diff --git a/docs/docs/core-components/agents.mdx b/docs/docs/core-components/agents.mdx
new file mode 100644
index 00000000..d8c6471b
--- /dev/null
+++ b/docs/docs/core-components/agents.mdx
@@ -0,0 +1,52 @@
+---
+title: Agents powered by Langflow
+slug: /agents
+---
+
+import Icon from "@site/src/components/icon/icon";
+import Tabs from '@theme/Tabs';
+import TabItem from '@theme/TabItem';
+
+OpenRAG leverages Langflow's Agent component to power the OpenRAG Open Search Agent flow.
+
+This flow intelligently chats with your knowledge by embedding your query, comparing it to the vector database embeddings, and generating a response with the LLM.
+
+The Agent component shines here in its ability to decide not only which query to send, but also whether a query is necessary to solve the problem at hand.
+
+
+## How do agents work?
+
+Agents extend Large Language Models (LLMs) by integrating tools, which are functions that provide additional context and enable autonomous task execution. These integrations make agents more specialized and powerful than standalone LLMs.
+
+Whereas an LLM might generate acceptable, inert responses to general queries and tasks, an agent can leverage the integrated context and tools to provide more relevant responses and even take action. For example, you might create an agent that can access your company's documentation, repositories, and other resources to help your team with tasks that require knowledge of your specific products, customers, and code.
+
+Agents use LLMs as a reasoning engine to process input, determine which actions to take to address the query, and then generate a response. The response could be a typical text-based LLM response, or it could involve an action, like editing a file, running a script, or calling an external API.
+
+In an agentic context, tools are functions that the agent can run to perform tasks or access external resources. A function is wrapped as a Tool object with a common interface that the agent understands. Agents become aware of tools through tool registration, which is when the agent is provided a list of available tools, typically at agent initialization. The Tool object's description tells the agent what the tool can do so that it can decide whether the tool is appropriate for a given request.
+
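+The following minimal sketch shows this pattern in plain Python. It is illustrative only, not OpenRAG's or Langflow's actual implementation; the `Tool` class, the `search_knowledge` function, and the keyword-matching decision step are hypothetical stand-ins for the Tool object, a registered tool, and the LLM's reasoning.
+
+```python
+from dataclasses import dataclass
+from typing import Callable
+
+@dataclass
+class Tool:
+    """A function wrapped with the metadata an agent needs to use it."""
+    name: str
+    description: str  # the agent reads this to decide when the tool applies
+    func: Callable[[str], str]
+
+def search_knowledge(query: str) -> str:
+    # Hypothetical stand-in for a vector search over the knowledge base.
+    return f"Top matching documents for: {query!r}"
+
+# Tool registration: the agent receives its tools at initialization.
+tools = [
+    Tool(
+        name="search_knowledge",
+        description="Search the knowledge base for documents relevant to a query.",
+        func=search_knowledge,
+    ),
+]
+
+def agent_step(request: str) -> str:
+    # Simplified decision step; in a real agent, the LLM makes this choice
+    # by reading each registered tool's description.
+    for tool in tools:
+        if "knowledge" in request.lower():
+            return tool.func(request)
+    return "Answer directly without a tool."
+
+print(agent_step("What does our knowledge base say about data quality?"))
+```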
+
+
+## Use the OpenRAG Open Search Agent flow
+
+If you've chatted with your knowledge in OpenRAG, you've already experienced the OpenRAG Open Search Agent chat flow.
+To view the flow, click **Settings**, and then click **Edit in Langflow**.
+This flow contains seven components:
+
+* The Agent component orchestrates the entire flow by deciding when to search the knowledge base, how to formulate search queries, and how to combine retrieved information with the user's question to generate a comprehensive response.
+The Agent behaves according to the prompt in the **Agent Instructions** field.
+* The Chat Input component is connected to the Agent component's Input port. This allows the flow to be triggered by an incoming prompt from a user or application, as shown in the sketch after this list.
+* The OpenSearch component is connected to the Agent component's Tools port. The agent doesn't query this database for every request; it uses the connection only when it decides the stored knowledge can help respond to the prompt.
+* The Language Model component is connected to the Agent component's Language Model port. The agent uses the connected LLM to reason through the request sent through Chat Input.
+* The Embedding Model component is connected to the OpenSearch component's Embedding port. This component converts text queries into vector representations that are compared with the document embeddings stored in OpenSearch for semantic similarity matching. This gives your agent's queries context.
+* The Text Input component is populated with the global variable `OPENRAG-QUERY-FILTER`.
+This is the Knowledge filter, which controls which knowledge sources are searched.
+* The Agent component's Output port is connected to the Chat Output component, which returns the final response to the user or application.
+
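+Because the flow starts with Chat Input and ends with Chat Output, an application can trigger it programmatically through the Langflow run API. The following sketch assumes a local Langflow instance and the standard run endpoint; the URL, flow ID, and API key are placeholders for your own deployment.
+
+```python
+import requests
+
+# Placeholders: substitute your Langflow URL, flow ID, and API key.
+url = "http://localhost:7860/api/v1/run/YOUR-FLOW-ID"
+payload = {
+    "input_value": "What documents are available to you?",  # arrives at Chat Input
+    "input_type": "chat",
+    "output_type": "chat",  # the response comes back through Chat Output
+}
+
+response = requests.post(url, json=payload, headers={"x-api-key": "YOUR-API-KEY"})
+response.raise_for_status()
+print(response.json())
+```
+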
+All flows included with OpenRAG are designed to be modular, performant, and provider-agnostic.
+To modify a flow, click **Settings**, and click **Edit in Langflow**.
+Flows are edited in the same way as in the [Langflow visual editor](https://docs.langflow.org/concepts-overview).
+
+For an example of changing out the agent's LLM in OpenRAG, see the [Quickstart](/quickstart#change-components).
+
+To restore the flow to its initial state, in OpenRAG, click **Settings**, and then click **Restore Flow**.
+OpenRAG warns you that this discards all custom settings. Click **Restore** to restore the flow.
\ No newline at end of file
diff --git a/docs/docs/get-started/quickstart.mdx b/docs/docs/get-started/quickstart.mdx
index 993d9739..68d15aef 100644
--- a/docs/docs/get-started/quickstart.mdx
+++ b/docs/docs/get-started/quickstart.mdx
@@ -16,6 +16,8 @@ Get started with OpenRAG by loading your knowledge, swapping out your language m
## Find your way around
1. In OpenRAG, click **Chat**.
+ The chat is powered by the OpenRAG Open Search Agent.
+ For more information, see [Langflow Agents](/agents).
2. Ask `What documents are available to you?`
The agent responds with a message summarizing the documents that OpenRAG loads by default, which are PDFs about evaluating data quality when using LLMs in health care.
3. To confirm the agent is correct, click **Knowledge**.
@@ -33,7 +35,7 @@ Get started with OpenRAG by loading your knowledge, swapping out your language m
These events log the agent's request to the tool and the tool's response, so you have direct visibility into your agent's functionality.
If you aren't getting the results you need, you can further tune the knowledge ingestion and agent behavior in the next section.
-## Swap out the language model to modify agent behavior
+## Swap out the language model to modify agent behavior {#change-components}
To modify the knowledge ingestion or Agent behavior, click **Settings**.
diff --git a/docs/sidebars.js b/docs/sidebars.js
index 588e7665..f76fdcda 100644
--- a/docs/sidebars.js
+++ b/docs/sidebars.js
@@ -47,6 +47,17 @@ const sidebars = {
},
],
},
+ {
+ type: "category",
+ label: "Core components",
+ items: [
+ {
+ type: "doc",
+ id: "core-components/agents",
+ label: "Langflow Agents"
+ },
+ ],
+ },
{
type: "category",
label: "Configuration",
diff --git a/flows/openrag_ingest_docling.json b/flows/openrag_ingest_docling.json
index 889f8425..cce73398 100644
--- a/flows/openrag_ingest_docling.json
+++ b/flows/openrag_ingest_docling.json
@@ -95,7 +95,7 @@
"data": {
"sourceHandle": {
"dataType": "EmbeddingModel",
- "id": "EmbeddingModel-cxG9r",
+ "id": "EmbeddingModel-eZ6bT",
"name": "embeddings",
"output_types": [
"Embeddings"
@@ -110,10 +110,10 @@
"type": "other"
}
},
- "id": "xy-edge__EmbeddingModel-cxG9r{œdataTypeœ:œEmbeddingModelœ,œidœ:œEmbeddingModel-cxG9rœ,œnameœ:œembeddingsœ,œoutput_typesœ:[œEmbeddingsœ]}-OpenSearchHybrid-XtKoA{œfieldNameœ:œembeddingœ,œidœ:œOpenSearchHybrid-XtKoAœ,œinputTypesœ:[œEmbeddingsœ],œtypeœ:œotherœ}",
+ "id": "xy-edge__EmbeddingModel-eZ6bT{œdataTypeœ:œEmbeddingModelœ,œidœ:œEmbeddingModel-eZ6bTœ,œnameœ:œembeddingsœ,œoutput_typesœ:[œEmbeddingsœ]}-OpenSearchHybrid-XtKoA{œfieldNameœ:œembeddingœ,œidœ:œOpenSearchHybrid-XtKoAœ,œinputTypesœ:[œEmbeddingsœ],œtypeœ:œotherœ}",
"selected": false,
- "source": "EmbeddingModel-cxG9r",
- "sourceHandle": "{œdataTypeœ:œEmbeddingModelœ,œidœ:œEmbeddingModel-cxG9rœ,œnameœ:œembeddingsœ,œoutput_typesœ:[œEmbeddingsœ]}",
+ "source": "EmbeddingModel-eZ6bT",
+ "sourceHandle": "{œdataTypeœ:œEmbeddingModelœ,œidœ:œEmbeddingModel-eZ6bTœ,œnameœ:œembeddingsœ,œoutput_typesœ:[œEmbeddingsœ]}",
"target": "OpenSearchHybrid-XtKoA",
"targetHandle": "{œfieldNameœ:œembeddingœ,œidœ:œOpenSearchHybrid-XtKoAœ,œinputTypesœ:[œEmbeddingsœ],œtypeœ:œotherœ}"
}
@@ -1631,7 +1631,7 @@
},
{
"data": {
- "id": "EmbeddingModel-cxG9r",
+ "id": "EmbeddingModel-eZ6bT",
"node": {
"base_classes": [
"Embeddings"
@@ -1657,7 +1657,7 @@
],
"frozen": false,
"icon": "binary",
- "last_updated": "2025-09-24T16:02:07.998Z",
+ "last_updated": "2025-09-22T15:54:52.885Z",
"legacy": false,
"metadata": {
"code_hash": "93faf11517da",
@@ -1738,7 +1738,7 @@
"show": true,
"title_case": false,
"type": "str",
- "value": ""
+ "value": "OPENAI_API_KEY"
},
"chunk_size": {
"_input_type": "IntInput",
@@ -1926,16 +1926,16 @@
"type": "EmbeddingModel"
},
"dragging": false,
- "id": "EmbeddingModel-cxG9r",
+ "id": "EmbeddingModel-eZ6bT",
"measured": {
- "height": 366,
+ "height": 369,
"width": 320
},
"position": {
- "x": 1743.8608432729177,
- "y": 1808.780792406514
+ "x": 1726.6943524438122,
+ "y": 1800.5330404375484
},
- "selected": false,
+ "selected": true,
"type": "genericNode"
}
],
diff --git a/frontend/components/label-wrapper.tsx b/frontend/components/label-wrapper.tsx
index ab785c5c..691b7726 100644
--- a/frontend/components/label-wrapper.tsx
+++ b/frontend/components/label-wrapper.tsx
@@ -10,18 +10,25 @@ export function LabelWrapper({
id,
required,
flex,
+ start,
children,
}: {
label: string;
description?: string;
- helperText?: string;
+ helperText?: string | React.ReactNode;
id: string;
required?: boolean;
flex?: boolean;
+ start?: boolean;
children: React.ReactNode;
}) {
return (
-
+
{label}
{required && * }
@@ -39,7 +46,7 @@ export function LabelWrapper({
- {helperText}
+ {helperText}
)}
@@ -48,7 +55,7 @@ export function LabelWrapper({
{description}
)}
- {flex &&
{children}
}
+ {flex &&
{children}
}
);
}
diff --git a/frontend/components/ui/tooltip.tsx b/frontend/components/ui/tooltip.tsx
index d7ca7125..3d86b3b3 100644
--- a/frontend/components/ui/tooltip.tsx
+++ b/frontend/components/ui/tooltip.tsx
@@ -19,7 +19,7 @@ const TooltipContent = React.forwardRef<
ref={ref}
sideOffset={sideOffset}
className={cn(
- "z-50 overflow-hidden rounded-md border bg-popover px-3 py-1.5 text-sm text-popover-foreground shadow-md animate-in fade-in-0 zoom-in-95 data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=closed]:zoom-out-95 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2 origin-[--radix-tooltip-content-transform-origin]",
+ "z-50 overflow-hidden rounded-md border bg-primary py-1 px-1.5 text-xs font-normal text-primary-foreground shadow-md animate-in fade-in-0 zoom-in-95 data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=closed]:zoom-out-95 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2 origin-[--radix-tooltip-content-transform-origin]",
className,
)}
{...props}
diff --git a/frontend/src/app/connectors/page.tsx b/frontend/src/app/connectors/page.tsx
index 26e1a88a..ad70ec90 100644
--- a/frontend/src/app/connectors/page.tsx
+++ b/frontend/src/app/connectors/page.tsx
@@ -1,20 +1,14 @@
-"use client"
+"use client";
import React, { useState } from "react";
-import { GoogleDrivePicker } from "@/components/google-drive-picker"
-import { useTask } from "@/contexts/task-context"
+import { UnifiedCloudPicker, CloudFile } from "@/components/cloud-picker";
+import { useTask } from "@/contexts/task-context";
-interface GoogleDriveFile {
- id: string;
- name: string;
- mimeType: string;
- webViewLink?: string;
- iconLink?: string;
-}
+// CloudFile interface is now imported from the unified cloud picker
export default function ConnectorsPage() {
- const { addTask } = useTask()
- const [selectedFiles, setSelectedFiles] = useState<GoogleDriveFile[]>([]);
+ const { addTask } = useTask();
+ const [selectedFiles, setSelectedFiles] = useState<CloudFile[]>([]);
const [isSyncing, setIsSyncing] = useState(false);
const [syncResult, setSyncResult] = useState<{
processed?: number;
@@ -25,16 +19,19 @@ export default function ConnectorsPage() {
errors?: number;
} | null>(null);
- const handleFileSelection = (files: GoogleDriveFile[]) => {
+ const handleFileSelection = (files: CloudFile[]) => {
setSelectedFiles(files);
};
- const handleSync = async (connector: { connectionId: string, type: string }) => {
- if (!connector.connectionId || selectedFiles.length === 0) return
-
- setIsSyncing(true)
- setSyncResult(null)
-
+ const handleSync = async (connector: {
+ connectionId: string;
+ type: string;
+ }) => {
+ if (!connector.connectionId || selectedFiles.length === 0) return;
+
+ setIsSyncing(true);
+ setSyncResult(null);
+
try {
const syncBody: {
connection_id: string;
@@ -42,54 +39,55 @@ export default function ConnectorsPage() {
selected_files?: string[];
} = {
connection_id: connector.connectionId,
- selected_files: selectedFiles.map(file => file.id)
- }
-
+ selected_files: selectedFiles.map(file => file.id),
+ };
+
const response = await fetch(`/api/connectors/${connector.type}/sync`, {
- method: 'POST',
+ method: "POST",
headers: {
- 'Content-Type': 'application/json',
+ "Content-Type": "application/json",
},
body: JSON.stringify(syncBody),
- })
-
- const result = await response.json()
-
+ });
+
+ const result = await response.json();
+
if (response.status === 201) {
- const taskId = result.task_id
+ const taskId = result.task_id;
if (taskId) {
- addTask(taskId)
- setSyncResult({
- processed: 0,
+ addTask(taskId);
+ setSyncResult({
+ processed: 0,
total: selectedFiles.length,
- status: 'started'
- })
+ status: "started",
+ });
}
} else if (response.ok) {
- setSyncResult(result)
+ setSyncResult(result);
} else {
- console.error('Sync failed:', result.error)
- setSyncResult({ error: result.error || 'Sync failed' })
+ console.error("Sync failed:", result.error);
+ setSyncResult({ error: result.error || "Sync failed" });
}
} catch (error) {
- console.error('Sync error:', error)
- setSyncResult({ error: 'Network error occurred' })
+ console.error("Sync error:", error);
+ setSyncResult({ error: "Network error occurred" });
} finally {
- setIsSyncing(false)
+ setIsSyncing(false);
}
};
return (
Connectors
-
+
- This is a demo page for the Google Drive picker component.
- For full connector functionality, visit the Settings page.
+ This is a demo page for the Google Drive picker component. For full
+ connector functionality, visit the Settings page.
-
-
0 && (
-
handleSync({ connectionId: "google-drive-connection-id", type: "google-drive" })}
+
+ handleSync({
+ connectionId: "google-drive-connection-id",
+ type: "google-drive",
+ })
+ }
disabled={isSyncing}
className="px-4 py-2 bg-blue-500 text-white rounded hover:bg-blue-600 disabled:opacity-50 disabled:cursor-not-allowed"
>
@@ -110,14 +113,15 @@ export default function ConnectorsPage() {
<>Sync {selectedFiles.length} Selected Items>
)}
-
+
{syncResult && (
{syncResult.error ? (
Error: {syncResult.error}
- ) : syncResult.status === 'started' ? (
+ ) : syncResult.status === "started" ? (
- Sync started for {syncResult.total} files. Check the task notification for progress.
+ Sync started for {syncResult.total} files. Check the task
+ notification for progress.
) : (
diff --git a/frontend/src/app/onboarding/components/openai-onboarding.tsx b/frontend/src/app/onboarding/components/openai-onboarding.tsx
index 236097a4..b202756d 100644
--- a/frontend/src/app/onboarding/components/openai-onboarding.tsx
+++ b/frontend/src/app/onboarding/components/openai-onboarding.tsx
@@ -2,7 +2,7 @@ import { useState } from "react";
import { LabelInput } from "@/components/label-input";
import { LabelWrapper } from "@/components/label-wrapper";
import OpenAILogo from "@/components/logo/openai-logo";
-import { Switch } from "@/components/ui/switch";
+import { Checkbox } from "@/components/ui/checkbox";
import { useDebouncedValue } from "@/lib/debounce";
import type { OnboardingVariables } from "../../api/mutations/useOnboardingMutation";
import { useGetOpenAIModelsQuery } from "../../api/queries/useGetModelsQuery";
@@ -72,11 +72,19 @@ export function OpenAIOnboarding({
<>
+ Reuse the key from your environment config.
+
+ Uncheck to enter a different key.
+ >
+ }
flex
+ start
>
-
@@ -86,6 +94,7 @@ export function OpenAIOnboarding({
)}
{modelsError && (
-
+
Invalid OpenAI API key. Verify or replace the key.
)}
diff --git a/frontend/src/app/upload/[provider]/page.tsx b/frontend/src/app/upload/[provider]/page.tsx
index ea2d319e..522a45b7 100644
--- a/frontend/src/app/upload/[provider]/page.tsx
+++ b/frontend/src/app/upload/[provider]/page.tsx
@@ -4,29 +4,12 @@ import { useState, useEffect } from "react";
import { useParams, useRouter } from "next/navigation";
import { Button } from "@/components/ui/button";
import { ArrowLeft, AlertCircle } from "lucide-react";
-import { GoogleDrivePicker } from "@/components/google-drive-picker";
-import { OneDrivePicker } from "@/components/onedrive-picker";
+import { UnifiedCloudPicker, CloudFile } from "@/components/cloud-picker";
+import type { IngestSettings } from "@/components/cloud-picker/types";
import { useTask } from "@/contexts/task-context";
import { Toast } from "@/components/ui/toast";
-interface GoogleDriveFile {
- id: string;
- name: string;
- mimeType: string;
- webViewLink?: string;
- iconLink?: string;
-}
-
-interface OneDriveFile {
- id: string;
- name: string;
- mimeType?: string;
- webUrl?: string;
- driveItem?: {
- file?: { mimeType: string };
- folder?: unknown;
- };
-}
+// CloudFile interface is now imported from the unified cloud picker
interface CloudConnector {
id: string;
@@ -35,6 +18,7 @@ interface CloudConnector {
status: "not_connected" | "connecting" | "connected" | "error";
type: string;
connectionId?: string;
+ clientId: string;
hasAccessToken: boolean;
accessTokenError?: string;
}
@@ -49,14 +33,19 @@ export default function UploadProviderPage() {
const [isLoading, setIsLoading] = useState(true);
const [error, setError] = useState<string | null>(null);
const [accessToken, setAccessToken] = useState<string | null>(null);
- const [selectedFiles, setSelectedFiles] = useState<
- GoogleDriveFile[] | OneDriveFile[]
- >([]);
+ const [selectedFiles, setSelectedFiles] = useState<CloudFile[]>([]);
const [isIngesting, setIsIngesting] = useState(false);
const [currentSyncTaskId, setCurrentSyncTaskId] = useState<string | null>(
null
);
const [showSuccessToast, setShowSuccessToast] = useState(false);
+ const [ingestSettings, setIngestSettings] = useState<IngestSettings>({
+ chunkSize: 1000,
+ chunkOverlap: 200,
+ ocr: false,
+ pictureDescriptions: false,
+ embeddingModel: "text-embedding-3-small",
+ });
useEffect(() => {
const fetchConnectorInfo = async () => {
@@ -129,6 +118,7 @@ export default function UploadProviderPage() {
status: isConnected ? "connected" : "not_connected",
type: provider,
connectionId: activeConnection?.connection_id,
+ clientId: activeConnection?.client_id,
hasAccessToken,
accessTokenError,
});
@@ -159,13 +149,6 @@ export default function UploadProviderPage() {
// Task completed successfully, show toast and redirect
setIsIngesting(false);
setShowSuccessToast(true);
-
- // Dispatch knowledge updated event to refresh the knowledge table
- console.log(
- "Cloud provider task completed, dispatching knowledgeUpdated event"
- );
- window.dispatchEvent(new CustomEvent("knowledgeUpdated"));
-
setTimeout(() => {
router.push("/knowledge");
}, 2000); // 2 second delay to let user see toast
@@ -176,20 +159,12 @@ export default function UploadProviderPage() {
}
}, [tasks, currentSyncTaskId, router]);
- const handleFileSelected = (files: GoogleDriveFile[] | OneDriveFile[]) => {
+ const handleFileSelected = (files: CloudFile[]) => {
setSelectedFiles(files);
console.log(`Selected ${files.length} files from ${provider}:`, files);
// You can add additional handling here like triggering sync, etc.
};
- const handleGoogleDriveFileSelected = (files: GoogleDriveFile[]) => {
- handleFileSelected(files);
- };
-
- const handleOneDriveFileSelected = (files: OneDriveFile[]) => {
- handleFileSelected(files);
- };
-
const handleSync = async (connector: CloudConnector) => {
if (!connector.connectionId || selectedFiles.length === 0) return;
@@ -200,9 +175,11 @@ export default function UploadProviderPage() {
connection_id: string;
max_files?: number;
selected_files?: string[];
+ settings?: IngestSettings;
} = {
connection_id: connector.connectionId,
selected_files: selectedFiles.map(file => file.id),
+ settings: ingestSettings,
};
const response = await fetch(`/api/connectors/${connector.type}/sync`, {
@@ -353,48 +330,49 @@ export default function UploadProviderPage() {
router.back()}>
-
+
-
Add Cloud Knowledge
+
+ Add from {getProviderDisplayName()}
+
- {connector.type === "google_drive" && (
-
- )}
-
- {(connector.type === "onedrive" || connector.type === "sharepoint") && (
-
- )}
+
- {selectedFiles.length > 0 && (
-
-
- handleSync(connector)}
- disabled={selectedFiles.length === 0 || isIngesting}
- >
- {isIngesting ? (
- <>Ingesting {selectedFiles.length} Files...>
- ) : (
- <>Ingest Files ({selectedFiles.length})>
- )}
-
-
+
+
+ router.back()}
+ >
+ Back
+
+ handleSync(connector)}
+ disabled={selectedFiles.length === 0 || isIngesting}
+ >
+ {isIngesting ? (
+ <>Ingesting {selectedFiles.length} Files...>
+ ) : (
+ <>Start ingest>
+ )}
+
- )}
+
{/* Success toast notification */}
void
- onFileSelected?: (files: GoogleDriveFile[] | OneDriveFile[], connectorType: string) => void
+ isOpen: boolean;
+ onOpenChange: (open: boolean) => void;
+ onFileSelected?: (files: CloudFile[], connectorType: string) => void;
}
-export function CloudConnectorsDialog({
- isOpen,
+export function CloudConnectorsDialog({
+ isOpen,
onOpenChange,
- onFileSelected
+ onFileSelected,
}: CloudConnectorsDialogProps) {
- const [connectors, setConnectors] = useState([])
- const [isLoading, setIsLoading] = useState(true)
- const [selectedFiles, setSelectedFiles] = useState<{[connectorId: string]: GoogleDriveFile[] | OneDriveFile[]}>({})
- const [connectorAccessTokens, setConnectorAccessTokens] = useState<{[connectorType: string]: string}>({})
- const [activePickerType, setActivePickerType] = useState(null)
+ const [connectors, setConnectors] = useState<CloudConnector[]>([]);
+ const [isLoading, setIsLoading] = useState(true);
+ const [selectedFiles, setSelectedFiles] = useState<{
+ [connectorId: string]: CloudFile[];
+ }>({});
+ const [connectorAccessTokens, setConnectorAccessTokens] = useState<{
+ [connectorType: string]: string;
+ }>({});
+ const [activePickerType, setActivePickerType] = useState<string | null>(null);
const getConnectorIcon = (iconName: string) => {
const iconMap: { [key: string]: React.ReactElement } = {
- 'google-drive': (
+ "google-drive": (
G
),
- 'sharepoint': (
+ sharepoint: (
SP
),
- 'onedrive': (
+ onedrive: (
OD
),
- }
- return iconMap[iconName] || (
-
- ?
-
- )
- }
+ };
+ return (
+ iconMap[iconName] || (
+
+ ?
+
+ )
+ );
+ };
const fetchConnectorStatuses = useCallback(async () => {
- if (!isOpen) return
-
- setIsLoading(true)
+ if (!isOpen) return;
+
+ setIsLoading(true);
try {
// Fetch available connectors from backend
- const connectorsResponse = await fetch('/api/connectors')
+ const connectorsResponse = await fetch("/api/connectors");
if (!connectorsResponse.ok) {
- throw new Error('Failed to load connectors')
+ throw new Error("Failed to load connectors");
}
-
- const connectorsResult = await connectorsResponse.json()
- const connectorTypes = Object.keys(connectorsResult.connectors)
-
+
+ const connectorsResult = await connectorsResponse.json();
+ const connectorTypes = Object.keys(connectorsResult.connectors);
+
// Filter to only cloud connectors
- const cloudConnectorTypes = connectorTypes.filter(type =>
- ['google_drive', 'onedrive', 'sharepoint'].includes(type) &&
- connectorsResult.connectors[type].available
- )
-
+ const cloudConnectorTypes = connectorTypes.filter(
+ type =>
+ ["google_drive", "onedrive", "sharepoint"].includes(type) &&
+ connectorsResult.connectors[type].available
+ );
+
// Initialize connectors list
const initialConnectors = cloudConnectorTypes.map(type => ({
id: type,
@@ -115,82 +105,95 @@ export function CloudConnectorsDialog({
status: "not_connected" as const,
type: type,
hasAccessToken: false,
- accessTokenError: undefined
- }))
-
- setConnectors(initialConnectors)
+ accessTokenError: undefined,
+ clientId: "",
+ }));
+
+ setConnectors(initialConnectors);
// Check status for each cloud connector type
for (const connectorType of cloudConnectorTypes) {
try {
- const response = await fetch(`/api/connectors/${connectorType}/status`)
+ const response = await fetch(
+ `/api/connectors/${connectorType}/status`
+ );
if (response.ok) {
- const data = await response.json()
- const connections = data.connections || []
- const activeConnection = connections.find((conn: { connection_id: string; is_active: boolean }) => conn.is_active)
- const isConnected = activeConnection !== undefined
-
- let hasAccessToken = false
- let accessTokenError: string | undefined = undefined
+ const data = await response.json();
+ const connections = data.connections || [];
+ const activeConnection = connections.find(
+ (conn: { connection_id: string; is_active: boolean }) =>
+ conn.is_active
+ );
+ const isConnected = activeConnection !== undefined;
+
+ let hasAccessToken = false;
+ let accessTokenError: string | undefined = undefined;
// Try to get access token for connected connectors
if (isConnected && activeConnection) {
try {
- const tokenResponse = await fetch(`/api/connectors/${connectorType}/token?connection_id=${activeConnection.connection_id}`)
+ const tokenResponse = await fetch(
+ `/api/connectors/${connectorType}/token?connection_id=${activeConnection.connection_id}`
+ );
if (tokenResponse.ok) {
- const tokenData = await tokenResponse.json()
+ const tokenData = await tokenResponse.json();
if (tokenData.access_token) {
- hasAccessToken = true
+ hasAccessToken = true;
setConnectorAccessTokens(prev => ({
...prev,
- [connectorType]: tokenData.access_token
- }))
+ [connectorType]: tokenData.access_token,
+ }));
}
} else {
- const errorData = await tokenResponse.json().catch(() => ({ error: 'Token unavailable' }))
- accessTokenError = errorData.error || 'Access token unavailable'
+ const errorData = await tokenResponse
+ .json()
+ .catch(() => ({ error: "Token unavailable" }));
+ accessTokenError =
+ errorData.error || "Access token unavailable";
}
} catch {
- accessTokenError = 'Failed to fetch access token'
+ accessTokenError = "Failed to fetch access token";
}
}
-
- setConnectors(prev => prev.map(c =>
- c.type === connectorType
- ? {
- ...c,
- status: isConnected ? "connected" : "not_connected",
- connectionId: activeConnection?.connection_id,
- hasAccessToken,
- accessTokenError
- }
- : c
- ))
+
+ setConnectors(prev =>
+ prev.map(c =>
+ c.type === connectorType
+ ? {
+ ...c,
+ status: isConnected ? "connected" : "not_connected",
+ connectionId: activeConnection?.connection_id,
+ clientId: activeConnection?.client_id,
+ hasAccessToken,
+ accessTokenError,
+ }
+ : c
+ )
+ );
}
} catch (error) {
- console.error(`Failed to check status for ${connectorType}:`, error)
+ console.error(`Failed to check status for ${connectorType}:`, error);
}
}
} catch (error) {
- console.error('Failed to load cloud connectors:', error)
+ console.error("Failed to load cloud connectors:", error);
} finally {
- setIsLoading(false)
+ setIsLoading(false);
}
- }, [isOpen])
+ }, [isOpen]);
- const handleFileSelection = (connectorId: string, files: GoogleDriveFile[] | OneDriveFile[]) => {
+ const handleFileSelection = (connectorId: string, files: CloudFile[]) => {
setSelectedFiles(prev => ({
...prev,
- [connectorId]: files
- }))
-
- onFileSelected?.(files, connectorId)
- }
+ [connectorId]: files,
+ }));
+
+ onFileSelected?.(files, connectorId);
+ };
useEffect(() => {
- fetchConnectorStatuses()
- }, [fetchConnectorStatuses])
-
+ fetchConnectorStatuses();
+ }, [fetchConnectorStatuses]);
return (
@@ -218,19 +221,24 @@ export function CloudConnectorsDialog({
{connectors
.filter(connector => connector.status === "connected")
- .map((connector) => (
+ .map(connector => (
{
- e.preventDefault()
- e.stopPropagation()
+ title={
+ !connector.hasAccessToken
+ ? connector.accessTokenError ||
+ "Access token required - try reconnecting your account"
+ : `Select files from ${connector.name}`
+ }
+ onClick={e => {
+ e.preventDefault();
+ e.stopPropagation();
if (connector.hasAccessToken) {
- setActivePickerType(connector.id)
+ setActivePickerType(connector.id);
}
}}
className="min-w-[120px]"
@@ -243,54 +251,46 @@ export function CloudConnectorsDialog({
{connectors.every(c => c.status !== "connected") && (
No connected cloud providers found.
-
Go to Settings to connect your cloud storage accounts.
+
+ Go to Settings to connect your cloud storage accounts.
+
)}
-
- {/* Render pickers inside dialog */}
- {activePickerType && connectors.find(c => c.id === activePickerType) && (() => {
- const connector = connectors.find(c => c.id === activePickerType)!
-
- if (connector.type === "google_drive") {
+
+ {/* Render unified picker inside dialog */}
+ {activePickerType &&
+ connectors.find(c => c.id === activePickerType) &&
+ (() => {
+ const connector = connectors.find(
+ c => c.id === activePickerType
+ )!;
+
return (
- {
- handleFileSelection(connector.id, files)
- setActivePickerType(null)
+ {
+ handleFileSelection(connector.id, files);
+ setActivePickerType(null);
}}
- selectedFiles={selectedFiles[connector.id] as GoogleDriveFile[] || []}
+ selectedFiles={selectedFiles[connector.id] || []}
isAuthenticated={connector.status === "connected"}
accessToken={connectorAccessTokens[connector.type]}
onPickerStateChange={() => {}}
+ clientId={connector.clientId}
/>
- )
- }
-
- if (connector.type === "onedrive" || connector.type === "sharepoint") {
- return (
-
- {
- handleFileSelection(connector.id, files)
- setActivePickerType(null)
- }}
- selectedFiles={selectedFiles[connector.id] as OneDriveFile[] || []}
- isAuthenticated={connector.status === "connected"}
- accessToken={connectorAccessTokens[connector.type]}
- connectorType={connector.type as "onedrive" | "sharepoint"}
- />
-
- )
- }
-
- return null
- })()}
+ );
+ })()}
)}
- )
-}
\ No newline at end of file
+ );
+}
diff --git a/frontend/src/components/cloud-picker/file-item.tsx b/frontend/src/components/cloud-picker/file-item.tsx
new file mode 100644
index 00000000..3f6b5ab5
--- /dev/null
+++ b/frontend/src/components/cloud-picker/file-item.tsx
@@ -0,0 +1,67 @@
+"use client";
+
+import { Badge } from "@/components/ui/badge";
+import { FileText, Folder, Trash } from "lucide-react";
+import { CloudFile } from "./types";
+
+interface FileItemProps {
+ file: CloudFile;
+ onRemove: (fileId: string) => void;
+}
+
+const getFileIcon = (mimeType: string) => {
+ if (mimeType.includes("folder")) {
+ return
;
+ }
+ return
;
+};
+
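+// Map common MIME types to short human-readable labels; unknown types fall
+// back to the MIME subtype (e.g. "text/markdown" -> "markdown").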
+const getMimeTypeLabel = (mimeType: string) => {
+ const typeMap: { [key: string]: string } = {
+ "application/vnd.google-apps.document": "Google Doc",
+ "application/vnd.google-apps.spreadsheet": "Google Sheet",
+ "application/vnd.google-apps.presentation": "Google Slides",
+ "application/vnd.google-apps.folder": "Folder",
+ "application/pdf": "PDF",
+ "text/plain": "Text",
+ "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
+ "Word Doc",
+ "application/vnd.openxmlformats-officedocument.presentationml.presentation":
+ "PowerPoint",
+ };
+
+ return typeMap[mimeType] || mimeType?.split("/").pop() || "Document";
+};
+
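+// Render a byte count as a human-readable size, e.g. 1536 -> "1.5 KB".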
+const formatFileSize = (bytes?: number) => {
+ if (!bytes) return "";
+ const sizes = ["B", "KB", "MB", "GB", "TB"];
+ if (bytes === 0) return "0 B";
+ const i = Math.floor(Math.log(bytes) / Math.log(1024));
+ return `${(bytes / Math.pow(1024, i)).toFixed(1)} ${sizes[i]}`;
+};
+
+export const FileItem = ({ file, onRemove }: FileItemProps) => (
+
+
+ {getFileIcon(file.mimeType)}
+ {file.name}
+
+ {getMimeTypeLabel(file.mimeType)}
+
+
+
+
+ {formatFileSize(file.size) || "—"}
+
+
+ onRemove(file.id)}
+ />
+
+
+);
diff --git a/frontend/src/components/cloud-picker/file-list.tsx b/frontend/src/components/cloud-picker/file-list.tsx
new file mode 100644
index 00000000..775d78c4
--- /dev/null
+++ b/frontend/src/components/cloud-picker/file-list.tsx
@@ -0,0 +1,42 @@
+"use client";
+
+import { Button } from "@/components/ui/button";
+import { CloudFile } from "./types";
+import { FileItem } from "./file-item";
+
+interface FileListProps {
+ files: CloudFile[];
+ onClearAll: () => void;
+ onRemoveFile: (fileId: string) => void;
+}
+
+export const FileList = ({
+ files,
+ onClearAll,
+ onRemoveFile,
+}: FileListProps) => {
+ if (files.length === 0) {
+ return null;
+ }
+
+ return (
+
+
+
Added files
+
+ Clear all
+
+
+
+ {files.map(file => (
+
+ ))}
+
+
+ );
+};
diff --git a/frontend/src/components/cloud-picker/index.ts b/frontend/src/components/cloud-picker/index.ts
new file mode 100644
index 00000000..ef7aa74b
--- /dev/null
+++ b/frontend/src/components/cloud-picker/index.ts
@@ -0,0 +1,7 @@
+export { UnifiedCloudPicker } from "./unified-cloud-picker";
+export { PickerHeader } from "./picker-header";
+export { FileList } from "./file-list";
+export { FileItem } from "./file-item";
+export { IngestSettings } from "./ingest-settings";
+export * from "./types";
+export * from "./provider-handlers";
diff --git a/frontend/src/components/cloud-picker/ingest-settings.tsx b/frontend/src/components/cloud-picker/ingest-settings.tsx
new file mode 100644
index 00000000..d5843a2a
--- /dev/null
+++ b/frontend/src/components/cloud-picker/ingest-settings.tsx
@@ -0,0 +1,139 @@
+"use client";
+
+import { Input } from "@/components/ui/input";
+import { Switch } from "@/components/ui/switch";
+import {
+ Collapsible,
+ CollapsibleContent,
+ CollapsibleTrigger,
+} from "@/components/ui/collapsible";
+import { ChevronRight, Info } from "lucide-react";
+import { IngestSettings as IngestSettingsType } from "./types";
+
+interface IngestSettingsProps {
+ isOpen: boolean;
+ onOpenChange: (open: boolean) => void;
+ settings?: IngestSettingsType;
+ onSettingsChange?: (settings: IngestSettingsType) => void;
+}
+
+export const IngestSettings = ({
+ isOpen,
+ onOpenChange,
+ settings,
+ onSettingsChange,
+}: IngestSettingsProps) => {
+ // Default settings
+ const defaultSettings: IngestSettingsType = {
+ chunkSize: 1000,
+ chunkOverlap: 200,
+ ocr: false,
+ pictureDescriptions: false,
+ embeddingModel: "text-embedding-3-small",
+ };
+
+ // Use provided settings or defaults
+ const currentSettings = settings || defaultSettings;
+
+ const handleSettingsChange = (newSettings: Partial<IngestSettingsType>) => {
+ const updatedSettings = { ...currentSettings, ...newSettings };
+ onSettingsChange?.(updatedSettings);
+ };
+
+ return (
+
+
+
+
+ Ingest settings
+
+
+
+
+
+
+
+
Chunk size
+
+ handleSettingsChange({
+ chunkSize: parseInt(e.target.value) || 0,
+ })
+ }
+ />
+
+
+
Chunk overlap
+
+ handleSettingsChange({
+ chunkOverlap: parseInt(e.target.value) || 0,
+ })
+ }
+ />
+
+
+
+
+
+
OCR
+
+ Extracts text from images/PDFs. Ingest is slower when enabled.
+
+
+
+ handleSettingsChange({ ocr: checked })
+ }
+ />
+
+
+
+
+
+ Picture descriptions
+
+
+ Adds captions for images. Ingest is more expensive when enabled.
+
+
+
+ handleSettingsChange({ pictureDescriptions: checked })
+ }
+ />
+
+
+
+
+ Embedding model
+
+
+
+ handleSettingsChange({ embeddingModel: e.target.value })
+ }
+ placeholder="text-embedding-3-small"
+ />
+
+
+
+
+ );
+};
diff --git a/frontend/src/components/cloud-picker/picker-header.tsx b/frontend/src/components/cloud-picker/picker-header.tsx
new file mode 100644
index 00000000..05dcaebd
--- /dev/null
+++ b/frontend/src/components/cloud-picker/picker-header.tsx
@@ -0,0 +1,70 @@
+"use client";
+
+import { Button } from "@/components/ui/button";
+import { Card, CardContent } from "@/components/ui/card";
+import { Plus } from "lucide-react";
+import { CloudProvider } from "./types";
+
+interface PickerHeaderProps {
+ provider: CloudProvider;
+ onAddFiles: () => void;
+ isPickerLoaded: boolean;
+ isPickerOpen: boolean;
+ accessToken?: string;
+ isAuthenticated: boolean;
+}
+
+const getProviderName = (provider: CloudProvider): string => {
+ switch (provider) {
+ case "google_drive":
+ return "Google Drive";
+ case "onedrive":
+ return "OneDrive";
+ case "sharepoint":
+ return "SharePoint";
+ default:
+ return "Cloud Storage";
+ }
+};
+
+export const PickerHeader = ({
+ provider,
+ onAddFiles,
+ isPickerLoaded,
+ isPickerOpen,
+ accessToken,
+ isAuthenticated,
+}: PickerHeaderProps) => {
+ if (!isAuthenticated) {
+ return (
+
+ Please connect to {getProviderName(provider)} first to select specific
+ files.
+
+ );
+ }
+
+ return (
+
+
+
+ Select files from {getProviderName(provider)} to ingest.
+
+
+
+ {isPickerOpen ? "Opening Picker..." : "Add Files"}
+
+
+ csv, json, pdf,{" "}
+
+16 more {" "}
+
150 MB max
+
+
+
+ );
+};
diff --git a/frontend/src/components/cloud-picker/provider-handlers.ts b/frontend/src/components/cloud-picker/provider-handlers.ts
new file mode 100644
index 00000000..4a39312f
--- /dev/null
+++ b/frontend/src/components/cloud-picker/provider-handlers.ts
@@ -0,0 +1,245 @@
+"use client";
+
+import {
+ CloudFile,
+ CloudProvider,
+ GooglePickerData,
+ GooglePickerDocument,
+} from "./types";
+
+export class GoogleDriveHandler {
+ private accessToken: string;
+ private onPickerStateChange?: (isOpen: boolean) => void;
+
+ constructor(
+ accessToken: string,
+ onPickerStateChange?: (isOpen: boolean) => void
+ ) {
+ this.accessToken = accessToken;
+ this.onPickerStateChange = onPickerStateChange;
+ }
+
+ async loadPickerApi(): Promise<boolean> {
+ return new Promise(resolve => {
+ if (typeof window !== "undefined" && window.gapi) {
+ window.gapi.load("picker", {
+ callback: () => resolve(true),
+ onerror: () => resolve(false),
+ });
+ } else {
+ // Load Google API script
+ const script = document.createElement("script");
+ script.src = "https://apis.google.com/js/api.js";
+ script.async = true;
+ script.defer = true;
+ script.onload = () => {
+ window.gapi.load("picker", {
+ callback: () => resolve(true),
+ onerror: () => resolve(false),
+ });
+ };
+ script.onerror = () => resolve(false);
+ document.head.appendChild(script);
+ }
+ });
+ }
+
+ openPicker(onFileSelected: (files: CloudFile[]) => void): void {
+ if (!window.google?.picker) {
+ return;
+ }
+
+ try {
+ this.onPickerStateChange?.(true);
+
+ const picker = new window.google.picker.PickerBuilder()
+ .addView(window.google.picker.ViewId.DOCS)
+ .addView(window.google.picker.ViewId.FOLDERS)
+ .setOAuthToken(this.accessToken)
+ .enableFeature(window.google.picker.Feature.MULTISELECT_ENABLED)
+ .setTitle("Select files from Google Drive")
+ .setCallback(data => this.pickerCallback(data, onFileSelected))
+ .build();
+
+ picker.setVisible(true);
+
+ // Apply z-index fix
+ setTimeout(() => {
+ const pickerElements = document.querySelectorAll(
+ ".picker-dialog, .goog-modalpopup"
+ );
+ pickerElements.forEach(el => {
+ (el as HTMLElement).style.zIndex = "10000";
+ });
+ const bgElements = document.querySelectorAll(
+ ".picker-dialog-bg, .goog-modalpopup-bg"
+ );
+ bgElements.forEach(el => {
+ (el as HTMLElement).style.zIndex = "9999";
+ });
+ }, 100);
+ } catch (error) {
+ console.error("Error creating picker:", error);
+ this.onPickerStateChange?.(false);
+ }
+ }
+
+ private async pickerCallback(
+ data: GooglePickerData,
+ onFileSelected: (files: CloudFile[]) => void
+ ): Promise<void> {
+ if (data.action === window.google.picker.Action.PICKED) {
+ const files: CloudFile[] = data.docs.map((doc: GooglePickerDocument) => ({
+ id: doc[window.google.picker.Document.ID],
+ name: doc[window.google.picker.Document.NAME],
+ mimeType: doc[window.google.picker.Document.MIME_TYPE],
+ webViewLink: doc[window.google.picker.Document.URL],
+ iconLink: doc[window.google.picker.Document.ICON_URL],
+ size: doc["sizeBytes"] ? parseInt(doc["sizeBytes"]) : undefined,
+ modifiedTime: doc["lastEditedUtc"],
+ isFolder:
+ doc[window.google.picker.Document.MIME_TYPE] ===
+ "application/vnd.google-apps.folder",
+ }));
+
+ // Enrich with additional file data if needed
+ if (files.some(f => !f.size && !f.isFolder)) {
+ try {
+ const enrichedFiles = await Promise.all(
+ files.map(async file => {
+ if (!file.size && !file.isFolder) {
+ try {
+ const response = await fetch(
+ `https://www.googleapis.com/drive/v3/files/${file.id}?fields=size,modifiedTime`,
+ {
+ headers: {
+ Authorization: `Bearer ${this.accessToken}`,
+ },
+ }
+ );
+ if (response.ok) {
+ const fileDetails = await response.json();
+ return {
+ ...file,
+ size: fileDetails.size
+ ? parseInt(fileDetails.size)
+ : undefined,
+ modifiedTime:
+ fileDetails.modifiedTime || file.modifiedTime,
+ };
+ }
+ } catch (error) {
+ console.warn("Failed to fetch file details:", error);
+ }
+ }
+ return file;
+ })
+ );
+ onFileSelected(enrichedFiles);
+ } catch (error) {
+ console.warn("Failed to enrich file data:", error);
+ onFileSelected(files);
+ }
+ } else {
+ onFileSelected(files);
+ }
+ }
+
+ this.onPickerStateChange?.(false);
+ }
+}
+
+export class OneDriveHandler {
+ private accessToken: string;
+ private clientId: string;
+ private provider: CloudProvider;
+ private baseUrl?: string;
+
+ constructor(
+ accessToken: string,
+ clientId: string,
+ provider: CloudProvider = "onedrive",
+ baseUrl?: string
+ ) {
+ this.accessToken = accessToken;
+ this.clientId = clientId;
+ this.provider = provider;
+ this.baseUrl = baseUrl;
+ }
+
+ async loadPickerApi(): Promise<boolean> {
+ return new Promise(resolve => {
+ const script = document.createElement("script");
+ script.src = "https://js.live.net/v7.2/OneDrive.js";
+ script.onload = () => resolve(true);
+ script.onerror = () => resolve(false);
+ document.head.appendChild(script);
+ });
+ }
+
+ openPicker(onFileSelected: (files: CloudFile[]) => void): void {
+ if (!window.OneDrive) {
+ return;
+ }
+
+ window.OneDrive.open({
+ clientId: this.clientId,
+ action: "query",
+ multiSelect: true,
+ advanced: {
+ endpointHint: "api.onedrive.com",
+ accessToken: this.accessToken,
+ },
+ success: (response: any) => {
+ const newFiles: CloudFile[] =
+ response.value?.map((item: any, index: number) => ({
+ id: item.id,
+ name:
+ item.name ||
+ `${this.getProviderName()} File ${index + 1} (${item.id.slice(
+ -8
+ )})`,
+ mimeType: item.file?.mimeType || "application/octet-stream",
+ webUrl: item.webUrl || "",
+ downloadUrl: item["@microsoft.graph.downloadUrl"] || "",
+ size: item.size,
+ modifiedTime: item.lastModifiedDateTime,
+ isFolder: !!item.folder,
+ })) || [];
+
+ onFileSelected(newFiles);
+ },
+ cancel: () => {
+ console.log("Picker cancelled");
+ },
+ error: (error: any) => {
+ console.error("Picker error:", error);
+ },
+ });
+ }
+
+ private getProviderName(): string {
+ return this.provider === "sharepoint" ? "SharePoint" : "OneDrive";
+ }
+}
+
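+// Factory that returns the picker handler for a provider; OneDrive and
+// SharePoint additionally require an OAuth client ID.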
+export const createProviderHandler = (
+ provider: CloudProvider,
+ accessToken: string,
+ onPickerStateChange?: (isOpen: boolean) => void,
+ clientId?: string,
+ baseUrl?: string
+) => {
+ switch (provider) {
+ case "google_drive":
+ return new GoogleDriveHandler(accessToken, onPickerStateChange);
+ case "onedrive":
+ case "sharepoint":
+ if (!clientId) {
+ throw new Error("Client ID required for OneDrive/SharePoint");
+ }
+ return new OneDriveHandler(accessToken, clientId, provider, baseUrl);
+ default:
+ throw new Error(`Unsupported provider: ${provider}`);
+ }
+};
diff --git a/frontend/src/components/cloud-picker/types.ts b/frontend/src/components/cloud-picker/types.ts
new file mode 100644
index 00000000..ca346bf0
--- /dev/null
+++ b/frontend/src/components/cloud-picker/types.ts
@@ -0,0 +1,106 @@
+export interface CloudFile {
+ id: string;
+ name: string;
+ mimeType: string;
+ webViewLink?: string;
+ iconLink?: string;
+ size?: number;
+ modifiedTime?: string;
+ isFolder?: boolean;
+ webUrl?: string;
+ downloadUrl?: string;
+}
+
+export type CloudProvider = "google_drive" | "onedrive" | "sharepoint";
+
+export interface UnifiedCloudPickerProps {
+ provider: CloudProvider;
+ onFileSelected: (files: CloudFile[]) => void;
+ selectedFiles?: CloudFile[];
+ isAuthenticated: boolean;
+ accessToken?: string;
+ onPickerStateChange?: (isOpen: boolean) => void;
+ // OneDrive/SharePoint specific props
+ clientId?: string;
+ baseUrl?: string;
+ // Ingest settings
+ onSettingsChange?: (settings: IngestSettings) => void;
+}
+
+export interface GoogleAPI {
+ load: (
+ api: string,
+ options: { callback: () => void; onerror?: () => void }
+ ) => void;
+}
+
+export interface GooglePickerData {
+ action: string;
+ docs: GooglePickerDocument[];
+}
+
+export interface GooglePickerDocument {
+ [key: string]: string;
+}
+
+declare global {
+ interface Window {
+ gapi: GoogleAPI;
+ google: {
+ picker: {
+ api: {
+ load: (callback: () => void) => void;
+ };
+ PickerBuilder: new () => GooglePickerBuilder;
+ ViewId: {
+ DOCS: string;
+ FOLDERS: string;
+ DOCS_IMAGES_AND_VIDEOS: string;
+ DOCUMENTS: string;
+ PRESENTATIONS: string;
+ SPREADSHEETS: string;
+ };
+ Feature: {
+ MULTISELECT_ENABLED: string;
+ NAV_HIDDEN: string;
+ SIMPLE_UPLOAD_ENABLED: string;
+ };
+ Action: {
+ PICKED: string;
+ CANCEL: string;
+ };
+ Document: {
+ ID: string;
+ NAME: string;
+ MIME_TYPE: string;
+ URL: string;
+ ICON_URL: string;
+ };
+ };
+ };
+ OneDrive?: any;
+ }
+}
+
+export interface GooglePickerBuilder {
+ addView: (view: string) => GooglePickerBuilder;
+ setOAuthToken: (token: string) => GooglePickerBuilder;
+ setCallback: (
+ callback: (data: GooglePickerData) => void
+ ) => GooglePickerBuilder;
+ enableFeature: (feature: string) => GooglePickerBuilder;
+ setTitle: (title: string) => GooglePickerBuilder;
+ build: () => GooglePicker;
+}
+
+export interface GooglePicker {
+ setVisible: (visible: boolean) => void;
+}
+
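+// Ingest options sent to the connector sync endpoint along with the selected files.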
+export interface IngestSettings {
+ chunkSize: number;
+ chunkOverlap: number;
+ ocr: boolean;
+ pictureDescriptions: boolean;
+ embeddingModel: string;
+}
diff --git a/frontend/src/components/cloud-picker/unified-cloud-picker.tsx b/frontend/src/components/cloud-picker/unified-cloud-picker.tsx
new file mode 100644
index 00000000..fd77698f
--- /dev/null
+++ b/frontend/src/components/cloud-picker/unified-cloud-picker.tsx
@@ -0,0 +1,195 @@
+"use client";
+
+import { useState, useEffect } from "react";
+import {
+ UnifiedCloudPickerProps,
+ CloudFile,
+ IngestSettings as IngestSettingsType,
+} from "./types";
+import { PickerHeader } from "./picker-header";
+import { FileList } from "./file-list";
+import { IngestSettings } from "./ingest-settings";
+import { createProviderHandler } from "./provider-handlers";
+
+export const UnifiedCloudPicker = ({
+ provider,
+ onFileSelected,
+ selectedFiles = [],
+ isAuthenticated,
+ accessToken,
+ onPickerStateChange,
+ clientId,
+ baseUrl,
+ onSettingsChange,
+}: UnifiedCloudPickerProps) => {
+ const [isPickerLoaded, setIsPickerLoaded] = useState(false);
+ const [isPickerOpen, setIsPickerOpen] = useState(false);
+ const [isIngestSettingsOpen, setIsIngestSettingsOpen] = useState(false);
+ const [isLoadingBaseUrl, setIsLoadingBaseUrl] = useState(false);
+ const [autoBaseUrl, setAutoBaseUrl] = useState<string | undefined>(undefined);
+
+ // Settings state with defaults
+ const [ingestSettings, setIngestSettings] = useState<IngestSettingsType>({
+ chunkSize: 1000,
+ chunkOverlap: 200,
+ ocr: false,
+ pictureDescriptions: false,
+ embeddingModel: "text-embedding-3-small",
+ });
+
+ // Handle settings changes and notify parent
+ const handleSettingsChange = (newSettings: IngestSettingsType) => {
+ setIngestSettings(newSettings);
+ onSettingsChange?.(newSettings);
+ };
+
+ const effectiveBaseUrl = baseUrl || autoBaseUrl;
+
+ // Auto-detect base URL for OneDrive personal accounts
+ useEffect(() => {
+ if (
+ (provider === "onedrive" || provider === "sharepoint") &&
+ !baseUrl &&
+ accessToken &&
+ !autoBaseUrl
+ ) {
+ const getBaseUrl = async () => {
+ setIsLoadingBaseUrl(true);
+ try {
+ setAutoBaseUrl("https://onedrive.live.com/picker");
+ } catch (error) {
+ console.error("Auto-detect baseUrl failed:", error);
+ } finally {
+ setIsLoadingBaseUrl(false);
+ }
+ };
+
+ getBaseUrl();
+ }
+ }, [accessToken, baseUrl, autoBaseUrl, provider]);
+
+ // Load picker API
+ useEffect(() => {
+ if (!accessToken || !isAuthenticated) return;
+
+ const loadApi = async () => {
+ try {
+ const handler = createProviderHandler(
+ provider,
+ accessToken,
+ onPickerStateChange,
+ clientId,
+ effectiveBaseUrl
+ );
+ const loaded = await handler.loadPickerApi();
+ setIsPickerLoaded(loaded);
+ } catch (error) {
+ console.error("Failed to create provider handler:", error);
+ setIsPickerLoaded(false);
+ }
+ };
+
+ loadApi();
+ }, [
+ accessToken,
+ isAuthenticated,
+ provider,
+ clientId,
+ effectiveBaseUrl,
+ onPickerStateChange,
+ ]);
+
+ const handleAddFiles = () => {
+ if (!isPickerLoaded || !accessToken) {
+ return;
+ }
+
+ if ((provider === "onedrive" || provider === "sharepoint") && !clientId) {
+ console.error("Client ID required for OneDrive/SharePoint");
+ return;
+ }
+
+ try {
+ setIsPickerOpen(true);
+ onPickerStateChange?.(true);
+
+ const handler = createProviderHandler(
+ provider,
+ accessToken,
+ isOpen => {
+ setIsPickerOpen(isOpen);
+ onPickerStateChange?.(isOpen);
+ },
+ clientId,
+ effectiveBaseUrl
+ );
+
+ handler.openPicker((files: CloudFile[]) => {
+ // Merge new files with existing ones, avoiding duplicates
+ const existingIds = new Set(selectedFiles.map(f => f.id));
+ const newFiles = files.filter(f => !existingIds.has(f.id));
+ onFileSelected([...selectedFiles, ...newFiles]);
+ });
+ } catch (error) {
+ console.error("Error opening picker:", error);
+ setIsPickerOpen(false);
+ onPickerStateChange?.(false);
+ }
+ };
+
+ const handleRemoveFile = (fileId: string) => {
+ const updatedFiles = selectedFiles.filter(file => file.id !== fileId);
+ onFileSelected(updatedFiles);
+ };
+
+ const handleClearAll = () => {
+ onFileSelected([]);
+ };
+
+ if (isLoadingBaseUrl) {
+ return (
+
+ Loading...
+
+ );
+ }
+
+ if (
+ (provider === "onedrive" || provider === "sharepoint") &&
+ !clientId &&
+ isAuthenticated
+ ) {
+ return (
+
+ Configuration required: Client ID missing for{" "}
+ {provider === "sharepoint" ? "SharePoint" : "OneDrive"}.
+
+ );
+ }
+
+ return (
+
+ );
+};
diff --git a/frontend/src/components/google-drive-picker.tsx b/frontend/src/components/google-drive-picker.tsx
deleted file mode 100644
index c9dee19a..00000000
--- a/frontend/src/components/google-drive-picker.tsx
+++ /dev/null
@@ -1,341 +0,0 @@
-"use client"
-
-import { useState, useEffect } from "react"
-import { Button } from "@/components/ui/button"
-import { Badge } from "@/components/ui/badge"
-import { FileText, Folder, Plus, Trash2 } from "lucide-react"
-import { Card, CardContent } from "@/components/ui/card"
-
-interface GoogleDrivePickerProps {
- onFileSelected: (files: GoogleDriveFile[]) => void
- selectedFiles?: GoogleDriveFile[]
- isAuthenticated: boolean
- accessToken?: string
- onPickerStateChange?: (isOpen: boolean) => void
-}
-
-interface GoogleDriveFile {
- id: string
- name: string
- mimeType: string
- webViewLink?: string
- iconLink?: string
- size?: number
- modifiedTime?: string
- isFolder?: boolean
-}
-
-interface GoogleAPI {
- load: (api: string, options: { callback: () => void; onerror?: () => void }) => void
-}
-
-interface GooglePickerData {
- action: string
- docs: GooglePickerDocument[]
-}
-
-interface GooglePickerDocument {
- [key: string]: string
-}
-
-declare global {
- interface Window {
- gapi: GoogleAPI
- google: {
- picker: {
- api: {
- load: (callback: () => void) => void
- }
- PickerBuilder: new () => GooglePickerBuilder
- ViewId: {
- DOCS: string
- FOLDERS: string
- DOCS_IMAGES_AND_VIDEOS: string
- DOCUMENTS: string
- PRESENTATIONS: string
- SPREADSHEETS: string
- }
- Feature: {
- MULTISELECT_ENABLED: string
- NAV_HIDDEN: string
- SIMPLE_UPLOAD_ENABLED: string
- }
- Action: {
- PICKED: string
- CANCEL: string
- }
- Document: {
- ID: string
- NAME: string
- MIME_TYPE: string
- URL: string
- ICON_URL: string
- }
- }
- }
- }
-}
-
-interface GooglePickerBuilder {
- addView: (view: string) => GooglePickerBuilder
- setOAuthToken: (token: string) => GooglePickerBuilder
- setCallback: (callback: (data: GooglePickerData) => void) => GooglePickerBuilder
- enableFeature: (feature: string) => GooglePickerBuilder
- setTitle: (title: string) => GooglePickerBuilder
- build: () => GooglePicker
-}
-
-interface GooglePicker {
- setVisible: (visible: boolean) => void
-}
-
-export function GoogleDrivePicker({
- onFileSelected,
- selectedFiles = [],
- isAuthenticated,
- accessToken,
- onPickerStateChange
-}: GoogleDrivePickerProps) {
- const [isPickerLoaded, setIsPickerLoaded] = useState(false)
- const [isPickerOpen, setIsPickerOpen] = useState(false)
-
- useEffect(() => {
- const loadPickerApi = () => {
- if (typeof window !== 'undefined' && window.gapi) {
- window.gapi.load('picker', {
- callback: () => {
- setIsPickerLoaded(true)
- },
- onerror: () => {
- console.error('Failed to load Google Picker API')
- }
- })
- }
- }
-
- // Load Google API script if not already loaded
- if (typeof window !== 'undefined') {
- if (!window.gapi) {
- const script = document.createElement('script')
- script.src = 'https://apis.google.com/js/api.js'
- script.async = true
- script.defer = true
- script.onload = loadPickerApi
- script.onerror = () => {
- console.error('Failed to load Google API script')
- }
- document.head.appendChild(script)
-
- return () => {
- if (document.head.contains(script)) {
- document.head.removeChild(script)
- }
- }
- } else {
- loadPickerApi()
- }
- }
- }, [])
-
-
- const openPicker = () => {
- if (!isPickerLoaded || !accessToken || !window.google?.picker) {
- return
- }
-
- try {
- setIsPickerOpen(true)
- onPickerStateChange?.(true)
-
- // Create picker with higher z-index and focus handling
- const picker = new window.google.picker.PickerBuilder()
- .addView(window.google.picker.ViewId.DOCS)
- .addView(window.google.picker.ViewId.FOLDERS)
- .setOAuthToken(accessToken)
- .enableFeature(window.google.picker.Feature.MULTISELECT_ENABLED)
- .setTitle('Select files from Google Drive')
- .setCallback(pickerCallback)
- .build()
-
- picker.setVisible(true)
-
- // Apply z-index fix after a short delay to ensure picker is rendered
- setTimeout(() => {
- const pickerElements = document.querySelectorAll('.picker-dialog, .goog-modalpopup')
- pickerElements.forEach(el => {
- (el as HTMLElement).style.zIndex = '10000'
- })
- const bgElements = document.querySelectorAll('.picker-dialog-bg, .goog-modalpopup-bg')
- bgElements.forEach(el => {
- (el as HTMLElement).style.zIndex = '9999'
- })
- }, 100)
-
- } catch (error) {
- console.error('Error creating picker:', error)
- setIsPickerOpen(false)
- onPickerStateChange?.(false)
- }
- }
-
- const pickerCallback = async (data: GooglePickerData) => {
- if (data.action === window.google.picker.Action.PICKED) {
- const files: GoogleDriveFile[] = data.docs.map((doc: GooglePickerDocument) => ({
- id: doc[window.google.picker.Document.ID],
- name: doc[window.google.picker.Document.NAME],
- mimeType: doc[window.google.picker.Document.MIME_TYPE],
- webViewLink: doc[window.google.picker.Document.URL],
- iconLink: doc[window.google.picker.Document.ICON_URL],
- size: doc['sizeBytes'] ? parseInt(doc['sizeBytes']) : undefined,
- modifiedTime: doc['lastEditedUtc'],
- isFolder: doc[window.google.picker.Document.MIME_TYPE] === 'application/vnd.google-apps.folder'
- }))
-
- // If size is still missing, try to fetch it via Google Drive API
- if (accessToken && files.some(f => !f.size && !f.isFolder)) {
- try {
- const enrichedFiles = await Promise.all(files.map(async (file) => {
- if (!file.size && !file.isFolder) {
- try {
- const response = await fetch(`https://www.googleapis.com/drive/v3/files/${file.id}?fields=size,modifiedTime`, {
- headers: {
- 'Authorization': `Bearer ${accessToken}`
- }
- })
- if (response.ok) {
- const fileDetails = await response.json()
- return {
- ...file,
- size: fileDetails.size ? parseInt(fileDetails.size) : undefined,
- modifiedTime: fileDetails.modifiedTime || file.modifiedTime
- }
- }
- } catch (error) {
- console.warn('Failed to fetch file details:', error)
- }
- }
- return file
- }))
- onFileSelected(enrichedFiles)
- } catch (error) {
- console.warn('Failed to enrich file data:', error)
- onFileSelected(files)
- }
- } else {
- onFileSelected(files)
- }
- }
-
- setIsPickerOpen(false)
- onPickerStateChange?.(false)
- }
-
- const removeFile = (fileId: string) => {
- const updatedFiles = selectedFiles.filter(file => file.id !== fileId)
- onFileSelected(updatedFiles)
- }
-
- const getFileIcon = (mimeType: string) => {
- if (mimeType.includes('folder')) {
- return
- }
- return
- }
-
- const getMimeTypeLabel = (mimeType: string) => {
- const typeMap: { [key: string]: string } = {
- 'application/vnd.google-apps.document': 'Google Doc',
- 'application/vnd.google-apps.spreadsheet': 'Google Sheet',
- 'application/vnd.google-apps.presentation': 'Google Slides',
- 'application/vnd.google-apps.folder': 'Folder',
- 'application/pdf': 'PDF',
- 'text/plain': 'Text',
- 'application/vnd.openxmlformats-officedocument.wordprocessingml.document': 'Word Doc',
- 'application/vnd.openxmlformats-officedocument.presentationml.presentation': 'PowerPoint'
- }
-
- return typeMap[mimeType] || 'Document'
- }
-
- const formatFileSize = (bytes?: number) => {
- if (!bytes) return ''
- const sizes = ['B', 'KB', 'MB', 'GB', 'TB']
- if (bytes === 0) return '0 B'
- const i = Math.floor(Math.log(bytes) / Math.log(1024))
- return `${(bytes / Math.pow(1024, i)).toFixed(1)} ${sizes[i]}`
- }
-
- if (!isAuthenticated) {
- return (
-
- Please connect to Google Drive first to select specific files.
-
- )
- }
-
- return (
-
-
-
-
- Select files from Google Drive to ingest.
-
-
-
- {isPickerOpen ? 'Opening Picker...' : 'Add Files'}
-
-
-
-
- {selectedFiles.length > 0 && (
-
-
-
- Added files
-
-
- onFileSelected([])}
- size="sm"
- variant="ghost"
- className="text-xs h-6"
- >
- Clear all
-
-
-
- {selectedFiles.map((file) => (
-
-
- {getFileIcon(file.mimeType)}
- {file.name}
-
- {getMimeTypeLabel(file.mimeType)}
-
-
-
- {formatFileSize(file.size)}
- removeFile(file.id)}
- size="sm"
- variant="ghost"
- className="h-6 w-6 p-0"
- >
-
-
-
-
- ))}
-
-
-
- )}
-
- )
-}
diff --git a/frontend/src/components/onedrive-picker.tsx b/frontend/src/components/onedrive-picker.tsx
deleted file mode 100644
index 739491c1..00000000
--- a/frontend/src/components/onedrive-picker.tsx
+++ /dev/null
@@ -1,320 +0,0 @@
-"use client"
-
-import { useState, useEffect } from "react"
-import { Button } from "@/components/ui/button"
-import { Badge } from "@/components/ui/badge"
-import { FileText, Folder, Trash2, X } from "lucide-react"
-
-interface OneDrivePickerProps {
- onFileSelected: (files: OneDriveFile[]) => void
- selectedFiles?: OneDriveFile[]
- isAuthenticated: boolean
- accessToken?: string
- connectorType?: "onedrive" | "sharepoint"
- onPickerStateChange?: (isOpen: boolean) => void
-}
-
-interface OneDriveFile {
- id: string
- name: string
- mimeType?: string
- webUrl?: string
- driveItem?: {
- file?: { mimeType: string }
- folder?: unknown
- }
-}
-
-interface GraphResponse {
- value: OneDriveFile[]
-}
-
-declare global {
- interface Window {
- mgt?: {
- Providers?: {
- globalProvider?: unknown
- }
- }
- }
-}
-
-export function OneDrivePicker({
- onFileSelected,
- selectedFiles = [],
- isAuthenticated,
- accessToken,
- connectorType = "onedrive",
- onPickerStateChange
-}: OneDrivePickerProps) {
- const [isLoading, setIsLoading] = useState(false)
- const [files, setFiles] = useState<OneDriveFile[]>([])
- const [isPickerOpen, setIsPickerOpen] = useState(false)
- const [currentPath, setCurrentPath] = useState(
- connectorType === "sharepoint" ? 'sites?search=' : 'me/drive/root/children'
- )
- const [breadcrumbs, setBreadcrumbs] = useState<{id: string, name: string}[]>([
- {id: 'root', name: connectorType === "sharepoint" ? 'SharePoint' : 'OneDrive'}
- ])
-
- useEffect(() => {
- const loadMGT = async () => {
- if (typeof window !== 'undefined' && !window.mgt) {
- try {
- await import('@microsoft/mgt-components')
- await import('@microsoft/mgt-msal2-provider')
-
- // For simplicity, we'll use direct Graph API calls instead of MGT components
- // MGT provider initialization would go here if needed
- } catch {
- console.warn('MGT not available, falling back to direct API calls')
- }
- }
- }
-
- loadMGT()
- }, [accessToken])
-
-
- const fetchFiles = async (path: string = currentPath) => {
- if (!accessToken) return
-
- setIsLoading(true)
- try {
- const response = await fetch(`https://graph.microsoft.com/v1.0/${path}`, {
- headers: {
- 'Authorization': `Bearer ${accessToken}`,
- 'Content-Type': 'application/json'
- }
- })
-
- if (response.ok) {
- const data: GraphResponse = await response.json()
- setFiles(data.value || [])
- } else {
- console.error('Failed to fetch OneDrive files:', response.statusText)
- }
- } catch (error) {
- console.error('Error fetching OneDrive files:', error)
- } finally {
- setIsLoading(false)
- }
- }
-
- const openPicker = () => {
- if (!accessToken) return
-
- setIsPickerOpen(true)
- onPickerStateChange?.(true)
- fetchFiles()
- }
-
- const closePicker = () => {
- setIsPickerOpen(false)
- onPickerStateChange?.(false)
- setFiles([])
- setCurrentPath(
- connectorType === "sharepoint" ? 'sites?search=' : 'me/drive/root/children'
- )
- setBreadcrumbs([
- {id: 'root', name: connectorType === "sharepoint" ? 'SharePoint' : 'OneDrive'}
- ])
- }
-
- const handleFileClick = (file: OneDriveFile) => {
- if (file.driveItem?.folder) {
- // Navigate to folder
- const newPath = `me/drive/items/${file.id}/children`
- setCurrentPath(newPath)
- setBreadcrumbs([...breadcrumbs, {id: file.id, name: file.name}])
- fetchFiles(newPath)
- } else {
- // Select file
- const isAlreadySelected = selectedFiles.some(f => f.id === file.id)
- if (!isAlreadySelected) {
- onFileSelected([...selectedFiles, file])
- }
- }
- }
-
- const navigateToBreadcrumb = (index: number) => {
- if (index === 0) {
- setCurrentPath('me/drive/root/children')
- setBreadcrumbs([{id: 'root', name: 'OneDrive'}])
- fetchFiles('me/drive/root/children')
- } else {
- const targetCrumb = breadcrumbs[index]
- const newPath = `me/drive/items/${targetCrumb.id}/children`
- setCurrentPath(newPath)
- setBreadcrumbs(breadcrumbs.slice(0, index + 1))
- fetchFiles(newPath)
- }
- }
-
- const removeFile = (fileId: string) => {
- const updatedFiles = selectedFiles.filter(file => file.id !== fileId)
- onFileSelected(updatedFiles)
- }
-
- const getFileIcon = (file: OneDriveFile) => {
- if (file.driveItem?.folder) {
- return
- }
- return
- }
-
- const getMimeTypeLabel = (file: OneDriveFile) => {
- const mimeType = file.driveItem?.file?.mimeType || file.mimeType || ''
- const typeMap: { [key: string]: string } = {
- 'application/vnd.openxmlformats-officedocument.wordprocessingml.document': 'Word Doc',
- 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet': 'Excel',
- 'application/vnd.openxmlformats-officedocument.presentationml.presentation': 'PowerPoint',
- 'application/pdf': 'PDF',
- 'text/plain': 'Text',
- 'image/jpeg': 'Image',
- 'image/png': 'Image',
- }
-
- if (file.driveItem?.folder) return 'Folder'
- return typeMap[mimeType] || 'Document'
- }
-
- const serviceName = connectorType === "sharepoint" ? "SharePoint" : "OneDrive"
-
- if (!isAuthenticated) {
- return (
-
- Please connect to {serviceName} first to select specific files.
-
- )
- }
-
- return (
-
-
-
-
- {serviceName} File Selection
-
- Choose specific files to sync instead of syncing everything
-
-
-
- {!accessToken ? "No Access Token" : "Select Files"}
-
-
-
- {/* Status message when access token is missing */}
- {isAuthenticated && !accessToken && (
-
-
- Access token unavailable
-
- The file picker requires an access token. Try disconnecting and reconnecting your {serviceName} account.
-
- )}
-
- {/* File Picker Modal */}
- {isPickerOpen && (
-
-
-
-
- Select Files from {serviceName}
-
-
-
-
-
- {/* Breadcrumbs */}
-
- {breadcrumbs.map((crumb, index) => (
-
- {index > 0 && / }
- navigateToBreadcrumb(index)}
- className="text-blue-600 hover:underline"
- >
- {crumb.name}
-
-
- ))}
-
-
- {/* File List */}
-
- {isLoading ? (
-
- Loading...
- ) : files.length === 0 ? (
-
- No files found
- ) : (
-
- {files.map((file) => (
-
- handleFileClick(file)}
- >
-
- {getFileIcon(file)}
- {file.name}
-
- {getMimeTypeLabel(file)}
-
-
- {selectedFiles.some(f => f.id === file.id) && (
-
- Selected
- )}
-
- ))}
-
- )}
-
-
-
- )}
-
- {selectedFiles.length > 0 && (
-
-
- Selected files ({selectedFiles.length}):
-
-
- {selectedFiles.map((file) => (
-
-
- {getFileIcon(file)}
- {file.name}
-
- {getMimeTypeLabel(file)}
-
-
-
- removeFile(file.id)
- size="sm"
- variant="ghost"
- className="h-6 w-6 p-0"
- >
-
-
-
- ))}
-
-
- onFileSelected([])}
- size="sm"
- variant="ghost"
- className="text-xs h-6"
- >
- Clear all
-
-
- )}
-
- )
-}
\ No newline at end of file
diff --git a/src/api/connectors.py b/src/api/connectors.py
index 4c9eab49..5fbcec86 100644
--- a/src/api/connectors.py
+++ b/src/api/connectors.py
@@ -127,6 +127,23 @@ async def connector_status(request: Request, connector_service, session_manager)
user_id=user.user_id, connector_type=connector_type
)
+    # Resolve each connection's OAuth client_id via its connector
+ connection_client_ids = {}
+ for connection in connections:
+ try:
+ connector = await connector_service._get_connector(connection.connection_id)
+ if connector is not None:
+ connection_client_ids[connection.connection_id] = connector.get_client_id()
+ else:
+ connection_client_ids[connection.connection_id] = None
+ except Exception as e:
+ logger.warning(
+ "Could not get connector for connection",
+ connection_id=connection.connection_id,
+ error=str(e),
+ )
+            connection_client_ids[connection.connection_id] = None
+
# Check if there are any active connections
active_connections = [conn for conn in connections if conn.is_active]
has_authenticated_connection = len(active_connections) > 0
@@ -140,6 +157,7 @@ async def connector_status(request: Request, connector_service, session_manager)
{
"connection_id": conn.connection_id,
"name": conn.name,
+ "client_id": connection_client_ids.get(conn.connection_id),
"is_active": conn.is_active,
"created_at": conn.created_at.isoformat(),
"last_sync": conn.last_sync.isoformat() if conn.last_sync else None,
@@ -323,8 +341,8 @@ async def connector_webhook(request: Request, connector_service, session_manager
)
async def connector_token(request: Request, connector_service, session_manager):
- """Get access token for connector API calls (e.g., Google Picker)"""
- connector_type = request.path_params.get("connector_type")
+ """Get access token for connector API calls (e.g., Pickers)."""
+ url_connector_type = request.path_params.get("connector_type")
connection_id = request.query_params.get("connection_id")
if not connection_id:
@@ -333,37 +351,81 @@ async def connector_token(request: Request, connector_service, session_manager):
user = request.state.user
try:
- # Get the connection and verify it belongs to the user
+ # 1) Load the connection and verify ownership
connection = await connector_service.connection_manager.get_connection(connection_id)
if not connection or connection.user_id != user.user_id:
return JSONResponse({"error": "Connection not found"}, status_code=404)
- # Get the connector instance
+ # 2) Get the ACTUAL connector instance/type for this connection_id
connector = await connector_service._get_connector(connection_id)
if not connector:
- return JSONResponse({"error": f"Connector not available - authentication may have failed for {connector_type}"}, status_code=404)
+ return JSONResponse(
+ {"error": f"Connector not available - authentication may have failed for {url_connector_type}"},
+ status_code=404,
+ )
- # For Google Drive, get the access token
- if connector_type == "google_drive" and hasattr(connector, 'oauth'):
+ real_type = getattr(connector, "type", None) or getattr(connection, "connector_type", None)
+ if real_type is None:
+ return JSONResponse({"error": "Unable to determine connector type"}, status_code=500)
+
+        # Reject the request if the URL path type disagrees with the real type
+        if url_connector_type and url_connector_type != real_type:
+            # Relax this to a logged warning instead if cross-routing is expected.
+ return JSONResponse(
+ {
+ "error": "Connector type mismatch",
+ "detail": {
+ "requested_type": url_connector_type,
+ "actual_type": real_type,
+ "hint": "Call the token endpoint using the correct connector_type for this connection_id.",
+ },
+ },
+ status_code=400,
+ )
+
+ # 3) Branch by the actual connector type
+ # GOOGLE DRIVE (google-auth)
+ if real_type == "google_drive" and hasattr(connector, "oauth"):
await connector.oauth.load_credentials()
if connector.oauth.creds and connector.oauth.creds.valid:
- return JSONResponse({
- "access_token": connector.oauth.creds.token,
- "expires_in": (connector.oauth.creds.expiry.timestamp() -
- __import__('time').time()) if connector.oauth.creds.expiry else None
- })
- else:
- return JSONResponse({"error": "Invalid or expired credentials"}, status_code=401)
-
- # For OneDrive and SharePoint, get the access token
- elif connector_type in ["onedrive", "sharepoint"] and hasattr(connector, 'oauth'):
+ expires_in = None
+ try:
+ if connector.oauth.creds.expiry:
+ import time
+ expires_in = max(0, int(connector.oauth.creds.expiry.timestamp() - time.time()))
+ except Exception:
+ expires_in = None
+
+ return JSONResponse(
+ {
+ "access_token": connector.oauth.creds.token,
+ "expires_in": expires_in,
+ }
+ )
+ return JSONResponse({"error": "Invalid or expired credentials"}, status_code=401)
+
+ # ONEDRIVE / SHAREPOINT (MSAL or custom)
+ if real_type in ("onedrive", "sharepoint") and hasattr(connector, "oauth"):
+ # Ensure cache/credentials are loaded before trying to use them
try:
+ # Prefer a dedicated is_authenticated() that loads cache internally
+ if hasattr(connector.oauth, "is_authenticated"):
+ ok = await connector.oauth.is_authenticated()
+ else:
+ # Fallback: try to load credentials explicitly if available
+ ok = True
+ if hasattr(connector.oauth, "load_credentials"):
+ ok = await connector.oauth.load_credentials()
+
+ if not ok:
+ return JSONResponse({"error": "Not authenticated"}, status_code=401)
+
+ # Now safe to fetch access token
access_token = connector.oauth.get_access_token()
- return JSONResponse({
- "access_token": access_token,
- "expires_in": None # MSAL handles token expiry internally
- })
+ # MSAL result has expiry, but we’re returning a raw token; keep expires_in None for simplicity
+ return JSONResponse({"access_token": access_token, "expires_in": None})
except ValueError as e:
+ # Typical when acquire_token_silent fails (e.g., needs re-auth)
return JSONResponse({"error": f"Failed to get access token: {str(e)}"}, status_code=401)
except Exception as e:
return JSONResponse({"error": f"Authentication error: {str(e)}"}, status_code=500)
@@ -371,7 +433,5 @@ async def connector_token(request: Request, connector_service, session_manager):
return JSONResponse({"error": "Token not available for this connector type"}, status_code=400)
except Exception as e:
- logger.error("Error getting connector token", error=str(e))
+ logger.error("Error getting connector token", exc_info=True)
return JSONResponse({"error": str(e)}, status_code=500)
-
-
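For reference, here is how a caller might consume this endpoint. A minimal sketch only: the `/connectors/{connector_type}/token` mount path and base URL are assumptions (the router prefix is outside this diff), while the query parameter and the `access_token`/`expires_in` response fields match the handler above.

```python
import httpx

def fetch_connector_token(base_url: str, connector_type: str, connection_id: str) -> dict:
    """Call the token endpoint and return its JSON payload."""
    resp = httpx.get(
        f"{base_url}/connectors/{connector_type}/token",  # assumed mount path
        params={"connection_id": connection_id},
    )
    if resp.status_code == 400:
        # Type-mismatch response defined above: retry with detail["actual_type"].
        detail = resp.json().get("detail", {})
        raise ValueError(f"Use connector_type={detail.get('actual_type')!r} for this connection")
    resp.raise_for_status()
    return resp.json()  # {"access_token": "...", "expires_in": 3599 or None}
```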
diff --git a/src/api/settings.py b/src/api/settings.py
index 3e242c4b..c2c7cbd0 100644
--- a/src/api/settings.py
+++ b/src/api/settings.py
@@ -556,6 +556,19 @@ async def onboarding(request, flows_service):
)
# Continue even if setting global variables fails
+ # Initialize the OpenSearch index now that we have the embedding model configured
+ try:
+ # Import here to avoid circular imports
+ from main import init_index
+
+ logger.info("Initializing OpenSearch index after onboarding configuration")
+ await init_index()
+ logger.info("OpenSearch index initialization completed successfully")
+ except Exception as e:
+ logger.error("Failed to initialize OpenSearch index after onboarding", error=str(e))
+ # Don't fail the entire onboarding process if index creation fails
+ # The application can still work, but document operations may fail
+
# Handle sample data ingestion if requested
if should_ingest_sample_data:
try:
diff --git a/src/config/config_manager.py b/src/config/config_manager.py
index 055d48a7..0b814470 100644
--- a/src/config/config_manager.py
+++ b/src/config/config_manager.py
@@ -16,6 +16,8 @@ class ProviderConfig:
model_provider: str = "openai" # openai, anthropic, etc.
api_key: str = ""
+ endpoint: str = "" # For providers like Watson/IBM that need custom endpoints
+ project_id: str = "" # For providers like Watson/IBM that need project IDs
@dataclass
@@ -129,6 +131,10 @@ class ConfigManager:
config_data["provider"]["model_provider"] = os.getenv("MODEL_PROVIDER")
if os.getenv("PROVIDER_API_KEY"):
config_data["provider"]["api_key"] = os.getenv("PROVIDER_API_KEY")
+ if os.getenv("PROVIDER_ENDPOINT"):
+ config_data["provider"]["endpoint"] = os.getenv("PROVIDER_ENDPOINT")
+ if os.getenv("PROVIDER_PROJECT_ID"):
+ config_data["provider"]["project_id"] = os.getenv("PROVIDER_PROJECT_ID")
# Backward compatibility for OpenAI
if os.getenv("OPENAI_API_KEY"):
config_data["provider"]["api_key"] = os.getenv("OPENAI_API_KEY")
diff --git a/src/config/settings.py b/src/config/settings.py
index 5f9b189d..3bf1e6cf 100644
--- a/src/config/settings.py
+++ b/src/config/settings.py
@@ -78,6 +78,31 @@ INDEX_NAME = "documents"
VECTOR_DIM = 1536
EMBED_MODEL = "text-embedding-3-small"
+OPENAI_EMBEDDING_DIMENSIONS = {
+ "text-embedding-3-small": 1536,
+ "text-embedding-3-large": 3072,
+ "text-embedding-ada-002": 1536,
+}
+
+OLLAMA_EMBEDDING_DIMENSIONS = {
+ "nomic-embed-text": 768,
+ "all-minilm": 384,
+ "mxbai-embed-large": 1024,
+}
+
+WATSONX_EMBEDDING_DIMENSIONS = {
+    # IBM Models
+    "ibm/granite-embedding-107m-multilingual": 384,
+    "ibm/granite-embedding-278m-multilingual": 1024,
+    "ibm/slate-125m-english-rtrvr": 768,
+    "ibm/slate-125m-english-rtrvr-v2": 768,
+    "ibm/slate-30m-english-rtrvr": 384,
+    "ibm/slate-30m-english-rtrvr-v2": 384,
+    # Third Party Models
+    "intfloat/multilingual-e5-large": 1024,
+    "sentence-transformers/all-minilm-l6-v2": 384,
+}
+
INDEX_BODY = {
"settings": {
"index": {"knn": True},
diff --git a/src/connectors/connection_manager.py b/src/connectors/connection_manager.py
index 2e70ee1f..07ebd5ee 100644
--- a/src/connectors/connection_manager.py
+++ b/src/connectors/connection_manager.py
@@ -294,32 +294,39 @@ class ConnectionManager:
async def get_connector(self, connection_id: str) -> Optional[BaseConnector]:
"""Get an active connector instance"""
+ logger.debug(f"Getting connector for connection_id: {connection_id}")
+
# Return cached connector if available
if connection_id in self.active_connectors:
connector = self.active_connectors[connection_id]
if connector.is_authenticated:
+ logger.debug(f"Returning cached authenticated connector for {connection_id}")
return connector
else:
# Remove unauthenticated connector from cache
+ logger.debug(f"Removing unauthenticated connector from cache for {connection_id}")
del self.active_connectors[connection_id]
# Try to create and authenticate connector
connection_config = self.connections.get(connection_id)
if not connection_config or not connection_config.is_active:
+ logger.debug(f"No active connection config found for {connection_id}")
return None
+ logger.debug(f"Creating connector for {connection_config.connector_type}")
connector = self._create_connector(connection_config)
- if await connector.authenticate():
+
+ logger.debug(f"Attempting authentication for {connection_id}")
+ auth_result = await connector.authenticate()
+ logger.debug(f"Authentication result for {connection_id}: {auth_result}")
+
+ if auth_result:
self.active_connectors[connection_id] = connector
-
- # Setup webhook subscription if not already set up
- await self._setup_webhook_if_needed(
- connection_id, connection_config, connector
- )
-
+            # Setup webhook subscription if not already set up
+            await self._setup_webhook_if_needed(
+                connection_id, connection_config, connector
+            )
return connector
-
- return None
+ else:
+ logger.warning(f"Authentication failed for {connection_id}")
+ return None
def get_available_connector_types(self) -> Dict[str, Dict[str, Any]]:
"""Get available connector types with their metadata"""
@@ -363,20 +370,23 @@ class ConnectionManager:
def _create_connector(self, config: ConnectionConfig) -> BaseConnector:
"""Factory method to create connector instances"""
- if config.connector_type == "google_drive":
- return GoogleDriveConnector(config.config)
- elif config.connector_type == "sharepoint":
- return SharePointConnector(config.config)
- elif config.connector_type == "onedrive":
- return OneDriveConnector(config.config)
- elif config.connector_type == "box":
- # Future: BoxConnector(config.config)
- raise NotImplementedError("Box connector not implemented yet")
- elif config.connector_type == "dropbox":
- # Future: DropboxConnector(config.config)
- raise NotImplementedError("Dropbox connector not implemented yet")
- else:
- raise ValueError(f"Unknown connector type: {config.connector_type}")
+ try:
+ if config.connector_type == "google_drive":
+ return GoogleDriveConnector(config.config)
+ elif config.connector_type == "sharepoint":
+ return SharePointConnector(config.config)
+ elif config.connector_type == "onedrive":
+ return OneDriveConnector(config.config)
+ elif config.connector_type == "box":
+ raise NotImplementedError("Box connector not implemented yet")
+ elif config.connector_type == "dropbox":
+ raise NotImplementedError("Dropbox connector not implemented yet")
+ else:
+ raise ValueError(f"Unknown connector type: {config.connector_type}")
+ except Exception as e:
+ logger.error(f"Failed to create {config.connector_type} connector: {e}")
+ # Re-raise the exception so caller can handle appropriately
+ raise
async def update_last_sync(self, connection_id: str):
"""Update the last sync timestamp for a connection"""
diff --git a/src/connectors/google_drive/connector.py b/src/connectors/google_drive/connector.py
index 0aa4234a..37ebdc8a 100644
--- a/src/connectors/google_drive/connector.py
+++ b/src/connectors/google_drive/connector.py
@@ -477,7 +477,7 @@ class GoogleDriveConnector(BaseConnector):
"next_page_token": None, # no more pages
}
except Exception as e:
- # Optionally log error with your base class logger
+ # Log the error
try:
logger.error(f"GoogleDriveConnector.list_files failed: {e}")
except Exception:
@@ -495,7 +495,6 @@ class GoogleDriveConnector(BaseConnector):
try:
blob = self._download_file_bytes(meta)
except Exception as e:
- # Use your base class logger if available
try:
logger.error(f"Download failed for {file_id}: {e}")
except Exception:
@@ -562,7 +561,6 @@ class GoogleDriveConnector(BaseConnector):
if not self.cfg.changes_page_token:
self.cfg.changes_page_token = self.get_start_page_token()
except Exception as e:
- # Optional: use your base logger
try:
logger.error(f"Failed to get start page token: {e}")
except Exception:
@@ -593,7 +591,6 @@ class GoogleDriveConnector(BaseConnector):
expiration = result.get("expiration")
# Persist in-memory so cleanup can stop this channel later.
- # If your project has a persistence layer, save these values there.
self._active_channel = {
"channel_id": channel_id,
"resource_id": resource_id,
@@ -803,7 +800,7 @@ class GoogleDriveConnector(BaseConnector):
"""
Perform a one-shot sync of the currently selected scope and emit documents.
- Emits ConnectorDocument instances (adapt to your BaseConnector ingestion).
+        Emits ConnectorDocument instances.
"""
items = self._iter_selected_items()
for meta in items:
diff --git a/src/connectors/onedrive/connector.py b/src/connectors/onedrive/connector.py
index 8b800b3d..3ef4bdaf 100644
--- a/src/connectors/onedrive/connector.py
+++ b/src/connectors/onedrive/connector.py
@@ -1,223 +1,494 @@
+import logging
+from pathlib import Path
+from typing import List, Dict, Any, Optional
+from datetime import datetime
import httpx
-import uuid
-from datetime import datetime, timedelta
-from typing import Dict, List, Any, Optional
from ..base import BaseConnector, ConnectorDocument, DocumentACL
from .oauth import OneDriveOAuth
+logger = logging.getLogger(__name__)
+
class OneDriveConnector(BaseConnector):
- """OneDrive connector using Microsoft Graph API"""
+ """OneDrive connector using MSAL-based OAuth for authentication."""
+ # Required BaseConnector class attributes
CLIENT_ID_ENV_VAR = "MICROSOFT_GRAPH_OAUTH_CLIENT_ID"
CLIENT_SECRET_ENV_VAR = "MICROSOFT_GRAPH_OAUTH_CLIENT_SECRET"
# Connector metadata
CONNECTOR_NAME = "OneDrive"
- CONNECTOR_DESCRIPTION = "Connect your personal OneDrive to sync documents"
+ CONNECTOR_DESCRIPTION = "Connect to OneDrive (personal) to sync documents and files"
CONNECTOR_ICON = "onedrive"
def __init__(self, config: Dict[str, Any]):
super().__init__(config)
- self.oauth = OneDriveOAuth(
- client_id=self.get_client_id(),
- client_secret=self.get_client_secret(),
- token_file=config.get("token_file", "onedrive_token.json"),
- )
- self.subscription_id = config.get("subscription_id") or config.get(
- "webhook_channel_id"
- )
- self.base_url = "https://graph.microsoft.com/v1.0"
- async def authenticate(self) -> bool:
- if await self.oauth.is_authenticated():
- self._authenticated = True
- return True
- return False
+ logger.debug(f"OneDrive connector __init__ called with config type: {type(config)}")
+ logger.debug(f"OneDrive connector __init__ config value: {config}")
- async def setup_subscription(self) -> str:
- if not self._authenticated:
- raise ValueError("Not authenticated")
+ if config is None:
+ logger.debug("Config was None, using empty dict")
+ config = {}
- webhook_url = self.config.get("webhook_url")
- if not webhook_url:
- raise ValueError("webhook_url required in config for subscriptions")
- expiration = (datetime.utcnow() + timedelta(days=2)).isoformat() + "Z"
- body = {
- "changeType": "created,updated,deleted",
- "notificationUrl": webhook_url,
- "resource": "/me/drive/root",
- "expirationDateTime": expiration,
- "clientState": str(uuid.uuid4()),
+ # Initialize with defaults that allow the connector to be listed
+ self.client_id = None
+ self.client_secret = None
+ self.redirect_uri = config.get("redirect_uri", "http://localhost")
+
+ # Try to get credentials, but don't fail if they're missing
+ try:
+ self.client_id = self.get_client_id()
+ logger.debug(f"Got client_id: {self.client_id is not None}")
+ except Exception as e:
+ logger.debug(f"Failed to get client_id: {e}")
+
+ try:
+ self.client_secret = self.get_client_secret()
+ logger.debug(f"Got client_secret: {self.client_secret is not None}")
+ except Exception as e:
+ logger.debug(f"Failed to get client_secret: {e}")
+
+ # Token file setup
+ project_root = Path(__file__).resolve().parent.parent.parent.parent
+ token_file = config.get("token_file") or str(project_root / "onedrive_token.json")
+ Path(token_file).parent.mkdir(parents=True, exist_ok=True)
+
+ # Only initialize OAuth if we have credentials
+ if self.client_id and self.client_secret:
+ connection_id = config.get("connection_id", "default")
+
+ # Use token_file from config if provided, otherwise generate one
+ if config.get("token_file"):
+ oauth_token_file = config["token_file"]
+ else:
+ # Use a per-connection cache file to avoid collisions with other connectors
+ oauth_token_file = f"onedrive_token_{connection_id}.json"
+
+ # MSA & org both work via /common for OneDrive personal testing
+ authority = "https://login.microsoftonline.com/common"
+
+ self.oauth = OneDriveOAuth(
+ client_id=self.client_id,
+ client_secret=self.client_secret,
+ token_file=oauth_token_file,
+ authority=authority,
+ allow_json_refresh=True, # allows one-time migration from legacy JSON if present
+ )
+ else:
+ self.oauth = None
+
+ # Track subscription ID for webhooks (note: change notifications might not be available for personal accounts)
+ self._subscription_id: Optional[str] = None
+
+ # Graph API defaults
+ self._graph_api_version = "v1.0"
+ self._default_params = {
+ "$select": "id,name,size,lastModifiedDateTime,createdDateTime,webUrl,file,folder,@microsoft.graph.downloadUrl"
}
- token = self.oauth.get_access_token()
- async with httpx.AsyncClient() as client:
- resp = await client.post(
- f"{self.base_url}/subscriptions",
- json=body,
- headers={"Authorization": f"Bearer {token}"},
- )
- resp.raise_for_status()
- data = resp.json()
+ @property
+ def _graph_base_url(self) -> str:
+ """Base URL for Microsoft Graph API calls."""
+ return f"https://graph.microsoft.com/{self._graph_api_version}"
- self.subscription_id = data["id"]
- return self.subscription_id
+    def emit(self, doc: ConnectorDocument) -> None:
+        """Placeholder emit hook: currently only logs the document; wire into BaseConnector ingestion as needed."""
+        logger.debug(f"Emitting OneDrive document: {doc.id} ({doc.filename})")
+
+ async def authenticate(self) -> bool:
+ """Test authentication - BaseConnector interface."""
+ logger.debug(f"OneDrive authenticate() called, oauth is None: {self.oauth is None}")
+ try:
+ if not self.oauth:
+ logger.debug("OneDrive authentication failed: OAuth not initialized")
+ self._authenticated = False
+ return False
+
+ logger.debug("Loading OneDrive credentials...")
+ load_result = await self.oauth.load_credentials()
+ logger.debug(f"Load credentials result: {load_result}")
+
+ logger.debug("Checking OneDrive authentication status...")
+ authenticated = await self.oauth.is_authenticated()
+ logger.debug(f"OneDrive is_authenticated result: {authenticated}")
+
+ self._authenticated = authenticated
+ return authenticated
+ except Exception as e:
+ logger.error(f"OneDrive authentication failed: {e}")
+ import traceback
+ traceback.print_exc()
+ self._authenticated = False
+ return False
+
+ def get_auth_url(self) -> str:
+ """Get OAuth authorization URL."""
+ if not self.oauth:
+ raise RuntimeError("OneDrive OAuth not initialized - missing credentials")
+ return self.oauth.create_authorization_url(self.redirect_uri)
+
+ async def handle_oauth_callback(self, auth_code: str) -> Dict[str, Any]:
+ """Handle OAuth callback."""
+ if not self.oauth:
+ raise RuntimeError("OneDrive OAuth not initialized - missing credentials")
+ try:
+ success = await self.oauth.handle_authorization_callback(auth_code, self.redirect_uri)
+ if success:
+ self._authenticated = True
+ return {"status": "success"}
+ else:
+ raise ValueError("OAuth callback failed")
+ except Exception as e:
+ logger.error(f"OAuth callback failed: {e}")
+ raise
+
+ def sync_once(self) -> None:
+ """
+ Perform a one-shot sync of OneDrive files and emit documents.
+ """
+ import asyncio
+
+ async def _async_sync():
+ try:
+ file_list = await self.list_files(max_files=1000)
+ files = file_list.get("files", [])
+ for file_info in files:
+ try:
+ file_id = file_info.get("id")
+ if not file_id:
+ continue
+ doc = await self.get_file_content(file_id)
+ self.emit(doc)
+ except Exception as e:
+ logger.error(f"Failed to sync OneDrive file {file_info.get('name', 'unknown')}: {e}")
+ continue
+ except Exception as e:
+ logger.error(f"OneDrive sync_once failed: {e}")
+ raise
+
+        # asyncio.run is available on all supported Python versions (3.7+)
+        asyncio.run(_async_sync())
+
+ async def setup_subscription(self) -> str:
+ """
+ Set up real-time subscription for file changes.
+ NOTE: Change notifications may not be available for personal OneDrive accounts.
+ """
+ webhook_url = self.config.get('webhook_url')
+ if not webhook_url:
+ logger.warning("No webhook URL configured, skipping OneDrive subscription setup")
+ return "no-webhook-configured"
+
+ try:
+ if not await self.authenticate():
+ raise RuntimeError("OneDrive authentication failed during subscription setup")
+
+ token = self.oauth.get_access_token()
+
+ # For OneDrive personal we target the user's drive
+ resource = "/me/drive/root"
+
+ subscription_data = {
+ "changeType": "created,updated,deleted",
+ "notificationUrl": f"{webhook_url}/webhook/onedrive",
+ "resource": resource,
+ "expirationDateTime": self._get_subscription_expiry(),
+ "clientState": "onedrive_personal",
+ }
+
+ headers = {
+ "Authorization": f"Bearer {token}",
+ "Content-Type": "application/json",
+ }
+
+ url = f"{self._graph_base_url}/subscriptions"
+
+ async with httpx.AsyncClient() as client:
+ response = await client.post(url, json=subscription_data, headers=headers, timeout=30)
+ response.raise_for_status()
+
+ result = response.json()
+ subscription_id = result.get("id")
+
+ if subscription_id:
+ self._subscription_id = subscription_id
+ logger.info(f"OneDrive subscription created: {subscription_id}")
+ return subscription_id
+ else:
+ raise ValueError("No subscription ID returned from Microsoft Graph")
+
+ except Exception as e:
+ logger.error(f"Failed to setup OneDrive subscription: {e}")
+ raise
+
+ def _get_subscription_expiry(self) -> str:
+ """Get subscription expiry time (Graph caps duration; often <= 3 days)."""
+ from datetime import datetime, timedelta
+ expiry = datetime.utcnow() + timedelta(days=3)
+ return expiry.strftime("%Y-%m-%dT%H:%M:%S.%fZ")
async def list_files(
- self, page_token: Optional[str] = None, limit: int = 100
+ self,
+ page_token: Optional[str] = None,
+ max_files: Optional[int] = None,
+ **kwargs
) -> Dict[str, Any]:
- if not self._authenticated:
- raise ValueError("Not authenticated")
+ """List files from OneDrive using Microsoft Graph."""
+ try:
+ if not await self.authenticate():
+ raise RuntimeError("OneDrive authentication failed during file listing")
- params = {"$top": str(limit)}
- if page_token:
- params["$skiptoken"] = page_token
+ files: List[Dict[str, Any]] = []
+ max_files_value = max_files if max_files is not None else 100
- token = self.oauth.get_access_token()
- async with httpx.AsyncClient() as client:
- resp = await client.get(
- f"{self.base_url}/me/drive/root/children",
- params=params,
- headers={"Authorization": f"Bearer {token}"},
- )
- resp.raise_for_status()
- data = resp.json()
+ base_url = f"{self._graph_base_url}/me/drive/root/children"
- files = []
- for item in data.get("value", []):
- if item.get("file"):
- files.append(
- {
- "id": item["id"],
- "name": item["name"],
- "mimeType": item.get("file", {}).get(
- "mimeType", "application/octet-stream"
- ),
- "webViewLink": item.get("webUrl"),
- "createdTime": item.get("createdDateTime"),
- "modifiedTime": item.get("lastModifiedDateTime"),
- }
- )
+ params = dict(self._default_params)
+ params["$top"] = str(max_files_value)
- next_token = None
- next_link = data.get("@odata.nextLink")
- if next_link:
- from urllib.parse import urlparse, parse_qs
+ if page_token:
+ params["$skiptoken"] = page_token
- parsed = urlparse(next_link)
- next_token = parse_qs(parsed.query).get("$skiptoken", [None])[0]
+ response = await self._make_graph_request(base_url, params=params)
+ data = response.json()
- return {"files": files, "nextPageToken": next_token}
+ items = data.get("value", [])
+ for item in items:
+ if item.get("file"): # include files only
+ files.append({
+ "id": item.get("id", ""),
+ "name": item.get("name", ""),
+ "path": f"/drive/items/{item.get('id')}",
+ "size": int(item.get("size", 0)),
+ "modified": item.get("lastModifiedDateTime"),
+ "created": item.get("createdDateTime"),
+ "mime_type": item.get("file", {}).get("mimeType", self._get_mime_type(item.get("name", ""))),
+ "url": item.get("webUrl", ""),
+ "download_url": item.get("@microsoft.graph.downloadUrl"),
+ })
+
+ # Next page
+ next_page_token = None
+ next_link = data.get("@odata.nextLink")
+ if next_link:
+ from urllib.parse import urlparse, parse_qs
+ parsed = urlparse(next_link)
+ query_params = parse_qs(parsed.query)
+ if "$skiptoken" in query_params:
+ next_page_token = query_params["$skiptoken"][0]
+
+ return {"files": files, "next_page_token": next_page_token}
+
+ except Exception as e:
+ logger.error(f"Failed to list OneDrive files: {e}")
+ return {"files": [], "next_page_token": None}
async def get_file_content(self, file_id: str) -> ConnectorDocument:
- if not self._authenticated:
- raise ValueError("Not authenticated")
+ """Get file content and metadata."""
+ try:
+ if not await self.authenticate():
+ raise RuntimeError("OneDrive authentication failed during file content retrieval")
+ file_metadata = await self._get_file_metadata_by_id(file_id)
+ if not file_metadata:
+ raise ValueError(f"File not found: {file_id}")
+
+ download_url = file_metadata.get("download_url")
+ if download_url:
+ content = await self._download_file_from_url(download_url)
+ else:
+ content = await self._download_file_content(file_id)
+
+ acl = DocumentACL(
+ owner="",
+ user_permissions={},
+ group_permissions={},
+ )
+
+ modified_time = self._parse_graph_date(file_metadata.get("modified"))
+ created_time = self._parse_graph_date(file_metadata.get("created"))
+
+ return ConnectorDocument(
+ id=file_id,
+ filename=file_metadata.get("name", ""),
+ mimetype=file_metadata.get("mime_type", "application/octet-stream"),
+ content=content,
+ source_url=file_metadata.get("url", ""),
+ acl=acl,
+ modified_time=modified_time,
+ created_time=created_time,
+ metadata={
+ "onedrive_path": file_metadata.get("path", ""),
+ "size": file_metadata.get("size", 0),
+ },
+ )
+
+ except Exception as e:
+ logger.error(f"Failed to get OneDrive file content {file_id}: {e}")
+ raise
+
+ async def _get_file_metadata_by_id(self, file_id: str) -> Optional[Dict[str, Any]]:
+ """Get file metadata by ID using Graph API."""
+ try:
+ url = f"{self._graph_base_url}/me/drive/items/{file_id}"
+ params = dict(self._default_params)
+
+ response = await self._make_graph_request(url, params=params)
+ item = response.json()
+
+ if item.get("file"):
+ return {
+ "id": file_id,
+ "name": item.get("name", ""),
+ "path": f"/drive/items/{file_id}",
+ "size": int(item.get("size", 0)),
+ "modified": item.get("lastModifiedDateTime"),
+ "created": item.get("createdDateTime"),
+ "mime_type": item.get("file", {}).get("mimeType", self._get_mime_type(item.get("name", ""))),
+ "url": item.get("webUrl", ""),
+ "download_url": item.get("@microsoft.graph.downloadUrl"),
+ }
+
+ return None
+
+ except Exception as e:
+ logger.error(f"Failed to get file metadata for {file_id}: {e}")
+ return None
+
+ async def _download_file_content(self, file_id: str) -> bytes:
+ """Download file content by file ID using Graph API."""
+ try:
+ url = f"{self._graph_base_url}/me/drive/items/{file_id}/content"
+ token = self.oauth.get_access_token()
+ headers = {"Authorization": f"Bearer {token}"}
+
+ async with httpx.AsyncClient() as client:
+ response = await client.get(url, headers=headers, timeout=60)
+ response.raise_for_status()
+ return response.content
+
+ except Exception as e:
+ logger.error(f"Failed to download file content for {file_id}: {e}")
+ raise
+
+ async def _download_file_from_url(self, download_url: str) -> bytes:
+ """Download file content from direct download URL."""
+ try:
+ async with httpx.AsyncClient() as client:
+ response = await client.get(download_url, timeout=60)
+ response.raise_for_status()
+ return response.content
+ except Exception as e:
+ logger.error(f"Failed to download from URL {download_url}: {e}")
+ raise
+
+ def _parse_graph_date(self, date_str: Optional[str]) -> datetime:
+ """Parse Microsoft Graph date string to datetime."""
+ if not date_str:
+ return datetime.now()
+ try:
+ if date_str.endswith('Z'):
+ return datetime.fromisoformat(date_str[:-1]).replace(tzinfo=None)
+ else:
+ return datetime.fromisoformat(date_str.replace('T', ' '))
+ except (ValueError, AttributeError):
+ return datetime.now()
+
+ async def _make_graph_request(self, url: str, method: str = "GET",
+ data: Optional[Dict] = None, params: Optional[Dict] = None) -> httpx.Response:
+ """Make authenticated API request to Microsoft Graph."""
token = self.oauth.get_access_token()
- headers = {"Authorization": f"Bearer {token}"}
+ headers = {
+ "Authorization": f"Bearer {token}",
+ "Content-Type": "application/json",
+ }
+
async with httpx.AsyncClient() as client:
- meta_resp = await client.get(
- f"{self.base_url}/me/drive/items/{file_id}", headers=headers
- )
- meta_resp.raise_for_status()
- metadata = meta_resp.json()
+ if method.upper() == "GET":
+ response = await client.get(url, headers=headers, params=params, timeout=30)
+ elif method.upper() == "POST":
+ response = await client.post(url, headers=headers, json=data, timeout=30)
+ elif method.upper() == "DELETE":
+ response = await client.delete(url, headers=headers, timeout=30)
+ else:
+ raise ValueError(f"Unsupported HTTP method: {method}")
- content_resp = await client.get(
- f"{self.base_url}/me/drive/items/{file_id}/content", headers=headers
- )
- content_resp.raise_for_status()
- content = content_resp.content
+ response.raise_for_status()
+ return response
- perm_resp = await client.get(
- f"{self.base_url}/me/drive/items/{file_id}/permissions", headers=headers
- )
- perm_resp.raise_for_status()
- permissions = perm_resp.json()
+ def _get_mime_type(self, filename: str) -> str:
+ """Get MIME type based on file extension."""
+ import mimetypes
+ mime_type, _ = mimetypes.guess_type(filename)
+ return mime_type or "application/octet-stream"
- acl = self._parse_permissions(metadata, permissions)
- modified = datetime.fromisoformat(
- metadata["lastModifiedDateTime"].replace("Z", "+00:00")
- ).replace(tzinfo=None)
- created = datetime.fromisoformat(
- metadata["createdDateTime"].replace("Z", "+00:00")
- ).replace(tzinfo=None)
-
- document = ConnectorDocument(
- id=metadata["id"],
- filename=metadata["name"],
- mimetype=metadata.get("file", {}).get(
- "mimeType", "application/octet-stream"
- ),
- content=content,
- source_url=metadata.get("webUrl"),
- acl=acl,
- modified_time=modified,
- created_time=created,
- metadata={"size": metadata.get("size")},
- )
- return document
-
- def _parse_permissions(
- self, metadata: Dict[str, Any], permissions: Dict[str, Any]
- ) -> DocumentACL:
- acl = DocumentACL()
- owner = metadata.get("createdBy", {}).get("user", {}).get("email")
- if owner:
- acl.owner = owner
- for perm in permissions.get("value", []):
- role = perm.get("roles", ["read"])[0]
- grantee = perm.get("grantedToV2") or perm.get("grantedTo")
- if not grantee:
- continue
- user = grantee.get("user")
- if user and user.get("email"):
- acl.user_permissions[user["email"]] = role
- group = grantee.get("group")
- if group and group.get("email"):
- acl.group_permissions[group["email"]] = role
- return acl
-
- def handle_webhook_validation(
- self, request_method: str, headers: Dict[str, str], query_params: Dict[str, str]
- ) -> Optional[str]:
- """Handle Microsoft Graph webhook validation"""
- if request_method == "GET":
- validation_token = query_params.get("validationtoken") or query_params.get(
- "validationToken"
- )
- if validation_token:
- return validation_token
+ # Webhook methods - BaseConnector interface
+ def handle_webhook_validation(self, request_method: str,
+ headers: Dict[str, str],
+ query_params: Dict[str, str]) -> Optional[str]:
+ """Handle webhook validation (Graph API specific)."""
+ if request_method == "POST" and "validationToken" in query_params:
+ return query_params["validationToken"]
return None
- def extract_webhook_channel_id(
- self, payload: Dict[str, Any], headers: Dict[str, str]
- ) -> Optional[str]:
- """Extract SharePoint subscription ID from webhook payload"""
- values = payload.get("value", [])
- return values[0].get("subscriptionId") if values else None
+ def extract_webhook_channel_id(self, payload: Dict[str, Any],
+ headers: Dict[str, str]) -> Optional[str]:
+ """Extract channel/subscription ID from webhook payload."""
+ notifications = payload.get("value", [])
+ if notifications:
+ return notifications[0].get("subscriptionId")
+ return None
async def handle_webhook(self, payload: Dict[str, Any]) -> List[str]:
- values = payload.get("value", [])
- file_ids = []
- for item in values:
- resource_data = item.get("resourceData", {})
- file_id = resource_data.get("id")
- if file_id:
- file_ids.append(file_id)
- return file_ids
+ """Handle webhook notification and return affected file IDs."""
+ affected_files: List[str] = []
+ notifications = payload.get("value", [])
+ for notification in notifications:
+ resource = notification.get("resource")
+ if resource and "/drive/items/" in resource:
+ file_id = resource.split("/drive/items/")[-1]
+ affected_files.append(file_id)
+ return affected_files
- async def cleanup_subscription(
- self, subscription_id: str, resource_id: str = None
- ) -> bool:
- if not self._authenticated:
+ async def cleanup_subscription(self, subscription_id: str) -> bool:
+ """Clean up subscription - BaseConnector interface."""
+ if subscription_id == "no-webhook-configured":
+ logger.info("No subscription to cleanup (webhook was not configured)")
+ return True
+
+ try:
+ if not await self.authenticate():
+ logger.error("OneDrive authentication failed during subscription cleanup")
+ return False
+
+ token = self.oauth.get_access_token()
+ headers = {"Authorization": f"Bearer {token}"}
+
+ url = f"{self._graph_base_url}/subscriptions/{subscription_id}"
+
+ async with httpx.AsyncClient() as client:
+ response = await client.delete(url, headers=headers, timeout=30)
+
+ if response.status_code in [200, 204, 404]:
+ logger.info(f"OneDrive subscription {subscription_id} cleaned up successfully")
+ return True
+ else:
+ logger.warning(f"Unexpected response cleaning up subscription: {response.status_code}")
+ return False
+
+ except Exception as e:
+ logger.error(f"Failed to cleanup OneDrive subscription {subscription_id}: {e}")
return False
- token = self.oauth.get_access_token()
- async with httpx.AsyncClient() as client:
- resp = await client.delete(
- f"{self.base_url}/subscriptions/{subscription_id}",
- headers={"Authorization": f"Bearer {token}"},
- )
- return resp.status_code in (200, 204)
diff --git a/src/connectors/onedrive/oauth.py b/src/connectors/onedrive/oauth.py
index a81124e6..a2c94d15 100644
--- a/src/connectors/onedrive/oauth.py
+++ b/src/connectors/onedrive/oauth.py
@@ -1,17 +1,28 @@
import os
+import json
+import logging
+from typing import Optional, Dict, Any
+
import aiofiles
-from typing import Optional
import msal
+logger = logging.getLogger(__name__)
+
class OneDriveOAuth:
- """Handles Microsoft Graph OAuth authentication flow"""
+ """Handles Microsoft Graph OAuth for OneDrive (personal Microsoft accounts by default)."""
- SCOPES = [
- "offline_access",
- "Files.Read.All",
- ]
+ # Reserved scopes that must NOT be sent on token or silent calls
+ RESERVED_SCOPES = {"openid", "profile", "offline_access"}
+ # For PERSONAL Microsoft Accounts (OneDrive consumer):
+ # - Use AUTH_SCOPES for interactive auth (consent + refresh token issuance)
+ # - Use RESOURCE_SCOPES for acquire_token_silent / refresh paths
+ AUTH_SCOPES = ["User.Read", "Files.Read.All", "offline_access"]
+ RESOURCE_SCOPES = ["User.Read", "Files.Read.All"]
+ SCOPES = AUTH_SCOPES # Backward-compat alias if something references .SCOPES
+
+ # Kept for reference; MSAL derives endpoints from `authority`
AUTH_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
TOKEN_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/token"
@@ -21,18 +32,29 @@ class OneDriveOAuth:
client_secret: str,
token_file: str = "onedrive_token.json",
authority: str = "https://login.microsoftonline.com/common",
+ allow_json_refresh: bool = True,
):
+ """
+ Initialize OneDriveOAuth.
+
+ Args:
+ client_id: Azure AD application (client) ID.
+ client_secret: Azure AD application client secret.
+ token_file: Path to persisted token cache file (MSAL cache format).
+ authority: Usually "https://login.microsoftonline.com/common" for MSA + org,
+ or tenant-specific for work/school.
+ allow_json_refresh: If True, permit one-time migration from legacy flat JSON
+ {"access_token","refresh_token",...}. Otherwise refuse it.
+ """
self.client_id = client_id
self.client_secret = client_secret
self.token_file = token_file
self.authority = authority
+ self.allow_json_refresh = allow_json_refresh
self.token_cache = msal.SerializableTokenCache()
+ self._current_account = None
- # Load existing cache if available
- if os.path.exists(self.token_file):
- with open(self.token_file, "r") as f:
- self.token_cache.deserialize(f.read())
-
+ # Initialize MSAL Confidential Client
self.app = msal.ConfidentialClientApplication(
client_id=self.client_id,
client_credential=self.client_secret,
@@ -40,56 +62,261 @@ class OneDriveOAuth:
token_cache=self.token_cache,
)
- async def save_cache(self):
- """Persist the token cache to file"""
- async with aiofiles.open(self.token_file, "w") as f:
- await f.write(self.token_cache.serialize())
+ async def load_credentials(self) -> bool:
+ """Load existing credentials from token file (async)."""
+ try:
+ logger.debug(f"OneDrive OAuth loading credentials from: {self.token_file}")
+ if os.path.exists(self.token_file):
+ logger.debug(f"Token file exists, reading: {self.token_file}")
- def create_authorization_url(self, redirect_uri: str) -> str:
- """Create authorization URL for OAuth flow"""
- return self.app.get_authorization_request_url(
- self.SCOPES, redirect_uri=redirect_uri
- )
+ # Read the token file
+ async with aiofiles.open(self.token_file, "r") as f:
+ cache_data = await f.read()
+ logger.debug(f"Read {len(cache_data)} chars from token file")
+
+ if cache_data.strip():
+ # 1) Try legacy flat JSON first
+ try:
+ json_data = json.loads(cache_data)
+ if isinstance(json_data, dict) and "refresh_token" in json_data:
+ if self.allow_json_refresh:
+ logger.debug(
+ "Found legacy JSON refresh_token and allow_json_refresh=True; attempting migration refresh"
+ )
+ return await self._refresh_from_json_token(json_data)
+ else:
+ logger.warning(
+ "Token file contains a legacy JSON refresh_token, but allow_json_refresh=False. "
+ "Delete the file and re-auth."
+ )
+ return False
+ except json.JSONDecodeError:
+ logger.debug("Token file is not flat JSON; attempting MSAL cache format")
+
+ # 2) Try MSAL cache format
+ logger.debug("Attempting MSAL cache deserialization")
+ self.token_cache.deserialize(cache_data)
+
+ # Get accounts from loaded cache
+ accounts = self.app.get_accounts()
+ logger.debug(f"Found {len(accounts)} accounts in MSAL cache")
+ if accounts:
+ self._current_account = accounts[0]
+ logger.debug(f"Set current account: {self._current_account.get('username', 'no username')}")
+
+ # Use RESOURCE_SCOPES (no reserved scopes) for silent acquisition
+ result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=self._current_account)
+ logger.debug(f"Silent token acquisition result keys: {list(result.keys()) if result else 'None'}")
+ if result and "access_token" in result:
+ logger.debug("Silent token acquisition successful")
+ await self.save_cache()
+ return True
+ else:
+ error_msg = (result or {}).get("error") or "No result"
+ logger.warning(f"Silent token acquisition failed: {error_msg}")
+ else:
+ logger.debug(f"Token file {self.token_file} is empty")
+ else:
+ logger.debug(f"Token file does not exist: {self.token_file}")
+
+ return False
+
+ except Exception as e:
+ logger.error(f"Failed to load OneDrive credentials: {e}")
+ import traceback
+ traceback.print_exc()
+ return False
+
+ async def _refresh_from_json_token(self, token_data: dict) -> bool:
+ """
+ Use refresh token from a legacy JSON file to get new tokens (one-time migration path).
+ Prefer using an MSAL cache file and acquire_token_silent(); this path is only for migrating older files.
+ """
+ try:
+ refresh_token = token_data.get("refresh_token")
+ if not refresh_token:
+ logger.error("No refresh_token found in JSON file - cannot refresh")
+ logger.error("You must re-authenticate interactively to obtain a valid token")
+ return False
+
+ # Use only RESOURCE_SCOPES when refreshing (no reserved scopes)
+ refresh_scopes = [s for s in self.RESOURCE_SCOPES if s not in self.RESERVED_SCOPES]
+ logger.debug(f"Using refresh token; refresh scopes = {refresh_scopes}")
+
+ result = self.app.acquire_token_by_refresh_token(
+ refresh_token=refresh_token,
+ scopes=refresh_scopes,
+ )
+
+ if result and "access_token" in result:
+ logger.debug("Successfully refreshed token via legacy JSON path")
+ await self.save_cache()
+
+ accounts = self.app.get_accounts()
+ logger.debug(f"After refresh, found {len(accounts)} accounts")
+ if accounts:
+ self._current_account = accounts[0]
+ logger.debug(f"Set current account after refresh: {self._current_account.get('username', 'no username')}")
+ return True
+
+ # Error handling
+ err = (result or {}).get("error_description") or (result or {}).get("error") or "Unknown error"
+ logger.error(f"Refresh token failed: {err}")
+
+ if any(code in err for code in ("AADSTS70000", "invalid_grant", "interaction_required")):
+ logger.warning(
+ "Refresh denied due to unauthorized/expired scopes or invalid grant. "
+ "Delete the token file and perform interactive sign-in with correct scopes."
+ )
+
+ return False
+
+ except Exception as e:
+ logger.error(f"Exception during refresh from JSON token: {e}")
+ import traceback
+ traceback.print_exc()
+ return False
+
+ async def save_cache(self):
+ """Persist the token cache to file."""
+ try:
+ # Ensure parent directory exists
+ parent = os.path.dirname(os.path.abspath(self.token_file))
+ if parent and not os.path.exists(parent):
+ os.makedirs(parent, exist_ok=True)
+
+ cache_data = self.token_cache.serialize()
+ if cache_data:
+ async with aiofiles.open(self.token_file, "w") as f:
+ await f.write(cache_data)
+ logger.debug(f"Token cache saved to {self.token_file}")
+ except Exception as e:
+ logger.error(f"Failed to save token cache: {e}")
+
+ def create_authorization_url(self, redirect_uri: str, state: Optional[str] = None) -> str:
+ """Create authorization URL for OAuth flow."""
+ # Store redirect URI for later use in callback
+ self._redirect_uri = redirect_uri
+
+ kwargs: Dict[str, Any] = {
+ # Interactive auth includes offline_access
+ "scopes": self.AUTH_SCOPES,
+ "redirect_uri": redirect_uri,
+ "prompt": "consent", # ensure refresh token on first run
+ }
+ if state:
+ kwargs["state"] = state # Optional CSRF protection
+
+ auth_url = self.app.get_authorization_request_url(**kwargs)
+
+ logger.debug(f"Generated auth URL: {auth_url}")
+ logger.debug(f"Auth scopes: {self.AUTH_SCOPES}")
+
+ return auth_url
async def handle_authorization_callback(
self, authorization_code: str, redirect_uri: str
) -> bool:
- """Handle OAuth callback and exchange code for tokens"""
- result = self.app.acquire_token_by_authorization_code(
- authorization_code,
- scopes=self.SCOPES,
- redirect_uri=redirect_uri,
- )
- if "access_token" in result:
- await self.save_cache()
- return True
- raise ValueError(result.get("error_description") or "Authorization failed")
+ """Handle OAuth callback and exchange code for tokens."""
+ try:
+ result = self.app.acquire_token_by_authorization_code(
+ authorization_code,
+ scopes=self.AUTH_SCOPES, # same as authorize step
+ redirect_uri=redirect_uri,
+ )
+
+ if result and "access_token" in result:
+ accounts = self.app.get_accounts()
+ if accounts:
+ self._current_account = accounts[0]
+
+ await self.save_cache()
+ logger.info("OneDrive OAuth authorization successful")
+ return True
+
+ error_msg = (result or {}).get("error_description") or (result or {}).get("error") or "Unknown error"
+ logger.error(f"OneDrive OAuth authorization failed: {error_msg}")
+ return False
+
+ except Exception as e:
+ logger.error(f"Exception during OneDrive OAuth authorization: {e}")
+ return False
async def is_authenticated(self) -> bool:
- """Check if we have valid credentials"""
- accounts = self.app.get_accounts()
- if not accounts:
+ """Check if we have valid credentials."""
+ try:
+ # First try to load credentials if we haven't already
+ if not self._current_account:
+ await self.load_credentials()
+
+ # Try to get a token (MSAL will refresh if needed)
+ if self._current_account:
+ result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=self._current_account)
+ if result and "access_token" in result:
+ return True
+ else:
+ error_msg = (result or {}).get("error") or "No result returned"
+ logger.debug(f"Token acquisition failed for current account: {error_msg}")
+
+ # Fallback: try without specific account
+ result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=None)
+ if result and "access_token" in result:
+ accounts = self.app.get_accounts()
+ if accounts:
+ self._current_account = accounts[0]
+ return True
+
+ return False
+
+ except Exception as e:
+ logger.error(f"Authentication check failed: {e}")
return False
- result = self.app.acquire_token_silent(self.SCOPES, account=accounts[0])
- if "access_token" in result:
- await self.save_cache()
- return True
- return False
def get_access_token(self) -> str:
- """Get an access token for Microsoft Graph"""
- accounts = self.app.get_accounts()
- if not accounts:
- raise ValueError("Not authenticated")
- result = self.app.acquire_token_silent(self.SCOPES, account=accounts[0])
- if "access_token" not in result:
- raise ValueError(
- result.get("error_description") or "Failed to acquire access token"
- )
- return result["access_token"]
+ """Get an access token for Microsoft Graph."""
+ try:
+ # Try with current account first
+ if self._current_account:
+ result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=self._current_account)
+ if result and "access_token" in result:
+ return result["access_token"]
+
+ # Fallback: try without specific account
+ result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=None)
+ if result and "access_token" in result:
+ return result["access_token"]
+
+ # If we get here, authentication has failed
+ error_msg = (result or {}).get("error_description") or (result or {}).get("error") or "No valid authentication"
+ raise ValueError(f"Failed to acquire access token: {error_msg}")
+
+ except Exception as e:
+ logger.error(f"Failed to get access token: {e}")
+ raise
async def revoke_credentials(self):
- """Clear token cache and remove token file"""
- self.token_cache.clear()
- if os.path.exists(self.token_file):
- os.remove(self.token_file)
+ """Clear token cache and remove token file."""
+ try:
+ # Clear in-memory state
+ self._current_account = None
+ self.token_cache = msal.SerializableTokenCache()
+
+ # Recreate MSAL app with fresh cache
+ self.app = msal.ConfidentialClientApplication(
+ client_id=self.client_id,
+ client_credential=self.client_secret,
+ authority=self.authority,
+ token_cache=self.token_cache,
+ )
+
+ # Remove token file
+ if os.path.exists(self.token_file):
+ os.remove(self.token_file)
+ logger.info(f"Removed OneDrive token file: {self.token_file}")
+
+ except Exception as e:
+ logger.error(f"Failed to revoke OneDrive credentials: {e}")
+
+ def get_service(self) -> str:
+ """Return an access token (Graph client is just the bearer)."""
+ return self.get_access_token()
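+
+
+# Illustrative token lifecycle (hypothetical variable names; the class is
+# defined above):
+#   await oauth.handle_authorization_callback(code, redirect_uri)  # first run
+#   authed = await oauth.is_authenticated()  # silent refresh via MSAL cache
+#   token = oauth.get_access_token()         # bearer for Graph requests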
diff --git a/src/connectors/sharepoint/connector.py b/src/connectors/sharepoint/connector.py
index 7135cc8e..f8283062 100644
--- a/src/connectors/sharepoint/connector.py
+++ b/src/connectors/sharepoint/connector.py
@@ -1,229 +1,567 @@
+import logging
+from pathlib import Path
+from typing import List, Dict, Any, Optional
+from urllib.parse import urlparse, parse_qs
+from datetime import datetime, timedelta
import httpx
-import uuid
-from datetime import datetime, timedelta
-from typing import Dict, List, Any, Optional
from ..base import BaseConnector, ConnectorDocument, DocumentACL
from .oauth import SharePointOAuth
+logger = logging.getLogger(__name__)
+
class SharePointConnector(BaseConnector):
- """SharePoint Sites connector using Microsoft Graph API"""
+ """SharePoint connector using MSAL-based OAuth for authentication"""
+ # Required BaseConnector class attributes
CLIENT_ID_ENV_VAR = "MICROSOFT_GRAPH_OAUTH_CLIENT_ID"
CLIENT_SECRET_ENV_VAR = "MICROSOFT_GRAPH_OAUTH_CLIENT_SECRET"
-
+
# Connector metadata
CONNECTOR_NAME = "SharePoint"
- CONNECTOR_DESCRIPTION = "Connect to SharePoint sites to sync team documents"
+ CONNECTOR_DESCRIPTION = "Connect to SharePoint to sync documents and files"
CONNECTOR_ICON = "sharepoint"
-
+
def __init__(self, config: Dict[str, Any]):
-        super().__init__(config)
- self.oauth = SharePointOAuth(
- client_id=self.get_client_id(),
- client_secret=self.get_client_secret(),
- token_file=config.get("token_file", "sharepoint_token.json"),
- )
- self.subscription_id = config.get("subscription_id") or config.get(
- "webhook_channel_id"
- )
- self.base_url = "https://graph.microsoft.com/v1.0"
- # SharePoint site configuration
- self.site_id = config.get("site_id") # Required for SharePoint
+ logger.debug(f"SharePoint connector __init__ called with config type: {type(config)}")
+ logger.debug(f"SharePoint connector __init__ config value: {config}")
+
+ # Ensure we always pass a valid config to the base class
+ if config is None:
+ logger.debug("Config was None, using empty dict")
+ config = {}
+
+ try:
+ logger.debug("Calling super().__init__")
+ super().__init__(config) # Now safe to call with empty dict instead of None
+ logger.debug("super().__init__ completed successfully")
+ except Exception as e:
+ logger.error(f"super().__init__ failed: {e}")
+ raise
+
+ # Initialize with defaults that allow the connector to be listed
+ self.client_id = None
+ self.client_secret = None
+ self.tenant_id = config.get("tenant_id", "common")
+ self.sharepoint_url = config.get("sharepoint_url")
+ self.redirect_uri = config.get("redirect_uri", "http://localhost")
+
+ # Try to get credentials, but don't fail if they're missing
+ try:
+ logger.debug("Attempting to get client_id")
+ self.client_id = self.get_client_id()
+ logger.debug(f"Got client_id: {self.client_id is not None}")
+ except Exception as e:
+ logger.debug(f"Failed to get client_id: {e}")
+ pass # Credentials not available, that's OK for listing
+
+ try:
+ logger.debug("Attempting to get client_secret")
+ self.client_secret = self.get_client_secret()
+ logger.debug(f"Got client_secret: {self.client_secret is not None}")
+ except Exception as e:
+ logger.debug(f"Failed to get client_secret: {e}")
+ pass # Credentials not available, that's OK for listing
- async def authenticate(self) -> bool:
- if await self.oauth.is_authenticated():
- self._authenticated = True
- return True
- return False
-
- async def setup_subscription(self) -> str:
- if not self._authenticated:
- raise ValueError("Not authenticated")
-
- webhook_url = self.config.get("webhook_url")
- if not webhook_url:
- raise ValueError("webhook_url required in config for subscriptions")
-
- expiration = (datetime.utcnow() + timedelta(days=2)).isoformat() + "Z"
- body = {
- "changeType": "created,updated,deleted",
- "notificationUrl": webhook_url,
- "resource": f"/sites/{self.site_id}/drive/root",
- "expirationDateTime": expiration,
- "clientState": str(uuid.uuid4()),
+        # Token file setup: default to a per-connection cache in the project root
+        project_root = Path(__file__).resolve().parent.parent.parent.parent
+
+        # Only initialize OAuth if we have credentials
+        if self.client_id and self.client_secret:
+            connection_id = config.get("connection_id", "default")
+
+            # Use token_file from config if provided, otherwise generate one per connection
+            if config.get("token_file"):
+                oauth_token_file = config["token_file"]
+            else:
+                oauth_token_file = str(project_root / f"sharepoint_token_{connection_id}.json")
+            Path(oauth_token_file).parent.mkdir(parents=True, exist_ok=True)
+
+ authority = f"https://login.microsoftonline.com/{self.tenant_id}" if self.tenant_id != "common" else "https://login.microsoftonline.com/common"
+
+ self.oauth = SharePointOAuth(
+ client_id=self.client_id,
+ client_secret=self.client_secret,
+ token_file=oauth_token_file,
+ authority=authority
+ )
+ else:
+ self.oauth = None
+
+ # Track subscription ID for webhooks
+ self._subscription_id: Optional[str] = None
+
+ # Add Graph API defaults similar to Google Drive flags
+ self._graph_api_version = "v1.0"
+ self._default_params = {
+ "$select": "id,name,size,lastModifiedDateTime,createdDateTime,webUrl,file,folder,@microsoft.graph.downloadUrl"
}
-
- token = self.oauth.get_access_token()
- async with httpx.AsyncClient() as client:
- resp = await client.post(
- f"{self.base_url}/subscriptions",
- json=body,
- headers={"Authorization": f"Bearer {token}"},
- )
- resp.raise_for_status()
- data = resp.json()
-
- self.subscription_id = data["id"]
- return self.subscription_id
-
- async def list_files(
- self, page_token: Optional[str] = None, limit: int = 100
- ) -> Dict[str, Any]:
- if not self._authenticated:
- raise ValueError("Not authenticated")
-
- params = {"$top": str(limit)}
- if page_token:
- params["$skiptoken"] = page_token
-
- token = self.oauth.get_access_token()
- async with httpx.AsyncClient() as client:
- resp = await client.get(
- f"{self.base_url}/sites/{self.site_id}/drive/root/children",
- params=params,
- headers={"Authorization": f"Bearer {token}"},
- )
- resp.raise_for_status()
- data = resp.json()
-
- files = []
- for item in data.get("value", []):
- if item.get("file"):
- files.append(
- {
- "id": item["id"],
- "name": item["name"],
- "mimeType": item.get("file", {}).get(
- "mimeType", "application/octet-stream"
- ),
- "webViewLink": item.get("webUrl"),
- "createdTime": item.get("createdDateTime"),
- "modifiedTime": item.get("lastModifiedDateTime"),
- }
- )
-
- next_token = None
- next_link = data.get("@odata.nextLink")
- if next_link:
- from urllib.parse import urlparse, parse_qs
-
- parsed = urlparse(next_link)
- next_token = parse_qs(parsed.query).get("$skiptoken", [None])[0]
-
- return {"files": files, "nextPageToken": next_token}
-
- async def get_file_content(self, file_id: str) -> ConnectorDocument:
- if not self._authenticated:
- raise ValueError("Not authenticated")
-
- token = self.oauth.get_access_token()
- headers = {"Authorization": f"Bearer {token}"}
- async with httpx.AsyncClient() as client:
- meta_resp = await client.get(
- f"{self.base_url}/sites/{self.site_id}/drive/items/{file_id}",
- headers=headers,
- )
- meta_resp.raise_for_status()
- metadata = meta_resp.json()
-
- content_resp = await client.get(
- f"{self.base_url}/sites/{self.site_id}/drive/items/{file_id}/content",
- headers=headers,
- )
- content_resp.raise_for_status()
- content = content_resp.content
-
- perm_resp = await client.get(
- f"{self.base_url}/sites/{self.site_id}/drive/items/{file_id}/permissions",
- headers=headers,
- )
- perm_resp.raise_for_status()
- permissions = perm_resp.json()
-
- acl = self._parse_permissions(metadata, permissions)
- modified = datetime.fromisoformat(
- metadata["lastModifiedDateTime"].replace("Z", "+00:00")
- ).replace(tzinfo=None)
- created = datetime.fromisoformat(
- metadata["createdDateTime"].replace("Z", "+00:00")
- ).replace(tzinfo=None)
-
- document = ConnectorDocument(
- id=metadata["id"],
- filename=metadata["name"],
- mimetype=metadata.get("file", {}).get(
- "mimeType", "application/octet-stream"
- ),
- content=content,
- source_url=metadata.get("webUrl"),
- acl=acl,
- modified_time=modified,
- created_time=created,
- metadata={"size": metadata.get("size")},
- )
- return document
-
- def _parse_permissions(
- self, metadata: Dict[str, Any], permissions: Dict[str, Any]
- ) -> DocumentACL:
- acl = DocumentACL()
- owner = metadata.get("createdBy", {}).get("user", {}).get("email")
- if owner:
- acl.owner = owner
- for perm in permissions.get("value", []):
- role = perm.get("roles", ["read"])[0]
- grantee = perm.get("grantedToV2") or perm.get("grantedTo")
- if not grantee:
- continue
- user = grantee.get("user")
- if user and user.get("email"):
- acl.user_permissions[user["email"]] = role
- group = grantee.get("group")
- if group and group.get("email"):
- acl.group_permissions[group["email"]] = role
- return acl
-
- def handle_webhook_validation(
- self, request_method: str, headers: Dict[str, str], query_params: Dict[str, str]
- ) -> Optional[str]:
- """Handle Microsoft Graph webhook validation"""
- if request_method == "GET":
- validation_token = query_params.get("validationtoken") or query_params.get(
- "validationToken"
- )
- if validation_token:
- return validation_token
- return None
-
- def extract_webhook_channel_id(
- self, payload: Dict[str, Any], headers: Dict[str, str]
- ) -> Optional[str]:
- """Extract SharePoint subscription ID from webhook payload"""
- values = payload.get("value", [])
- return values[0].get("subscriptionId") if values else None
-
- async def handle_webhook(self, payload: Dict[str, Any]) -> List[str]:
- values = payload.get("value", [])
- file_ids = []
- for item in values:
- resource_data = item.get("resourceData", {})
- file_id = resource_data.get("id")
- if file_id:
- file_ids.append(file_id)
- return file_ids
-
- async def cleanup_subscription(
- self, subscription_id: str, resource_id: str = None
- ) -> bool:
- if not self._authenticated:
+
+ @property
+ def _graph_base_url(self) -> str:
+ """Base URL for Microsoft Graph API calls"""
+ return f"https://graph.microsoft.com/{self._graph_api_version}"
+
+ def emit(self, doc: ConnectorDocument) -> None:
+ """
+ Emit a ConnectorDocument instance.
+ """
+ logger.debug(f"Emitting SharePoint document: {doc.id} ({doc.filename})")
+
+ async def authenticate(self) -> bool:
+ """Test authentication - BaseConnector interface"""
+ logger.debug(f"SharePoint authenticate() called, oauth is None: {self.oauth is None}")
+ try:
+ if not self.oauth:
+ logger.debug("SharePoint authentication failed: OAuth not initialized")
+ self._authenticated = False
+ return False
+
+ logger.debug("Loading SharePoint credentials...")
+ # Try to load existing credentials first
+ load_result = await self.oauth.load_credentials()
+ logger.debug(f"Load credentials result: {load_result}")
+
+ logger.debug("Checking SharePoint authentication status...")
+ authenticated = await self.oauth.is_authenticated()
+ logger.debug(f"SharePoint is_authenticated result: {authenticated}")
+
+ self._authenticated = authenticated
+ return authenticated
+ except Exception as e:
+ logger.error(f"SharePoint authentication failed: {e}")
+ import traceback
+ traceback.print_exc()
+ self._authenticated = False
return False
- token = self.oauth.get_access_token()
- async with httpx.AsyncClient() as client:
- resp = await client.delete(
- f"{self.base_url}/subscriptions/{subscription_id}",
- headers={"Authorization": f"Bearer {token}"},
+
+ def get_auth_url(self) -> str:
+ """Get OAuth authorization URL"""
+ if not self.oauth:
+ raise RuntimeError("SharePoint OAuth not initialized - missing credentials")
+ return self.oauth.create_authorization_url(self.redirect_uri)
+
+ async def handle_oauth_callback(self, auth_code: str) -> Dict[str, Any]:
+ """Handle OAuth callback"""
+ if not self.oauth:
+ raise RuntimeError("SharePoint OAuth not initialized - missing credentials")
+ try:
+ success = await self.oauth.handle_authorization_callback(auth_code, self.redirect_uri)
+ if success:
+ self._authenticated = True
+ return {"status": "success"}
+ else:
+ raise ValueError("OAuth callback failed")
+ except Exception as e:
+ logger.error(f"OAuth callback failed: {e}")
+ raise
+
+ def sync_once(self) -> None:
+ """
+ Perform a one-shot sync of SharePoint files and emit documents.
+ This method mirrors the Google Drive connector's sync_once functionality.
+ """
+ import asyncio
+
+ async def _async_sync():
+ try:
+ # Get list of files
+ file_list = await self.list_files(max_files=1000) # Adjust as needed
+ files = file_list.get("files", [])
+
+ for file_info in files:
+ try:
+ file_id = file_info.get("id")
+ if not file_id:
+ continue
+
+ # Get full document content
+ doc = await self.get_file_content(file_id)
+ self.emit(doc)
+
+ except Exception as e:
+ logger.error(f"Failed to sync SharePoint file {file_info.get('name', 'unknown')}: {e}")
+ continue
+
+ except Exception as e:
+ logger.error(f"SharePoint sync_once failed: {e}")
+ raise
+
+        # Run the async sync; asyncio.run() starts a fresh event loop, so this
+        # must be called from synchronous code, not from inside a running loop.
+        asyncio.run(_async_sync())
+
+ async def setup_subscription(self) -> str:
+ """Set up real-time subscription for file changes - BaseConnector interface"""
+ webhook_url = self.config.get('webhook_url')
+ if not webhook_url:
+ logger.warning("No webhook URL configured, skipping SharePoint subscription setup")
+ return "no-webhook-configured"
+
+ try:
+ # Ensure we're authenticated
+ if not await self.authenticate():
+ raise RuntimeError("SharePoint authentication failed during subscription setup")
+
+ token = self.oauth.get_access_token()
+
+ # Microsoft Graph subscription for SharePoint site
+ site_info = self._parse_sharepoint_url()
+ if site_info:
+ resource = f"sites/{site_info['host_name']}:/sites/{site_info['site_name']}:/drive/root"
+ else:
+ resource = "/me/drive/root"
+
+ subscription_data = {
+ "changeType": "created,updated,deleted",
+ "notificationUrl": f"{webhook_url}/webhook/sharepoint",
+ "resource": resource,
+ "expirationDateTime": self._get_subscription_expiry(),
+ "clientState": f"sharepoint_{self.tenant_id}"
+ }
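+            # Note: Graph validates notificationUrl synchronously at creation
+            # time by POSTing a validationToken, so handle_webhook_validation
+            # below must be reachable at that URL.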
+
+ headers = {
+ "Authorization": f"Bearer {token}",
+ "Content-Type": "application/json"
+ }
+
+ url = f"{self._graph_base_url}/subscriptions"
+
+ async with httpx.AsyncClient() as client:
+ response = await client.post(url, json=subscription_data, headers=headers, timeout=30)
+ response.raise_for_status()
+
+ result = response.json()
+ subscription_id = result.get("id")
+
+ if subscription_id:
+ self._subscription_id = subscription_id
+ logger.info(f"SharePoint subscription created: {subscription_id}")
+ return subscription_id
+ else:
+ raise ValueError("No subscription ID returned from Microsoft Graph")
+
+ except Exception as e:
+ logger.error(f"Failed to setup SharePoint subscription: {e}")
+ raise
+
+ def _get_subscription_expiry(self) -> str:
+ """Get subscription expiry time (max 3 days for Graph API)"""
+ from datetime import datetime, timedelta
+ expiry = datetime.utcnow() + timedelta(days=3) # 3 days max for Graph
+ return expiry.strftime("%Y-%m-%dT%H:%M:%S.%fZ")
+
+ def _parse_sharepoint_url(self) -> Optional[Dict[str, str]]:
+ """Parse SharePoint URL to extract site information for Graph API"""
+ if not self.sharepoint_url:
+ return None
+
+ try:
+ parsed = urlparse(self.sharepoint_url)
+ # Extract hostname and site name from URL like: https://contoso.sharepoint.com/sites/teamsite
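+            # That example yields host_name="contoso.sharepoint.com" and
+            # site_name="teamsite", which callers expand into the Graph path
+            # sites/contoso.sharepoint.com:/sites/teamsite:/drive/root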
+ host_name = parsed.netloc
+ path_parts = parsed.path.strip('/').split('/')
+
+ if len(path_parts) >= 2 and path_parts[0] == 'sites':
+ site_name = path_parts[1]
+ return {
+ "host_name": host_name,
+ "site_name": site_name
+ }
+ except Exception as e:
+ logger.warning(f"Could not parse SharePoint URL {self.sharepoint_url}: {e}")
+
+ return None
+
+ async def list_files(
+ self,
+ page_token: Optional[str] = None,
+ max_files: Optional[int] = None,
+ **kwargs
+ ) -> Dict[str, Any]:
+ """List all files using Microsoft Graph API - BaseConnector interface"""
+ try:
+ # Ensure authentication
+ if not await self.authenticate():
+ raise RuntimeError("SharePoint authentication failed during file listing")
+
+ files = []
+ max_files_value = max_files if max_files is not None else 100
+
+ # Build Graph API URL for the site or fallback to user's OneDrive
+ site_info = self._parse_sharepoint_url()
+ if site_info:
+ base_url = f"{self._graph_base_url}/sites/{site_info['host_name']}:/sites/{site_info['site_name']}:/drive/root/children"
+ else:
+ base_url = f"{self._graph_base_url}/me/drive/root/children"
+
+ params = dict(self._default_params)
+ params["$top"] = str(max_files_value)
+
+ if page_token:
+ params["$skiptoken"] = page_token
+
+ response = await self._make_graph_request(base_url, params=params)
+ data = response.json()
+
+ items = data.get("value", [])
+ for item in items:
+ # Only include files, not folders
+ if item.get("file"):
+ files.append({
+ "id": item.get("id", ""),
+ "name": item.get("name", ""),
+ "path": f"/drive/items/{item.get('id')}",
+ "size": int(item.get("size", 0)),
+ "modified": item.get("lastModifiedDateTime"),
+ "created": item.get("createdDateTime"),
+ "mime_type": item.get("file", {}).get("mimeType", self._get_mime_type(item.get("name", ""))),
+ "url": item.get("webUrl", ""),
+ "download_url": item.get("@microsoft.graph.downloadUrl")
+ })
+
+ # Check for next page
+ next_page_token = None
+ next_link = data.get("@odata.nextLink")
+ if next_link:
+ parsed = urlparse(next_link)
+ query_params = parse_qs(parsed.query)
+ if "$skiptoken" in query_params:
+ next_page_token = query_params["$skiptoken"][0]
+
+ return {
+ "files": files,
+ "next_page_token": next_page_token
+ }
+
+ except Exception as e:
+ logger.error(f"Failed to list SharePoint files: {e}")
+ return {"files": [], "next_page_token": None} # Return empty result instead of raising
+
+ async def get_file_content(self, file_id: str) -> ConnectorDocument:
+ """Get file content and metadata - BaseConnector interface"""
+ try:
+ # Ensure authentication
+ if not await self.authenticate():
+ raise RuntimeError("SharePoint authentication failed during file content retrieval")
+
+ # First get file metadata using Graph API
+ file_metadata = await self._get_file_metadata_by_id(file_id)
+
+ if not file_metadata:
+ raise ValueError(f"File not found: {file_id}")
+
+ # Download file content
+ download_url = file_metadata.get("download_url")
+ if download_url:
+ content = await self._download_file_from_url(download_url)
+ else:
+ content = await self._download_file_content(file_id)
+
+ # Create ACL from metadata
+ acl = DocumentACL(
+ owner="", # Graph API requires additional calls for detailed permissions
+ user_permissions={},
+ group_permissions={}
)
- return resp.status_code in (200, 204)
+
+ # Parse dates
+ modified_time = self._parse_graph_date(file_metadata.get("modified"))
+ created_time = self._parse_graph_date(file_metadata.get("created"))
+
+ return ConnectorDocument(
+ id=file_id,
+ filename=file_metadata.get("name", ""),
+ mimetype=file_metadata.get("mime_type", "application/octet-stream"),
+ content=content,
+ source_url=file_metadata.get("url", ""),
+ acl=acl,
+ modified_time=modified_time,
+ created_time=created_time,
+ metadata={
+ "sharepoint_path": file_metadata.get("path", ""),
+ "sharepoint_url": self.sharepoint_url,
+ "size": file_metadata.get("size", 0)
+ }
+ )
+
+ except Exception as e:
+ logger.error(f"Failed to get SharePoint file content {file_id}: {e}")
+ raise
+
+ async def _get_file_metadata_by_id(self, file_id: str) -> Optional[Dict[str, Any]]:
+ """Get file metadata by ID using Graph API"""
+ try:
+ # Try site-specific path first, then fallback to user drive
+ site_info = self._parse_sharepoint_url()
+ if site_info:
+ url = f"{self._graph_base_url}/sites/{site_info['host_name']}:/sites/{site_info['site_name']}:/drive/items/{file_id}"
+ else:
+ url = f"{self._graph_base_url}/me/drive/items/{file_id}"
+
+ params = dict(self._default_params)
+
+ response = await self._make_graph_request(url, params=params)
+ item = response.json()
+
+ if item.get("file"):
+ return {
+ "id": file_id,
+ "name": item.get("name", ""),
+ "path": f"/drive/items/{file_id}",
+ "size": int(item.get("size", 0)),
+ "modified": item.get("lastModifiedDateTime"),
+ "created": item.get("createdDateTime"),
+ "mime_type": item.get("file", {}).get("mimeType", self._get_mime_type(item.get("name", ""))),
+ "url": item.get("webUrl", ""),
+ "download_url": item.get("@microsoft.graph.downloadUrl")
+ }
+
+ return None
+
+ except Exception as e:
+ logger.error(f"Failed to get file metadata for {file_id}: {e}")
+ return None
+
+ async def _download_file_content(self, file_id: str) -> bytes:
+ """Download file content by file ID using Graph API"""
+ try:
+ site_info = self._parse_sharepoint_url()
+ if site_info:
+ url = f"{self._graph_base_url}/sites/{site_info['host_name']}:/sites/{site_info['site_name']}:/drive/items/{file_id}/content"
+ else:
+ url = f"{self._graph_base_url}/me/drive/items/{file_id}/content"
+
+ token = self.oauth.get_access_token()
+ headers = {"Authorization": f"Bearer {token}"}
+
+ async with httpx.AsyncClient() as client:
+ response = await client.get(url, headers=headers, timeout=60)
+ response.raise_for_status()
+ return response.content
+
+ except Exception as e:
+ logger.error(f"Failed to download file content for {file_id}: {e}")
+ raise
+
+ async def _download_file_from_url(self, download_url: str) -> bytes:
+ """Download file content from direct download URL"""
+ try:
+ async with httpx.AsyncClient() as client:
+ response = await client.get(download_url, timeout=60)
+ response.raise_for_status()
+ return response.content
+ except Exception as e:
+ logger.error(f"Failed to download from URL {download_url}: {e}")
+ raise
+
+ def _parse_graph_date(self, date_str: Optional[str]) -> datetime:
+ """Parse Microsoft Graph date string to datetime"""
+ if not date_str:
+ return datetime.now()
+
+        try:
+            # Graph returns ISO 8601 timestamps, typically with a trailing "Z"
+            return datetime.fromisoformat(date_str.replace("Z", "+00:00")).replace(tzinfo=None)
+ except (ValueError, AttributeError):
+ return datetime.now()
+
+ async def _make_graph_request(self, url: str, method: str = "GET",
+ data: Optional[Dict] = None, params: Optional[Dict] = None) -> httpx.Response:
+ """Make authenticated API request to Microsoft Graph"""
+ token = self.oauth.get_access_token()
+ headers = {
+ "Authorization": f"Bearer {token}",
+ "Content-Type": "application/json"
+ }
+
+ async with httpx.AsyncClient() as client:
+ if method.upper() == "GET":
+ response = await client.get(url, headers=headers, params=params, timeout=30)
+ elif method.upper() == "POST":
+ response = await client.post(url, headers=headers, json=data, timeout=30)
+ elif method.upper() == "DELETE":
+ response = await client.delete(url, headers=headers, timeout=30)
+ else:
+ raise ValueError(f"Unsupported HTTP method: {method}")
+
+ response.raise_for_status()
+ return response
+
+ def _get_mime_type(self, filename: str) -> str:
+ """Get MIME type based on file extension"""
+ import mimetypes
+ mime_type, _ = mimetypes.guess_type(filename)
+ return mime_type or "application/octet-stream"
+
+ # Webhook methods - BaseConnector interface
+ def handle_webhook_validation(self, request_method: str, headers: Dict[str, str],
+ query_params: Dict[str, str]) -> Optional[str]:
+ """Handle webhook validation (Graph API specific)"""
+ if request_method == "POST" and "validationToken" in query_params:
+ return query_params["validationToken"]
+ return None
+
+ def extract_webhook_channel_id(self, payload: Dict[str, Any],
+ headers: Dict[str, str]) -> Optional[str]:
+ """Extract channel/subscription ID from webhook payload"""
+ notifications = payload.get("value", [])
+ if notifications:
+ return notifications[0].get("subscriptionId")
+ return None
+
+ async def handle_webhook(self, payload: Dict[str, Any]) -> List[str]:
+ """Handle webhook notification and return affected file IDs"""
+ affected_files = []
+
+ # Process Microsoft Graph webhook payload
+ notifications = payload.get("value", [])
+ for notification in notifications:
+ resource = notification.get("resource")
+ if resource and "/drive/items/" in resource:
+ file_id = resource.split("/drive/items/")[-1]
+ affected_files.append(file_id)
+
+ return affected_files
+
+ async def cleanup_subscription(self, subscription_id: str) -> bool:
+ """Clean up subscription - BaseConnector interface"""
+ if subscription_id == "no-webhook-configured":
+ logger.info("No subscription to cleanup (webhook was not configured)")
+ return True
+
+ try:
+ # Ensure authentication
+ if not await self.authenticate():
+ logger.error("SharePoint authentication failed during subscription cleanup")
+ return False
+
+ token = self.oauth.get_access_token()
+ headers = {"Authorization": f"Bearer {token}"}
+
+ url = f"{self._graph_base_url}/subscriptions/{subscription_id}"
+
+ async with httpx.AsyncClient() as client:
+ response = await client.delete(url, headers=headers, timeout=30)
+
+ if response.status_code in [200, 204, 404]:
+ logger.info(f"SharePoint subscription {subscription_id} cleaned up successfully")
+ return True
+ else:
+ logger.warning(f"Unexpected response cleaning up subscription: {response.status_code}")
+ return False
+
+ except Exception as e:
+ logger.error(f"Failed to cleanup SharePoint subscription {subscription_id}: {e}")
+ return False
diff --git a/src/connectors/sharepoint/oauth.py b/src/connectors/sharepoint/oauth.py
index fa7424e9..4a96581f 100644
--- a/src/connectors/sharepoint/oauth.py
+++ b/src/connectors/sharepoint/oauth.py
@@ -1,18 +1,28 @@
import os
+import json
+import logging
+from typing import Optional, Dict, Any
+
import aiofiles
-from typing import Optional
import msal
+logger = logging.getLogger(__name__)
+
class SharePointOAuth:
- """Handles Microsoft Graph OAuth authentication flow"""
+ """Handles Microsoft Graph OAuth authentication flow following Google Drive pattern."""
- SCOPES = [
- "offline_access",
- "Files.Read.All",
- "Sites.Read.All",
- ]
+ # Reserved scopes that must NOT be sent on token or silent calls
+ RESERVED_SCOPES = {"openid", "profile", "offline_access"}
+ # For PERSONAL Microsoft Accounts (OneDrive consumer):
+ # - Use AUTH_SCOPES for interactive auth (consent + refresh token issuance)
+ # - Use RESOURCE_SCOPES for acquire_token_silent / refresh paths
+ AUTH_SCOPES = ["User.Read", "Files.Read.All", "offline_access"]
+ RESOURCE_SCOPES = ["User.Read", "Files.Read.All"]
+ SCOPES = AUTH_SCOPES # Backward compatibility alias
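+    # MSAL adds the reserved scopes itself (and warns if they are passed to
+    # token calls), hence the split: AUTH_SCOPES only for the interactive
+    # authorize step, RESOURCE_SCOPES for silent/refresh acquisition.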
+
+ # Kept for reference; MSAL derives endpoints from `authority`
AUTH_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
TOKEN_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/token"
@@ -22,18 +32,29 @@ class SharePointOAuth:
client_secret: str,
token_file: str = "sharepoint_token.json",
authority: str = "https://login.microsoftonline.com/common",
+ allow_json_refresh: bool = True,
):
+ """
+ Initialize SharePointOAuth.
+
+ Args:
+ client_id: Azure AD application (client) ID.
+ client_secret: Azure AD application client secret.
+ token_file: Path to persisted token cache file (MSAL cache format).
+ authority: Usually "https://login.microsoftonline.com/common" for MSA + org,
+ or tenant-specific for work/school.
+ allow_json_refresh: If True, permit one-time migration from legacy flat JSON
+ {"access_token","refresh_token",...}. Otherwise refuse it.
+ """
self.client_id = client_id
self.client_secret = client_secret
self.token_file = token_file
self.authority = authority
+ self.allow_json_refresh = allow_json_refresh
self.token_cache = msal.SerializableTokenCache()
+ self._current_account = None
- # Load existing cache if available
- if os.path.exists(self.token_file):
- with open(self.token_file, "r") as f:
- self.token_cache.deserialize(f.read())
-
+ # Initialize MSAL Confidential Client
self.app = msal.ConfidentialClientApplication(
client_id=self.client_id,
client_credential=self.client_secret,
@@ -41,56 +62,268 @@ class SharePointOAuth:
token_cache=self.token_cache,
)
- async def save_cache(self):
- """Persist the token cache to file"""
- async with aiofiles.open(self.token_file, "w") as f:
- await f.write(self.token_cache.serialize())
+ async def load_credentials(self) -> bool:
+ """Load existing credentials from token file (async)."""
+ try:
+ logger.debug(f"SharePoint OAuth loading credentials from: {self.token_file}")
+ if os.path.exists(self.token_file):
+ logger.debug(f"Token file exists, reading: {self.token_file}")
- def create_authorization_url(self, redirect_uri: str) -> str:
- """Create authorization URL for OAuth flow"""
- return self.app.get_authorization_request_url(
- self.SCOPES, redirect_uri=redirect_uri
- )
+ # Read the token file
+ async with aiofiles.open(self.token_file, "r") as f:
+ cache_data = await f.read()
+ logger.debug(f"Read {len(cache_data)} chars from token file")
+
+ if cache_data.strip():
+ # 1) Try legacy flat JSON first
+ try:
+ json_data = json.loads(cache_data)
+ if isinstance(json_data, dict) and "refresh_token" in json_data:
+ if self.allow_json_refresh:
+ logger.debug(
+ "Found legacy JSON refresh_token and allow_json_refresh=True; attempting migration refresh"
+ )
+ return await self._refresh_from_json_token(json_data)
+ else:
+ logger.warning(
+ "Token file contains a legacy JSON refresh_token, but allow_json_refresh=False. "
+ "Delete the file and re-auth."
+ )
+ return False
+ except json.JSONDecodeError:
+ logger.debug("Token file is not flat JSON; attempting MSAL cache format")
+
+ # 2) Try MSAL cache format
+ logger.debug("Attempting MSAL cache deserialization")
+ self.token_cache.deserialize(cache_data)
+
+ # Get accounts from loaded cache
+ accounts = self.app.get_accounts()
+ logger.debug(f"Found {len(accounts)} accounts in MSAL cache")
+ if accounts:
+ self._current_account = accounts[0]
+ logger.debug(f"Set current account: {self._current_account.get('username', 'no username')}")
+
+ # IMPORTANT: Use RESOURCE_SCOPES (no reserved scopes) for silent acquisition
+ result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=self._current_account)
+ logger.debug(f"Silent token acquisition result keys: {list(result.keys()) if result else 'None'}")
+ if result and "access_token" in result:
+ logger.debug("Silent token acquisition successful")
+ await self.save_cache()
+ return True
+ else:
+ error_msg = (result or {}).get("error") or "No result"
+ logger.warning(f"Silent token acquisition failed: {error_msg}")
+ else:
+ logger.debug(f"Token file {self.token_file} is empty")
+ else:
+ logger.debug(f"Token file does not exist: {self.token_file}")
+
+ return False
+
+ except Exception as e:
+ logger.error(f"Failed to load SharePoint credentials: {e}")
+ import traceback
+ traceback.print_exc()
+ return False
+
+ async def _refresh_from_json_token(self, token_data: dict) -> bool:
+ """
+ Use refresh token from a legacy JSON file to get new tokens (one-time migration path).
+
+ Notes:
+ - Prefer using an MSAL cache file and acquire_token_silent().
+ - This path is only for migrating older refresh_token JSON files.
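+
+        Example legacy file shape (fields abbreviated):
+            {"access_token": "...", "refresh_token": "...", ...}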
+ """
+ try:
+ refresh_token = token_data.get("refresh_token")
+ if not refresh_token:
+ logger.error("No refresh_token found in JSON file - cannot refresh")
+ logger.error("You must re-authenticate interactively to obtain a valid token")
+ return False
+
+ # Use only RESOURCE_SCOPES when refreshing (no reserved scopes)
+ refresh_scopes = [s for s in self.RESOURCE_SCOPES if s not in self.RESERVED_SCOPES]
+ logger.debug(f"Using refresh token; refresh scopes = {refresh_scopes}")
+
+ result = self.app.acquire_token_by_refresh_token(
+ refresh_token=refresh_token,
+ scopes=refresh_scopes,
+ )
+
+ if result and "access_token" in result:
+ logger.debug("Successfully refreshed token via legacy JSON path")
+ await self.save_cache()
+
+ accounts = self.app.get_accounts()
+ logger.debug(f"After refresh, found {len(accounts)} accounts")
+ if accounts:
+ self._current_account = accounts[0]
+ logger.debug(f"Set current account after refresh: {self._current_account.get('username', 'no username')}")
+ return True
+
+ # Error handling
+ err = (result or {}).get("error_description") or (result or {}).get("error") or "Unknown error"
+ logger.error(f"Refresh token failed: {err}")
+
+ if any(code in err for code in ("AADSTS70000", "invalid_grant", "interaction_required")):
+ logger.warning(
+ "Refresh denied due to unauthorized/expired scopes or invalid grant. "
+ "Delete the token file and perform interactive sign-in with correct scopes."
+ )
+
+ return False
+
+ except Exception as e:
+ logger.error(f"Exception during refresh from JSON token: {e}")
+ import traceback
+ traceback.print_exc()
+ return False
+
+ async def save_cache(self):
+ """Persist the token cache to file."""
+ try:
+ # Ensure parent directory exists
+ parent = os.path.dirname(os.path.abspath(self.token_file))
+ if parent and not os.path.exists(parent):
+ os.makedirs(parent, exist_ok=True)
+
+ cache_data = self.token_cache.serialize()
+ if cache_data:
+ async with aiofiles.open(self.token_file, "w") as f:
+ await f.write(cache_data)
+ logger.debug(f"Token cache saved to {self.token_file}")
+ except Exception as e:
+ logger.error(f"Failed to save token cache: {e}")
+
+ def create_authorization_url(self, redirect_uri: str, state: Optional[str] = None) -> str:
+ """Create authorization URL for OAuth flow."""
+ # Store redirect URI for later use in callback
+ self._redirect_uri = redirect_uri
+
+ kwargs: Dict[str, Any] = {
+ # IMPORTANT: interactive auth includes offline_access
+ "scopes": self.AUTH_SCOPES,
+ "redirect_uri": redirect_uri,
+ "prompt": "consent", # ensure refresh token on first run
+ }
+ if state:
+ kwargs["state"] = state # Optional CSRF protection
+
+ auth_url = self.app.get_authorization_request_url(**kwargs)
+
+ logger.debug(f"Generated auth URL: {auth_url}")
+ logger.debug(f"Auth scopes: {self.AUTH_SCOPES}")
+
+ return auth_url
async def handle_authorization_callback(
self, authorization_code: str, redirect_uri: str
) -> bool:
- """Handle OAuth callback and exchange code for tokens"""
- result = self.app.acquire_token_by_authorization_code(
- authorization_code,
- scopes=self.SCOPES,
- redirect_uri=redirect_uri,
- )
- if "access_token" in result:
- await self.save_cache()
- return True
- raise ValueError(result.get("error_description") or "Authorization failed")
+ """Handle OAuth callback and exchange code for tokens."""
+ try:
+ # For code exchange, we pass the same auth scopes as used in the authorize step
+ result = self.app.acquire_token_by_authorization_code(
+ authorization_code,
+ scopes=self.AUTH_SCOPES,
+ redirect_uri=redirect_uri,
+ )
+
+ if result and "access_token" in result:
+ # Store the account for future use
+ accounts = self.app.get_accounts()
+ if accounts:
+ self._current_account = accounts[0]
+
+ await self.save_cache()
+ logger.info("SharePoint OAuth authorization successful")
+ return True
+
+ error_msg = (result or {}).get("error_description") or (result or {}).get("error") or "Unknown error"
+ logger.error(f"SharePoint OAuth authorization failed: {error_msg}")
+ return False
+
+ except Exception as e:
+ logger.error(f"Exception during SharePoint OAuth authorization: {e}")
+ return False
async def is_authenticated(self) -> bool:
- """Check if we have valid credentials"""
- accounts = self.app.get_accounts()
- if not accounts:
+ """Check if we have valid credentials (simplified like Google Drive)."""
+ try:
+ # First try to load credentials if we haven't already
+ if not self._current_account:
+ await self.load_credentials()
+
+ # If we have an account, try to get a token (MSAL will refresh if needed)
+ if self._current_account:
+ # IMPORTANT: use RESOURCE_SCOPES here
+ result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=self._current_account)
+ if result and "access_token" in result:
+ return True
+ else:
+ error_msg = (result or {}).get("error") or "No result returned"
+ logger.debug(f"Token acquisition failed for current account: {error_msg}")
+
+ # Fallback: try without specific account
+ result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=None)
+ if result and "access_token" in result:
+ # Update current account if this worked
+ accounts = self.app.get_accounts()
+ if accounts:
+ self._current_account = accounts[0]
+ return True
+
+ return False
+
+ except Exception as e:
+ logger.error(f"Authentication check failed: {e}")
return False
- result = self.app.acquire_token_silent(self.SCOPES, account=accounts[0])
- if "access_token" in result:
- await self.save_cache()
- return True
- return False
def get_access_token(self) -> str:
- """Get an access token for Microsoft Graph"""
- accounts = self.app.get_accounts()
- if not accounts:
- raise ValueError("Not authenticated")
- result = self.app.acquire_token_silent(self.SCOPES, account=accounts[0])
- if "access_token" not in result:
- raise ValueError(
- result.get("error_description") or "Failed to acquire access token"
- )
- return result["access_token"]
+ """Get an access token for Microsoft Graph (simplified like Google Drive)."""
+ try:
+ # Try with current account first
+ if self._current_account:
+ result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=self._current_account)
+ if result and "access_token" in result:
+ return result["access_token"]
+
+ # Fallback: try without specific account
+ result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=None)
+ if result and "access_token" in result:
+ return result["access_token"]
+
+ # If we get here, authentication has failed
+ error_msg = (result or {}).get("error_description") or (result or {}).get("error") or "No valid authentication"
+ raise ValueError(f"Failed to acquire access token: {error_msg}")
+
+ except Exception as e:
+ logger.error(f"Failed to get access token: {e}")
+ raise
async def revoke_credentials(self):
- """Clear token cache and remove token file"""
- self.token_cache.clear()
- if os.path.exists(self.token_file):
- os.remove(self.token_file)
+ """Clear token cache and remove token file (like Google Drive)."""
+ try:
+ # Clear in-memory state
+ self._current_account = None
+ self.token_cache = msal.SerializableTokenCache()
+
+ # Recreate MSAL app with fresh cache
+ self.app = msal.ConfidentialClientApplication(
+ client_id=self.client_id,
+ client_credential=self.client_secret,
+ authority=self.authority,
+ token_cache=self.token_cache,
+ )
+
+ # Remove token file
+ if os.path.exists(self.token_file):
+ os.remove(self.token_file)
+ logger.info(f"Removed SharePoint token file: {self.token_file}")
+
+ except Exception as e:
+ logger.error(f"Failed to revoke SharePoint credentials: {e}")
+
+ def get_service(self) -> str:
+ """Return an access token (Graph doesn't need a generated client like Google Drive)."""
+ return self.get_access_token()
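+
+
+# Illustrative first-run usage (assumes valid Azure AD app credentials):
+#   oauth = SharePointOAuth(client_id, client_secret)
+#   url = oauth.create_authorization_url("http://localhost")  # user signs in
+#   ok = await oauth.handle_authorization_callback(code, "http://localhost")
+#   token = oauth.get_access_token()  # bearer for Microsoft Graph requests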
diff --git a/src/main.py b/src/main.py
index 10592e47..db4f6f4d 100644
--- a/src/main.py
+++ b/src/main.py
@@ -2,6 +2,7 @@
from connectors.langflow_connector_service import LangflowConnectorService
from connectors.service import ConnectorService
from services.flows_service import FlowsService
+from utils.embeddings import create_dynamic_index_body
from utils.logging_config import configure_from_env, get_logger
configure_from_env()
@@ -52,11 +53,11 @@ from auth_middleware import optional_auth, require_auth
from config.settings import (
DISABLE_INGEST_WITH_LANGFLOW,
EMBED_MODEL,
- INDEX_BODY,
INDEX_NAME,
SESSION_SECRET,
clients,
is_no_auth_mode,
+ get_openrag_config,
)
from services.auth_service import AuthService
from services.langflow_mcp_service import LangflowMCPService
@@ -81,7 +82,6 @@ logger.info(
cuda_version=torch.version.cuda,
)
-
async def wait_for_opensearch():
"""Wait for OpenSearch to be ready with retries"""
max_retries = 30
@@ -132,12 +132,19 @@ async def init_index():
"""Initialize OpenSearch index and security roles"""
await wait_for_opensearch()
+ # Get the configured embedding model from user configuration
+ config = get_openrag_config()
+ embedding_model = config.knowledge.embedding_model
+
+ # Create dynamic index body based on the configured embedding model
+ dynamic_index_body = create_dynamic_index_body(embedding_model)
+
# Create documents index
if not await clients.opensearch.indices.exists(index=INDEX_NAME):
- await clients.opensearch.indices.create(index=INDEX_NAME, body=INDEX_BODY)
- logger.info("Created OpenSearch index", index_name=INDEX_NAME)
+ await clients.opensearch.indices.create(index=INDEX_NAME, body=dynamic_index_body)
+ logger.info("Created OpenSearch index", index_name=INDEX_NAME, embedding_model=embedding_model)
else:
- logger.info("Index already exists, skipping creation", index_name=INDEX_NAME)
+ logger.info("Index already exists, skipping creation", index_name=INDEX_NAME, embedding_model=embedding_model)
# Create knowledge filters index
knowledge_filter_index_name = "knowledge_filters"
@@ -395,7 +402,12 @@ async def _ingest_default_documents_openrag(services, file_paths):
async def startup_tasks(services):
"""Startup tasks"""
logger.info("Starting startup tasks")
- await init_index()
+ # Only initialize basic OpenSearch connection, not the index
+ # Index will be created after onboarding when we know the embedding model
+ await wait_for_opensearch()
+
+ # Configure alerting security
+ await configure_alerting_security()
async def initialize_services():
diff --git a/src/services/flows_service.py b/src/services/flows_service.py
index 0d7a7bc8..7397cf6b 100644
--- a/src/services/flows_service.py
+++ b/src/services/flows_service.py
@@ -1,3 +1,4 @@
+import asyncio
from config.settings import (
NUDGES_FLOW_ID,
LANGFLOW_URL,
@@ -19,6 +20,7 @@ from config.settings import (
WATSONX_LLM_COMPONENT_ID,
OLLAMA_EMBEDDING_COMPONENT_ID,
OLLAMA_LLM_COMPONENT_ID,
+ get_openrag_config,
)
import json
import os
@@ -29,6 +31,74 @@ logger = get_logger(__name__)
class FlowsService:
+ def __init__(self):
+ # Cache for flow file mappings to avoid repeated filesystem scans
+ self._flow_file_cache = {}
+
+ def _get_flows_directory(self):
+ """Get the flows directory path"""
+ current_file_dir = os.path.dirname(os.path.abspath(__file__)) # src/services/
+ src_dir = os.path.dirname(current_file_dir) # src/
+ project_root = os.path.dirname(src_dir) # project root
+ return os.path.join(project_root, "flows")
+
+ def _find_flow_file_by_id(self, flow_id: str):
+ """
+ Scan the flows directory and find the JSON file that contains the specified flow ID.
+
+ Args:
+ flow_id: The flow ID to search for
+
+ Returns:
+ str: The path to the flow file, or None if not found
+ """
+ if not flow_id:
+ raise ValueError("flow_id is required")
+
+ # Check cache first
+ if flow_id in self._flow_file_cache:
+ cached_path = self._flow_file_cache[flow_id]
+ if os.path.exists(cached_path):
+ return cached_path
+ else:
+ # Remove stale cache entry
+ del self._flow_file_cache[flow_id]
+
+ flows_dir = self._get_flows_directory()
+
+ if not os.path.exists(flows_dir):
+ logger.warning(f"Flows directory not found: {flows_dir}")
+ return None
+
+ # Scan all JSON files in the flows directory
+ try:
+ for filename in os.listdir(flows_dir):
+ if not filename.endswith('.json'):
+ continue
+
+ file_path = os.path.join(flows_dir, filename)
+
+ try:
+ with open(file_path, 'r') as f:
+ flow_data = json.load(f)
+
+ # Check if this file contains the flow we're looking for
+ if flow_data.get('id') == flow_id:
+ # Cache the result
+ self._flow_file_cache[flow_id] = file_path
+ logger.info(f"Found flow {flow_id} in file: {filename}")
+ return file_path
+
+ except (json.JSONDecodeError, FileNotFoundError) as e:
+ logger.warning(f"Error reading flow file {filename}: {e}")
+ continue
+
+ except Exception as e:
+ logger.error(f"Error scanning flows directory: {e}")
+ return None
+
+ logger.warning(f"Flow with ID {flow_id} not found in flows directory")
+ return None
async def reset_langflow_flow(self, flow_type: str):
"""Reset a Langflow flow by uploading the corresponding JSON file
@@ -41,59 +111,35 @@ class FlowsService:
if not LANGFLOW_URL:
raise ValueError("LANGFLOW_URL environment variable is required")
- # Determine flow file and ID based on type
+ # Determine flow ID based on type
if flow_type == "nudges":
- flow_file = "flows/openrag_nudges.json"
flow_id = NUDGES_FLOW_ID
elif flow_type == "retrieval":
- flow_file = "flows/openrag_agent.json"
flow_id = LANGFLOW_CHAT_FLOW_ID
elif flow_type == "ingest":
- flow_file = "flows/ingestion_flow.json"
flow_id = LANGFLOW_INGEST_FLOW_ID
else:
raise ValueError(
"flow_type must be either 'nudges', 'retrieval', or 'ingest'"
)
+ if not flow_id:
+ raise ValueError(f"Flow ID not configured for flow_type '{flow_type}'")
+
+ # Dynamically find the flow file by ID
+ flow_path = self._find_flow_file_by_id(flow_id)
+ if not flow_path:
+ raise FileNotFoundError(f"Flow file not found for flow ID: {flow_id}")
+
# Load flow JSON file
try:
- # Get the project root directory (go up from src/services/ to project root)
- # __file__ is src/services/chat_service.py
- # os.path.dirname(__file__) is src/services/
- # os.path.dirname(os.path.dirname(__file__)) is src/
- # os.path.dirname(os.path.dirname(os.path.dirname(__file__))) is project root
- current_file_dir = os.path.dirname(
- os.path.abspath(__file__)
- ) # src/services/
- src_dir = os.path.dirname(current_file_dir) # src/
- project_root = os.path.dirname(src_dir) # project root
- flow_path = os.path.join(project_root, flow_file)
-
- if not os.path.exists(flow_path):
- # List contents of project root to help debug
- try:
- contents = os.listdir(project_root)
- logger.info(f"Project root contents: {contents}")
-
- flows_dir = os.path.join(project_root, "flows")
- if os.path.exists(flows_dir):
- flows_contents = os.listdir(flows_dir)
- logger.info(f"Flows directory contents: {flows_contents}")
- else:
- logger.info("Flows directory does not exist")
- except Exception as e:
- logger.error(f"Error listing directory contents: {e}")
-
- raise FileNotFoundError(f"Flow file not found at: {flow_path}")
-
with open(flow_path, "r") as f:
flow_data = json.load(f)
- logger.info(f"Successfully loaded flow data from {flow_file}")
+ logger.info(f"Successfully loaded flow data for {flow_type} from {os.path.basename(flow_path)}")
+ except json.JSONDecodeError as e:
+ raise ValueError(f"Invalid JSON in flow file {flow_path}: {e}")
except FileNotFoundError:
raise ValueError(f"Flow file not found: {flow_path}")
- except json.JSONDecodeError as e:
- raise ValueError(f"Invalid JSON in flow file {flow_file}: {e}")
# Make PATCH request to Langflow API to update the flow using shared client
try:
@@ -106,8 +152,54 @@ class FlowsService:
logger.info(
f"Successfully reset {flow_type} flow",
flow_id=flow_id,
- flow_file=flow_file,
+ flow_file=os.path.basename(flow_path),
)
+
+ # Now update the flow with current configuration settings
+ try:
+ config = get_openrag_config()
+
+ # Check if configuration has been edited (onboarding completed)
+ if config.edited:
+ logger.info(f"Updating {flow_type} flow with current configuration settings")
+
+ provider = config.provider.model_provider.lower()
+
+ # Step 1: Assign model provider (replace components) if not OpenAI
+ if provider != "openai":
+ logger.info(f"Assigning {provider} components to {flow_type} flow")
+ provider_result = await self.assign_model_provider(provider)
+
+ if not provider_result.get("success"):
+ logger.warning(f"Failed to assign {provider} components: {provider_result.get('error', 'Unknown error')}")
+ # Continue anyway, maybe just value updates will work
+
+ # Step 2: Update model values for the specific flow being reset
+ single_flow_config = [{
+ "name": flow_type,
+ "flow_id": flow_id,
+ }]
+
+ logger.info(f"Updating {flow_type} flow model values")
+ update_result = await self.change_langflow_model_value(
+ provider=provider,
+ embedding_model=config.knowledge.embedding_model,
+ llm_model=config.agent.llm_model,
+ endpoint=config.provider.endpoint if config.provider.endpoint else None,
+ flow_configs=single_flow_config
+ )
+
+ if update_result.get("success"):
+ logger.info(f"Successfully updated {flow_type} flow with current configuration")
+ else:
+ logger.warning(f"Failed to update {flow_type} flow with current configuration: {update_result.get('error', 'Unknown error')}")
+ else:
+ logger.info(f"Configuration not yet edited (onboarding not completed), skipping model updates for {flow_type} flow")
+
+ except Exception as e:
+ logger.error(f"Error updating {flow_type} flow with current configuration", error=str(e))
+ # Don't fail the entire reset operation if configuration update fails
+
return {
"success": True,
"message": f"Successfully reset {flow_type} flow",
@@ -155,11 +247,10 @@ class FlowsService:
logger.info(f"Assigning {provider} components")
- # Define flow configurations
+        # Define flow configurations; flow files are resolved dynamically by flow ID
flow_configs = [
{
"name": "nudges",
- "file": "flows/openrag_nudges.json",
"flow_id": NUDGES_FLOW_ID,
"embedding_id": OPENAI_EMBEDDING_COMPONENT_ID,
"llm_id": OPENAI_LLM_COMPONENT_ID,
@@ -167,7 +258,6 @@ class FlowsService:
},
{
"name": "retrieval",
- "file": "flows/openrag_agent.json",
"flow_id": LANGFLOW_CHAT_FLOW_ID,
"embedding_id": OPENAI_EMBEDDING_COMPONENT_ID,
"llm_id": OPENAI_LLM_COMPONENT_ID,
@@ -175,7 +265,6 @@ class FlowsService:
},
{
"name": "ingest",
- "file": "flows/ingestion_flow.json",
"flow_id": LANGFLOW_INGEST_FLOW_ID,
"embedding_id": OPENAI_EMBEDDING_COMPONENT_ID,
"llm_id": None, # Ingestion flow might not have LLM
@@ -272,7 +361,6 @@ class FlowsService:
async def _update_flow_components(self, config, llm_template, embedding_template, llm_text_template):
"""Update components in a specific flow"""
flow_name = config["name"]
- flow_file = config["file"]
flow_id = config["flow_id"]
old_embedding_id = config["embedding_id"]
old_llm_id = config["llm_id"]
@@ -281,14 +369,11 @@ class FlowsService:
new_llm_id = llm_template["data"]["id"]
new_embedding_id = embedding_template["data"]["id"]
new_llm_text_id = llm_text_template["data"]["id"]
- # Get the project root directory
- current_file_dir = os.path.dirname(os.path.abspath(__file__))
- src_dir = os.path.dirname(current_file_dir)
- project_root = os.path.dirname(src_dir)
- flow_path = os.path.join(project_root, flow_file)
- if not os.path.exists(flow_path):
- raise FileNotFoundError(f"Flow file not found at: {flow_path}")
+ # Dynamically find the flow file by ID
+ flow_path = self._find_flow_file_by_id(flow_id)
+ if not flow_path:
+ raise FileNotFoundError(f"Flow file not found for flow ID: {flow_id}")
# Load flow JSON
with open(flow_path, "r") as f:
@@ -527,16 +612,17 @@ class FlowsService:
return False
async def change_langflow_model_value(
- self, provider: str, embedding_model: str, llm_model: str, endpoint: str = None
+ self, provider: str, embedding_model: str, llm_model: str, endpoint: str = None, flow_configs: list = None
):
"""
- Change dropdown values for provider-specific components across all flows
+ Change dropdown values for provider-specific components across flows
Args:
provider: The provider ("watsonx", "ollama", "openai")
embedding_model: The embedding model name to set
llm_model: The LLM model name to set
endpoint: The endpoint URL (required for watsonx/ibm provider)
+ flow_configs: Optional list of specific flow configs to update. If None, updates all flows.
Returns:
dict: Success/error response with details for each flow
@@ -552,24 +638,22 @@ class FlowsService:
f"Changing dropdown values for provider {provider}, embedding: {embedding_model}, llm: {llm_model}, endpoint: {endpoint}"
)
- # Define flow configurations with provider-specific component IDs
- flow_configs = [
- {
- "name": "nudges",
- "file": "flows/openrag_nudges.json",
- "flow_id": NUDGES_FLOW_ID,
- },
- {
- "name": "retrieval",
- "file": "flows/openrag_agent.json",
- "flow_id": LANGFLOW_CHAT_FLOW_ID,
- },
- {
- "name": "ingest",
- "file": "flows/ingestion_flow.json",
- "flow_id": LANGFLOW_INGEST_FLOW_ID,
- },
- ]
+ # Use provided flow_configs or default to all flows
+ if flow_configs is None:
+ flow_configs = [
+ {
+ "name": "nudges",
+ "flow_id": NUDGES_FLOW_ID,
+ },
+ {
+ "name": "retrieval",
+ "flow_id": LANGFLOW_CHAT_FLOW_ID,
+ },
+ {
+ "name": "ingest",
+ "flow_id": LANGFLOW_INGEST_FLOW_ID,
+ },
+ ]
# Determine target component IDs based on provider
target_embedding_id, target_llm_id, target_llm_text_id = self._get_provider_component_ids(
diff --git a/src/utils/embeddings.py b/src/utils/embeddings.py
new file mode 100644
index 00000000..f3c902e7
--- /dev/null
+++ b/src/utils/embeddings.py
@@ -0,0 +1,64 @@
+from config.settings import OLLAMA_EMBEDDING_DIMENSIONS, OPENAI_EMBEDDING_DIMENSIONS, VECTOR_DIM, WATSONX_EMBEDDING_DIMENSIONS
+from utils.logging_config import get_logger
+
+
+logger = get_logger(__name__)
+
+def get_embedding_dimensions(model_name: str) -> int:
+ """Get the embedding dimensions for a given model name."""
+
+ # Check all model dictionaries
+ all_models = {**OPENAI_EMBEDDING_DIMENSIONS, **OLLAMA_EMBEDDING_DIMENSIONS, **WATSONX_EMBEDDING_DIMENSIONS}
+
+ if model_name in all_models:
+ dimensions = all_models[model_name]
+ logger.info(f"Found dimensions for model '{model_name}': {dimensions}")
+ return dimensions
+
+ logger.warning(
+ f"Unknown embedding model '{model_name}', using default dimensions: {VECTOR_DIM}"
+ )
+ return VECTOR_DIM
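+
+# e.g. get_embedding_dimensions("text-embedding-3-small") returns 1536,
+# assuming that model appears in OPENAI_EMBEDDING_DIMENSIONS; unknown models
+# fall back to VECTOR_DIM with a warning.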
+
+
+def create_dynamic_index_body(embedding_model: str) -> dict:
+ """Create a dynamic index body configuration based on the embedding model."""
+ dimensions = get_embedding_dimensions(embedding_model)
+
+ return {
+ "settings": {
+ "index": {"knn": True},
+ "number_of_shards": 1,
+ "number_of_replicas": 1,
+ },
+ "mappings": {
+ "properties": {
+ "document_id": {"type": "keyword"},
+ "filename": {"type": "keyword"},
+ "mimetype": {"type": "keyword"},
+ "page": {"type": "integer"},
+ "text": {"type": "text"},
+ "chunk_embedding": {
+ "type": "knn_vector",
+ "dimension": dimensions,
+ "method": {
+ "name": "disk_ann",
+ "engine": "jvector",
+ "space_type": "l2",
+ "parameters": {"ef_construction": 100, "m": 16},
+ },
+ },
+ "source_url": {"type": "keyword"},
+ "connector_type": {"type": "keyword"},
+ "owner": {"type": "keyword"},
+ "allowed_users": {"type": "keyword"},
+ "allowed_groups": {"type": "keyword"},
+ "user_permissions": {"type": "object"},
+ "group_permissions": {"type": "object"},
+ "created_time": {"type": "date"},
+ "modified_time": {"type": "date"},
+ "indexed_time": {"type": "date"},
+ "metadata": {"type": "object"},
+ }
+ },
+ }
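+
+
+# Illustrative usage (mirrors init_index in src/main.py):
+#   body = create_dynamic_index_body(config.knowledge.embedding_model)
+#   await clients.opensearch.indices.create(index=INDEX_NAME, body=body)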