diff --git a/flows/openrag_ingest_docling.json b/flows/openrag_ingest_docling.json
index 889f8425..cce73398 100644
--- a/flows/openrag_ingest_docling.json
+++ b/flows/openrag_ingest_docling.json
@@ -95,7 +95,7 @@
"data": {
"sourceHandle": {
"dataType": "EmbeddingModel",
- "id": "EmbeddingModel-cxG9r",
+ "id": "EmbeddingModel-eZ6bT",
"name": "embeddings",
"output_types": [
"Embeddings"
@@ -110,10 +110,10 @@
"type": "other"
}
},
- "id": "xy-edge__EmbeddingModel-cxG9r{œdataTypeœ:œEmbeddingModelœ,œidœ:œEmbeddingModel-cxG9rœ,œnameœ:œembeddingsœ,œoutput_typesœ:[œEmbeddingsœ]}-OpenSearchHybrid-XtKoA{œfieldNameœ:œembeddingœ,œidœ:œOpenSearchHybrid-XtKoAœ,œinputTypesœ:[œEmbeddingsœ],œtypeœ:œotherœ}",
+ "id": "xy-edge__EmbeddingModel-eZ6bT{œdataTypeœ:œEmbeddingModelœ,œidœ:œEmbeddingModel-eZ6bTœ,œnameœ:œembeddingsœ,œoutput_typesœ:[œEmbeddingsœ]}-OpenSearchHybrid-XtKoA{œfieldNameœ:œembeddingœ,œidœ:œOpenSearchHybrid-XtKoAœ,œinputTypesœ:[œEmbeddingsœ],œtypeœ:œotherœ}",
"selected": false,
- "source": "EmbeddingModel-cxG9r",
- "sourceHandle": "{œdataTypeœ:œEmbeddingModelœ,œidœ:œEmbeddingModel-cxG9rœ,œnameœ:œembeddingsœ,œoutput_typesœ:[œEmbeddingsœ]}",
+ "source": "EmbeddingModel-eZ6bT",
+ "sourceHandle": "{œdataTypeœ:œEmbeddingModelœ,œidœ:œEmbeddingModel-eZ6bTœ,œnameœ:œembeddingsœ,œoutput_typesœ:[œEmbeddingsœ]}",
"target": "OpenSearchHybrid-XtKoA",
"targetHandle": "{œfieldNameœ:œembeddingœ,œidœ:œOpenSearchHybrid-XtKoAœ,œinputTypesœ:[œEmbeddingsœ],œtypeœ:œotherœ}"
}
@@ -1631,7 +1631,7 @@
},
{
"data": {
- "id": "EmbeddingModel-cxG9r",
+ "id": "EmbeddingModel-eZ6bT",
"node": {
"base_classes": [
"Embeddings"
@@ -1657,7 +1657,7 @@
],
"frozen": false,
"icon": "binary",
- "last_updated": "2025-09-24T16:02:07.998Z",
+ "last_updated": "2025-09-22T15:54:52.885Z",
"legacy": false,
"metadata": {
"code_hash": "93faf11517da",
@@ -1738,7 +1738,7 @@
"show": true,
"title_case": false,
"type": "str",
- "value": ""
+ "value": "OPENAI_API_KEY"
},
"chunk_size": {
"_input_type": "IntInput",
@@ -1926,16 +1926,16 @@
"type": "EmbeddingModel"
},
"dragging": false,
- "id": "EmbeddingModel-cxG9r",
+ "id": "EmbeddingModel-eZ6bT",
"measured": {
- "height": 366,
+ "height": 369,
"width": 320
},
"position": {
- "x": 1743.8608432729177,
- "y": 1808.780792406514
+ "x": 1726.6943524438122,
+ "y": 1800.5330404375484
},
- "selected": false,
+ "selected": true,
"type": "genericNode"
}
],
diff --git a/frontend/components/label-wrapper.tsx b/frontend/components/label-wrapper.tsx
index ab785c5c..691b7726 100644
--- a/frontend/components/label-wrapper.tsx
+++ b/frontend/components/label-wrapper.tsx
@@ -10,18 +10,25 @@ export function LabelWrapper({
id,
required,
flex,
+ start,
children,
}: {
label: string;
description?: string;
- helperText?: string;
+ helperText?: string | React.ReactNode;
id: string;
required?: boolean;
flex?: boolean;
+ start?: boolean;
children: React.ReactNode;
}) {
return (
-
+
{label}
{required && * }
@@ -39,7 +46,7 @@ export function LabelWrapper({
- {helperText}
+ {helperText}
)}
@@ -48,7 +55,7 @@ export function LabelWrapper({
{description}
)}
- {flex &&
{children}
}
+ {flex &&
{children}
}
);
}
diff --git a/frontend/components/ui/tooltip.tsx b/frontend/components/ui/tooltip.tsx
index d7ca7125..3d86b3b3 100644
--- a/frontend/components/ui/tooltip.tsx
+++ b/frontend/components/ui/tooltip.tsx
@@ -19,7 +19,7 @@ const TooltipContent = React.forwardRef<
ref={ref}
sideOffset={sideOffset}
className={cn(
- "z-50 overflow-hidden rounded-md border bg-popover px-3 py-1.5 text-sm text-popover-foreground shadow-md animate-in fade-in-0 zoom-in-95 data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=closed]:zoom-out-95 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2 origin-[--radix-tooltip-content-transform-origin]",
+ "z-50 overflow-hidden rounded-md border bg-primary py-1 px-1.5 text-xs font-normal text-primary-foreground shadow-md animate-in fade-in-0 zoom-in-95 data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=closed]:zoom-out-95 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2 origin-[--radix-tooltip-content-transform-origin]",
className,
)}
{...props}
diff --git a/frontend/src/app/connectors/page.tsx b/frontend/src/app/connectors/page.tsx
index 26e1a88a..ad70ec90 100644
--- a/frontend/src/app/connectors/page.tsx
+++ b/frontend/src/app/connectors/page.tsx
@@ -1,20 +1,14 @@
-"use client"
+"use client";
import React, { useState } from "react";
-import { GoogleDrivePicker } from "@/components/google-drive-picker"
-import { useTask } from "@/contexts/task-context"
+import { UnifiedCloudPicker, CloudFile } from "@/components/cloud-picker";
+import { useTask } from "@/contexts/task-context";
-interface GoogleDriveFile {
- id: string;
- name: string;
- mimeType: string;
- webViewLink?: string;
- iconLink?: string;
-}
+// CloudFile interface is now imported from the unified cloud picker
export default function ConnectorsPage() {
- const { addTask } = useTask()
- const [selectedFiles, setSelectedFiles] = useState
([]);
+ const { addTask } = useTask();
+ const [selectedFiles, setSelectedFiles] = useState([]);
const [isSyncing, setIsSyncing] = useState(false);
const [syncResult, setSyncResult] = useState<{
processed?: number;
@@ -25,16 +19,19 @@ export default function ConnectorsPage() {
errors?: number;
} | null>(null);
- const handleFileSelection = (files: GoogleDriveFile[]) => {
+ const handleFileSelection = (files: CloudFile[]) => {
setSelectedFiles(files);
};
- const handleSync = async (connector: { connectionId: string, type: string }) => {
- if (!connector.connectionId || selectedFiles.length === 0) return
-
- setIsSyncing(true)
- setSyncResult(null)
-
+ const handleSync = async (connector: {
+ connectionId: string;
+ type: string;
+ }) => {
+ if (!connector.connectionId || selectedFiles.length === 0) return;
+
+ setIsSyncing(true);
+ setSyncResult(null);
+
try {
const syncBody: {
connection_id: string;
@@ -42,54 +39,55 @@ export default function ConnectorsPage() {
selected_files?: string[];
} = {
connection_id: connector.connectionId,
- selected_files: selectedFiles.map(file => file.id)
- }
-
+ selected_files: selectedFiles.map(file => file.id),
+ };
+
const response = await fetch(`/api/connectors/${connector.type}/sync`, {
- method: 'POST',
+ method: "POST",
headers: {
- 'Content-Type': 'application/json',
+ "Content-Type": "application/json",
},
body: JSON.stringify(syncBody),
- })
-
- const result = await response.json()
-
+ });
+
+ const result = await response.json();
+
if (response.status === 201) {
- const taskId = result.task_id
+ const taskId = result.task_id;
if (taskId) {
- addTask(taskId)
- setSyncResult({
- processed: 0,
+ addTask(taskId);
+ setSyncResult({
+ processed: 0,
total: selectedFiles.length,
- status: 'started'
- })
+ status: "started",
+ });
}
} else if (response.ok) {
- setSyncResult(result)
+ setSyncResult(result);
} else {
- console.error('Sync failed:', result.error)
- setSyncResult({ error: result.error || 'Sync failed' })
+ console.error("Sync failed:", result.error);
+ setSyncResult({ error: result.error || "Sync failed" });
}
} catch (error) {
- console.error('Sync error:', error)
- setSyncResult({ error: 'Network error occurred' })
+ console.error("Sync error:", error);
+ setSyncResult({ error: "Network error occurred" });
} finally {
- setIsSyncing(false)
+ setIsSyncing(false);
}
};
return (
Connectors
-
+
- This is a demo page for the Google Drive picker component.
- For full connector functionality, visit the Settings page.
+ This is a demo page for the Google Drive picker component. For full
+ connector functionality, visit the Settings page.
-
-
0 && (
-
handleSync({ connectionId: "google-drive-connection-id", type: "google-drive" })}
+
+ handleSync({
+ connectionId: "google-drive-connection-id",
+ type: "google-drive",
+ })
+ }
disabled={isSyncing}
className="px-4 py-2 bg-blue-500 text-white rounded hover:bg-blue-600 disabled:opacity-50 disabled:cursor-not-allowed"
>
@@ -110,14 +113,15 @@ export default function ConnectorsPage() {
<>Sync {selectedFiles.length} Selected Items>
)}
-
+
{syncResult && (
{syncResult.error ? (
Error: {syncResult.error}
- ) : syncResult.status === 'started' ? (
+ ) : syncResult.status === "started" ? (
- Sync started for {syncResult.total} files. Check the task notification for progress.
+ Sync started for {syncResult.total} files. Check the task
+ notification for progress.
) : (
diff --git a/frontend/src/app/onboarding/components/openai-onboarding.tsx b/frontend/src/app/onboarding/components/openai-onboarding.tsx
index 236097a4..b202756d 100644
--- a/frontend/src/app/onboarding/components/openai-onboarding.tsx
+++ b/frontend/src/app/onboarding/components/openai-onboarding.tsx
@@ -2,7 +2,7 @@ import { useState } from "react";
import { LabelInput } from "@/components/label-input";
import { LabelWrapper } from "@/components/label-wrapper";
import OpenAILogo from "@/components/logo/openai-logo";
-import { Switch } from "@/components/ui/switch";
+import { Checkbox } from "@/components/ui/checkbox";
import { useDebouncedValue } from "@/lib/debounce";
import type { OnboardingVariables } from "../../api/mutations/useOnboardingMutation";
import { useGetOpenAIModelsQuery } from "../../api/queries/useGetModelsQuery";
@@ -72,11 +72,19 @@ export function OpenAIOnboarding({
<>
+ Reuse the key from your environment config.
+
+ Uncheck to enter a different key.
+ >
+ }
flex
+ start
>
-
@@ -86,6 +94,7 @@ export function OpenAIOnboarding({
)}
{modelsError && (
-
+
Invalid OpenAI API key. Verify or replace the key.
)}
diff --git a/frontend/src/app/upload/[provider]/page.tsx b/frontend/src/app/upload/[provider]/page.tsx
index ea2d319e..522a45b7 100644
--- a/frontend/src/app/upload/[provider]/page.tsx
+++ b/frontend/src/app/upload/[provider]/page.tsx
@@ -4,29 +4,12 @@ import { useState, useEffect } from "react";
import { useParams, useRouter } from "next/navigation";
import { Button } from "@/components/ui/button";
import { ArrowLeft, AlertCircle } from "lucide-react";
-import { GoogleDrivePicker } from "@/components/google-drive-picker";
-import { OneDrivePicker } from "@/components/onedrive-picker";
+import { UnifiedCloudPicker, CloudFile } from "@/components/cloud-picker";
+import type { IngestSettings } from "@/components/cloud-picker/types";
import { useTask } from "@/contexts/task-context";
import { Toast } from "@/components/ui/toast";
-interface GoogleDriveFile {
- id: string;
- name: string;
- mimeType: string;
- webViewLink?: string;
- iconLink?: string;
-}
-
-interface OneDriveFile {
- id: string;
- name: string;
- mimeType?: string;
- webUrl?: string;
- driveItem?: {
- file?: { mimeType: string };
- folder?: unknown;
- };
-}
+// CloudFile interface is now imported from the unified cloud picker
interface CloudConnector {
id: string;
@@ -35,6 +18,7 @@ interface CloudConnector {
status: "not_connected" | "connecting" | "connected" | "error";
type: string;
connectionId?: string;
+ clientId: string;
hasAccessToken: boolean;
accessTokenError?: string;
}
@@ -49,14 +33,19 @@ export default function UploadProviderPage() {
const [isLoading, setIsLoading] = useState(true);
const [error, setError] = useState(null);
const [accessToken, setAccessToken] = useState(null);
- const [selectedFiles, setSelectedFiles] = useState<
- GoogleDriveFile[] | OneDriveFile[]
- >([]);
+ const [selectedFiles, setSelectedFiles] = useState([]);
const [isIngesting, setIsIngesting] = useState(false);
const [currentSyncTaskId, setCurrentSyncTaskId] = useState(
null
);
const [showSuccessToast, setShowSuccessToast] = useState(false);
+ const [ingestSettings, setIngestSettings] = useState({
+ chunkSize: 1000,
+ chunkOverlap: 200,
+ ocr: false,
+ pictureDescriptions: false,
+ embeddingModel: "text-embedding-3-small",
+ });
useEffect(() => {
const fetchConnectorInfo = async () => {
@@ -129,6 +118,7 @@ export default function UploadProviderPage() {
status: isConnected ? "connected" : "not_connected",
type: provider,
connectionId: activeConnection?.connection_id,
+ clientId: activeConnection?.client_id,
hasAccessToken,
accessTokenError,
});
@@ -159,13 +149,6 @@ export default function UploadProviderPage() {
// Task completed successfully, show toast and redirect
setIsIngesting(false);
setShowSuccessToast(true);
-
- // Dispatch knowledge updated event to refresh the knowledge table
- console.log(
- "Cloud provider task completed, dispatching knowledgeUpdated event"
- );
- window.dispatchEvent(new CustomEvent("knowledgeUpdated"));
-
setTimeout(() => {
router.push("/knowledge");
}, 2000); // 2 second delay to let user see toast
@@ -176,20 +159,12 @@ export default function UploadProviderPage() {
}
}, [tasks, currentSyncTaskId, router]);
- const handleFileSelected = (files: GoogleDriveFile[] | OneDriveFile[]) => {
+ const handleFileSelected = (files: CloudFile[]) => {
setSelectedFiles(files);
console.log(`Selected ${files.length} files from ${provider}:`, files);
// You can add additional handling here like triggering sync, etc.
};
- const handleGoogleDriveFileSelected = (files: GoogleDriveFile[]) => {
- handleFileSelected(files);
- };
-
- const handleOneDriveFileSelected = (files: OneDriveFile[]) => {
- handleFileSelected(files);
- };
-
const handleSync = async (connector: CloudConnector) => {
if (!connector.connectionId || selectedFiles.length === 0) return;
@@ -200,9 +175,11 @@ export default function UploadProviderPage() {
connection_id: string;
max_files?: number;
selected_files?: string[];
+ settings?: IngestSettings;
} = {
connection_id: connector.connectionId,
selected_files: selectedFiles.map(file => file.id),
+ settings: ingestSettings,
};
const response = await fetch(`/api/connectors/${connector.type}/sync`, {
@@ -353,48 +330,49 @@ export default function UploadProviderPage() {
router.back()}>
-
+
-
Add Cloud Knowledge
+
+ Add from {getProviderDisplayName()}
+
- {connector.type === "google_drive" && (
-
- )}
-
- {(connector.type === "onedrive" || connector.type === "sharepoint") && (
-
- )}
+
- {selectedFiles.length > 0 && (
-
-
- handleSync(connector)}
- disabled={selectedFiles.length === 0 || isIngesting}
- >
- {isIngesting ? (
- <>Ingesting {selectedFiles.length} Files...>
- ) : (
- <>Ingest Files ({selectedFiles.length})>
- )}
-
-
+
+
+ router.back()}
+ >
+ Back
+
+ handleSync(connector)}
+ disabled={selectedFiles.length === 0 || isIngesting}
+ >
+ {isIngesting ? (
+ <>Ingesting {selectedFiles.length} Files...>
+ ) : (
+ <>Start ingest>
+ )}
+
- )}
+
{/* Success toast notification */}
void
- onFileSelected?: (files: GoogleDriveFile[] | OneDriveFile[], connectorType: string) => void
+ isOpen: boolean;
+ onOpenChange: (open: boolean) => void;
+ onFileSelected?: (files: CloudFile[], connectorType: string) => void;
}
-export function CloudConnectorsDialog({
- isOpen,
+export function CloudConnectorsDialog({
+ isOpen,
onOpenChange,
- onFileSelected
+ onFileSelected,
}: CloudConnectorsDialogProps) {
- const [connectors, setConnectors] = useState([])
- const [isLoading, setIsLoading] = useState(true)
- const [selectedFiles, setSelectedFiles] = useState<{[connectorId: string]: GoogleDriveFile[] | OneDriveFile[]}>({})
- const [connectorAccessTokens, setConnectorAccessTokens] = useState<{[connectorType: string]: string}>({})
- const [activePickerType, setActivePickerType] = useState(null)
+ const [connectors, setConnectors] = useState([]);
+ const [isLoading, setIsLoading] = useState(true);
+ const [selectedFiles, setSelectedFiles] = useState<{
+ [connectorId: string]: CloudFile[];
+ }>({});
+ const [connectorAccessTokens, setConnectorAccessTokens] = useState<{
+ [connectorType: string]: string;
+ }>({});
+ const [activePickerType, setActivePickerType] = useState(null);
const getConnectorIcon = (iconName: string) => {
const iconMap: { [key: string]: React.ReactElement } = {
- 'google-drive': (
+ "google-drive": (
G
),
- 'sharepoint': (
+ sharepoint: (
SP
),
- 'onedrive': (
+ onedrive: (
OD
),
- }
- return iconMap[iconName] || (
-
- ?
-
- )
- }
+ };
+ return (
+ iconMap[iconName] || (
+
+ ?
+
+ )
+ );
+ };
const fetchConnectorStatuses = useCallback(async () => {
- if (!isOpen) return
-
- setIsLoading(true)
+ if (!isOpen) return;
+
+ setIsLoading(true);
try {
// Fetch available connectors from backend
- const connectorsResponse = await fetch('/api/connectors')
+ const connectorsResponse = await fetch("/api/connectors");
if (!connectorsResponse.ok) {
- throw new Error('Failed to load connectors')
+ throw new Error("Failed to load connectors");
}
-
- const connectorsResult = await connectorsResponse.json()
- const connectorTypes = Object.keys(connectorsResult.connectors)
-
+
+ const connectorsResult = await connectorsResponse.json();
+ const connectorTypes = Object.keys(connectorsResult.connectors);
+
// Filter to only cloud connectors
- const cloudConnectorTypes = connectorTypes.filter(type =>
- ['google_drive', 'onedrive', 'sharepoint'].includes(type) &&
- connectorsResult.connectors[type].available
- )
-
+ const cloudConnectorTypes = connectorTypes.filter(
+ type =>
+ ["google_drive", "onedrive", "sharepoint"].includes(type) &&
+ connectorsResult.connectors[type].available
+ );
+
// Initialize connectors list
const initialConnectors = cloudConnectorTypes.map(type => ({
id: type,
@@ -115,82 +105,95 @@ export function CloudConnectorsDialog({
status: "not_connected" as const,
type: type,
hasAccessToken: false,
- accessTokenError: undefined
- }))
-
- setConnectors(initialConnectors)
+ accessTokenError: undefined,
+ clientId: "",
+ }));
+
+ setConnectors(initialConnectors);
// Check status for each cloud connector type
for (const connectorType of cloudConnectorTypes) {
try {
- const response = await fetch(`/api/connectors/${connectorType}/status`)
+ const response = await fetch(
+ `/api/connectors/${connectorType}/status`
+ );
if (response.ok) {
- const data = await response.json()
- const connections = data.connections || []
- const activeConnection = connections.find((conn: { connection_id: string; is_active: boolean }) => conn.is_active)
- const isConnected = activeConnection !== undefined
-
- let hasAccessToken = false
- let accessTokenError: string | undefined = undefined
+ const data = await response.json();
+ const connections = data.connections || [];
+ const activeConnection = connections.find(
+ (conn: { connection_id: string; is_active: boolean }) =>
+ conn.is_active
+ );
+ const isConnected = activeConnection !== undefined;
+
+ let hasAccessToken = false;
+ let accessTokenError: string | undefined = undefined;
// Try to get access token for connected connectors
if (isConnected && activeConnection) {
try {
- const tokenResponse = await fetch(`/api/connectors/${connectorType}/token?connection_id=${activeConnection.connection_id}`)
+ const tokenResponse = await fetch(
+ `/api/connectors/${connectorType}/token?connection_id=${activeConnection.connection_id}`
+ );
if (tokenResponse.ok) {
- const tokenData = await tokenResponse.json()
+ const tokenData = await tokenResponse.json();
if (tokenData.access_token) {
- hasAccessToken = true
+ hasAccessToken = true;
setConnectorAccessTokens(prev => ({
...prev,
- [connectorType]: tokenData.access_token
- }))
+ [connectorType]: tokenData.access_token,
+ }));
}
} else {
- const errorData = await tokenResponse.json().catch(() => ({ error: 'Token unavailable' }))
- accessTokenError = errorData.error || 'Access token unavailable'
+ const errorData = await tokenResponse
+ .json()
+ .catch(() => ({ error: "Token unavailable" }));
+ accessTokenError =
+ errorData.error || "Access token unavailable";
}
} catch {
- accessTokenError = 'Failed to fetch access token'
+ accessTokenError = "Failed to fetch access token";
}
}
-
- setConnectors(prev => prev.map(c =>
- c.type === connectorType
- ? {
- ...c,
- status: isConnected ? "connected" : "not_connected",
- connectionId: activeConnection?.connection_id,
- hasAccessToken,
- accessTokenError
- }
- : c
- ))
+
+ setConnectors(prev =>
+ prev.map(c =>
+ c.type === connectorType
+ ? {
+ ...c,
+ status: isConnected ? "connected" : "not_connected",
+ connectionId: activeConnection?.connection_id,
+ clientId: activeConnection?.client_id,
+ hasAccessToken,
+ accessTokenError,
+ }
+ : c
+ )
+ );
}
} catch (error) {
- console.error(`Failed to check status for ${connectorType}:`, error)
+ console.error(`Failed to check status for ${connectorType}:`, error);
}
}
} catch (error) {
- console.error('Failed to load cloud connectors:', error)
+ console.error("Failed to load cloud connectors:", error);
} finally {
- setIsLoading(false)
+ setIsLoading(false);
}
- }, [isOpen])
+ }, [isOpen]);
- const handleFileSelection = (connectorId: string, files: GoogleDriveFile[] | OneDriveFile[]) => {
+ const handleFileSelection = (connectorId: string, files: CloudFile[]) => {
setSelectedFiles(prev => ({
...prev,
- [connectorId]: files
- }))
-
- onFileSelected?.(files, connectorId)
- }
+ [connectorId]: files,
+ }));
+
+ onFileSelected?.(files, connectorId);
+ };
useEffect(() => {
- fetchConnectorStatuses()
- }, [fetchConnectorStatuses])
-
+ fetchConnectorStatuses();
+ }, [fetchConnectorStatuses]);
return (
@@ -218,19 +221,24 @@ export function CloudConnectorsDialog({
{connectors
.filter(connector => connector.status === "connected")
- .map((connector) => (
+ .map(connector => (
{
- e.preventDefault()
- e.stopPropagation()
+ title={
+ !connector.hasAccessToken
+ ? connector.accessTokenError ||
+ "Access token required - try reconnecting your account"
+ : `Select files from ${connector.name}`
+ }
+ onClick={e => {
+ e.preventDefault();
+ e.stopPropagation();
if (connector.hasAccessToken) {
- setActivePickerType(connector.id)
+ setActivePickerType(connector.id);
}
}}
className="min-w-[120px]"
@@ -243,54 +251,46 @@ export function CloudConnectorsDialog({
{connectors.every(c => c.status !== "connected") && (
No connected cloud providers found.
-
Go to Settings to connect your cloud storage accounts.
+
+ Go to Settings to connect your cloud storage accounts.
+
)}
-
- {/* Render pickers inside dialog */}
- {activePickerType && connectors.find(c => c.id === activePickerType) && (() => {
- const connector = connectors.find(c => c.id === activePickerType)!
-
- if (connector.type === "google_drive") {
+
+ {/* Render unified picker inside dialog */}
+ {activePickerType &&
+ connectors.find(c => c.id === activePickerType) &&
+ (() => {
+ const connector = connectors.find(
+ c => c.id === activePickerType
+ )!;
+
return (
- {
- handleFileSelection(connector.id, files)
- setActivePickerType(null)
+ {
+ handleFileSelection(connector.id, files);
+ setActivePickerType(null);
}}
- selectedFiles={selectedFiles[connector.id] as GoogleDriveFile[] || []}
+ selectedFiles={selectedFiles[connector.id] || []}
isAuthenticated={connector.status === "connected"}
accessToken={connectorAccessTokens[connector.type]}
onPickerStateChange={() => {}}
+ clientId={connector.clientId}
/>
- )
- }
-
- if (connector.type === "onedrive" || connector.type === "sharepoint") {
- return (
-
- {
- handleFileSelection(connector.id, files)
- setActivePickerType(null)
- }}
- selectedFiles={selectedFiles[connector.id] as OneDriveFile[] || []}
- isAuthenticated={connector.status === "connected"}
- accessToken={connectorAccessTokens[connector.type]}
- connectorType={connector.type as "onedrive" | "sharepoint"}
- />
-
- )
- }
-
- return null
- })()}
+ );
+ })()}
)}
- )
-}
\ No newline at end of file
+ );
+}
diff --git a/frontend/src/components/cloud-picker/file-item.tsx b/frontend/src/components/cloud-picker/file-item.tsx
new file mode 100644
index 00000000..3f6b5ab5
--- /dev/null
+++ b/frontend/src/components/cloud-picker/file-item.tsx
@@ -0,0 +1,67 @@
+"use client";
+
+import { Badge } from "@/components/ui/badge";
+import { FileText, Folder, Trash } from "lucide-react";
+import { CloudFile } from "./types";
+
+interface FileItemProps {
+ file: CloudFile;
+ onRemove: (fileId: string) => void;
+}
+
+const getFileIcon = (mimeType: string) => {
+ if (mimeType.includes("folder")) {
+ return
;
+ }
+ return
;
+};
+
+const getMimeTypeLabel = (mimeType: string) => {
+ const typeMap: { [key: string]: string } = {
+ "application/vnd.google-apps.document": "Google Doc",
+ "application/vnd.google-apps.spreadsheet": "Google Sheet",
+ "application/vnd.google-apps.presentation": "Google Slides",
+ "application/vnd.google-apps.folder": "Folder",
+ "application/pdf": "PDF",
+ "text/plain": "Text",
+ "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
+ "Word Doc",
+ "application/vnd.openxmlformats-officedocument.presentationml.presentation":
+ "PowerPoint",
+ };
+
+ return typeMap[mimeType] || mimeType?.split("/").pop() || "Document";
+};
+
+const formatFileSize = (bytes?: number) => {
+ if (!bytes) return "";
+ const sizes = ["B", "KB", "MB", "GB", "TB"];
+ if (bytes === 0) return "0 B";
+ const i = Math.floor(Math.log(bytes) / Math.log(1024));
+ return `${(bytes / Math.pow(1024, i)).toFixed(1)} ${sizes[i]}`;
+};
+
+export const FileItem = ({ file, onRemove }: FileItemProps) => (
+
+
+ {getFileIcon(file.mimeType)}
+ {file.name}
+
+ {getMimeTypeLabel(file.mimeType)}
+
+
+
+
+ {formatFileSize(file.size) || "—"}
+
+
+ onRemove(file.id)}
+ />
+
+
+);
diff --git a/frontend/src/components/cloud-picker/file-list.tsx b/frontend/src/components/cloud-picker/file-list.tsx
new file mode 100644
index 00000000..775d78c4
--- /dev/null
+++ b/frontend/src/components/cloud-picker/file-list.tsx
@@ -0,0 +1,42 @@
+"use client";
+
+import { Button } from "@/components/ui/button";
+import { CloudFile } from "./types";
+import { FileItem } from "./file-item";
+
+interface FileListProps {
+ files: CloudFile[];
+ onClearAll: () => void;
+ onRemoveFile: (fileId: string) => void;
+}
+
+export const FileList = ({
+ files,
+ onClearAll,
+ onRemoveFile,
+}: FileListProps) => {
+ if (files.length === 0) {
+ return null;
+ }
+
+ return (
+
+
+
Added files
+
+ Clear all
+
+
+
+ {files.map(file => (
+
+ ))}
+
+
+ );
+};
diff --git a/frontend/src/components/cloud-picker/index.ts b/frontend/src/components/cloud-picker/index.ts
new file mode 100644
index 00000000..ef7aa74b
--- /dev/null
+++ b/frontend/src/components/cloud-picker/index.ts
@@ -0,0 +1,7 @@
+export { UnifiedCloudPicker } from "./unified-cloud-picker";
+export { PickerHeader } from "./picker-header";
+export { FileList } from "./file-list";
+export { FileItem } from "./file-item";
+export { IngestSettings } from "./ingest-settings";
+export * from "./types";
+export * from "./provider-handlers";
diff --git a/frontend/src/components/cloud-picker/ingest-settings.tsx b/frontend/src/components/cloud-picker/ingest-settings.tsx
new file mode 100644
index 00000000..d5843a2a
--- /dev/null
+++ b/frontend/src/components/cloud-picker/ingest-settings.tsx
@@ -0,0 +1,139 @@
+"use client";
+
+import { Input } from "@/components/ui/input";
+import { Switch } from "@/components/ui/switch";
+import {
+ Collapsible,
+ CollapsibleContent,
+ CollapsibleTrigger,
+} from "@/components/ui/collapsible";
+import { ChevronRight, Info } from "lucide-react";
+import { IngestSettings as IngestSettingsType } from "./types";
+
+interface IngestSettingsProps {
+ isOpen: boolean;
+ onOpenChange: (open: boolean) => void;
+ settings?: IngestSettingsType;
+ onSettingsChange?: (settings: IngestSettingsType) => void;
+}
+
+export const IngestSettings = ({
+ isOpen,
+ onOpenChange,
+ settings,
+ onSettingsChange,
+}: IngestSettingsProps) => {
+ // Default settings
+ const defaultSettings: IngestSettingsType = {
+ chunkSize: 1000,
+ chunkOverlap: 200,
+ ocr: false,
+ pictureDescriptions: false,
+ embeddingModel: "text-embedding-3-small",
+ };
+
+ // Use provided settings or defaults
+ const currentSettings = settings || defaultSettings;
+
+ const handleSettingsChange = (newSettings: Partial
) => {
+ const updatedSettings = { ...currentSettings, ...newSettings };
+ onSettingsChange?.(updatedSettings);
+ };
+
+ return (
+
+
+
+
+ Ingest settings
+
+
+
+
+
+
+
+
Chunk size
+
+ handleSettingsChange({
+ chunkSize: parseInt(e.target.value) || 0,
+ })
+ }
+ />
+
+
+
Chunk overlap
+
+ handleSettingsChange({
+ chunkOverlap: parseInt(e.target.value) || 0,
+ })
+ }
+ />
+
+
+
+
+
+
OCR
+
+ Extracts text from images/PDFs. Ingest is slower when enabled.
+
+
+
+ handleSettingsChange({ ocr: checked })
+ }
+ />
+
+
+
+
+
+ Picture descriptions
+
+
+ Adds captions for images. Ingest is more expensive when enabled.
+
+
+
+ handleSettingsChange({ pictureDescriptions: checked })
+ }
+ />
+
+
+
+
+ Embedding model
+
+
+
+ handleSettingsChange({ embeddingModel: e.target.value })
+ }
+ placeholder="text-embedding-3-small"
+ />
+
+
+
+
+ );
+};
diff --git a/frontend/src/components/cloud-picker/picker-header.tsx b/frontend/src/components/cloud-picker/picker-header.tsx
new file mode 100644
index 00000000..05dcaebd
--- /dev/null
+++ b/frontend/src/components/cloud-picker/picker-header.tsx
@@ -0,0 +1,70 @@
+"use client";
+
+import { Button } from "@/components/ui/button";
+import { Card, CardContent } from "@/components/ui/card";
+import { Plus } from "lucide-react";
+import { CloudProvider } from "./types";
+
+interface PickerHeaderProps {
+ provider: CloudProvider;
+ onAddFiles: () => void;
+ isPickerLoaded: boolean;
+ isPickerOpen: boolean;
+ accessToken?: string;
+ isAuthenticated: boolean;
+}
+
+const getProviderName = (provider: CloudProvider): string => {
+ switch (provider) {
+ case "google_drive":
+ return "Google Drive";
+ case "onedrive":
+ return "OneDrive";
+ case "sharepoint":
+ return "SharePoint";
+ default:
+ return "Cloud Storage";
+ }
+};
+
+export const PickerHeader = ({
+ provider,
+ onAddFiles,
+ isPickerLoaded,
+ isPickerOpen,
+ accessToken,
+ isAuthenticated,
+}: PickerHeaderProps) => {
+ if (!isAuthenticated) {
+ return (
+
+ Please connect to {getProviderName(provider)} first to select specific
+ files.
+
+ );
+ }
+
+ return (
+
+
+
+ Select files from {getProviderName(provider)} to ingest.
+
+
+
+ {isPickerOpen ? "Opening Picker..." : "Add Files"}
+
+
+ csv, json, pdf,{" "}
+
+16 more {" "}
+
150 MB max
+
+
+
+ );
+};
diff --git a/frontend/src/components/cloud-picker/provider-handlers.ts b/frontend/src/components/cloud-picker/provider-handlers.ts
new file mode 100644
index 00000000..4a39312f
--- /dev/null
+++ b/frontend/src/components/cloud-picker/provider-handlers.ts
@@ -0,0 +1,245 @@
+"use client";
+
+import {
+ CloudFile,
+ CloudProvider,
+ GooglePickerData,
+ GooglePickerDocument,
+} from "./types";
+
+export class GoogleDriveHandler {
+ private accessToken: string;
+ private onPickerStateChange?: (isOpen: boolean) => void;
+
+ constructor(
+ accessToken: string,
+ onPickerStateChange?: (isOpen: boolean) => void
+ ) {
+ this.accessToken = accessToken;
+ this.onPickerStateChange = onPickerStateChange;
+ }
+
+ async loadPickerApi(): Promise<boolean> {
+ return new Promise(resolve => {
+ if (typeof window !== "undefined" && window.gapi) {
+ window.gapi.load("picker", {
+ callback: () => resolve(true),
+ onerror: () => resolve(false),
+ });
+ } else {
+ // Load Google API script
+ const script = document.createElement("script");
+ script.src = "https://apis.google.com/js/api.js";
+ script.async = true;
+ script.defer = true;
+ script.onload = () => {
+ window.gapi.load("picker", {
+ callback: () => resolve(true),
+ onerror: () => resolve(false),
+ });
+ };
+ script.onerror = () => resolve(false);
+ document.head.appendChild(script);
+ }
+ });
+ }
+
+ openPicker(onFileSelected: (files: CloudFile[]) => void): void {
+ if (!window.google?.picker) {
+ return;
+ }
+
+ try {
+ this.onPickerStateChange?.(true);
+
+ const picker = new window.google.picker.PickerBuilder()
+ .addView(window.google.picker.ViewId.DOCS)
+ .addView(window.google.picker.ViewId.FOLDERS)
+ .setOAuthToken(this.accessToken)
+ .enableFeature(window.google.picker.Feature.MULTISELECT_ENABLED)
+ .setTitle("Select files from Google Drive")
+ .setCallback(data => this.pickerCallback(data, onFileSelected))
+ .build();
+
+ picker.setVisible(true);
+
+ // Apply z-index fix
+ setTimeout(() => {
+ const pickerElements = document.querySelectorAll(
+ ".picker-dialog, .goog-modalpopup"
+ );
+ pickerElements.forEach(el => {
+ (el as HTMLElement).style.zIndex = "10000";
+ });
+ const bgElements = document.querySelectorAll(
+ ".picker-dialog-bg, .goog-modalpopup-bg"
+ );
+ bgElements.forEach(el => {
+ (el as HTMLElement).style.zIndex = "9999";
+ });
+ }, 100);
+ } catch (error) {
+ console.error("Error creating picker:", error);
+ this.onPickerStateChange?.(false);
+ }
+ }
+
+ private async pickerCallback(
+ data: GooglePickerData,
+ onFileSelected: (files: CloudFile[]) => void
+ ): Promise<void> {
+ if (data.action === window.google.picker.Action.PICKED) {
+ const files: CloudFile[] = data.docs.map((doc: GooglePickerDocument) => ({
+ id: doc[window.google.picker.Document.ID],
+ name: doc[window.google.picker.Document.NAME],
+ mimeType: doc[window.google.picker.Document.MIME_TYPE],
+ webViewLink: doc[window.google.picker.Document.URL],
+ iconLink: doc[window.google.picker.Document.ICON_URL],
+ size: doc["sizeBytes"] ? parseInt(doc["sizeBytes"]) : undefined,
+ modifiedTime: doc["lastEditedUtc"],
+ isFolder:
+ doc[window.google.picker.Document.MIME_TYPE] ===
+ "application/vnd.google-apps.folder",
+ }));
+
+ // Enrich with additional file data if needed
+ if (files.some(f => !f.size && !f.isFolder)) {
+ try {
+ const enrichedFiles = await Promise.all(
+ files.map(async file => {
+ if (!file.size && !file.isFolder) {
+ try {
+ const response = await fetch(
+ `https://www.googleapis.com/drive/v3/files/${file.id}?fields=size,modifiedTime`,
+ {
+ headers: {
+ Authorization: `Bearer ${this.accessToken}`,
+ },
+ }
+ );
+ if (response.ok) {
+ const fileDetails = await response.json();
+ return {
+ ...file,
+ size: fileDetails.size
+ ? parseInt(fileDetails.size)
+ : undefined,
+ modifiedTime:
+ fileDetails.modifiedTime || file.modifiedTime,
+ };
+ }
+ } catch (error) {
+ console.warn("Failed to fetch file details:", error);
+ }
+ }
+ return file;
+ })
+ );
+ onFileSelected(enrichedFiles);
+ } catch (error) {
+ console.warn("Failed to enrich file data:", error);
+ onFileSelected(files);
+ }
+ } else {
+ onFileSelected(files);
+ }
+ }
+
+ this.onPickerStateChange?.(false);
+ }
+}
+
+export class OneDriveHandler {
+ private accessToken: string;
+ private clientId: string;
+ private provider: CloudProvider;
+ private baseUrl?: string;
+
+ constructor(
+ accessToken: string,
+ clientId: string,
+ provider: CloudProvider = "onedrive",
+ baseUrl?: string
+ ) {
+ this.accessToken = accessToken;
+ this.clientId = clientId;
+ this.provider = provider;
+ this.baseUrl = baseUrl;
+ }
+
+ async loadPickerApi(): Promise<boolean> {
+ return new Promise(resolve => {
+ const script = document.createElement("script");
+ script.src = "https://js.live.net/v7.2/OneDrive.js";
+ script.onload = () => resolve(true);
+ script.onerror = () => resolve(false);
+ document.head.appendChild(script);
+ });
+ }
+
+ openPicker(onFileSelected: (files: CloudFile[]) => void): void {
+ if (!window.OneDrive) {
+ return;
+ }
+
+ window.OneDrive.open({
+ clientId: this.clientId,
+ action: "query",
+ multiSelect: true,
+ advanced: {
+ endpointHint: "api.onedrive.com",
+ accessToken: this.accessToken,
+ },
+ success: (response: any) => {
+ const newFiles: CloudFile[] =
+ response.value?.map((item: any, index: number) => ({
+ id: item.id,
+ name:
+ item.name ||
+ `${this.getProviderName()} File ${index + 1} (${item.id.slice(
+ -8
+ )})`,
+ mimeType: item.file?.mimeType || "application/octet-stream",
+ webUrl: item.webUrl || "",
+ downloadUrl: item["@microsoft.graph.downloadUrl"] || "",
+ size: item.size,
+ modifiedTime: item.lastModifiedDateTime,
+ isFolder: !!item.folder,
+ })) || [];
+
+ onFileSelected(newFiles);
+ },
+ cancel: () => {
+ console.log("Picker cancelled");
+ },
+ error: (error: any) => {
+ console.error("Picker error:", error);
+ },
+ });
+ }
+
+ private getProviderName(): string {
+ return this.provider === "sharepoint" ? "SharePoint" : "OneDrive";
+ }
+}
+
+export const createProviderHandler = (
+ provider: CloudProvider,
+ accessToken: string,
+ onPickerStateChange?: (isOpen: boolean) => void,
+ clientId?: string,
+ baseUrl?: string
+) => {
+ switch (provider) {
+ case "google_drive":
+ return new GoogleDriveHandler(accessToken, onPickerStateChange);
+ case "onedrive":
+ case "sharepoint":
+ if (!clientId) {
+ throw new Error("Client ID required for OneDrive/SharePoint");
+ }
+ return new OneDriveHandler(accessToken, clientId, provider, baseUrl);
+ default:
+ throw new Error(`Unsupported provider: ${provider}`);
+ }
+};
diff --git a/frontend/src/components/cloud-picker/types.ts b/frontend/src/components/cloud-picker/types.ts
new file mode 100644
index 00000000..ca346bf0
--- /dev/null
+++ b/frontend/src/components/cloud-picker/types.ts
@@ -0,0 +1,106 @@
+export interface CloudFile {
+ id: string;
+ name: string;
+ mimeType: string;
+ webViewLink?: string;
+ iconLink?: string;
+ size?: number;
+ modifiedTime?: string;
+ isFolder?: boolean;
+ webUrl?: string;
+ downloadUrl?: string;
+}
+
+export type CloudProvider = "google_drive" | "onedrive" | "sharepoint";
+
+export interface UnifiedCloudPickerProps {
+ provider: CloudProvider;
+ onFileSelected: (files: CloudFile[]) => void;
+ selectedFiles?: CloudFile[];
+ isAuthenticated: boolean;
+ accessToken?: string;
+ onPickerStateChange?: (isOpen: boolean) => void;
+ // OneDrive/SharePoint specific props
+ clientId?: string;
+ baseUrl?: string;
+ // Ingest settings
+ onSettingsChange?: (settings: IngestSettings) => void;
+}
+
+export interface GoogleAPI {
+ load: (
+ api: string,
+ options: { callback: () => void; onerror?: () => void }
+ ) => void;
+}
+
+export interface GooglePickerData {
+ action: string;
+ docs: GooglePickerDocument[];
+}
+
+export interface GooglePickerDocument {
+ [key: string]: string;
+}
+
+declare global {
+ interface Window {
+ gapi: GoogleAPI;
+ google: {
+ picker: {
+ api: {
+ load: (callback: () => void) => void;
+ };
+ PickerBuilder: new () => GooglePickerBuilder;
+ ViewId: {
+ DOCS: string;
+ FOLDERS: string;
+ DOCS_IMAGES_AND_VIDEOS: string;
+ DOCUMENTS: string;
+ PRESENTATIONS: string;
+ SPREADSHEETS: string;
+ };
+ Feature: {
+ MULTISELECT_ENABLED: string;
+ NAV_HIDDEN: string;
+ SIMPLE_UPLOAD_ENABLED: string;
+ };
+ Action: {
+ PICKED: string;
+ CANCEL: string;
+ };
+ Document: {
+ ID: string;
+ NAME: string;
+ MIME_TYPE: string;
+ URL: string;
+ ICON_URL: string;
+ };
+ };
+ };
+ OneDrive?: any;
+ }
+}
+
+export interface GooglePickerBuilder {
+ addView: (view: string) => GooglePickerBuilder;
+ setOAuthToken: (token: string) => GooglePickerBuilder;
+ setCallback: (
+ callback: (data: GooglePickerData) => void
+ ) => GooglePickerBuilder;
+ enableFeature: (feature: string) => GooglePickerBuilder;
+ setTitle: (title: string) => GooglePickerBuilder;
+ build: () => GooglePicker;
+}
+
+export interface GooglePicker {
+ setVisible: (visible: boolean) => void;
+}
+
+export interface IngestSettings {
+ chunkSize: number;
+ chunkOverlap: number;
+ ocr: boolean;
+ pictureDescriptions: boolean;
+ embeddingModel: string;
+}
diff --git a/frontend/src/components/cloud-picker/unified-cloud-picker.tsx b/frontend/src/components/cloud-picker/unified-cloud-picker.tsx
new file mode 100644
index 00000000..fd77698f
--- /dev/null
+++ b/frontend/src/components/cloud-picker/unified-cloud-picker.tsx
@@ -0,0 +1,195 @@
+"use client";
+
+import { useState, useEffect } from "react";
+import {
+ UnifiedCloudPickerProps,
+ CloudFile,
+ IngestSettings as IngestSettingsType,
+} from "./types";
+import { PickerHeader } from "./picker-header";
+import { FileList } from "./file-list";
+import { IngestSettings } from "./ingest-settings";
+import { createProviderHandler } from "./provider-handlers";
+
+export const UnifiedCloudPicker = ({
+ provider,
+ onFileSelected,
+ selectedFiles = [],
+ isAuthenticated,
+ accessToken,
+ onPickerStateChange,
+ clientId,
+ baseUrl,
+ onSettingsChange,
+}: UnifiedCloudPickerProps) => {
+ const [isPickerLoaded, setIsPickerLoaded] = useState(false);
+ const [isPickerOpen, setIsPickerOpen] = useState(false);
+ const [isIngestSettingsOpen, setIsIngestSettingsOpen] = useState(false);
+ const [isLoadingBaseUrl, setIsLoadingBaseUrl] = useState(false);
+ const [autoBaseUrl, setAutoBaseUrl] = useState<string | undefined>(undefined);
+
+ // Settings state with defaults
+ const [ingestSettings, setIngestSettings] = useState<IngestSettingsType>({
+ chunkSize: 1000,
+ chunkOverlap: 200,
+ ocr: false,
+ pictureDescriptions: false,
+ embeddingModel: "text-embedding-3-small",
+ });
+
+ // Handle settings changes and notify parent
+ const handleSettingsChange = (newSettings: IngestSettingsType) => {
+ setIngestSettings(newSettings);
+ onSettingsChange?.(newSettings);
+ };
+
+ const effectiveBaseUrl = baseUrl || autoBaseUrl;
+
+ // Auto-detect base URL for OneDrive personal accounts
+ useEffect(() => {
+ if (
+ (provider === "onedrive" || provider === "sharepoint") &&
+ !baseUrl &&
+ accessToken &&
+ !autoBaseUrl
+ ) {
+ const getBaseUrl = async () => {
+ setIsLoadingBaseUrl(true);
+ try {
+ setAutoBaseUrl("https://onedrive.live.com/picker");
+ } catch (error) {
+ console.error("Auto-detect baseUrl failed:", error);
+ } finally {
+ setIsLoadingBaseUrl(false);
+ }
+ };
+
+ getBaseUrl();
+ }
+ }, [accessToken, baseUrl, autoBaseUrl, provider]);
+
+ // Load picker API
+ useEffect(() => {
+ if (!accessToken || !isAuthenticated) return;
+
+ const loadApi = async () => {
+ try {
+ const handler = createProviderHandler(
+ provider,
+ accessToken,
+ onPickerStateChange,
+ clientId,
+ effectiveBaseUrl
+ );
+ const loaded = await handler.loadPickerApi();
+ setIsPickerLoaded(loaded);
+ } catch (error) {
+ console.error("Failed to create provider handler:", error);
+ setIsPickerLoaded(false);
+ }
+ };
+
+ loadApi();
+ }, [
+ accessToken,
+ isAuthenticated,
+ provider,
+ clientId,
+ effectiveBaseUrl,
+ onPickerStateChange,
+ ]);
+
+ const handleAddFiles = () => {
+ if (!isPickerLoaded || !accessToken) {
+ return;
+ }
+
+ if ((provider === "onedrive" || provider === "sharepoint") && !clientId) {
+ console.error("Client ID required for OneDrive/SharePoint");
+ return;
+ }
+
+ try {
+ setIsPickerOpen(true);
+ onPickerStateChange?.(true);
+
+ const handler = createProviderHandler(
+ provider,
+ accessToken,
+ isOpen => {
+ setIsPickerOpen(isOpen);
+ onPickerStateChange?.(isOpen);
+ },
+ clientId,
+ effectiveBaseUrl
+ );
+
+ handler.openPicker((files: CloudFile[]) => {
+ // Merge new files with existing ones, avoiding duplicates
+ const existingIds = new Set(selectedFiles.map(f => f.id));
+ const newFiles = files.filter(f => !existingIds.has(f.id));
+ onFileSelected([...selectedFiles, ...newFiles]);
+ });
+ } catch (error) {
+ console.error("Error opening picker:", error);
+ setIsPickerOpen(false);
+ onPickerStateChange?.(false);
+ }
+ };
+
+ const handleRemoveFile = (fileId: string) => {
+ const updatedFiles = selectedFiles.filter(file => file.id !== fileId);
+ onFileSelected(updatedFiles);
+ };
+
+ const handleClearAll = () => {
+ onFileSelected([]);
+ };
+
+ if (isLoadingBaseUrl) {
+ return (
+
+ Loading...
+
+ );
+ }
+
+ if (
+ (provider === "onedrive" || provider === "sharepoint") &&
+ !clientId &&
+ isAuthenticated
+ ) {
+ return (
+
+ Configuration required: Client ID missing for{" "}
+ {provider === "sharepoint" ? "SharePoint" : "OneDrive"}.
+
+ );
+ }
+
+ return (
+
+ );
+};
diff --git a/frontend/src/components/google-drive-picker.tsx b/frontend/src/components/google-drive-picker.tsx
deleted file mode 100644
index c9dee19a..00000000
--- a/frontend/src/components/google-drive-picker.tsx
+++ /dev/null
@@ -1,341 +0,0 @@
-"use client"
-
-import { useState, useEffect } from "react"
-import { Button } from "@/components/ui/button"
-import { Badge } from "@/components/ui/badge"
-import { FileText, Folder, Plus, Trash2 } from "lucide-react"
-import { Card, CardContent } from "@/components/ui/card"
-
-interface GoogleDrivePickerProps {
- onFileSelected: (files: GoogleDriveFile[]) => void
- selectedFiles?: GoogleDriveFile[]
- isAuthenticated: boolean
- accessToken?: string
- onPickerStateChange?: (isOpen: boolean) => void
-}
-
-interface GoogleDriveFile {
- id: string
- name: string
- mimeType: string
- webViewLink?: string
- iconLink?: string
- size?: number
- modifiedTime?: string
- isFolder?: boolean
-}
-
-interface GoogleAPI {
- load: (api: string, options: { callback: () => void; onerror?: () => void }) => void
-}
-
-interface GooglePickerData {
- action: string
- docs: GooglePickerDocument[]
-}
-
-interface GooglePickerDocument {
- [key: string]: string
-}
-
-declare global {
- interface Window {
- gapi: GoogleAPI
- google: {
- picker: {
- api: {
- load: (callback: () => void) => void
- }
- PickerBuilder: new () => GooglePickerBuilder
- ViewId: {
- DOCS: string
- FOLDERS: string
- DOCS_IMAGES_AND_VIDEOS: string
- DOCUMENTS: string
- PRESENTATIONS: string
- SPREADSHEETS: string
- }
- Feature: {
- MULTISELECT_ENABLED: string
- NAV_HIDDEN: string
- SIMPLE_UPLOAD_ENABLED: string
- }
- Action: {
- PICKED: string
- CANCEL: string
- }
- Document: {
- ID: string
- NAME: string
- MIME_TYPE: string
- URL: string
- ICON_URL: string
- }
- }
- }
- }
-}
-
-interface GooglePickerBuilder {
- addView: (view: string) => GooglePickerBuilder
- setOAuthToken: (token: string) => GooglePickerBuilder
- setCallback: (callback: (data: GooglePickerData) => void) => GooglePickerBuilder
- enableFeature: (feature: string) => GooglePickerBuilder
- setTitle: (title: string) => GooglePickerBuilder
- build: () => GooglePicker
-}
-
-interface GooglePicker {
- setVisible: (visible: boolean) => void
-}
-
-export function GoogleDrivePicker({
- onFileSelected,
- selectedFiles = [],
- isAuthenticated,
- accessToken,
- onPickerStateChange
-}: GoogleDrivePickerProps) {
- const [isPickerLoaded, setIsPickerLoaded] = useState(false)
- const [isPickerOpen, setIsPickerOpen] = useState(false)
-
- useEffect(() => {
- const loadPickerApi = () => {
- if (typeof window !== 'undefined' && window.gapi) {
- window.gapi.load('picker', {
- callback: () => {
- setIsPickerLoaded(true)
- },
- onerror: () => {
- console.error('Failed to load Google Picker API')
- }
- })
- }
- }
-
- // Load Google API script if not already loaded
- if (typeof window !== 'undefined') {
- if (!window.gapi) {
- const script = document.createElement('script')
- script.src = 'https://apis.google.com/js/api.js'
- script.async = true
- script.defer = true
- script.onload = loadPickerApi
- script.onerror = () => {
- console.error('Failed to load Google API script')
- }
- document.head.appendChild(script)
-
- return () => {
- if (document.head.contains(script)) {
- document.head.removeChild(script)
- }
- }
- } else {
- loadPickerApi()
- }
- }
- }, [])
-
-
- const openPicker = () => {
- if (!isPickerLoaded || !accessToken || !window.google?.picker) {
- return
- }
-
- try {
- setIsPickerOpen(true)
- onPickerStateChange?.(true)
-
- // Create picker with higher z-index and focus handling
- const picker = new window.google.picker.PickerBuilder()
- .addView(window.google.picker.ViewId.DOCS)
- .addView(window.google.picker.ViewId.FOLDERS)
- .setOAuthToken(accessToken)
- .enableFeature(window.google.picker.Feature.MULTISELECT_ENABLED)
- .setTitle('Select files from Google Drive')
- .setCallback(pickerCallback)
- .build()
-
- picker.setVisible(true)
-
- // Apply z-index fix after a short delay to ensure picker is rendered
- setTimeout(() => {
- const pickerElements = document.querySelectorAll('.picker-dialog, .goog-modalpopup')
- pickerElements.forEach(el => {
- (el as HTMLElement).style.zIndex = '10000'
- })
- const bgElements = document.querySelectorAll('.picker-dialog-bg, .goog-modalpopup-bg')
- bgElements.forEach(el => {
- (el as HTMLElement).style.zIndex = '9999'
- })
- }, 100)
-
- } catch (error) {
- console.error('Error creating picker:', error)
- setIsPickerOpen(false)
- onPickerStateChange?.(false)
- }
- }
-
- const pickerCallback = async (data: GooglePickerData) => {
- if (data.action === window.google.picker.Action.PICKED) {
- const files: GoogleDriveFile[] = data.docs.map((doc: GooglePickerDocument) => ({
- id: doc[window.google.picker.Document.ID],
- name: doc[window.google.picker.Document.NAME],
- mimeType: doc[window.google.picker.Document.MIME_TYPE],
- webViewLink: doc[window.google.picker.Document.URL],
- iconLink: doc[window.google.picker.Document.ICON_URL],
- size: doc['sizeBytes'] ? parseInt(doc['sizeBytes']) : undefined,
- modifiedTime: doc['lastEditedUtc'],
- isFolder: doc[window.google.picker.Document.MIME_TYPE] === 'application/vnd.google-apps.folder'
- }))
-
- // If size is still missing, try to fetch it via Google Drive API
- if (accessToken && files.some(f => !f.size && !f.isFolder)) {
- try {
- const enrichedFiles = await Promise.all(files.map(async (file) => {
- if (!file.size && !file.isFolder) {
- try {
- const response = await fetch(`https://www.googleapis.com/drive/v3/files/${file.id}?fields=size,modifiedTime`, {
- headers: {
- 'Authorization': `Bearer ${accessToken}`
- }
- })
- if (response.ok) {
- const fileDetails = await response.json()
- return {
- ...file,
- size: fileDetails.size ? parseInt(fileDetails.size) : undefined,
- modifiedTime: fileDetails.modifiedTime || file.modifiedTime
- }
- }
- } catch (error) {
- console.warn('Failed to fetch file details:', error)
- }
- }
- return file
- }))
- onFileSelected(enrichedFiles)
- } catch (error) {
- console.warn('Failed to enrich file data:', error)
- onFileSelected(files)
- }
- } else {
- onFileSelected(files)
- }
- }
-
- setIsPickerOpen(false)
- onPickerStateChange?.(false)
- }
-
- const removeFile = (fileId: string) => {
- const updatedFiles = selectedFiles.filter(file => file.id !== fileId)
- onFileSelected(updatedFiles)
- }
-
- const getFileIcon = (mimeType: string) => {
- if (mimeType.includes('folder')) {
- return
- }
- return
- }
-
- const getMimeTypeLabel = (mimeType: string) => {
- const typeMap: { [key: string]: string } = {
- 'application/vnd.google-apps.document': 'Google Doc',
- 'application/vnd.google-apps.spreadsheet': 'Google Sheet',
- 'application/vnd.google-apps.presentation': 'Google Slides',
- 'application/vnd.google-apps.folder': 'Folder',
- 'application/pdf': 'PDF',
- 'text/plain': 'Text',
- 'application/vnd.openxmlformats-officedocument.wordprocessingml.document': 'Word Doc',
- 'application/vnd.openxmlformats-officedocument.presentationml.presentation': 'PowerPoint'
- }
-
- return typeMap[mimeType] || 'Document'
- }
-
- const formatFileSize = (bytes?: number) => {
- if (!bytes) return ''
- const sizes = ['B', 'KB', 'MB', 'GB', 'TB']
- if (bytes === 0) return '0 B'
- const i = Math.floor(Math.log(bytes) / Math.log(1024))
- return `${(bytes / Math.pow(1024, i)).toFixed(1)} ${sizes[i]}`
- }
-
- if (!isAuthenticated) {
- return (
-
- Please connect to Google Drive first to select specific files.
-
- )
- }
-
- return (
-
-
-
-
- Select files from Google Drive to ingest.
-
-
-
- {isPickerOpen ? 'Opening Picker...' : 'Add Files'}
-
-
-
-
- {selectedFiles.length > 0 && (
-
-
-
- Added files
-
-
onFileSelected([])}
- size="sm"
- variant="ghost"
- className="text-xs h-6"
- >
- Clear all
-
-
-
- {selectedFiles.map((file) => (
-
-
- {getFileIcon(file.mimeType)}
- {file.name}
-
- {getMimeTypeLabel(file.mimeType)}
-
-
-
- {formatFileSize(file.size)}
- removeFile(file.id)}
- size="sm"
- variant="ghost"
- className="h-6 w-6 p-0"
- >
-
-
-
-
- ))}
-
-
-
- )}
-
- )
-}
diff --git a/frontend/src/components/onedrive-picker.tsx b/frontend/src/components/onedrive-picker.tsx
deleted file mode 100644
index 739491c1..00000000
--- a/frontend/src/components/onedrive-picker.tsx
+++ /dev/null
@@ -1,320 +0,0 @@
-"use client"
-
-import { useState, useEffect } from "react"
-import { Button } from "@/components/ui/button"
-import { Badge } from "@/components/ui/badge"
-import { FileText, Folder, Trash2, X } from "lucide-react"
-
-interface OneDrivePickerProps {
- onFileSelected: (files: OneDriveFile[]) => void
- selectedFiles?: OneDriveFile[]
- isAuthenticated: boolean
- accessToken?: string
- connectorType?: "onedrive" | "sharepoint"
- onPickerStateChange?: (isOpen: boolean) => void
-}
-
-interface OneDriveFile {
- id: string
- name: string
- mimeType?: string
- webUrl?: string
- driveItem?: {
- file?: { mimeType: string }
- folder?: unknown
- }
-}
-
-interface GraphResponse {
- value: OneDriveFile[]
-}
-
-declare global {
- interface Window {
- mgt?: {
- Providers?: {
- globalProvider?: unknown
- }
- }
- }
-}
-
-export function OneDrivePicker({
- onFileSelected,
- selectedFiles = [],
- isAuthenticated,
- accessToken,
- connectorType = "onedrive",
- onPickerStateChange
-}: OneDrivePickerProps) {
- const [isLoading, setIsLoading] = useState(false)
- const [files, setFiles] = useState([])
- const [isPickerOpen, setIsPickerOpen] = useState(false)
- const [currentPath, setCurrentPath] = useState(
- connectorType === "sharepoint" ? 'sites?search=' : 'me/drive/root/children'
- )
- const [breadcrumbs, setBreadcrumbs] = useState<{id: string, name: string}[]>([
- {id: 'root', name: connectorType === "sharepoint" ? 'SharePoint' : 'OneDrive'}
- ])
-
- useEffect(() => {
- const loadMGT = async () => {
- if (typeof window !== 'undefined' && !window.mgt) {
- try {
- await import('@microsoft/mgt-components')
- await import('@microsoft/mgt-msal2-provider')
-
- // For simplicity, we'll use direct Graph API calls instead of MGT components
- // MGT provider initialization would go here if needed
- } catch {
- console.warn('MGT not available, falling back to direct API calls')
- }
- }
- }
-
- loadMGT()
- }, [accessToken])
-
-
- const fetchFiles = async (path: string = currentPath) => {
- if (!accessToken) return
-
- setIsLoading(true)
- try {
- const response = await fetch(`https://graph.microsoft.com/v1.0/${path}`, {
- headers: {
- 'Authorization': `Bearer ${accessToken}`,
- 'Content-Type': 'application/json'
- }
- })
-
- if (response.ok) {
- const data: GraphResponse = await response.json()
- setFiles(data.value || [])
- } else {
- console.error('Failed to fetch OneDrive files:', response.statusText)
- }
- } catch (error) {
- console.error('Error fetching OneDrive files:', error)
- } finally {
- setIsLoading(false)
- }
- }
-
- const openPicker = () => {
- if (!accessToken) return
-
- setIsPickerOpen(true)
- onPickerStateChange?.(true)
- fetchFiles()
- }
-
- const closePicker = () => {
- setIsPickerOpen(false)
- onPickerStateChange?.(false)
- setFiles([])
- setCurrentPath(
- connectorType === "sharepoint" ? 'sites?search=' : 'me/drive/root/children'
- )
- setBreadcrumbs([
- {id: 'root', name: connectorType === "sharepoint" ? 'SharePoint' : 'OneDrive'}
- ])
- }
-
- const handleFileClick = (file: OneDriveFile) => {
- if (file.driveItem?.folder) {
- // Navigate to folder
- const newPath = `me/drive/items/${file.id}/children`
- setCurrentPath(newPath)
- setBreadcrumbs([...breadcrumbs, {id: file.id, name: file.name}])
- fetchFiles(newPath)
- } else {
- // Select file
- const isAlreadySelected = selectedFiles.some(f => f.id === file.id)
- if (!isAlreadySelected) {
- onFileSelected([...selectedFiles, file])
- }
- }
- }
-
- const navigateToBreadcrumb = (index: number) => {
- if (index === 0) {
- setCurrentPath('me/drive/root/children')
- setBreadcrumbs([{id: 'root', name: 'OneDrive'}])
- fetchFiles('me/drive/root/children')
- } else {
- const targetCrumb = breadcrumbs[index]
- const newPath = `me/drive/items/${targetCrumb.id}/children`
- setCurrentPath(newPath)
- setBreadcrumbs(breadcrumbs.slice(0, index + 1))
- fetchFiles(newPath)
- }
- }
-
- const removeFile = (fileId: string) => {
- const updatedFiles = selectedFiles.filter(file => file.id !== fileId)
- onFileSelected(updatedFiles)
- }
-
- const getFileIcon = (file: OneDriveFile) => {
- if (file.driveItem?.folder) {
- return
- }
- return
- }
-
- const getMimeTypeLabel = (file: OneDriveFile) => {
- const mimeType = file.driveItem?.file?.mimeType || file.mimeType || ''
- const typeMap: { [key: string]: string } = {
- 'application/vnd.openxmlformats-officedocument.wordprocessingml.document': 'Word Doc',
- 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet': 'Excel',
- 'application/vnd.openxmlformats-officedocument.presentationml.presentation': 'PowerPoint',
- 'application/pdf': 'PDF',
- 'text/plain': 'Text',
- 'image/jpeg': 'Image',
- 'image/png': 'Image',
- }
-
- if (file.driveItem?.folder) return 'Folder'
- return typeMap[mimeType] || 'Document'
- }
-
- const serviceName = connectorType === "sharepoint" ? "SharePoint" : "OneDrive"
-
- if (!isAuthenticated) {
- return (
-
- Please connect to {serviceName} first to select specific files.
-
- )
- }
-
- return (
-
-
-
-
{serviceName} File Selection
-
- Choose specific files to sync instead of syncing everything
-
-
-
- {!accessToken ? "No Access Token" : "Select Files"}
-
-
-
- {/* Status message when access token is missing */}
- {isAuthenticated && !accessToken && (
-
-
Access token unavailable
-
The file picker requires an access token. Try disconnecting and reconnecting your {serviceName} account.
-
- )}
-
- {/* File Picker Modal */}
- {isPickerOpen && (
-
-
-
-
Select Files from {serviceName}
-
-
-
-
-
- {/* Breadcrumbs */}
-
- {breadcrumbs.map((crumb, index) => (
-
- {index > 0 && / }
- navigateToBreadcrumb(index)}
- className="text-blue-600 hover:underline"
- >
- {crumb.name}
-
-
- ))}
-
-
- {/* File List */}
-
- {isLoading ? (
-
Loading...
- ) : files.length === 0 ? (
-
No files found
- ) : (
-
- {files.map((file) => (
-
handleFileClick(file)}
- >
-
- {getFileIcon(file)}
- {file.name}
-
- {getMimeTypeLabel(file)}
-
-
- {selectedFiles.some(f => f.id === file.id) && (
-
Selected
- )}
-
- ))}
-
- )}
-
-
-
- )}
-
- {selectedFiles.length > 0 && (
-
-
- Selected files ({selectedFiles.length}):
-
-
- {selectedFiles.map((file) => (
-
-
- {getFileIcon(file)}
- {file.name}
-
- {getMimeTypeLabel(file)}
-
-
-
removeFile(file.id)}
- size="sm"
- variant="ghost"
- className="h-6 w-6 p-0"
- >
-
-
-
- ))}
-
-
onFileSelected([])}
- size="sm"
- variant="ghost"
- className="text-xs h-6"
- >
- Clear all
-
-
- )}
-
- )
-}
\ No newline at end of file
diff --git a/src/api/connectors.py b/src/api/connectors.py
index 4c9eab49..5fbcec86 100644
--- a/src/api/connectors.py
+++ b/src/api/connectors.py
@@ -127,6 +127,23 @@ async def connector_status(request: Request, connector_service, session_manager)
user_id=user.user_id, connector_type=connector_type
)
+ # Get the connector for each connection
+ connection_client_ids = {}
+ for connection in connections:
+ try:
+ connector = await connector_service._get_connector(connection.connection_id)
+ if connector is not None:
+ connection_client_ids[connection.connection_id] = connector.get_client_id()
+ else:
+ connection_client_ids[connection.connection_id] = None
+ except Exception as e:
+ logger.warning(
+ "Could not get connector for connection",
+ connection_id=connection.connection_id,
+ error=str(e),
+ )
+ connection_client_ids[connection.connection_id] = None
+
# Check if there are any active connections
active_connections = [conn for conn in connections if conn.is_active]
has_authenticated_connection = len(active_connections) > 0
@@ -140,6 +157,7 @@ async def connector_status(request: Request, connector_service, session_manager)
{
"connection_id": conn.connection_id,
"name": conn.name,
+ "client_id": connection_client_ids.get(conn.connection_id),
"is_active": conn.is_active,
"created_at": conn.created_at.isoformat(),
"last_sync": conn.last_sync.isoformat() if conn.last_sync else None,
@@ -323,8 +341,8 @@ async def connector_webhook(request: Request, connector_service, session_manager
)
async def connector_token(request: Request, connector_service, session_manager):
- """Get access token for connector API calls (e.g., Google Picker)"""
- connector_type = request.path_params.get("connector_type")
+ """Get access token for connector API calls (e.g., Pickers)."""
+ url_connector_type = request.path_params.get("connector_type")
connection_id = request.query_params.get("connection_id")
if not connection_id:
@@ -333,37 +351,81 @@ async def connector_token(request: Request, connector_service, session_manager):
user = request.state.user
try:
- # Get the connection and verify it belongs to the user
+ # 1) Load the connection and verify ownership
connection = await connector_service.connection_manager.get_connection(connection_id)
if not connection or connection.user_id != user.user_id:
return JSONResponse({"error": "Connection not found"}, status_code=404)
- # Get the connector instance
+ # 2) Get the ACTUAL connector instance/type for this connection_id
connector = await connector_service._get_connector(connection_id)
if not connector:
- return JSONResponse({"error": f"Connector not available - authentication may have failed for {connector_type}"}, status_code=404)
+ return JSONResponse(
+ {"error": f"Connector not available - authentication may have failed for {url_connector_type}"},
+ status_code=404,
+ )
- # For Google Drive, get the access token
- if connector_type == "google_drive" and hasattr(connector, 'oauth'):
+ real_type = getattr(connector, "type", None) or getattr(connection, "connector_type", None)
+ if real_type is None:
+ return JSONResponse({"error": "Unable to determine connector type"}, status_code=500)
+
+ # Reject the request if the URL path type disagrees with the connection's real type
+ if url_connector_type and url_connector_type != real_type:
+ # You can downgrade this to debug if you expect cross-routing.
+ return JSONResponse(
+ {
+ "error": "Connector type mismatch",
+ "detail": {
+ "requested_type": url_connector_type,
+ "actual_type": real_type,
+ "hint": "Call the token endpoint using the correct connector_type for this connection_id.",
+ },
+ },
+ status_code=400,
+ )
+
+ # 3) Branch by the actual connector type
+ # GOOGLE DRIVE (google-auth)
+ if real_type == "google_drive" and hasattr(connector, "oauth"):
await connector.oauth.load_credentials()
if connector.oauth.creds and connector.oauth.creds.valid:
- return JSONResponse({
- "access_token": connector.oauth.creds.token,
- "expires_in": (connector.oauth.creds.expiry.timestamp() -
- __import__('time').time()) if connector.oauth.creds.expiry else None
- })
- else:
- return JSONResponse({"error": "Invalid or expired credentials"}, status_code=401)
-
- # For OneDrive and SharePoint, get the access token
- elif connector_type in ["onedrive", "sharepoint"] and hasattr(connector, 'oauth'):
+ expires_in = None
+ try:
+ if connector.oauth.creds.expiry:
+ import time
+ expires_in = max(0, int(connector.oauth.creds.expiry.timestamp() - time.time()))
+ except Exception:
+ expires_in = None
+
+ return JSONResponse(
+ {
+ "access_token": connector.oauth.creds.token,
+ "expires_in": expires_in,
+ }
+ )
+ return JSONResponse({"error": "Invalid or expired credentials"}, status_code=401)
+
+ # ONEDRIVE / SHAREPOINT (MSAL or custom)
+ if real_type in ("onedrive", "sharepoint") and hasattr(connector, "oauth"):
+ # Ensure cache/credentials are loaded before trying to use them
try:
+ # Prefer a dedicated is_authenticated() that loads cache internally
+ if hasattr(connector.oauth, "is_authenticated"):
+ ok = await connector.oauth.is_authenticated()
+ else:
+ # Fallback: try to load credentials explicitly if available
+ ok = True
+ if hasattr(connector.oauth, "load_credentials"):
+ ok = await connector.oauth.load_credentials()
+
+ if not ok:
+ return JSONResponse({"error": "Not authenticated"}, status_code=401)
+
+ # Now safe to fetch access token
access_token = connector.oauth.get_access_token()
- return JSONResponse({
- "access_token": access_token,
- "expires_in": None # MSAL handles token expiry internally
- })
+ # MSAL result has expiry, but we're returning a raw token; keep expires_in None for simplicity
+ return JSONResponse({"access_token": access_token, "expires_in": None})
except ValueError as e:
+ # Typical when acquire_token_silent fails (e.g., needs re-auth)
return JSONResponse({"error": f"Failed to get access token: {str(e)}"}, status_code=401)
except Exception as e:
return JSONResponse({"error": f"Authentication error: {str(e)}"}, status_code=500)
@@ -371,7 +433,5 @@ async def connector_token(request: Request, connector_service, session_manager):
return JSONResponse({"error": "Token not available for this connector type"}, status_code=400)
except Exception as e:
- logger.error("Error getting connector token", error=str(e))
+ logger.error("Error getting connector token", exc_info=True)
return JSONResponse({"error": str(e)}, status_code=500)
-
-
diff --git a/src/api/settings.py b/src/api/settings.py
index 3e242c4b..c2c7cbd0 100644
--- a/src/api/settings.py
+++ b/src/api/settings.py
@@ -556,6 +556,19 @@ async def onboarding(request, flows_service):
)
# Continue even if setting global variables fails
+ # Initialize the OpenSearch index now that we have the embedding model configured
+ try:
+ # Import here to avoid circular imports
+ from main import init_index
+
+ logger.info("Initializing OpenSearch index after onboarding configuration")
+ await init_index()
+ logger.info("OpenSearch index initialization completed successfully")
+ except Exception as e:
+ logger.error("Failed to initialize OpenSearch index after onboarding", error=str(e))
+ # Don't fail the entire onboarding process if index creation fails
+ # The application can still work, but document operations may fail
+
# Handle sample data ingestion if requested
if should_ingest_sample_data:
try:
diff --git a/src/config/config_manager.py b/src/config/config_manager.py
index 055d48a7..0b814470 100644
--- a/src/config/config_manager.py
+++ b/src/config/config_manager.py
@@ -16,6 +16,8 @@ class ProviderConfig:
model_provider: str = "openai" # openai, anthropic, etc.
api_key: str = ""
+ endpoint: str = "" # For providers like Watson/IBM that need custom endpoints
+ project_id: str = "" # For providers like Watson/IBM that need project IDs
@dataclass
@@ -129,6 +131,10 @@ class ConfigManager:
config_data["provider"]["model_provider"] = os.getenv("MODEL_PROVIDER")
if os.getenv("PROVIDER_API_KEY"):
config_data["provider"]["api_key"] = os.getenv("PROVIDER_API_KEY")
+ if os.getenv("PROVIDER_ENDPOINT"):
+ config_data["provider"]["endpoint"] = os.getenv("PROVIDER_ENDPOINT")
+ if os.getenv("PROVIDER_PROJECT_ID"):
+ config_data["provider"]["project_id"] = os.getenv("PROVIDER_PROJECT_ID")
# Backward compatibility for OpenAI
if os.getenv("OPENAI_API_KEY"):
config_data["provider"]["api_key"] = os.getenv("OPENAI_API_KEY")
diff --git a/src/config/settings.py b/src/config/settings.py
index 5f9b189d..3bf1e6cf 100644
--- a/src/config/settings.py
+++ b/src/config/settings.py
@@ -78,6 +78,31 @@ INDEX_NAME = "documents"
VECTOR_DIM = 1536
EMBED_MODEL = "text-embedding-3-small"
+OPENAI_EMBEDDING_DIMENSIONS = {
+ "text-embedding-3-small": 1536,
+ "text-embedding-3-large": 3072,
+ "text-embedding-ada-002": 1536,
+ }
+
+OLLAMA_EMBEDDING_DIMENSIONS = {
+ "nomic-embed-text": 768,
+ "all-minilm": 384,
+ "mxbai-embed-large": 1024,
+}
+
+WATSONX_EMBEDDING_DIMENSIONS = {
+# IBM Models
+"ibm/granite-embedding-107m-multilingual": 384,
+"ibm/granite-embedding-278m-multilingual": 1024,
+"ibm/slate-125m-english-rtrvr": 768,
+"ibm/slate-125m-english-rtrvr-v2": 768,
+"ibm/slate-30m-english-rtrvr": 384,
+"ibm/slate-30m-english-rtrvr-v2": 384,
+# Third Party Models
+"intfloat/multilingual-e5-large": 1024,
+"sentence-transformers/all-minilm-l6-v2": 384,
+}
+
INDEX_BODY = {
"settings": {
"index": {"knn": True},
diff --git a/src/connectors/connection_manager.py b/src/connectors/connection_manager.py
index 2e70ee1f..07ebd5ee 100644
--- a/src/connectors/connection_manager.py
+++ b/src/connectors/connection_manager.py
@@ -294,32 +294,39 @@ class ConnectionManager:
async def get_connector(self, connection_id: str) -> Optional[BaseConnector]:
"""Get an active connector instance"""
+ logger.debug(f"Getting connector for connection_id: {connection_id}")
+
# Return cached connector if available
if connection_id in self.active_connectors:
connector = self.active_connectors[connection_id]
if connector.is_authenticated:
+ logger.debug(f"Returning cached authenticated connector for {connection_id}")
return connector
else:
# Remove unauthenticated connector from cache
+ logger.debug(f"Removing unauthenticated connector from cache for {connection_id}")
del self.active_connectors[connection_id]
# Try to create and authenticate connector
connection_config = self.connections.get(connection_id)
if not connection_config or not connection_config.is_active:
+ logger.debug(f"No active connection config found for {connection_id}")
return None
+ logger.debug(f"Creating connector for {connection_config.connector_type}")
connector = self._create_connector(connection_config)
- if await connector.authenticate():
+
+ logger.debug(f"Attempting authentication for {connection_id}")
+ auth_result = await connector.authenticate()
+ logger.debug(f"Authentication result for {connection_id}: {auth_result}")
+
+ if auth_result:
self.active_connectors[connection_id] = connector
-
- # Setup webhook subscription if not already set up
- await self._setup_webhook_if_needed(
- connection_id, connection_config, connector
- )
-
+ # ... rest of the method
return connector
-
- return None
+ else:
+ logger.warning(f"Authentication failed for {connection_id}")
+ return None
def get_available_connector_types(self) -> Dict[str, Dict[str, Any]]:
"""Get available connector types with their metadata"""
@@ -363,20 +370,23 @@ class ConnectionManager:
def _create_connector(self, config: ConnectionConfig) -> BaseConnector:
"""Factory method to create connector instances"""
- if config.connector_type == "google_drive":
- return GoogleDriveConnector(config.config)
- elif config.connector_type == "sharepoint":
- return SharePointConnector(config.config)
- elif config.connector_type == "onedrive":
- return OneDriveConnector(config.config)
- elif config.connector_type == "box":
- # Future: BoxConnector(config.config)
- raise NotImplementedError("Box connector not implemented yet")
- elif config.connector_type == "dropbox":
- # Future: DropboxConnector(config.config)
- raise NotImplementedError("Dropbox connector not implemented yet")
- else:
- raise ValueError(f"Unknown connector type: {config.connector_type}")
+ try:
+ if config.connector_type == "google_drive":
+ return GoogleDriveConnector(config.config)
+ elif config.connector_type == "sharepoint":
+ return SharePointConnector(config.config)
+ elif config.connector_type == "onedrive":
+ return OneDriveConnector(config.config)
+ elif config.connector_type == "box":
+ raise NotImplementedError("Box connector not implemented yet")
+ elif config.connector_type == "dropbox":
+ raise NotImplementedError("Dropbox connector not implemented yet")
+ else:
+ raise ValueError(f"Unknown connector type: {config.connector_type}")
+ except Exception as e:
+ logger.error(f"Failed to create {config.connector_type} connector: {e}")
+ # Re-raise the exception so caller can handle appropriately
+ raise
async def update_last_sync(self, connection_id: str):
"""Update the last sync timestamp for a connection"""
diff --git a/src/connectors/google_drive/connector.py b/src/connectors/google_drive/connector.py
index 0aa4234a..37ebdc8a 100644
--- a/src/connectors/google_drive/connector.py
+++ b/src/connectors/google_drive/connector.py
@@ -477,7 +477,7 @@ class GoogleDriveConnector(BaseConnector):
"next_page_token": None, # no more pages
}
except Exception as e:
- # Optionally log error with your base class logger
+ # Log the error
try:
logger.error(f"GoogleDriveConnector.list_files failed: {e}")
except Exception:
@@ -495,7 +495,6 @@ class GoogleDriveConnector(BaseConnector):
try:
blob = self._download_file_bytes(meta)
except Exception as e:
- # Use your base class logger if available
try:
logger.error(f"Download failed for {file_id}: {e}")
except Exception:
@@ -562,7 +561,6 @@ class GoogleDriveConnector(BaseConnector):
if not self.cfg.changes_page_token:
self.cfg.changes_page_token = self.get_start_page_token()
except Exception as e:
- # Optional: use your base logger
try:
logger.error(f"Failed to get start page token: {e}")
except Exception:
@@ -593,7 +591,6 @@ class GoogleDriveConnector(BaseConnector):
expiration = result.get("expiration")
# Persist in-memory so cleanup can stop this channel later.
- # If your project has a persistence layer, save these values there.
self._active_channel = {
"channel_id": channel_id,
"resource_id": resource_id,
@@ -803,7 +800,7 @@ class GoogleDriveConnector(BaseConnector):
"""
Perform a one-shot sync of the currently selected scope and emit documents.
- Emits ConnectorDocument instances (adapt to your BaseConnector ingestion).
+ Emits ConnectorDocument instances
"""
items = self._iter_selected_items()
for meta in items:
diff --git a/src/connectors/onedrive/connector.py b/src/connectors/onedrive/connector.py
index 8b800b3d..3ef4bdaf 100644
--- a/src/connectors/onedrive/connector.py
+++ b/src/connectors/onedrive/connector.py
@@ -1,223 +1,494 @@
+import logging
+from pathlib import Path
+from typing import List, Dict, Any, Optional
+from datetime import datetime
import httpx
-import uuid
-from datetime import datetime, timedelta
-from typing import Dict, List, Any, Optional
from ..base import BaseConnector, ConnectorDocument, DocumentACL
from .oauth import OneDriveOAuth
+logger = logging.getLogger(__name__)
+
class OneDriveConnector(BaseConnector):
- """OneDrive connector using Microsoft Graph API"""
+ """OneDrive connector using MSAL-based OAuth for authentication."""
+ # Required BaseConnector class attributes
CLIENT_ID_ENV_VAR = "MICROSOFT_GRAPH_OAUTH_CLIENT_ID"
CLIENT_SECRET_ENV_VAR = "MICROSOFT_GRAPH_OAUTH_CLIENT_SECRET"
# Connector metadata
CONNECTOR_NAME = "OneDrive"
- CONNECTOR_DESCRIPTION = "Connect your personal OneDrive to sync documents"
+ CONNECTOR_DESCRIPTION = "Connect to OneDrive (personal) to sync documents and files"
CONNECTOR_ICON = "onedrive"
def __init__(self, config: Dict[str, Any]):
super().__init__(config)
- self.oauth = OneDriveOAuth(
- client_id=self.get_client_id(),
- client_secret=self.get_client_secret(),
- token_file=config.get("token_file", "onedrive_token.json"),
- )
- self.subscription_id = config.get("subscription_id") or config.get(
- "webhook_channel_id"
- )
- self.base_url = "https://graph.microsoft.com/v1.0"
- async def authenticate(self) -> bool:
- if await self.oauth.is_authenticated():
- self._authenticated = True
- return True
- return False
+ logger.debug(f"OneDrive connector __init__ called with config type: {type(config)}")
+ logger.debug(f"OneDrive connector __init__ config value: {config}")
- async def setup_subscription(self) -> str:
- if not self._authenticated:
- raise ValueError("Not authenticated")
+ if config is None:
+ logger.debug("Config was None, using empty dict")
+ config = {}
- webhook_url = self.config.get("webhook_url")
- if not webhook_url:
- raise ValueError("webhook_url required in config for subscriptions")
+ try:
+ logger.debug("Calling super().__init__")
+ super().__init__(config)
+ logger.debug("super().__init__ completed successfully")
+ except Exception as e:
+ logger.error(f"super().__init__ failed: {e}")
+ raise
- expiration = (datetime.utcnow() + timedelta(days=2)).isoformat() + "Z"
- body = {
- "changeType": "created,updated,deleted",
- "notificationUrl": webhook_url,
- "resource": "/me/drive/root",
- "expirationDateTime": expiration,
- "clientState": str(uuid.uuid4()),
+ # Initialize with defaults that allow the connector to be listed
+ self.client_id = None
+ self.client_secret = None
+ self.redirect_uri = config.get("redirect_uri", "http://localhost")
+
+ # Try to get credentials, but don't fail if they're missing
+ try:
+ self.client_id = self.get_client_id()
+ logger.debug(f"Got client_id: {self.client_id is not None}")
+ except Exception as e:
+ logger.debug(f"Failed to get client_id: {e}")
+
+ try:
+ self.client_secret = self.get_client_secret()
+ logger.debug(f"Got client_secret: {self.client_secret is not None}")
+ except Exception as e:
+ logger.debug(f"Failed to get client_secret: {e}")
+
+ # Token file setup
+ project_root = Path(__file__).resolve().parent.parent.parent.parent
+ token_file = config.get("token_file") or str(project_root / "onedrive_token.json")
+ Path(token_file).parent.mkdir(parents=True, exist_ok=True)
+
+ # Only initialize OAuth if we have credentials
+ if self.client_id and self.client_secret:
+ connection_id = config.get("connection_id", "default")
+
+ # Use token_file from config if provided, otherwise generate one
+ if config.get("token_file"):
+ oauth_token_file = config["token_file"]
+ else:
+ # Use a per-connection cache file to avoid collisions with other connectors
+ oauth_token_file = f"onedrive_token_{connection_id}.json"
+
+ # MSA & org both work via /common for OneDrive personal testing
+ authority = "https://login.microsoftonline.com/common"
+
+ self.oauth = OneDriveOAuth(
+ client_id=self.client_id,
+ client_secret=self.client_secret,
+ token_file=oauth_token_file,
+ authority=authority,
+ allow_json_refresh=True, # allows one-time migration from legacy JSON if present
+ )
+ else:
+ self.oauth = None
+
+ # Track subscription ID for webhooks (note: change notifications might not be available for personal accounts)
+ self._subscription_id: Optional[str] = None
+
+ # Graph API defaults
+ self._graph_api_version = "v1.0"
+ self._default_params = {
+ "$select": "id,name,size,lastModifiedDateTime,createdDateTime,webUrl,file,folder,@microsoft.graph.downloadUrl"
}
- token = self.oauth.get_access_token()
- async with httpx.AsyncClient() as client:
- resp = await client.post(
- f"{self.base_url}/subscriptions",
- json=body,
- headers={"Authorization": f"Bearer {token}"},
- )
- resp.raise_for_status()
- data = resp.json()
+ @property
+ def _graph_base_url(self) -> str:
+ """Base URL for Microsoft Graph API calls."""
+ return f"https://graph.microsoft.com/{self._graph_api_version}"
- self.subscription_id = data["id"]
- return self.subscription_id
+ def emit(self, doc: ConnectorDocument) -> None:
+ """Emit a ConnectorDocument instance."""
+ logger.debug(f"Emitting OneDrive document: {doc.id} ({doc.filename})")
+
+ async def authenticate(self) -> bool:
+ """Test authentication - BaseConnector interface."""
+ logger.debug(f"OneDrive authenticate() called, oauth is None: {self.oauth is None}")
+ try:
+ if not self.oauth:
+ logger.debug("OneDrive authentication failed: OAuth not initialized")
+ self._authenticated = False
+ return False
+
+ logger.debug("Loading OneDrive credentials...")
+ load_result = await self.oauth.load_credentials()
+ logger.debug(f"Load credentials result: {load_result}")
+
+ logger.debug("Checking OneDrive authentication status...")
+ authenticated = await self.oauth.is_authenticated()
+ logger.debug(f"OneDrive is_authenticated result: {authenticated}")
+
+ self._authenticated = authenticated
+ return authenticated
+ except Exception as e:
+ logger.error(f"OneDrive authentication failed: {e}")
+ import traceback
+ traceback.print_exc()
+ self._authenticated = False
+ return False
+
+ def get_auth_url(self) -> str:
+ """Get OAuth authorization URL."""
+ if not self.oauth:
+ raise RuntimeError("OneDrive OAuth not initialized - missing credentials")
+ return self.oauth.create_authorization_url(self.redirect_uri)
+
+ async def handle_oauth_callback(self, auth_code: str) -> Dict[str, Any]:
+ """Handle OAuth callback."""
+ if not self.oauth:
+ raise RuntimeError("OneDrive OAuth not initialized - missing credentials")
+ try:
+ success = await self.oauth.handle_authorization_callback(auth_code, self.redirect_uri)
+ if success:
+ self._authenticated = True
+ return {"status": "success"}
+ else:
+ raise ValueError("OAuth callback failed")
+ except Exception as e:
+ logger.error(f"OAuth callback failed: {e}")
+ raise
+
+ def sync_once(self) -> None:
+ """
+ Perform a one-shot sync of OneDrive files and emit documents.
+ """
+ import asyncio
+
+ async def _async_sync():
+ try:
+ file_list = await self.list_files(max_files=1000)
+ files = file_list.get("files", [])
+ for file_info in files:
+ try:
+ file_id = file_info.get("id")
+ if not file_id:
+ continue
+ doc = await self.get_file_content(file_id)
+ self.emit(doc)
+ except Exception as e:
+ logger.error(f"Failed to sync OneDrive file {file_info.get('name', 'unknown')}: {e}")
+ continue
+ except Exception as e:
+ logger.error(f"OneDrive sync_once failed: {e}")
+ raise
+
+ if hasattr(asyncio, 'run'):
+ asyncio.run(_async_sync())
+ else:
+ loop = asyncio.get_event_loop()
+ loop.run_until_complete(_async_sync())
+
+ async def setup_subscription(self) -> str:
+ """
+ Set up real-time subscription for file changes.
+ NOTE: Change notifications may not be available for personal OneDrive accounts.
+ """
+ webhook_url = self.config.get('webhook_url')
+ if not webhook_url:
+ logger.warning("No webhook URL configured, skipping OneDrive subscription setup")
+ return "no-webhook-configured"
+
+ try:
+ if not await self.authenticate():
+ raise RuntimeError("OneDrive authentication failed during subscription setup")
+
+ token = self.oauth.get_access_token()
+
+ # For OneDrive personal we target the user's drive
+ resource = "/me/drive/root"
+
+ subscription_data = {
+ "changeType": "created,updated,deleted",
+ "notificationUrl": f"{webhook_url}/webhook/onedrive",
+ "resource": resource,
+ "expirationDateTime": self._get_subscription_expiry(),
+ "clientState": "onedrive_personal",
+ }
+
+ headers = {
+ "Authorization": f"Bearer {token}",
+ "Content-Type": "application/json",
+ }
+
+ url = f"{self._graph_base_url}/subscriptions"
+
+ async with httpx.AsyncClient() as client:
+ response = await client.post(url, json=subscription_data, headers=headers, timeout=30)
+ response.raise_for_status()
+
+ result = response.json()
+ subscription_id = result.get("id")
+
+ if subscription_id:
+ self._subscription_id = subscription_id
+ logger.info(f"OneDrive subscription created: {subscription_id}")
+ return subscription_id
+ else:
+ raise ValueError("No subscription ID returned from Microsoft Graph")
+
+ except Exception as e:
+ logger.error(f"Failed to setup OneDrive subscription: {e}")
+ raise
+
+ def _get_subscription_expiry(self) -> str:
+ """Get subscription expiry time (Graph caps duration; often <= 3 days)."""
+ from datetime import datetime, timedelta
+ expiry = datetime.utcnow() + timedelta(days=3)
+ return expiry.strftime("%Y-%m-%dT%H:%M:%S.%fZ")
async def list_files(
- self, page_token: Optional[str] = None, limit: int = 100
+ self,
+ page_token: Optional[str] = None,
+ max_files: Optional[int] = None,
+ **kwargs
) -> Dict[str, Any]:
- if not self._authenticated:
- raise ValueError("Not authenticated")
+ """List files from OneDrive using Microsoft Graph."""
+ try:
+ if not await self.authenticate():
+ raise RuntimeError("OneDrive authentication failed during file listing")
- params = {"$top": str(limit)}
- if page_token:
- params["$skiptoken"] = page_token
+ files: List[Dict[str, Any]] = []
+ max_files_value = max_files if max_files is not None else 100
- token = self.oauth.get_access_token()
- async with httpx.AsyncClient() as client:
- resp = await client.get(
- f"{self.base_url}/me/drive/root/children",
- params=params,
- headers={"Authorization": f"Bearer {token}"},
- )
- resp.raise_for_status()
- data = resp.json()
+ base_url = f"{self._graph_base_url}/me/drive/root/children"
- files = []
- for item in data.get("value", []):
- if item.get("file"):
- files.append(
- {
- "id": item["id"],
- "name": item["name"],
- "mimeType": item.get("file", {}).get(
- "mimeType", "application/octet-stream"
- ),
- "webViewLink": item.get("webUrl"),
- "createdTime": item.get("createdDateTime"),
- "modifiedTime": item.get("lastModifiedDateTime"),
- }
- )
+ params = dict(self._default_params)
+ params["$top"] = str(max_files_value)
- next_token = None
- next_link = data.get("@odata.nextLink")
- if next_link:
- from urllib.parse import urlparse, parse_qs
+ if page_token:
+ params["$skiptoken"] = page_token
- parsed = urlparse(next_link)
- next_token = parse_qs(parsed.query).get("$skiptoken", [None])[0]
+ response = await self._make_graph_request(base_url, params=params)
+ data = response.json()
- return {"files": files, "nextPageToken": next_token}
+ items = data.get("value", [])
+ for item in items:
+ if item.get("file"): # include files only
+ files.append({
+ "id": item.get("id", ""),
+ "name": item.get("name", ""),
+ "path": f"/drive/items/{item.get('id')}",
+ "size": int(item.get("size", 0)),
+ "modified": item.get("lastModifiedDateTime"),
+ "created": item.get("createdDateTime"),
+ "mime_type": item.get("file", {}).get("mimeType", self._get_mime_type(item.get("name", ""))),
+ "url": item.get("webUrl", ""),
+ "download_url": item.get("@microsoft.graph.downloadUrl"),
+ })
+
+ # Next page
+ next_page_token = None
+ next_link = data.get("@odata.nextLink")
+ if next_link:
+ from urllib.parse import urlparse, parse_qs
+ parsed = urlparse(next_link)
+ query_params = parse_qs(parsed.query)
+ if "$skiptoken" in query_params:
+ next_page_token = query_params["$skiptoken"][0]
+
+ return {"files": files, "next_page_token": next_page_token}
+
+ except Exception as e:
+ logger.error(f"Failed to list OneDrive files: {e}")
+ return {"files": [], "next_page_token": None}
async def get_file_content(self, file_id: str) -> ConnectorDocument:
- if not self._authenticated:
- raise ValueError("Not authenticated")
+ """Get file content and metadata."""
+ try:
+ if not await self.authenticate():
+ raise RuntimeError("OneDrive authentication failed during file content retrieval")
+ file_metadata = await self._get_file_metadata_by_id(file_id)
+ if not file_metadata:
+ raise ValueError(f"File not found: {file_id}")
+
+ download_url = file_metadata.get("download_url")
+ if download_url:
+ content = await self._download_file_from_url(download_url)
+ else:
+ content = await self._download_file_content(file_id)
+
+ acl = DocumentACL(
+ owner="",
+ user_permissions={},
+ group_permissions={},
+ )
+
+ modified_time = self._parse_graph_date(file_metadata.get("modified"))
+ created_time = self._parse_graph_date(file_metadata.get("created"))
+
+ return ConnectorDocument(
+ id=file_id,
+ filename=file_metadata.get("name", ""),
+ mimetype=file_metadata.get("mime_type", "application/octet-stream"),
+ content=content,
+ source_url=file_metadata.get("url", ""),
+ acl=acl,
+ modified_time=modified_time,
+ created_time=created_time,
+ metadata={
+ "onedrive_path": file_metadata.get("path", ""),
+ "size": file_metadata.get("size", 0),
+ },
+ )
+
+ except Exception as e:
+ logger.error(f"Failed to get OneDrive file content {file_id}: {e}")
+ raise
+
+ async def _get_file_metadata_by_id(self, file_id: str) -> Optional[Dict[str, Any]]:
+ """Get file metadata by ID using Graph API."""
+ try:
+ url = f"{self._graph_base_url}/me/drive/items/{file_id}"
+ params = dict(self._default_params)
+
+ response = await self._make_graph_request(url, params=params)
+ item = response.json()
+
+ if item.get("file"):
+ return {
+ "id": file_id,
+ "name": item.get("name", ""),
+ "path": f"/drive/items/{file_id}",
+ "size": int(item.get("size", 0)),
+ "modified": item.get("lastModifiedDateTime"),
+ "created": item.get("createdDateTime"),
+ "mime_type": item.get("file", {}).get("mimeType", self._get_mime_type(item.get("name", ""))),
+ "url": item.get("webUrl", ""),
+ "download_url": item.get("@microsoft.graph.downloadUrl"),
+ }
+
+ return None
+
+ except Exception as e:
+ logger.error(f"Failed to get file metadata for {file_id}: {e}")
+ return None
+
+ async def _download_file_content(self, file_id: str) -> bytes:
+ """Download file content by file ID using Graph API."""
+ try:
+ url = f"{self._graph_base_url}/me/drive/items/{file_id}/content"
+ token = self.oauth.get_access_token()
+ headers = {"Authorization": f"Bearer {token}"}
+
+ async with httpx.AsyncClient() as client:
+ response = await client.get(url, headers=headers, timeout=60)
+ response.raise_for_status()
+ return response.content
+
+ except Exception as e:
+ logger.error(f"Failed to download file content for {file_id}: {e}")
+ raise
+
+ async def _download_file_from_url(self, download_url: str) -> bytes:
+ """Download file content from direct download URL."""
+ try:
+ async with httpx.AsyncClient() as client:
+ response = await client.get(download_url, timeout=60)
+ response.raise_for_status()
+ return response.content
+ except Exception as e:
+ logger.error(f"Failed to download from URL {download_url}: {e}")
+ raise
+
+ def _parse_graph_date(self, date_str: Optional[str]) -> datetime:
+ """Parse Microsoft Graph date string to datetime."""
+ if not date_str:
+ return datetime.now()
+ try:
+ if date_str.endswith('Z'):
+ return datetime.fromisoformat(date_str[:-1]).replace(tzinfo=None)
+ else:
+ return datetime.fromisoformat(date_str.replace('T', ' '))
+ except (ValueError, AttributeError):
+ return datetime.now()
+
+ async def _make_graph_request(self, url: str, method: str = "GET",
+ data: Optional[Dict] = None, params: Optional[Dict] = None) -> httpx.Response:
+ """Make authenticated API request to Microsoft Graph."""
token = self.oauth.get_access_token()
- headers = {"Authorization": f"Bearer {token}"}
+ headers = {
+ "Authorization": f"Bearer {token}",
+ "Content-Type": "application/json",
+ }
+
async with httpx.AsyncClient() as client:
- meta_resp = await client.get(
- f"{self.base_url}/me/drive/items/{file_id}", headers=headers
- )
- meta_resp.raise_for_status()
- metadata = meta_resp.json()
+ if method.upper() == "GET":
+ response = await client.get(url, headers=headers, params=params, timeout=30)
+ elif method.upper() == "POST":
+ response = await client.post(url, headers=headers, json=data, timeout=30)
+ elif method.upper() == "DELETE":
+ response = await client.delete(url, headers=headers, timeout=30)
+ else:
+ raise ValueError(f"Unsupported HTTP method: {method}")
- content_resp = await client.get(
- f"{self.base_url}/me/drive/items/{file_id}/content", headers=headers
- )
- content_resp.raise_for_status()
- content = content_resp.content
+ response.raise_for_status()
+ return response
- perm_resp = await client.get(
- f"{self.base_url}/me/drive/items/{file_id}/permissions", headers=headers
- )
- perm_resp.raise_for_status()
- permissions = perm_resp.json()
+ def _get_mime_type(self, filename: str) -> str:
+ """Get MIME type based on file extension."""
+ import mimetypes
+ mime_type, _ = mimetypes.guess_type(filename)
+ return mime_type or "application/octet-stream"
- acl = self._parse_permissions(metadata, permissions)
- modified = datetime.fromisoformat(
- metadata["lastModifiedDateTime"].replace("Z", "+00:00")
- ).replace(tzinfo=None)
- created = datetime.fromisoformat(
- metadata["createdDateTime"].replace("Z", "+00:00")
- ).replace(tzinfo=None)
-
- document = ConnectorDocument(
- id=metadata["id"],
- filename=metadata["name"],
- mimetype=metadata.get("file", {}).get(
- "mimeType", "application/octet-stream"
- ),
- content=content,
- source_url=metadata.get("webUrl"),
- acl=acl,
- modified_time=modified,
- created_time=created,
- metadata={"size": metadata.get("size")},
- )
- return document
-
- def _parse_permissions(
- self, metadata: Dict[str, Any], permissions: Dict[str, Any]
- ) -> DocumentACL:
- acl = DocumentACL()
- owner = metadata.get("createdBy", {}).get("user", {}).get("email")
- if owner:
- acl.owner = owner
- for perm in permissions.get("value", []):
- role = perm.get("roles", ["read"])[0]
- grantee = perm.get("grantedToV2") or perm.get("grantedTo")
- if not grantee:
- continue
- user = grantee.get("user")
- if user and user.get("email"):
- acl.user_permissions[user["email"]] = role
- group = grantee.get("group")
- if group and group.get("email"):
- acl.group_permissions[group["email"]] = role
- return acl
-
- def handle_webhook_validation(
- self, request_method: str, headers: Dict[str, str], query_params: Dict[str, str]
- ) -> Optional[str]:
- """Handle Microsoft Graph webhook validation"""
- if request_method == "GET":
- validation_token = query_params.get("validationtoken") or query_params.get(
- "validationToken"
- )
- if validation_token:
- return validation_token
+ # Webhook methods - BaseConnector interface
+ def handle_webhook_validation(self, request_method: str,
+ headers: Dict[str, str],
+ query_params: Dict[str, str]) -> Optional[str]:
+ """Handle webhook validation (Graph API specific)."""
+ if request_method == "POST" and "validationToken" in query_params:
+ return query_params["validationToken"]
return None
- def extract_webhook_channel_id(
- self, payload: Dict[str, Any], headers: Dict[str, str]
- ) -> Optional[str]:
- """Extract SharePoint subscription ID from webhook payload"""
- values = payload.get("value", [])
- return values[0].get("subscriptionId") if values else None
+ def extract_webhook_channel_id(self, payload: Dict[str, Any],
+ headers: Dict[str, str]) -> Optional[str]:
+ """Extract channel/subscription ID from webhook payload."""
+ notifications = payload.get("value", [])
+ if notifications:
+ return notifications[0].get("subscriptionId")
+ return None
async def handle_webhook(self, payload: Dict[str, Any]) -> List[str]:
- values = payload.get("value", [])
- file_ids = []
- for item in values:
- resource_data = item.get("resourceData", {})
- file_id = resource_data.get("id")
- if file_id:
- file_ids.append(file_id)
- return file_ids
+ """Handle webhook notification and return affected file IDs."""
+ affected_files: List[str] = []
+ notifications = payload.get("value", [])
+ for notification in notifications:
+ resource = notification.get("resource")
+ if resource and "/drive/items/" in resource:
+ file_id = resource.split("/drive/items/")[-1]
+ affected_files.append(file_id)
+ return affected_files
- async def cleanup_subscription(
- self, subscription_id: str, resource_id: str = None
- ) -> bool:
- if not self._authenticated:
+ async def cleanup_subscription(self, subscription_id: str) -> bool:
+ """Clean up subscription - BaseConnector interface."""
+ if subscription_id == "no-webhook-configured":
+ logger.info("No subscription to cleanup (webhook was not configured)")
+ return True
+
+ try:
+ if not await self.authenticate():
+ logger.error("OneDrive authentication failed during subscription cleanup")
+ return False
+
+ token = self.oauth.get_access_token()
+ headers = {"Authorization": f"Bearer {token}"}
+
+ url = f"{self._graph_base_url}/subscriptions/{subscription_id}"
+
+ async with httpx.AsyncClient() as client:
+ response = await client.delete(url, headers=headers, timeout=30)
+
+ if response.status_code in [200, 204, 404]:
+ logger.info(f"OneDrive subscription {subscription_id} cleaned up successfully")
+ return True
+ else:
+ logger.warning(f"Unexpected response cleaning up subscription: {response.status_code}")
+ return False
+
+ except Exception as e:
+ logger.error(f"Failed to cleanup OneDrive subscription {subscription_id}: {e}")
return False
- token = self.oauth.get_access_token()
- async with httpx.AsyncClient() as client:
- resp = await client.delete(
- f"{self.base_url}/subscriptions/{subscription_id}",
- headers={"Authorization": f"Bearer {token}"},
- )
- return resp.status_code in (200, 204)
diff --git a/src/connectors/onedrive/oauth.py b/src/connectors/onedrive/oauth.py
index a81124e6..a2c94d15 100644
--- a/src/connectors/onedrive/oauth.py
+++ b/src/connectors/onedrive/oauth.py
@@ -1,17 +1,28 @@
import os
+import json
+import logging
+from typing import Optional, Dict, Any
+
import aiofiles
-from typing import Optional
import msal
+logger = logging.getLogger(__name__)
+
class OneDriveOAuth:
- """Handles Microsoft Graph OAuth authentication flow"""
+ """Handles Microsoft Graph OAuth for OneDrive (personal Microsoft accounts by default)."""
- SCOPES = [
- "offline_access",
- "Files.Read.All",
- ]
+ # Reserved scopes that must NOT be sent on token or silent calls
+ RESERVED_SCOPES = {"openid", "profile", "offline_access"}
+ # For PERSONAL Microsoft Accounts (OneDrive consumer):
+ # - Use AUTH_SCOPES for interactive auth (consent + refresh token issuance)
+ # - Use RESOURCE_SCOPES for acquire_token_silent / refresh paths
+ AUTH_SCOPES = ["User.Read", "Files.Read.All", "offline_access"]
+ RESOURCE_SCOPES = ["User.Read", "Files.Read.All"]
+ SCOPES = AUTH_SCOPES # Backward-compat alias if something references .SCOPES
+
+ # Kept for reference; MSAL derives endpoints from `authority`
AUTH_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
TOKEN_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/token"
@@ -21,18 +32,29 @@ class OneDriveOAuth:
client_secret: str,
token_file: str = "onedrive_token.json",
authority: str = "https://login.microsoftonline.com/common",
+ allow_json_refresh: bool = True,
):
+ """
+ Initialize OneDriveOAuth.
+
+ Args:
+ client_id: Azure AD application (client) ID.
+ client_secret: Azure AD application client secret.
+ token_file: Path to persisted token cache file (MSAL cache format).
+ authority: Usually "https://login.microsoftonline.com/common" for MSA + org,
+ or tenant-specific for work/school.
+ allow_json_refresh: If True, permit one-time migration from legacy flat JSON
+ {"access_token","refresh_token",...}. Otherwise refuse it.
+ """
self.client_id = client_id
self.client_secret = client_secret
self.token_file = token_file
self.authority = authority
+ self.allow_json_refresh = allow_json_refresh
self.token_cache = msal.SerializableTokenCache()
+ self._current_account = None
- # Load existing cache if available
- if os.path.exists(self.token_file):
- with open(self.token_file, "r") as f:
- self.token_cache.deserialize(f.read())
-
+ # Initialize MSAL Confidential Client
self.app = msal.ConfidentialClientApplication(
client_id=self.client_id,
client_credential=self.client_secret,
@@ -40,56 +62,261 @@ class OneDriveOAuth:
token_cache=self.token_cache,
)
- async def save_cache(self):
- """Persist the token cache to file"""
- async with aiofiles.open(self.token_file, "w") as f:
- await f.write(self.token_cache.serialize())
+ async def load_credentials(self) -> bool:
+ """Load existing credentials from token file (async)."""
+ try:
+ logger.debug(f"OneDrive OAuth loading credentials from: {self.token_file}")
+ if os.path.exists(self.token_file):
+ logger.debug(f"Token file exists, reading: {self.token_file}")
- def create_authorization_url(self, redirect_uri: str) -> str:
- """Create authorization URL for OAuth flow"""
- return self.app.get_authorization_request_url(
- self.SCOPES, redirect_uri=redirect_uri
- )
+ # Read the token file
+ async with aiofiles.open(self.token_file, "r") as f:
+ cache_data = await f.read()
+ logger.debug(f"Read {len(cache_data)} chars from token file")
+
+ if cache_data.strip():
+ # 1) Try legacy flat JSON first
+ try:
+ json_data = json.loads(cache_data)
+ if isinstance(json_data, dict) and "refresh_token" in json_data:
+ if self.allow_json_refresh:
+ logger.debug(
+ "Found legacy JSON refresh_token and allow_json_refresh=True; attempting migration refresh"
+ )
+ return await self._refresh_from_json_token(json_data)
+ else:
+ logger.warning(
+ "Token file contains a legacy JSON refresh_token, but allow_json_refresh=False. "
+ "Delete the file and re-auth."
+ )
+ return False
+ except json.JSONDecodeError:
+ logger.debug("Token file is not flat JSON; attempting MSAL cache format")
+
+ # 2) Try MSAL cache format
+ logger.debug("Attempting MSAL cache deserialization")
+ self.token_cache.deserialize(cache_data)
+
+ # Get accounts from loaded cache
+ accounts = self.app.get_accounts()
+ logger.debug(f"Found {len(accounts)} accounts in MSAL cache")
+ if accounts:
+ self._current_account = accounts[0]
+ logger.debug(f"Set current account: {self._current_account.get('username', 'no username')}")
+
+ # Use RESOURCE_SCOPES (no reserved scopes) for silent acquisition
+ result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=self._current_account)
+ logger.debug(f"Silent token acquisition result keys: {list(result.keys()) if result else 'None'}")
+ if result and "access_token" in result:
+ logger.debug("Silent token acquisition successful")
+ await self.save_cache()
+ return True
+ else:
+ error_msg = (result or {}).get("error") or "No result"
+ logger.warning(f"Silent token acquisition failed: {error_msg}")
+ else:
+ logger.debug(f"Token file {self.token_file} is empty")
+ else:
+ logger.debug(f"Token file does not exist: {self.token_file}")
+
+ return False
+
+ except Exception as e:
+ logger.error(f"Failed to load OneDrive credentials: {e}")
+ import traceback
+ traceback.print_exc()
+ return False
+
+ async def _refresh_from_json_token(self, token_data: dict) -> bool:
+ """
+ Use refresh token from a legacy JSON file to get new tokens (one-time migration path).
+ Prefer using an MSAL cache file and acquire_token_silent(); this path is only for migrating older files.
+ """
+ try:
+ refresh_token = token_data.get("refresh_token")
+ if not refresh_token:
+ logger.error("No refresh_token found in JSON file - cannot refresh")
+ logger.error("You must re-authenticate interactively to obtain a valid token")
+ return False
+
+ # Use only RESOURCE_SCOPES when refreshing (no reserved scopes)
+ refresh_scopes = [s for s in self.RESOURCE_SCOPES if s not in self.RESERVED_SCOPES]
+ logger.debug(f"Using refresh token; refresh scopes = {refresh_scopes}")
+
+ result = self.app.acquire_token_by_refresh_token(
+ refresh_token=refresh_token,
+ scopes=refresh_scopes,
+ )
+
+ if result and "access_token" in result:
+ logger.debug("Successfully refreshed token via legacy JSON path")
+ await self.save_cache()
+
+ accounts = self.app.get_accounts()
+ logger.debug(f"After refresh, found {len(accounts)} accounts")
+ if accounts:
+ self._current_account = accounts[0]
+ logger.debug(f"Set current account after refresh: {self._current_account.get('username', 'no username')}")
+ return True
+
+ # Error handling
+ err = (result or {}).get("error_description") or (result or {}).get("error") or "Unknown error"
+ logger.error(f"Refresh token failed: {err}")
+
+ if any(code in err for code in ("AADSTS70000", "invalid_grant", "interaction_required")):
+ logger.warning(
+ "Refresh denied due to unauthorized/expired scopes or invalid grant. "
+ "Delete the token file and perform interactive sign-in with correct scopes."
+ )
+
+ return False
+
+ except Exception as e:
+ logger.error(f"Exception during refresh from JSON token: {e}")
+ import traceback
+ traceback.print_exc()
+ return False
+
+ async def save_cache(self):
+ """Persist the token cache to file."""
+ try:
+ # Ensure parent directory exists
+ parent = os.path.dirname(os.path.abspath(self.token_file))
+ if parent and not os.path.exists(parent):
+ os.makedirs(parent, exist_ok=True)
+
+ cache_data = self.token_cache.serialize()
+ if cache_data:
+ async with aiofiles.open(self.token_file, "w") as f:
+ await f.write(cache_data)
+ logger.debug(f"Token cache saved to {self.token_file}")
+ except Exception as e:
+ logger.error(f"Failed to save token cache: {e}")
+
+ def create_authorization_url(self, redirect_uri: str, state: Optional[str] = None) -> str:
+ """Create authorization URL for OAuth flow."""
+ # Store redirect URI for later use in callback
+ self._redirect_uri = redirect_uri
+
+ kwargs: Dict[str, Any] = {
+ # Interactive auth includes offline_access
+ "scopes": self.AUTH_SCOPES,
+ "redirect_uri": redirect_uri,
+ "prompt": "consent", # ensure refresh token on first run
+ }
+ if state:
+ kwargs["state"] = state # Optional CSRF protection
+
+ auth_url = self.app.get_authorization_request_url(**kwargs)
+
+ logger.debug(f"Generated auth URL: {auth_url}")
+ logger.debug(f"Auth scopes: {self.AUTH_SCOPES}")
+
+ return auth_url
async def handle_authorization_callback(
self, authorization_code: str, redirect_uri: str
) -> bool:
- """Handle OAuth callback and exchange code for tokens"""
- result = self.app.acquire_token_by_authorization_code(
- authorization_code,
- scopes=self.SCOPES,
- redirect_uri=redirect_uri,
- )
- if "access_token" in result:
- await self.save_cache()
- return True
- raise ValueError(result.get("error_description") or "Authorization failed")
+ """Handle OAuth callback and exchange code for tokens."""
+ try:
+ result = self.app.acquire_token_by_authorization_code(
+ authorization_code,
+ scopes=self.AUTH_SCOPES, # same as authorize step
+ redirect_uri=redirect_uri,
+ )
+
+ if result and "access_token" in result:
+ accounts = self.app.get_accounts()
+ if accounts:
+ self._current_account = accounts[0]
+
+ await self.save_cache()
+ logger.info("OneDrive OAuth authorization successful")
+ return True
+
+ error_msg = (result or {}).get("error_description") or (result or {}).get("error") or "Unknown error"
+ logger.error(f"OneDrive OAuth authorization failed: {error_msg}")
+ return False
+
+ except Exception as e:
+ logger.error(f"Exception during OneDrive OAuth authorization: {e}")
+ return False
async def is_authenticated(self) -> bool:
- """Check if we have valid credentials"""
- accounts = self.app.get_accounts()
- if not accounts:
+ """Check if we have valid credentials."""
+ try:
+ # First try to load credentials if we haven't already
+ if not self._current_account:
+ await self.load_credentials()
+
+ # Try to get a token (MSAL will refresh if needed)
+ if self._current_account:
+ result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=self._current_account)
+ if result and "access_token" in result:
+ return True
+ else:
+ error_msg = (result or {}).get("error") or "No result returned"
+ logger.debug(f"Token acquisition failed for current account: {error_msg}")
+
+ # Fallback: try without specific account
+ result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=None)
+ if result and "access_token" in result:
+ accounts = self.app.get_accounts()
+ if accounts:
+ self._current_account = accounts[0]
+ return True
+
+ return False
+
+ except Exception as e:
+ logger.error(f"Authentication check failed: {e}")
return False
- result = self.app.acquire_token_silent(self.SCOPES, account=accounts[0])
- if "access_token" in result:
- await self.save_cache()
- return True
- return False
def get_access_token(self) -> str:
- """Get an access token for Microsoft Graph"""
- accounts = self.app.get_accounts()
- if not accounts:
- raise ValueError("Not authenticated")
- result = self.app.acquire_token_silent(self.SCOPES, account=accounts[0])
- if "access_token" not in result:
- raise ValueError(
- result.get("error_description") or "Failed to acquire access token"
- )
- return result["access_token"]
+ """Get an access token for Microsoft Graph."""
+ try:
+ # Try with current account first
+ if self._current_account:
+ result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=self._current_account)
+ if result and "access_token" in result:
+ return result["access_token"]
+
+ # Fallback: try without specific account
+ result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=None)
+ if result and "access_token" in result:
+ return result["access_token"]
+
+ # If we get here, authentication has failed
+ error_msg = (result or {}).get("error_description") or (result or {}).get("error") or "No valid authentication"
+ raise ValueError(f"Failed to acquire access token: {error_msg}")
+
+ except Exception as e:
+ logger.error(f"Failed to get access token: {e}")
+ raise
async def revoke_credentials(self):
- """Clear token cache and remove token file"""
- self.token_cache.clear()
- if os.path.exists(self.token_file):
- os.remove(self.token_file)
+ """Clear token cache and remove token file."""
+ try:
+ # Clear in-memory state
+ self._current_account = None
+ self.token_cache = msal.SerializableTokenCache()
+
+ # Recreate MSAL app with fresh cache
+ self.app = msal.ConfidentialClientApplication(
+ client_id=self.client_id,
+ client_credential=self.client_secret,
+ authority=self.authority,
+ token_cache=self.token_cache,
+ )
+
+ # Remove token file
+ if os.path.exists(self.token_file):
+ os.remove(self.token_file)
+ logger.info(f"Removed OneDrive token file: {self.token_file}")
+
+ except Exception as e:
+ logger.error(f"Failed to revoke OneDrive credentials: {e}")
+
+ def get_service(self) -> str:
+ """Return an access token (Graph client is just the bearer)."""
+ return self.get_access_token()
diff --git a/src/connectors/sharepoint/connector.py b/src/connectors/sharepoint/connector.py
index 7135cc8e..f8283062 100644
--- a/src/connectors/sharepoint/connector.py
+++ b/src/connectors/sharepoint/connector.py
@@ -1,229 +1,567 @@
+import logging
+from pathlib import Path
+from typing import List, Dict, Any, Optional
+from urllib.parse import urlparse
+from datetime import datetime
import httpx
-import uuid
-from datetime import datetime, timedelta
-from typing import Dict, List, Any, Optional
from ..base import BaseConnector, ConnectorDocument, DocumentACL
from .oauth import SharePointOAuth
+logger = logging.getLogger(__name__)
+
class SharePointConnector(BaseConnector):
- """SharePoint Sites connector using Microsoft Graph API"""
+ """SharePoint connector using MSAL-based OAuth for authentication"""
+ # Required BaseConnector class attributes
CLIENT_ID_ENV_VAR = "MICROSOFT_GRAPH_OAUTH_CLIENT_ID"
CLIENT_SECRET_ENV_VAR = "MICROSOFT_GRAPH_OAUTH_CLIENT_SECRET"
-
+
# Connector metadata
CONNECTOR_NAME = "SharePoint"
- CONNECTOR_DESCRIPTION = "Connect to SharePoint sites to sync team documents"
+ CONNECTOR_DESCRIPTION = "Connect to SharePoint to sync documents and files"
CONNECTOR_ICON = "sharepoint"
-
+
def __init__(self, config: Dict[str, Any]):
super().__init__(config)
- self.oauth = SharePointOAuth(
- client_id=self.get_client_id(),
- client_secret=self.get_client_secret(),
- token_file=config.get("token_file", "sharepoint_token.json"),
- )
- self.subscription_id = config.get("subscription_id") or config.get(
- "webhook_channel_id"
- )
- self.base_url = "https://graph.microsoft.com/v1.0"
- # SharePoint site configuration
- self.site_id = config.get("site_id") # Required for SharePoint
+ logger.debug(f"SharePoint connector __init__ called with config type: {type(config)}")
+ logger.debug(f"SharePoint connector __init__ config value: {config}")
+
+ # Ensure we always pass a valid config to the base class
+ if config is None:
+ logger.debug("Config was None, using empty dict")
+ config = {}
+
+ try:
+ logger.debug("Calling super().__init__")
+ super().__init__(config) # Now safe to call with empty dict instead of None
+ logger.debug("super().__init__ completed successfully")
+ except Exception as e:
+ logger.error(f"super().__init__ failed: {e}")
+ raise
+
+ # Initialize with defaults that allow the connector to be listed
+ self.client_id = None
+ self.client_secret = None
+ self.tenant_id = config.get("tenant_id", "common")
+ self.sharepoint_url = config.get("sharepoint_url")
+ self.redirect_uri = config.get("redirect_uri", "http://localhost")
+
+ # Try to get credentials, but don't fail if they're missing
+ try:
+ logger.debug("Attempting to get client_id")
+ self.client_id = self.get_client_id()
+ logger.debug(f"Got client_id: {self.client_id is not None}")
+ except Exception as e:
+ logger.debug(f"Failed to get client_id: {e}")
+ pass # Credentials not available, that's OK for listing
+
+ try:
+ logger.debug("Attempting to get client_secret")
+ self.client_secret = self.get_client_secret()
+ logger.debug(f"Got client_secret: {self.client_secret is not None}")
+ except Exception as e:
+ logger.debug(f"Failed to get client_secret: {e}")
+ pass # Credentials not available, that's OK for listing
- async def authenticate(self) -> bool:
- if await self.oauth.is_authenticated():
- self._authenticated = True
- return True
- return False
-
- async def setup_subscription(self) -> str:
- if not self._authenticated:
- raise ValueError("Not authenticated")
-
- webhook_url = self.config.get("webhook_url")
- if not webhook_url:
- raise ValueError("webhook_url required in config for subscriptions")
-
- expiration = (datetime.utcnow() + timedelta(days=2)).isoformat() + "Z"
- body = {
- "changeType": "created,updated,deleted",
- "notificationUrl": webhook_url,
- "resource": f"/sites/{self.site_id}/drive/root",
- "expirationDateTime": expiration,
- "clientState": str(uuid.uuid4()),
+ # Token file setup
+ project_root = Path(__file__).resolve().parent.parent.parent.parent
+ token_file = config.get("token_file") or str(project_root / "sharepoint_token.json")
+ Path(token_file).parent.mkdir(parents=True, exist_ok=True)
+
+ # Only initialize OAuth if we have credentials
+ if self.client_id and self.client_secret:
+ connection_id = config.get("connection_id", "default")
+
+ # Use token_file from config if provided, otherwise generate one
+ if config.get("token_file"):
+ oauth_token_file = config["token_file"]
+ else:
+ oauth_token_file = f"sharepoint_token_{connection_id}.json"
+
+ authority = f"https://login.microsoftonline.com/{self.tenant_id}" if self.tenant_id != "common" else "https://login.microsoftonline.com/common"
+
+ self.oauth = SharePointOAuth(
+ client_id=self.client_id,
+ client_secret=self.client_secret,
+ token_file=oauth_token_file,
+ authority=authority
+ )
+ else:
+ self.oauth = None
+
+ # Track subscription ID for webhooks
+ self._subscription_id: Optional[str] = None
+
+ # Add Graph API defaults similar to Google Drive flags
+ self._graph_api_version = "v1.0"
+ self._default_params = {
+ "$select": "id,name,size,lastModifiedDateTime,createdDateTime,webUrl,file,folder,@microsoft.graph.downloadUrl"
}
-
- token = self.oauth.get_access_token()
- async with httpx.AsyncClient() as client:
- resp = await client.post(
- f"{self.base_url}/subscriptions",
- json=body,
- headers={"Authorization": f"Bearer {token}"},
- )
- resp.raise_for_status()
- data = resp.json()
-
- self.subscription_id = data["id"]
- return self.subscription_id
-
- async def list_files(
- self, page_token: Optional[str] = None, limit: int = 100
- ) -> Dict[str, Any]:
- if not self._authenticated:
- raise ValueError("Not authenticated")
-
- params = {"$top": str(limit)}
- if page_token:
- params["$skiptoken"] = page_token
-
- token = self.oauth.get_access_token()
- async with httpx.AsyncClient() as client:
- resp = await client.get(
- f"{self.base_url}/sites/{self.site_id}/drive/root/children",
- params=params,
- headers={"Authorization": f"Bearer {token}"},
- )
- resp.raise_for_status()
- data = resp.json()
-
- files = []
- for item in data.get("value", []):
- if item.get("file"):
- files.append(
- {
- "id": item["id"],
- "name": item["name"],
- "mimeType": item.get("file", {}).get(
- "mimeType", "application/octet-stream"
- ),
- "webViewLink": item.get("webUrl"),
- "createdTime": item.get("createdDateTime"),
- "modifiedTime": item.get("lastModifiedDateTime"),
- }
- )
-
- next_token = None
- next_link = data.get("@odata.nextLink")
- if next_link:
- from urllib.parse import urlparse, parse_qs
-
- parsed = urlparse(next_link)
- next_token = parse_qs(parsed.query).get("$skiptoken", [None])[0]
-
- return {"files": files, "nextPageToken": next_token}
-
- async def get_file_content(self, file_id: str) -> ConnectorDocument:
- if not self._authenticated:
- raise ValueError("Not authenticated")
-
- token = self.oauth.get_access_token()
- headers = {"Authorization": f"Bearer {token}"}
- async with httpx.AsyncClient() as client:
- meta_resp = await client.get(
- f"{self.base_url}/sites/{self.site_id}/drive/items/{file_id}",
- headers=headers,
- )
- meta_resp.raise_for_status()
- metadata = meta_resp.json()
-
- content_resp = await client.get(
- f"{self.base_url}/sites/{self.site_id}/drive/items/{file_id}/content",
- headers=headers,
- )
- content_resp.raise_for_status()
- content = content_resp.content
-
- perm_resp = await client.get(
- f"{self.base_url}/sites/{self.site_id}/drive/items/{file_id}/permissions",
- headers=headers,
- )
- perm_resp.raise_for_status()
- permissions = perm_resp.json()
-
- acl = self._parse_permissions(metadata, permissions)
- modified = datetime.fromisoformat(
- metadata["lastModifiedDateTime"].replace("Z", "+00:00")
- ).replace(tzinfo=None)
- created = datetime.fromisoformat(
- metadata["createdDateTime"].replace("Z", "+00:00")
- ).replace(tzinfo=None)
-
- document = ConnectorDocument(
- id=metadata["id"],
- filename=metadata["name"],
- mimetype=metadata.get("file", {}).get(
- "mimeType", "application/octet-stream"
- ),
- content=content,
- source_url=metadata.get("webUrl"),
- acl=acl,
- modified_time=modified,
- created_time=created,
- metadata={"size": metadata.get("size")},
- )
- return document
-
- def _parse_permissions(
- self, metadata: Dict[str, Any], permissions: Dict[str, Any]
- ) -> DocumentACL:
- acl = DocumentACL()
- owner = metadata.get("createdBy", {}).get("user", {}).get("email")
- if owner:
- acl.owner = owner
- for perm in permissions.get("value", []):
- role = perm.get("roles", ["read"])[0]
- grantee = perm.get("grantedToV2") or perm.get("grantedTo")
- if not grantee:
- continue
- user = grantee.get("user")
- if user and user.get("email"):
- acl.user_permissions[user["email"]] = role
- group = grantee.get("group")
- if group and group.get("email"):
- acl.group_permissions[group["email"]] = role
- return acl
-
- def handle_webhook_validation(
- self, request_method: str, headers: Dict[str, str], query_params: Dict[str, str]
- ) -> Optional[str]:
- """Handle Microsoft Graph webhook validation"""
- if request_method == "GET":
- validation_token = query_params.get("validationtoken") or query_params.get(
- "validationToken"
- )
- if validation_token:
- return validation_token
- return None
-
- def extract_webhook_channel_id(
- self, payload: Dict[str, Any], headers: Dict[str, str]
- ) -> Optional[str]:
- """Extract SharePoint subscription ID from webhook payload"""
- values = payload.get("value", [])
- return values[0].get("subscriptionId") if values else None
-
- async def handle_webhook(self, payload: Dict[str, Any]) -> List[str]:
- values = payload.get("value", [])
- file_ids = []
- for item in values:
- resource_data = item.get("resourceData", {})
- file_id = resource_data.get("id")
- if file_id:
- file_ids.append(file_id)
- return file_ids
-
- async def cleanup_subscription(
- self, subscription_id: str, resource_id: str = None
- ) -> bool:
- if not self._authenticated:
+
+ @property
+ def _graph_base_url(self) -> str:
+ """Base URL for Microsoft Graph API calls"""
+ return f"https://graph.microsoft.com/{self._graph_api_version}"
+
+ def emit(self, doc: ConnectorDocument) -> None:
+ """
+ Emit a ConnectorDocument instance.
+ """
+ logger.debug(f"Emitting SharePoint document: {doc.id} ({doc.filename})")
+
+ async def authenticate(self) -> bool:
+ """Test authentication - BaseConnector interface"""
+ logger.debug(f"SharePoint authenticate() called, oauth is None: {self.oauth is None}")
+ try:
+ if not self.oauth:
+ logger.debug("SharePoint authentication failed: OAuth not initialized")
+ self._authenticated = False
+ return False
+
+ logger.debug("Loading SharePoint credentials...")
+ # Try to load existing credentials first
+ load_result = await self.oauth.load_credentials()
+ logger.debug(f"Load credentials result: {load_result}")
+
+ logger.debug("Checking SharePoint authentication status...")
+ authenticated = await self.oauth.is_authenticated()
+ logger.debug(f"SharePoint is_authenticated result: {authenticated}")
+
+ self._authenticated = authenticated
+ return authenticated
+ except Exception as e:
+ logger.error(f"SharePoint authentication failed: {e}")
+ import traceback
+ traceback.print_exc()
+ self._authenticated = False
return False
- token = self.oauth.get_access_token()
- async with httpx.AsyncClient() as client:
- resp = await client.delete(
- f"{self.base_url}/subscriptions/{subscription_id}",
- headers={"Authorization": f"Bearer {token}"},
+
+ def get_auth_url(self) -> str:
+ """Get OAuth authorization URL"""
+ if not self.oauth:
+ raise RuntimeError("SharePoint OAuth not initialized - missing credentials")
+ return self.oauth.create_authorization_url(self.redirect_uri)
+
+ async def handle_oauth_callback(self, auth_code: str) -> Dict[str, Any]:
+ """Handle OAuth callback"""
+ if not self.oauth:
+ raise RuntimeError("SharePoint OAuth not initialized - missing credentials")
+ try:
+ success = await self.oauth.handle_authorization_callback(auth_code, self.redirect_uri)
+ if success:
+ self._authenticated = True
+ return {"status": "success"}
+ else:
+ raise ValueError("OAuth callback failed")
+ except Exception as e:
+ logger.error(f"OAuth callback failed: {e}")
+ raise
+
+ def sync_once(self) -> None:
+ """
+ Perform a one-shot sync of SharePoint files and emit documents.
+ This method mirrors the Google Drive connector's sync_once functionality.
+ """
+ import asyncio
+
+ async def _async_sync():
+ try:
+ # Get list of files
+ file_list = await self.list_files(max_files=1000) # Adjust as needed
+ files = file_list.get("files", [])
+
+ for file_info in files:
+ try:
+ file_id = file_info.get("id")
+ if not file_id:
+ continue
+
+ # Get full document content
+ doc = await self.get_file_content(file_id)
+ self.emit(doc)
+
+ except Exception as e:
+ logger.error(f"Failed to sync SharePoint file {file_info.get('name', 'unknown')}: {e}")
+ continue
+
+ except Exception as e:
+ logger.error(f"SharePoint sync_once failed: {e}")
+ raise
+
+ # Run the async sync
+ if hasattr(asyncio, 'run'):
+ asyncio.run(_async_sync())
+ else:
+ # Python < 3.7 compatibility
+ loop = asyncio.get_event_loop()
+ loop.run_until_complete(_async_sync())
+
+ async def setup_subscription(self) -> str:
+ """Set up real-time subscription for file changes - BaseConnector interface"""
+ webhook_url = self.config.get('webhook_url')
+ if not webhook_url:
+ logger.warning("No webhook URL configured, skipping SharePoint subscription setup")
+ return "no-webhook-configured"
+
+ try:
+ # Ensure we're authenticated
+ if not await self.authenticate():
+ raise RuntimeError("SharePoint authentication failed during subscription setup")
+
+ token = self.oauth.get_access_token()
+
+ # Microsoft Graph subscription for SharePoint site
+ site_info = self._parse_sharepoint_url()
+ if site_info:
+ resource = f"sites/{site_info['host_name']}:/sites/{site_info['site_name']}:/drive/root"
+ else:
+ resource = "/me/drive/root"
+
+ subscription_data = {
+ "changeType": "created,updated,deleted",
+ "notificationUrl": f"{webhook_url}/webhook/sharepoint",
+ "resource": resource,
+ "expirationDateTime": self._get_subscription_expiry(),
+ "clientState": f"sharepoint_{self.tenant_id}"
+ }
+
+ headers = {
+ "Authorization": f"Bearer {token}",
+ "Content-Type": "application/json"
+ }
+
+ url = f"{self._graph_base_url}/subscriptions"
+
+ async with httpx.AsyncClient() as client:
+ response = await client.post(url, json=subscription_data, headers=headers, timeout=30)
+ response.raise_for_status()
+
+ result = response.json()
+ subscription_id = result.get("id")
+
+ if subscription_id:
+ self._subscription_id = subscription_id
+ logger.info(f"SharePoint subscription created: {subscription_id}")
+ return subscription_id
+ else:
+ raise ValueError("No subscription ID returned from Microsoft Graph")
+
+ except Exception as e:
+ logger.error(f"Failed to setup SharePoint subscription: {e}")
+ raise
+
+ def _get_subscription_expiry(self) -> str:
+ """Get subscription expiry time (max 3 days for Graph API)"""
+ from datetime import datetime, timedelta
+ expiry = datetime.utcnow() + timedelta(days=3) # 3 days max for Graph
+ return expiry.strftime("%Y-%m-%dT%H:%M:%S.%fZ")
+
+ def _parse_sharepoint_url(self) -> Optional[Dict[str, str]]:
+ """Parse SharePoint URL to extract site information for Graph API"""
+ if not self.sharepoint_url:
+ return None
+
+ try:
+ parsed = urlparse(self.sharepoint_url)
+ # Extract hostname and site name from URL like: https://contoso.sharepoint.com/sites/teamsite
+ host_name = parsed.netloc
+ path_parts = parsed.path.strip('/').split('/')
+
+ if len(path_parts) >= 2 and path_parts[0] == 'sites':
+ site_name = path_parts[1]
+ return {
+ "host_name": host_name,
+ "site_name": site_name
+ }
+ except Exception as e:
+ logger.warning(f"Could not parse SharePoint URL {self.sharepoint_url}: {e}")
+
+ return None
+
+ async def list_files(
+ self,
+ page_token: Optional[str] = None,
+ max_files: Optional[int] = None,
+ **kwargs
+ ) -> Dict[str, Any]:
+ """List all files using Microsoft Graph API - BaseConnector interface"""
+ try:
+ # Ensure authentication
+ if not await self.authenticate():
+ raise RuntimeError("SharePoint authentication failed during file listing")
+
+ files = []
+ max_files_value = max_files if max_files is not None else 100
+
+ # Build Graph API URL for the site or fallback to user's OneDrive
+ site_info = self._parse_sharepoint_url()
+ if site_info:
+ base_url = f"{self._graph_base_url}/sites/{site_info['host_name']}:/sites/{site_info['site_name']}:/drive/root/children"
+ else:
+ base_url = f"{self._graph_base_url}/me/drive/root/children"
+
+ params = dict(self._default_params)
+ params["$top"] = str(max_files_value)
+
+ if page_token:
+ params["$skiptoken"] = page_token
+
+ response = await self._make_graph_request(base_url, params=params)
+ data = response.json()
+
+ items = data.get("value", [])
+ for item in items:
+ # Only include files, not folders
+ if item.get("file"):
+ files.append({
+ "id": item.get("id", ""),
+ "name": item.get("name", ""),
+ "path": f"/drive/items/{item.get('id')}",
+ "size": int(item.get("size", 0)),
+ "modified": item.get("lastModifiedDateTime"),
+ "created": item.get("createdDateTime"),
+ "mime_type": item.get("file", {}).get("mimeType", self._get_mime_type(item.get("name", ""))),
+ "url": item.get("webUrl", ""),
+ "download_url": item.get("@microsoft.graph.downloadUrl")
+ })
+
+ # Check for next page
+ next_page_token = None
+ next_link = data.get("@odata.nextLink")
+ if next_link:
+ from urllib.parse import urlparse, parse_qs
+ parsed = urlparse(next_link)
+ query_params = parse_qs(parsed.query)
+ if "$skiptoken" in query_params:
+ next_page_token = query_params["$skiptoken"][0]
+
+ return {
+ "files": files,
+ "next_page_token": next_page_token
+ }
+
+ except Exception as e:
+ logger.error(f"Failed to list SharePoint files: {e}")
+ return {"files": [], "next_page_token": None} # Return empty result instead of raising
+
+ async def get_file_content(self, file_id: str) -> ConnectorDocument:
+ """Get file content and metadata - BaseConnector interface"""
+ try:
+ # Ensure authentication
+ if not await self.authenticate():
+ raise RuntimeError("SharePoint authentication failed during file content retrieval")
+
+ # First get file metadata using Graph API
+ file_metadata = await self._get_file_metadata_by_id(file_id)
+
+ if not file_metadata:
+ raise ValueError(f"File not found: {file_id}")
+
+ # Download file content
+ download_url = file_metadata.get("download_url")
+ if download_url:
+ content = await self._download_file_from_url(download_url)
+ else:
+ content = await self._download_file_content(file_id)
+
+ # Create ACL from metadata
+ acl = DocumentACL(
+ owner="", # Graph API requires additional calls for detailed permissions
+ user_permissions={},
+ group_permissions={}
)
- return resp.status_code in (200, 204)
+
+ # Parse dates
+ modified_time = self._parse_graph_date(file_metadata.get("modified"))
+ created_time = self._parse_graph_date(file_metadata.get("created"))
+
+ return ConnectorDocument(
+ id=file_id,
+ filename=file_metadata.get("name", ""),
+ mimetype=file_metadata.get("mime_type", "application/octet-stream"),
+ content=content,
+ source_url=file_metadata.get("url", ""),
+ acl=acl,
+ modified_time=modified_time,
+ created_time=created_time,
+ metadata={
+ "sharepoint_path": file_metadata.get("path", ""),
+ "sharepoint_url": self.sharepoint_url,
+ "size": file_metadata.get("size", 0)
+ }
+ )
+
+ except Exception as e:
+ logger.error(f"Failed to get SharePoint file content {file_id}: {e}")
+ raise
+
+ async def _get_file_metadata_by_id(self, file_id: str) -> Optional[Dict[str, Any]]:
+ """Get file metadata by ID using Graph API"""
+ try:
+ # Try site-specific path first, then fallback to user drive
+ site_info = self._parse_sharepoint_url()
+ if site_info:
+ url = f"{self._graph_base_url}/sites/{site_info['host_name']}:/sites/{site_info['site_name']}:/drive/items/{file_id}"
+ else:
+ url = f"{self._graph_base_url}/me/drive/items/{file_id}"
+
+ params = dict(self._default_params)
+
+ response = await self._make_graph_request(url, params=params)
+ item = response.json()
+
+ if item.get("file"):
+ return {
+ "id": file_id,
+ "name": item.get("name", ""),
+ "path": f"/drive/items/{file_id}",
+ "size": int(item.get("size", 0)),
+ "modified": item.get("lastModifiedDateTime"),
+ "created": item.get("createdDateTime"),
+ "mime_type": item.get("file", {}).get("mimeType", self._get_mime_type(item.get("name", ""))),
+ "url": item.get("webUrl", ""),
+ "download_url": item.get("@microsoft.graph.downloadUrl")
+ }
+
+ return None
+
+ except Exception as e:
+ logger.error(f"Failed to get file metadata for {file_id}: {e}")
+ return None
+
+ async def _download_file_content(self, file_id: str) -> bytes:
+ """Download file content by file ID using Graph API"""
+ try:
+ site_info = self._parse_sharepoint_url()
+ if site_info:
+ url = f"{self._graph_base_url}/sites/{site_info['host_name']}:/sites/{site_info['site_name']}:/drive/items/{file_id}/content"
+ else:
+ url = f"{self._graph_base_url}/me/drive/items/{file_id}/content"
+
+ token = self.oauth.get_access_token()
+ headers = {"Authorization": f"Bearer {token}"}
+
+ async with httpx.AsyncClient() as client:
+ response = await client.get(url, headers=headers, timeout=60)
+ response.raise_for_status()
+ return response.content
+
+ except Exception as e:
+ logger.error(f"Failed to download file content for {file_id}: {e}")
+ raise
+
+ async def _download_file_from_url(self, download_url: str) -> bytes:
+ """Download file content from direct download URL"""
+ try:
+ async with httpx.AsyncClient() as client:
+ response = await client.get(download_url, timeout=60)
+ response.raise_for_status()
+ return response.content
+ except Exception as e:
+ logger.error(f"Failed to download from URL {download_url}: {e}")
+ raise
+
+ def _parse_graph_date(self, date_str: Optional[str]) -> datetime:
+ """Parse Microsoft Graph date string to datetime"""
+ if not date_str:
+ return datetime.now()
+
+ try:
+ if date_str.endswith('Z'):
+ return datetime.fromisoformat(date_str[:-1]).replace(tzinfo=None)
+ else:
+ return datetime.fromisoformat(date_str.replace('T', ' '))
+ except (ValueError, AttributeError):
+ return datetime.now()
+
+ async def _make_graph_request(self, url: str, method: str = "GET",
+ data: Optional[Dict] = None, params: Optional[Dict] = None) -> httpx.Response:
+ """Make authenticated API request to Microsoft Graph"""
+ token = self.oauth.get_access_token()
+ headers = {
+ "Authorization": f"Bearer {token}",
+ "Content-Type": "application/json"
+ }
+
+ async with httpx.AsyncClient() as client:
+ if method.upper() == "GET":
+ response = await client.get(url, headers=headers, params=params, timeout=30)
+ elif method.upper() == "POST":
+ response = await client.post(url, headers=headers, json=data, timeout=30)
+ elif method.upper() == "DELETE":
+ response = await client.delete(url, headers=headers, timeout=30)
+ else:
+ raise ValueError(f"Unsupported HTTP method: {method}")
+
+ response.raise_for_status()
+ return response
+
+ def _get_mime_type(self, filename: str) -> str:
+ """Get MIME type based on file extension"""
+ import mimetypes
+ mime_type, _ = mimetypes.guess_type(filename)
+ return mime_type or "application/octet-stream"
+
+ # Webhook methods - BaseConnector interface
+ def handle_webhook_validation(self, request_method: str, headers: Dict[str, str],
+ query_params: Dict[str, str]) -> Optional[str]:
+ """Handle webhook validation (Graph API specific)"""
+ if request_method == "POST" and "validationToken" in query_params:
+ return query_params["validationToken"]
+ return None
+
+ def extract_webhook_channel_id(self, payload: Dict[str, Any],
+ headers: Dict[str, str]) -> Optional[str]:
+ """Extract channel/subscription ID from webhook payload"""
+ notifications = payload.get("value", [])
+ if notifications:
+ return notifications[0].get("subscriptionId")
+ return None
+
+ async def handle_webhook(self, payload: Dict[str, Any]) -> List[str]:
+ """Handle webhook notification and return affected file IDs"""
+ affected_files = []
+
+ # Process Microsoft Graph webhook payload
+ notifications = payload.get("value", [])
+ for notification in notifications:
+ resource = notification.get("resource")
+ if resource and "/drive/items/" in resource:
+ file_id = resource.split("/drive/items/")[-1]
+ affected_files.append(file_id)
+
+ return affected_files
+
+ async def cleanup_subscription(self, subscription_id: str) -> bool:
+ """Clean up subscription - BaseConnector interface"""
+ if subscription_id == "no-webhook-configured":
+ logger.info("No subscription to cleanup (webhook was not configured)")
+ return True
+
+ try:
+ # Ensure authentication
+ if not await self.authenticate():
+ logger.error("SharePoint authentication failed during subscription cleanup")
+ return False
+
+ token = self.oauth.get_access_token()
+ headers = {"Authorization": f"Bearer {token}"}
+
+ url = f"{self._graph_base_url}/subscriptions/{subscription_id}"
+
+ async with httpx.AsyncClient() as client:
+ response = await client.delete(url, headers=headers, timeout=30)
+
+ if response.status_code in [200, 204, 404]:
+ logger.info(f"SharePoint subscription {subscription_id} cleaned up successfully")
+ return True
+ else:
+ logger.warning(f"Unexpected response cleaning up subscription: {response.status_code}")
+ return False
+
+ except Exception as e:
+ logger.error(f"Failed to cleanup SharePoint subscription {subscription_id}: {e}")
+ return False
diff --git a/src/connectors/sharepoint/oauth.py b/src/connectors/sharepoint/oauth.py
index fa7424e9..4a96581f 100644
--- a/src/connectors/sharepoint/oauth.py
+++ b/src/connectors/sharepoint/oauth.py
@@ -1,18 +1,28 @@
import os
+import json
+import logging
+from typing import Optional, Dict, Any
+
import aiofiles
-from typing import Optional
import msal
+logger = logging.getLogger(__name__)
+
class SharePointOAuth:
- """Handles Microsoft Graph OAuth authentication flow"""
+ """Handles Microsoft Graph OAuth authentication flow following Google Drive pattern."""
- SCOPES = [
- "offline_access",
- "Files.Read.All",
- "Sites.Read.All",
- ]
+ # Reserved scopes that must NOT be sent on token or silent calls
+ RESERVED_SCOPES = {"openid", "profile", "offline_access"}
+ # For PERSONAL Microsoft Accounts (OneDrive consumer):
+ # - Use AUTH_SCOPES for interactive auth (consent + refresh token issuance)
+ # - Use RESOURCE_SCOPES for acquire_token_silent / refresh paths
+ AUTH_SCOPES = ["User.Read", "Files.Read.All", "offline_access"]
+ RESOURCE_SCOPES = ["User.Read", "Files.Read.All"]
+ SCOPES = AUTH_SCOPES # Backward compatibility alias
+
+ # Kept for reference; MSAL derives endpoints from `authority`
AUTH_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
TOKEN_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/token"
@@ -22,18 +32,29 @@ class SharePointOAuth:
client_secret: str,
token_file: str = "sharepoint_token.json",
authority: str = "https://login.microsoftonline.com/common",
+ allow_json_refresh: bool = True,
):
+ """
+ Initialize SharePointOAuth.
+
+ Args:
+ client_id: Azure AD application (client) ID.
+ client_secret: Azure AD application client secret.
+ token_file: Path to persisted token cache file (MSAL cache format).
+ authority: Usually "https://login.microsoftonline.com/common" for MSA + org,
+ or tenant-specific for work/school.
+ allow_json_refresh: If True, permit one-time migration from legacy flat JSON
+ {"access_token","refresh_token",...}. Otherwise refuse it.
+ """
self.client_id = client_id
self.client_secret = client_secret
self.token_file = token_file
self.authority = authority
+ self.allow_json_refresh = allow_json_refresh
self.token_cache = msal.SerializableTokenCache()
+ self._current_account = None
- # Load existing cache if available
- if os.path.exists(self.token_file):
- with open(self.token_file, "r") as f:
- self.token_cache.deserialize(f.read())
-
+ # Initialize MSAL Confidential Client
self.app = msal.ConfidentialClientApplication(
client_id=self.client_id,
client_credential=self.client_secret,
@@ -41,56 +62,268 @@ class SharePointOAuth:
token_cache=self.token_cache,
)
- async def save_cache(self):
- """Persist the token cache to file"""
- async with aiofiles.open(self.token_file, "w") as f:
- await f.write(self.token_cache.serialize())
+ async def load_credentials(self) -> bool:
+ """Load existing credentials from token file (async)."""
+ try:
+ logger.debug(f"SharePoint OAuth loading credentials from: {self.token_file}")
+ if os.path.exists(self.token_file):
+ logger.debug(f"Token file exists, reading: {self.token_file}")
- def create_authorization_url(self, redirect_uri: str) -> str:
- """Create authorization URL for OAuth flow"""
- return self.app.get_authorization_request_url(
- self.SCOPES, redirect_uri=redirect_uri
- )
+ # Read the token file
+ async with aiofiles.open(self.token_file, "r") as f:
+ cache_data = await f.read()
+ logger.debug(f"Read {len(cache_data)} chars from token file")
+
+ if cache_data.strip():
+ # 1) Try legacy flat JSON first
+ try:
+ json_data = json.loads(cache_data)
+ if isinstance(json_data, dict) and "refresh_token" in json_data:
+ if self.allow_json_refresh:
+ logger.debug(
+ "Found legacy JSON refresh_token and allow_json_refresh=True; attempting migration refresh"
+ )
+ return await self._refresh_from_json_token(json_data)
+ else:
+ logger.warning(
+ "Token file contains a legacy JSON refresh_token, but allow_json_refresh=False. "
+ "Delete the file and re-auth."
+ )
+ return False
+ except json.JSONDecodeError:
+ logger.debug("Token file is not flat JSON; attempting MSAL cache format")
+
+ # 2) Try MSAL cache format
+ logger.debug("Attempting MSAL cache deserialization")
+ self.token_cache.deserialize(cache_data)
+
+ # Get accounts from loaded cache
+ accounts = self.app.get_accounts()
+ logger.debug(f"Found {len(accounts)} accounts in MSAL cache")
+ if accounts:
+ self._current_account = accounts[0]
+ logger.debug(f"Set current account: {self._current_account.get('username', 'no username')}")
+
+ # IMPORTANT: Use RESOURCE_SCOPES (no reserved scopes) for silent acquisition
+ result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=self._current_account)
+ logger.debug(f"Silent token acquisition result keys: {list(result.keys()) if result else 'None'}")
+ if result and "access_token" in result:
+ logger.debug("Silent token acquisition successful")
+ await self.save_cache()
+ return True
+ else:
+ error_msg = (result or {}).get("error") or "No result"
+ logger.warning(f"Silent token acquisition failed: {error_msg}")
+ else:
+ logger.debug(f"Token file {self.token_file} is empty")
+ else:
+ logger.debug(f"Token file does not exist: {self.token_file}")
+
+ return False
+
+ except Exception as e:
+ logger.error(f"Failed to load SharePoint credentials: {e}")
+ import traceback
+ traceback.print_exc()
+ return False
+
+ async def _refresh_from_json_token(self, token_data: dict) -> bool:
+ """
+ Use refresh token from a legacy JSON file to get new tokens (one-time migration path).
+
+ Notes:
+ - Prefer using an MSAL cache file and acquire_token_silent().
+ - This path is only for migrating older refresh_token JSON files.
+ """
+ try:
+ refresh_token = token_data.get("refresh_token")
+ if not refresh_token:
+ logger.error("No refresh_token found in JSON file - cannot refresh")
+ logger.error("You must re-authenticate interactively to obtain a valid token")
+ return False
+
+ # Use only RESOURCE_SCOPES when refreshing (no reserved scopes)
+ refresh_scopes = [s for s in self.RESOURCE_SCOPES if s not in self.RESERVED_SCOPES]
+ logger.debug(f"Using refresh token; refresh scopes = {refresh_scopes}")
+
+ result = self.app.acquire_token_by_refresh_token(
+ refresh_token=refresh_token,
+ scopes=refresh_scopes,
+ )
+
+ if result and "access_token" in result:
+ logger.debug("Successfully refreshed token via legacy JSON path")
+ await self.save_cache()
+
+ accounts = self.app.get_accounts()
+ logger.debug(f"After refresh, found {len(accounts)} accounts")
+ if accounts:
+ self._current_account = accounts[0]
+ logger.debug(f"Set current account after refresh: {self._current_account.get('username', 'no username')}")
+ return True
+
+ # Error handling
+ err = (result or {}).get("error_description") or (result or {}).get("error") or "Unknown error"
+ logger.error(f"Refresh token failed: {err}")
+
+ if any(code in err for code in ("AADSTS70000", "invalid_grant", "interaction_required")):
+ logger.warning(
+ "Refresh denied due to unauthorized/expired scopes or invalid grant. "
+ "Delete the token file and perform interactive sign-in with correct scopes."
+ )
+
+ return False
+
+ except Exception as e:
+ logger.error(f"Exception during refresh from JSON token: {e}")
+ import traceback
+ traceback.print_exc()
+ return False
+
+ async def save_cache(self):
+ """Persist the token cache to file."""
+ try:
+ # Ensure parent directory exists
+ parent = os.path.dirname(os.path.abspath(self.token_file))
+ if parent and not os.path.exists(parent):
+ os.makedirs(parent, exist_ok=True)
+
+ cache_data = self.token_cache.serialize()
+ if cache_data:
+ async with aiofiles.open(self.token_file, "w") as f:
+ await f.write(cache_data)
+ logger.debug(f"Token cache saved to {self.token_file}")
+ except Exception as e:
+ logger.error(f"Failed to save token cache: {e}")
+
+ def create_authorization_url(self, redirect_uri: str, state: Optional[str] = None) -> str:
+ """Create authorization URL for OAuth flow."""
+ # Store redirect URI for later use in callback
+ self._redirect_uri = redirect_uri
+
+ kwargs: Dict[str, Any] = {
+ # IMPORTANT: interactive auth includes offline_access
+ "scopes": self.AUTH_SCOPES,
+ "redirect_uri": redirect_uri,
+ "prompt": "consent", # ensure refresh token on first run
+ }
+ if state:
+ kwargs["state"] = state # Optional CSRF protection
+
+ auth_url = self.app.get_authorization_request_url(**kwargs)
+
+ logger.debug(f"Generated auth URL: {auth_url}")
+ logger.debug(f"Auth scopes: {self.AUTH_SCOPES}")
+
+ return auth_url
async def handle_authorization_callback(
self, authorization_code: str, redirect_uri: str
) -> bool:
- """Handle OAuth callback and exchange code for tokens"""
- result = self.app.acquire_token_by_authorization_code(
- authorization_code,
- scopes=self.SCOPES,
- redirect_uri=redirect_uri,
- )
- if "access_token" in result:
- await self.save_cache()
- return True
- raise ValueError(result.get("error_description") or "Authorization failed")
+ """Handle OAuth callback and exchange code for tokens."""
+ try:
+ # For code exchange, we pass the same auth scopes as used in the authorize step
+ result = self.app.acquire_token_by_authorization_code(
+ authorization_code,
+ scopes=self.AUTH_SCOPES,
+ redirect_uri=redirect_uri,
+ )
+
+ if result and "access_token" in result:
+ # Store the account for future use
+ accounts = self.app.get_accounts()
+ if accounts:
+ self._current_account = accounts[0]
+
+ await self.save_cache()
+ logger.info("SharePoint OAuth authorization successful")
+ return True
+
+ error_msg = (result or {}).get("error_description") or (result or {}).get("error") or "Unknown error"
+ logger.error(f"SharePoint OAuth authorization failed: {error_msg}")
+ return False
+
+ except Exception as e:
+ logger.error(f"Exception during SharePoint OAuth authorization: {e}")
+ return False
async def is_authenticated(self) -> bool:
- """Check if we have valid credentials"""
- accounts = self.app.get_accounts()
- if not accounts:
+ """Check if we have valid credentials (simplified like Google Drive)."""
+ try:
+ # First try to load credentials if we haven't already
+ if not self._current_account:
+ await self.load_credentials()
+
+ # If we have an account, try to get a token (MSAL will refresh if needed)
+ if self._current_account:
+ # IMPORTANT: use RESOURCE_SCOPES here
+ result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=self._current_account)
+ if result and "access_token" in result:
+ return True
+ else:
+ error_msg = (result or {}).get("error") or "No result returned"
+ logger.debug(f"Token acquisition failed for current account: {error_msg}")
+
+ # Fallback: try without specific account
+ result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=None)
+ if result and "access_token" in result:
+ # Update current account if this worked
+ accounts = self.app.get_accounts()
+ if accounts:
+ self._current_account = accounts[0]
+ return True
+
+ return False
+
+ except Exception as e:
+ logger.error(f"Authentication check failed: {e}")
return False
- result = self.app.acquire_token_silent(self.SCOPES, account=accounts[0])
- if "access_token" in result:
- await self.save_cache()
- return True
- return False
def get_access_token(self) -> str:
- """Get an access token for Microsoft Graph"""
- accounts = self.app.get_accounts()
- if not accounts:
- raise ValueError("Not authenticated")
- result = self.app.acquire_token_silent(self.SCOPES, account=accounts[0])
- if "access_token" not in result:
- raise ValueError(
- result.get("error_description") or "Failed to acquire access token"
- )
- return result["access_token"]
+ """Get an access token for Microsoft Graph (simplified like Google Drive)."""
+ try:
+ # Try with current account first
+ if self._current_account:
+ result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=self._current_account)
+ if result and "access_token" in result:
+ return result["access_token"]
+
+ # Fallback: try without specific account
+ result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=None)
+ if result and "access_token" in result:
+ return result["access_token"]
+
+ # If we get here, authentication has failed
+ error_msg = (result or {}).get("error_description") or (result or {}).get("error") or "No valid authentication"
+ raise ValueError(f"Failed to acquire access token: {error_msg}")
+
+ except Exception as e:
+ logger.error(f"Failed to get access token: {e}")
+ raise
async def revoke_credentials(self):
- """Clear token cache and remove token file"""
- self.token_cache.clear()
- if os.path.exists(self.token_file):
- os.remove(self.token_file)
+ """Clear token cache and remove token file (like Google Drive)."""
+ try:
+ # Clear in-memory state
+ self._current_account = None
+ self.token_cache = msal.SerializableTokenCache()
+
+ # Recreate MSAL app with fresh cache
+ self.app = msal.ConfidentialClientApplication(
+ client_id=self.client_id,
+ client_credential=self.client_secret,
+ authority=self.authority,
+ token_cache=self.token_cache,
+ )
+
+ # Remove token file
+ if os.path.exists(self.token_file):
+ os.remove(self.token_file)
+ logger.info(f"Removed SharePoint token file: {self.token_file}")
+
+ except Exception as e:
+ logger.error(f"Failed to revoke SharePoint credentials: {e}")
+
+ def get_service(self) -> str:
+ """Return an access token (Graph doesn't need a generated client like Google Drive)."""
+ return self.get_access_token()
diff --git a/src/main.py b/src/main.py
index 90add401..69f2ad9f 100644
--- a/src/main.py
+++ b/src/main.py
@@ -2,6 +2,7 @@
from connectors.langflow_connector_service import LangflowConnectorService
from connectors.service import ConnectorService
from services.flows_service import FlowsService
+from utils.embeddings import create_dynamic_index_body
from utils.logging_config import configure_from_env, get_logger
configure_from_env()
@@ -52,11 +53,11 @@ from auth_middleware import optional_auth, require_auth
from config.settings import (
DISABLE_INGEST_WITH_LANGFLOW,
EMBED_MODEL,
- INDEX_BODY,
INDEX_NAME,
SESSION_SECRET,
clients,
is_no_auth_mode,
+ get_openrag_config,
)
from services.auth_service import AuthService
from services.langflow_mcp_service import LangflowMCPService
@@ -81,7 +82,6 @@ logger.info(
cuda_version=torch.version.cuda,
)
-
async def wait_for_opensearch():
"""Wait for OpenSearch to be ready with retries"""
max_retries = 30
@@ -132,12 +132,19 @@ async def init_index():
"""Initialize OpenSearch index and security roles"""
await wait_for_opensearch()
+ # Get the configured embedding model from user configuration
+ config = get_openrag_config()
+ embedding_model = config.knowledge.embedding_model
+
+ # Create dynamic index body based on the configured embedding model
+ dynamic_index_body = create_dynamic_index_body(embedding_model)
+
# Create documents index
if not await clients.opensearch.indices.exists(index=INDEX_NAME):
- await clients.opensearch.indices.create(index=INDEX_NAME, body=INDEX_BODY)
- logger.info("Created OpenSearch index", index_name=INDEX_NAME)
+ await clients.opensearch.indices.create(index=INDEX_NAME, body=dynamic_index_body)
+ logger.info("Created OpenSearch index", index_name=INDEX_NAME, embedding_model=embedding_model)
else:
- logger.info("Index already exists, skipping creation", index_name=INDEX_NAME)
+ logger.info("Index already exists, skipping creation", index_name=INDEX_NAME, embedding_model=embedding_model)
# Create knowledge filters index
knowledge_filter_index_name = "knowledge_filters"
@@ -391,7 +398,12 @@ async def _ingest_default_documents_openrag(services, file_paths):
async def startup_tasks(services):
"""Startup tasks"""
logger.info("Starting startup tasks")
- await init_index()
+ # Only initialize basic OpenSearch connection, not the index
+ # Index will be created after onboarding when we know the embedding model
+ await wait_for_opensearch()
+
+ # Configure alerting security
+ await configure_alerting_security()
async def initialize_services():
diff --git a/src/services/flows_service.py b/src/services/flows_service.py
index 0d7a7bc8..7397cf6b 100644
--- a/src/services/flows_service.py
+++ b/src/services/flows_service.py
@@ -1,3 +1,4 @@
+import asyncio
from config.settings import (
NUDGES_FLOW_ID,
LANGFLOW_URL,
@@ -19,6 +20,7 @@ from config.settings import (
WATSONX_LLM_COMPONENT_ID,
OLLAMA_EMBEDDING_COMPONENT_ID,
OLLAMA_LLM_COMPONENT_ID,
+ get_openrag_config,
)
import json
import os
@@ -29,6 +31,74 @@ logger = get_logger(__name__)
class FlowsService:
+ def __init__(self):
+ # Cache for flow file mappings to avoid repeated filesystem scans
+ self._flow_file_cache = {}
+
+ def _get_flows_directory(self):
+ """Get the flows directory path"""
+ current_file_dir = os.path.dirname(os.path.abspath(__file__)) # src/services/
+ src_dir = os.path.dirname(current_file_dir) # src/
+ project_root = os.path.dirname(src_dir) # project root
+ return os.path.join(project_root, "flows")
+
+ def _find_flow_file_by_id(self, flow_id: str):
+ """
+ Scan the flows directory and find the JSON file that contains the specified flow ID.
+
+ Args:
+ flow_id: The flow ID to search for
+
+ Returns:
+ str: The path to the flow file, or None if not found
+ """
+ if not flow_id:
+ raise ValueError("flow_id is required")
+
+ # Check cache first
+ if flow_id in self._flow_file_cache:
+ cached_path = self._flow_file_cache[flow_id]
+ if os.path.exists(cached_path):
+ return cached_path
+ else:
+ # Remove stale cache entry
+ del self._flow_file_cache[flow_id]
+
+ flows_dir = self._get_flows_directory()
+
+ if not os.path.exists(flows_dir):
+ logger.warning(f"Flows directory not found: {flows_dir}")
+ return None
+
+ # Scan all JSON files in the flows directory
+ try:
+ for filename in os.listdir(flows_dir):
+ if not filename.endswith('.json'):
+ continue
+
+ file_path = os.path.join(flows_dir, filename)
+
+ try:
+ with open(file_path, 'r') as f:
+ flow_data = json.load(f)
+
+ # Check if this file contains the flow we're looking for
+ if flow_data.get('id') == flow_id:
+ # Cache the result
+ self._flow_file_cache[flow_id] = file_path
+                        logger.info(f"Found flow {flow_id} in file: {file_path}")
+ return file_path
+
+ except (json.JSONDecodeError, FileNotFoundError) as e:
+                    logger.warning(f"Error reading flow file {file_path}: {e}")
+ continue
+
+ except Exception as e:
+ logger.error(f"Error scanning flows directory: {e}")
+ return None
+
+ logger.warning(f"Flow with ID {flow_id} not found in flows directory")
+ return None
async def reset_langflow_flow(self, flow_type: str):
"""Reset a Langflow flow by uploading the corresponding JSON file
@@ -41,59 +111,35 @@ class FlowsService:
if not LANGFLOW_URL:
raise ValueError("LANGFLOW_URL environment variable is required")
- # Determine flow file and ID based on type
+ # Determine flow ID based on type
if flow_type == "nudges":
- flow_file = "flows/openrag_nudges.json"
flow_id = NUDGES_FLOW_ID
elif flow_type == "retrieval":
- flow_file = "flows/openrag_agent.json"
flow_id = LANGFLOW_CHAT_FLOW_ID
elif flow_type == "ingest":
- flow_file = "flows/ingestion_flow.json"
flow_id = LANGFLOW_INGEST_FLOW_ID
else:
raise ValueError(
"flow_type must be either 'nudges', 'retrieval', or 'ingest'"
)
+ if not flow_id:
+ raise ValueError(f"Flow ID not configured for flow_type '{flow_type}'")
+
+ # Dynamically find the flow file by ID
+ flow_path = self._find_flow_file_by_id(flow_id)
+ if not flow_path:
+ raise FileNotFoundError(f"Flow file not found for flow ID: {flow_id}")
+
# Load flow JSON file
try:
- # Get the project root directory (go up from src/services/ to project root)
- # __file__ is src/services/chat_service.py
- # os.path.dirname(__file__) is src/services/
- # os.path.dirname(os.path.dirname(__file__)) is src/
- # os.path.dirname(os.path.dirname(os.path.dirname(__file__))) is project root
- current_file_dir = os.path.dirname(
- os.path.abspath(__file__)
- ) # src/services/
- src_dir = os.path.dirname(current_file_dir) # src/
- project_root = os.path.dirname(src_dir) # project root
- flow_path = os.path.join(project_root, flow_file)
-
- if not os.path.exists(flow_path):
- # List contents of project root to help debug
- try:
- contents = os.listdir(project_root)
- logger.info(f"Project root contents: {contents}")
-
- flows_dir = os.path.join(project_root, "flows")
- if os.path.exists(flows_dir):
- flows_contents = os.listdir(flows_dir)
- logger.info(f"Flows directory contents: {flows_contents}")
- else:
- logger.info("Flows directory does not exist")
- except Exception as e:
- logger.error(f"Error listing directory contents: {e}")
-
- raise FileNotFoundError(f"Flow file not found at: {flow_path}")
-
with open(flow_path, "r") as f:
flow_data = json.load(f)
- logger.info(f"Successfully loaded flow data from {flow_file}")
+ logger.info(f"Successfully loaded flow data for {flow_type} from {os.path.basename(flow_path)}")
+ except json.JSONDecodeError as e:
+ raise ValueError(f"Invalid JSON in flow file {flow_path}: {e}")
except FileNotFoundError:
raise ValueError(f"Flow file not found: {flow_path}")
- except json.JSONDecodeError as e:
- raise ValueError(f"Invalid JSON in flow file {flow_file}: {e}")
# Make PATCH request to Langflow API to update the flow using shared client
try:
@@ -106,8 +152,54 @@ class FlowsService:
logger.info(
f"Successfully reset {flow_type} flow",
flow_id=flow_id,
- flow_file=flow_file,
+ flow_file=os.path.basename(flow_path),
)
+
+ # Now update the flow with current configuration settings
+ try:
+ config = get_openrag_config()
+
+ # Check if configuration has been edited (onboarding completed)
+ if config.edited:
+ logger.info(f"Updating {flow_type} flow with current configuration settings")
+
+ provider = config.provider.model_provider.lower()
+
+ # Step 1: Assign model provider (replace components) if not OpenAI
+ if provider != "openai":
+ logger.info(f"Assigning {provider} components to {flow_type} flow")
+ provider_result = await self.assign_model_provider(provider)
+
+ if not provider_result.get("success"):
+ logger.warning(f"Failed to assign {provider} components: {provider_result.get('error', 'Unknown error')}")
+ # Continue anyway, maybe just value updates will work
+
+ # Step 2: Update model values for the specific flow being reset
+ single_flow_config = [{
+ "name": flow_type,
+ "flow_id": flow_id,
+ }]
+
+ logger.info(f"Updating {flow_type} flow model values")
+ update_result = await self.change_langflow_model_value(
+ provider=provider,
+ embedding_model=config.knowledge.embedding_model,
+ llm_model=config.agent.llm_model,
+ endpoint=config.provider.endpoint if config.provider.endpoint else None,
+ flow_configs=single_flow_config
+ )
+
+ if update_result.get("success"):
+ logger.info(f"Successfully updated {flow_type} flow with current configuration")
+ else:
+ logger.warning(f"Failed to update {flow_type} flow with current configuration: {update_result.get('error', 'Unknown error')}")
+ else:
+ logger.info(f"Configuration not yet edited (onboarding not completed), skipping model updates for {flow_type} flow")
+
+ except Exception as e:
+ logger.error(f"Error updating {flow_type} flow with current configuration", error=str(e))
+ # Don't fail the entire reset operation if configuration update fails
+
return {
"success": True,
"message": f"Successfully reset {flow_type} flow",
@@ -155,11 +247,10 @@ class FlowsService:
logger.info(f"Assigning {provider} components")
- # Define flow configurations
+ # Define flow configurations (removed hardcoded file paths)
flow_configs = [
{
"name": "nudges",
- "file": "flows/openrag_nudges.json",
"flow_id": NUDGES_FLOW_ID,
"embedding_id": OPENAI_EMBEDDING_COMPONENT_ID,
"llm_id": OPENAI_LLM_COMPONENT_ID,
@@ -167,7 +258,6 @@ class FlowsService:
},
{
"name": "retrieval",
- "file": "flows/openrag_agent.json",
"flow_id": LANGFLOW_CHAT_FLOW_ID,
"embedding_id": OPENAI_EMBEDDING_COMPONENT_ID,
"llm_id": OPENAI_LLM_COMPONENT_ID,
@@ -175,7 +265,6 @@ class FlowsService:
},
{
"name": "ingest",
- "file": "flows/ingestion_flow.json",
"flow_id": LANGFLOW_INGEST_FLOW_ID,
"embedding_id": OPENAI_EMBEDDING_COMPONENT_ID,
"llm_id": None, # Ingestion flow might not have LLM
@@ -272,7 +361,6 @@ class FlowsService:
async def _update_flow_components(self, config, llm_template, embedding_template, llm_text_template):
"""Update components in a specific flow"""
flow_name = config["name"]
- flow_file = config["file"]
flow_id = config["flow_id"]
old_embedding_id = config["embedding_id"]
old_llm_id = config["llm_id"]
@@ -281,14 +369,11 @@ class FlowsService:
new_llm_id = llm_template["data"]["id"]
new_embedding_id = embedding_template["data"]["id"]
new_llm_text_id = llm_text_template["data"]["id"]
- # Get the project root directory
- current_file_dir = os.path.dirname(os.path.abspath(__file__))
- src_dir = os.path.dirname(current_file_dir)
- project_root = os.path.dirname(src_dir)
- flow_path = os.path.join(project_root, flow_file)
- if not os.path.exists(flow_path):
- raise FileNotFoundError(f"Flow file not found at: {flow_path}")
+ # Dynamically find the flow file by ID
+ flow_path = self._find_flow_file_by_id(flow_id)
+ if not flow_path:
+ raise FileNotFoundError(f"Flow file not found for flow ID: {flow_id}")
# Load flow JSON
with open(flow_path, "r") as f:
@@ -527,16 +612,17 @@ class FlowsService:
return False
async def change_langflow_model_value(
- self, provider: str, embedding_model: str, llm_model: str, endpoint: str = None
+ self, provider: str, embedding_model: str, llm_model: str, endpoint: str = None, flow_configs: list = None
):
"""
- Change dropdown values for provider-specific components across all flows
+ Change dropdown values for provider-specific components across flows
Args:
provider: The provider ("watsonx", "ollama", "openai")
embedding_model: The embedding model name to set
llm_model: The LLM model name to set
endpoint: The endpoint URL (required for watsonx/ibm provider)
+ flow_configs: Optional list of specific flow configs to update. If None, updates all flows.
Returns:
dict: Success/error response with details for each flow
@@ -552,24 +638,22 @@ class FlowsService:
f"Changing dropdown values for provider {provider}, embedding: {embedding_model}, llm: {llm_model}, endpoint: {endpoint}"
)
- # Define flow configurations with provider-specific component IDs
- flow_configs = [
- {
- "name": "nudges",
- "file": "flows/openrag_nudges.json",
- "flow_id": NUDGES_FLOW_ID,
- },
- {
- "name": "retrieval",
- "file": "flows/openrag_agent.json",
- "flow_id": LANGFLOW_CHAT_FLOW_ID,
- },
- {
- "name": "ingest",
- "file": "flows/ingestion_flow.json",
- "flow_id": LANGFLOW_INGEST_FLOW_ID,
- },
- ]
+ # Use provided flow_configs or default to all flows
+ if flow_configs is None:
+ flow_configs = [
+ {
+ "name": "nudges",
+ "flow_id": NUDGES_FLOW_ID,
+ },
+ {
+ "name": "retrieval",
+ "flow_id": LANGFLOW_CHAT_FLOW_ID,
+ },
+ {
+ "name": "ingest",
+ "flow_id": LANGFLOW_INGEST_FLOW_ID,
+ },
+ ]
# Determine target component IDs based on provider
target_embedding_id, target_llm_id, target_llm_text_id = self._get_provider_component_ids(
diff --git a/src/utils/embeddings.py b/src/utils/embeddings.py
new file mode 100644
index 00000000..f3c902e7
--- /dev/null
+++ b/src/utils/embeddings.py
@@ -0,0 +1,64 @@
+from config.settings import OLLAMA_EMBEDDING_DIMENSIONS, OPENAI_EMBEDDING_DIMENSIONS, VECTOR_DIM, WATSONX_EMBEDDING_DIMENSIONS
+from utils.logging_config import get_logger
+
+
+logger = get_logger(__name__)
+
+def get_embedding_dimensions(model_name: str) -> int:
+ """Get the embedding dimensions for a given model name."""
+
+ # Check all model dictionaries
+ all_models = {**OPENAI_EMBEDDING_DIMENSIONS, **OLLAMA_EMBEDDING_DIMENSIONS, **WATSONX_EMBEDDING_DIMENSIONS}
+
+ if model_name in all_models:
+ dimensions = all_models[model_name]
+ logger.info(f"Found dimensions for model '{model_name}': {dimensions}")
+ return dimensions
+
+ logger.warning(
+ f"Unknown embedding model '{model_name}', using default dimensions: {VECTOR_DIM}"
+ )
+ return VECTOR_DIM
+
+
+def create_dynamic_index_body(embedding_model: str) -> dict:
+ """Create a dynamic index body configuration based on the embedding model."""
+ dimensions = get_embedding_dimensions(embedding_model)
+
+ return {
+ "settings": {
+ "index": {"knn": True},
+ "number_of_shards": 1,
+ "number_of_replicas": 1,
+ },
+ "mappings": {
+ "properties": {
+ "document_id": {"type": "keyword"},
+ "filename": {"type": "keyword"},
+ "mimetype": {"type": "keyword"},
+ "page": {"type": "integer"},
+ "text": {"type": "text"},
+ "chunk_embedding": {
+ "type": "knn_vector",
+ "dimension": dimensions,
+ "method": {
+ "name": "disk_ann",
+ "engine": "jvector",
+ "space_type": "l2",
+ "parameters": {"ef_construction": 100, "m": 16},
+ },
+ },
+ "source_url": {"type": "keyword"},
+ "connector_type": {"type": "keyword"},
+ "owner": {"type": "keyword"},
+ "allowed_users": {"type": "keyword"},
+ "allowed_groups": {"type": "keyword"},
+ "user_permissions": {"type": "object"},
+ "group_permissions": {"type": "object"},
+ "created_time": {"type": "date"},
+ "modified_time": {"type": "date"},
+ "indexed_time": {"type": "date"},
+ "metadata": {"type": "object"},
+ }
+ },
+ }
\ No newline at end of file