Merge branch 'main' into lfx-openrag-update-flows

This commit is contained in:
Edwin Jose 2025-09-26 12:16:02 -05:00 committed by GitHub
commit b51efd0d5e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
33 changed files with 3197 additions and 1584 deletions


@ -0,0 +1,52 @@
---
title: Agents powered by Langflow
slug: /agents
---
import Icon from "@site/src/components/icon/icon";
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
OpenRAG leverages Langflow's Agent component to power the OpenRAG Open Search Agent flow.
This flow chats with your knowledge by embedding your query, comparing it to the embeddings stored in the vector database, and generating a response with the LLM.
The Agent component shines here because it decides not only which query to send, but also whether a query is needed at all to solve the problem at hand.
<details>
<summary>How do agents work?</summary>
Agents extend Large Language Models (LLMs) by integrating tools, which are functions that provide additional context and enable autonomous task execution. These integrations make agents more specialized and powerful than standalone LLMs.
Whereas an LLM might generate acceptable, inert responses to general queries and tasks, an agent can leverage the integrated context and tools to provide more relevant responses and even take action. For example, you might create an agent that can access your company's documentation, repositories, and other resources to help your team with tasks that require knowledge of your specific products, customers, and code.
Agents use LLMs as a reasoning engine to process input, determine which actions to take to address the query, and then generate a response. The response could be a typical text-based LLM response, or it could involve an action, like editing a file, running a script, or calling an external API.
In an agentic context, tools are functions that the agent can run to perform tasks or access external resources. A function is wrapped as a Tool object with a common interface that the agent understands. Agents become aware of tools through tool registration, which is when the agent is provided a list of available tools typically at agent initialization. The Tool object's description tells the agent what the tool can do so that it can decide whether the tool is appropriate for a given request.
</details>
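Tool registration can be illustrated with a minimal sketch. This is a generic example of the pattern described above, not Langflow's actual `Tool` API; the names `Tool`, `search_knowledge`, and `tools` are hypothetical:

```python
from dataclasses import dataclass
from typing import Callable

@dataclass
class Tool:
    # The description is what the agent's reasoning LLM reads when
    # deciding whether this tool fits the current request.
    name: str
    description: str
    func: Callable[[str], str]

def search_knowledge(query: str) -> str:
    """Hypothetical stand-in for a vector search against the knowledge base."""
    return f"Top documents for: {query}"

# Tool registration: the agent is given the list of available tools,
# typically at initialization, and routes requests by description.
tools = [
    Tool(
        name="knowledge_search",
        description="Search the knowledge base for relevant documents.",
        func=search_knowledge,
    )
]

print(tools[0].func("data quality in health care"))
# → Top documents for: data quality in health care
```

The key design point is the common interface: because every tool exposes the same shape (name, description, callable), the agent can choose among them without knowing their internals.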
## Use the OpenRAG Open Search Agent flow
If you've chatted with your knowledge in OpenRAG, you've already experienced the OpenRAG Open Search Agent chat flow.
To view the flow, click <Icon name="Settings" aria-hidden="true"/> **Settings**, and then click **Edit in Langflow**.
This flow contains seven components:
* The Agent component orchestrates the entire flow by deciding when to search the knowledge base, how to formulate search queries, and how to combine retrieved information with the user's question to generate a comprehensive response.
The Agent behaves according to the prompt in the **Agent Instructions** field.
* The Chat Input component is connected to the Agent component's Input port. This allows the flow to be triggered by an incoming prompt from a user or application.
* The OpenSearch component is connected to the Agent component's Tools port. The agent doesn't query this database on every request; it uses the connection only when it decides the stored knowledge can help answer the prompt.
* The Language Model component is connected to the Agent component's Language Model port. The agent uses the connected LLM to reason through the request sent through Chat Input.
* The Embedding Model component is connected to the OpenSearch component's Embedding port. This component converts text queries into vector representations that are compared with document embeddings stored in OpenSearch for semantic similarity matching. This gives your Agent's queries context.
* The Text Input component is populated with the global variable `OPENRAG-QUERY-FILTER`.
This is the Knowledge filter, which determines which knowledge sources are searched.
* The Agent component's Output port is connected to the Chat Output component, which returns the final response to the user or application.
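The retrieval step the flow performs — embed the query, score it against stored document embeddings, and hand the best match to the LLM as context — can be sketched in plain Python. A toy embedding function stands in for the Embedding Model component, and an in-memory list stands in for OpenSearch; all names here are illustrative:

```python
import math

def embed(text: str) -> list[float]:
    # Toy embedding: character codes summed into four buckets, normalized.
    # A real flow uses the Embedding Model component instead.
    buckets = [0.0] * 4
    for i, ch in enumerate(text):
        buckets[i % 4] += ord(ch)
    norm = math.sqrt(sum(b * b for b in buckets)) or 1.0
    return [b / norm for b in buckets]

def cosine(a: list[float], b: list[float]) -> float:
    # Vectors are already normalized, so the dot product is the cosine.
    return sum(x * y for x, y in zip(a, b))

# Stand-in for documents already embedded and stored in the vector database.
docs = ["LLM data quality in health care", "Vector search basics"]
index = [(d, embed(d)) for d in docs]

query = "health care data quality"
q_vec = embed(query)
best_doc = max(index, key=lambda pair: cosine(q_vec, pair[1]))[0]

# The agent would then pass this context to the Language Model component.
prompt = f"Answer using this context:\n{best_doc}\n\nQuestion: {query}"
print(prompt)
```

The agentic part of the real flow sits on top of this: the Agent decides whether to run the search at all, and how to phrase the query, before assembling the final prompt.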
All flows included with OpenRAG are designed to be modular, performant, and provider-agnostic.
To modify a flow, click <Icon name="Settings" aria-hidden="true"/> **Settings**, and then click **Edit in Langflow**.
Flows are edited in the same way as in the [Langflow visual editor](https://docs.langflow.org/concepts-overview).
For an example of changing out the agent's LLM in OpenRAG, see the [Quickstart](/quickstart#change-components).
To restore the flow to its initial state, in OpenRAG, click <Icon name="Settings" aria-hidden="true"/> **Settings**, and then click **Restore Flow**.
OpenRAG warns you that this discards all custom settings. Click **Restore** to restore the flow.


@ -16,6 +16,8 @@ Get started with OpenRAG by loading your knowledge, swapping out your language m
## Find your way around
1. In OpenRAG, click <Icon name="MessageSquare" aria-hidden="true"/> **Chat**.
The chat is powered by the OpenRAG Open Search Agent.
For more information, see [Langflow Agents](/agents).
2. Ask `What documents are available to you?`
The agent responds with a message summarizing the documents that OpenRAG loads by default, which are PDFs about evaluating data quality when using LLMs in health care.
3. To confirm the agent is correct, click <Icon name="Library" aria-hidden="true"/> **Knowledge**.
@ -33,7 +35,7 @@ Get started with OpenRAG by loading your knowledge, swapping out your language m
These events log the agent's request to the tool and the tool's response, so you have direct visibility into your agent's functionality.
If you aren't getting the results you need, you can further tune the knowledge ingestion and agent behavior in the next section.
## Swap out the language model to modify agent behavior {#change-components}
To modify the knowledge ingestion or Agent behavior, click <Icon name="Settings" aria-hidden="true"/> **Settings**.


@ -47,6 +47,17 @@ const sidebars = {
},
],
},
{
type: "category",
label: "Core components",
items: [
{
type: "doc",
id: "core-components/agents",
label: "Langflow Agents"
},
],
},
{
type: "category",
label: "Configuration",


@ -95,7 +95,7 @@
"data": {
"sourceHandle": {
"dataType": "EmbeddingModel",
"id": "EmbeddingModel-eZ6bT",
"name": "embeddings",
"output_types": [
"Embeddings"
@ -110,10 +110,10 @@
"type": "other"
}
},
"id": "xy-edge__EmbeddingModel-eZ6bT{œdataTypeœ:œEmbeddingModelœ,œidœ:œEmbeddingModel-eZ6bTœ,œnameœ:œembeddingsœ,œoutput_typesœ:[œEmbeddingsœ]}-OpenSearchHybrid-XtKoA{œfieldNameœ:œembeddingœ,œidœ:œOpenSearchHybrid-XtKoAœ,œinputTypesœ:[œEmbeddingsœ],œtypeœ:œotherœ}",
"selected": false,
"source": "EmbeddingModel-eZ6bT",
"sourceHandle": "{œdataTypeœ:œEmbeddingModelœ,œidœ:œEmbeddingModel-eZ6bTœ,œnameœ:œembeddingsœ,œoutput_typesœ:[œEmbeddingsœ]}",
"target": "OpenSearchHybrid-XtKoA",
"targetHandle": "{œfieldNameœ:œembeddingœ,œidœ:œOpenSearchHybrid-XtKoAœ,œinputTypesœ:[œEmbeddingsœ],œtypeœ:œotherœ}"
}
@ -1631,7 +1631,7 @@
},
{
"data": {
"id": "EmbeddingModel-eZ6bT",
"node": {
"base_classes": [
"Embeddings"
@ -1657,7 +1657,7 @@
],
"frozen": false,
"icon": "binary",
"last_updated": "2025-09-22T15:54:52.885Z",
"legacy": false,
"metadata": {
"code_hash": "93faf11517da",
@ -1738,7 +1738,7 @@
"show": true,
"title_case": false,
"type": "str",
"value": "OPENAI_API_KEY"
},
"chunk_size": {
"_input_type": "IntInput",
@ -1926,16 +1926,16 @@
"type": "EmbeddingModel"
},
"dragging": false,
"id": "EmbeddingModel-eZ6bT",
"measured": {
"height": 369,
"width": 320
},
"position": {
"x": 1726.6943524438122,
"y": 1800.5330404375484
},
"selected": true,
"type": "genericNode"
}
],


@ -10,18 +10,25 @@ export function LabelWrapper({
id,
required,
flex,
start,
children,
}: {
label: string;
description?: string;
helperText?: string | React.ReactNode;
id: string;
required?: boolean;
flex?: boolean;
start?: boolean;
children: React.ReactNode;
}) {
return (
<div
className={cn(
"flex w-full items-center",
start ? "justify-start flex-row-reverse gap-3" : "justify-between",
)}
>
<div
className={cn(
"flex flex-1 flex-col items-start",
@ -30,7 +37,7 @@ export function LabelWrapper({
>
<Label
htmlFor={id}
className="!text-mmd font-medium flex items-center gap-1.5"
>
{label}
{required && <span className="text-red-500">*</span>}
@ -39,7 +46,7 @@ export function LabelWrapper({
<TooltipTrigger>
<Info className="w-3.5 h-3.5 text-muted-foreground" />
</TooltipTrigger>
<TooltipContent side="right">{helperText}</TooltipContent>
</Tooltip>
)}
</Label>
@ -48,7 +55,7 @@ export function LabelWrapper({
<p className="text-mmd text-muted-foreground">{description}</p>
)}
</div>
{flex && <div className="relative items-center flex">{children}</div>}
</div>
);
}


@ -19,7 +19,7 @@ const TooltipContent = React.forwardRef<
ref={ref}
sideOffset={sideOffset}
className={cn(
"z-50 overflow-hidden rounded-md border bg-primary py-1 px-1.5 text-xs font-normal text-primary-foreground shadow-md animate-in fade-in-0 zoom-in-95 data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=closed]:zoom-out-95 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2 origin-[--radix-tooltip-content-transform-origin]",
className,
)}
{...props}


@ -1,20 +1,14 @@
"use client";
import React, { useState } from "react";
import { UnifiedCloudPicker, CloudFile } from "@/components/cloud-picker";
import { useTask } from "@/contexts/task-context";
// CloudFile interface is now imported from the unified cloud picker
export default function ConnectorsPage() {
const { addTask } = useTask();
const [selectedFiles, setSelectedFiles] = useState<CloudFile[]>([]);
const [isSyncing, setIsSyncing] = useState<boolean>(false);
const [syncResult, setSyncResult] = useState<{
processed?: number;
@ -25,15 +19,18 @@ export default function ConnectorsPage() {
errors?: number;
} | null>(null);
const handleFileSelection = (files: CloudFile[]) => {
setSelectedFiles(files);
};
const handleSync = async (connector: {
connectionId: string;
type: string;
}) => {
if (!connector.connectionId || selectedFiles.length === 0) return;
setIsSyncing(true);
setSyncResult(null);
try {
const syncBody: {
@ -42,40 +39,40 @@ export default function ConnectorsPage() {
selected_files?: string[];
} = {
connection_id: connector.connectionId,
selected_files: selectedFiles.map(file => file.id),
};
const response = await fetch(`/api/connectors/${connector.type}/sync`, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify(syncBody),
});
const result = await response.json();
if (response.status === 201) {
const taskId = result.task_id;
if (taskId) {
addTask(taskId);
setSyncResult({
processed: 0,
total: selectedFiles.length,
status: "started",
});
}
} else if (response.ok) {
setSyncResult(result);
} else {
console.error("Sync failed:", result.error);
setSyncResult({ error: result.error || "Sync failed" });
}
} catch (error) {
console.error("Sync error:", error);
setSyncResult({ error: "Network error occurred" });
} finally {
setIsSyncing(false);
}
};
@ -85,11 +82,12 @@ export default function ConnectorsPage() {
<div className="mb-6">
<p className="text-sm text-gray-600 mb-4">
This is a demo page for the Google Drive picker component. For full
connector functionality, visit the Settings page.
</p>
<UnifiedCloudPicker
provider="google_drive"
onFileSelected={handleFileSelection}
selectedFiles={selectedFiles}
isAuthenticated={false} // This would come from auth context in real usage
@ -100,7 +98,12 @@ export default function ConnectorsPage() {
{selectedFiles.length > 0 && (
<div className="space-y-4">
<button
onClick={() =>
handleSync({
connectionId: "google-drive-connection-id",
type: "google-drive",
})
}
disabled={isSyncing}
className="px-4 py-2 bg-blue-500 text-white rounded hover:bg-blue-600 disabled:opacity-50 disabled:cursor-not-allowed"
>
@ -115,9 +118,10 @@ export default function ConnectorsPage() {
<div className="p-3 bg-gray-100 rounded text-sm">
{syncResult.error ? (
<div className="text-red-600">Error: {syncResult.error}</div>
) : syncResult.status === "started" ? (
<div className="text-blue-600">
Sync started for {syncResult.total} files. Check the task
notification for progress.
</div>
) : (
<div className="text-green-600">


@ -2,7 +2,7 @@ import { useState } from "react";
import { LabelInput } from "@/components/label-input";
import { LabelWrapper } from "@/components/label-wrapper";
import OpenAILogo from "@/components/logo/openai-logo";
import { Checkbox } from "@/components/ui/checkbox";
import { useDebouncedValue } from "@/lib/debounce";
import type { OnboardingVariables } from "../../api/mutations/useOnboardingMutation";
import { useGetOpenAIModelsQuery } from "../../api/queries/useGetModelsQuery";
@ -72,11 +72,19 @@ export function OpenAIOnboarding({
<>
<div className="space-y-5">
<LabelWrapper
label="Use environment OpenAI API key"
id="get-api-key"
helperText={
<>
Reuse the key from your environment config.
<br />
Uncheck to enter a different key.
</>
}
flex
start
>
<Checkbox
checked={getFromEnv}
onCheckedChange={handleGetFromEnvChange}
/>
@ -86,6 +94,7 @@ export function OpenAIOnboarding({
<LabelInput
label="OpenAI API key"
helperText="The API key for your OpenAI account."
className={modelsError ? "!border-destructive" : ""}
id="api-key"
type="password"
required
@ -99,7 +108,7 @@ export function OpenAIOnboarding({
</p>
)}
{modelsError && (
<p className="text-mmd text-destructive">
Invalid OpenAI API key. Verify or replace the key.
</p>
)}


@ -4,29 +4,12 @@ import { useState, useEffect } from "react";
import { useParams, useRouter } from "next/navigation";
import { Button } from "@/components/ui/button";
import { ArrowLeft, AlertCircle } from "lucide-react";
import { UnifiedCloudPicker, CloudFile } from "@/components/cloud-picker";
import type { IngestSettings } from "@/components/cloud-picker/types";
import { useTask } from "@/contexts/task-context";
import { Toast } from "@/components/ui/toast";
// CloudFile interface is now imported from the unified cloud picker
interface CloudConnector {
id: string;
@ -35,6 +18,7 @@ interface CloudConnector {
status: "not_connected" | "connecting" | "connected" | "error";
type: string;
connectionId?: string;
clientId: string;
hasAccessToken: boolean;
accessTokenError?: string;
}
@ -49,14 +33,19 @@ export default function UploadProviderPage() {
const [isLoading, setIsLoading] = useState(true);
const [error, setError] = useState<string | null>(null);
const [accessToken, setAccessToken] = useState<string | null>(null);
const [selectedFiles, setSelectedFiles] = useState<CloudFile[]>([]);
const [isIngesting, setIsIngesting] = useState<boolean>(false);
const [currentSyncTaskId, setCurrentSyncTaskId] = useState<string | null>(
null
);
const [showSuccessToast, setShowSuccessToast] = useState(false);
const [ingestSettings, setIngestSettings] = useState<IngestSettings>({
chunkSize: 1000,
chunkOverlap: 200,
ocr: false,
pictureDescriptions: false,
embeddingModel: "text-embedding-3-small",
});
useEffect(() => {
const fetchConnectorInfo = async () => {
@ -129,6 +118,7 @@ export default function UploadProviderPage() {
status: isConnected ? "connected" : "not_connected",
type: provider,
connectionId: activeConnection?.connection_id,
clientId: activeConnection?.client_id,
hasAccessToken,
accessTokenError,
});
@ -159,13 +149,6 @@ export default function UploadProviderPage() {
// Task completed successfully, show toast and redirect
setIsIngesting(false);
setShowSuccessToast(true);
setTimeout(() => {
router.push("/knowledge");
}, 2000); // 2 second delay to let user see toast
@ -176,20 +159,12 @@ export default function UploadProviderPage() {
}
}, [tasks, currentSyncTaskId, router]);
const handleFileSelected = (files: CloudFile[]) => {
setSelectedFiles(files);
console.log(`Selected ${files.length} files from ${provider}:`, files);
// You can add additional handling here like triggering sync, etc.
};
const handleSync = async (connector: CloudConnector) => {
if (!connector.connectionId || selectedFiles.length === 0) return;
@ -200,9 +175,11 @@ export default function UploadProviderPage() {
connection_id: string;
max_files?: number;
selected_files?: string[];
settings?: IngestSettings;
} = {
connection_id: connector.connectionId,
selected_files: selectedFiles.map(file => file.id),
settings: ingestSettings,
};
const response = await fetch(`/api/connectors/${connector.type}/sync`, {
@ -353,48 +330,49 @@ export default function UploadProviderPage() {
<div className="container mx-auto max-w-3xl p-6">
<div className="mb-6 flex gap-2 items-center">
<Button variant="ghost" onClick={() => router.back()}>
<ArrowLeft className="h-4 w-4 scale-125" />
</Button>
<h2 className="text-2xl font-bold">
Add from {getProviderDisplayName()}
</h2>
</div>
<div className="max-w-3xl mx-auto">
<UnifiedCloudPicker
provider={
connector.type as "google_drive" | "onedrive" | "sharepoint"
}
onFileSelected={handleFileSelected}
selectedFiles={selectedFiles}
isAuthenticated={true}
accessToken={accessToken || undefined}
clientId={connector.clientId}
onSettingsChange={setIngestSettings}
/>
</div>
<div className="max-w-3xl mx-auto mt-6">
<div className="flex justify-between gap-3 mb-4">
<Button
variant="ghost"
className=" border bg-transparent border-border rounded-lg text-secondary-foreground"
onClick={() => router.back()}
>
Back
</Button>
<Button
variant="secondary"
onClick={() => handleSync(connector)}
disabled={selectedFiles.length === 0 || isIngesting}
>
{isIngesting ? (
<>Ingesting {selectedFiles.length} Files...</>
) : (
<>Start ingest</>
)}
</Button>
</div>
</div>
{/* Success toast notification */}
<Toast


@ -1,110 +1,100 @@
"use client" "use client";
import { useState, useEffect, useCallback } from "react" import { useState, useEffect, useCallback } from "react";
import { Button } from "@/components/ui/button" import { Button } from "@/components/ui/button";
import { import {
Dialog, Dialog,
DialogContent, DialogContent,
DialogDescription, DialogDescription,
DialogHeader, DialogHeader,
DialogTitle, DialogTitle,
} from "@/components/ui/dialog" } from "@/components/ui/dialog";
import { GoogleDrivePicker } from "@/components/google-drive-picker" import { UnifiedCloudPicker, CloudFile } from "@/components/cloud-picker";
import { OneDrivePicker } from "@/components/onedrive-picker" import { Loader2 } from "lucide-react";
import { Loader2 } from "lucide-react"
interface GoogleDriveFile { // CloudFile interface is now imported from the unified cloud picker
id: string
name: string
mimeType: string
webViewLink?: string
iconLink?: string
}
interface OneDriveFile {
id: string
name: string
mimeType?: string
webUrl?: string
driveItem?: {
file?: { mimeType: string }
folder?: unknown
}
}
interface CloudConnector {
  id: string;
  name: string;
  description: string;
  icon: React.ReactNode;
  status: "not_connected" | "connecting" | "connected" | "error";
  type: string;
  connectionId?: string;
  clientId: string;
  hasAccessToken: boolean;
  accessTokenError?: string;
}
interface CloudConnectorsDialogProps {
  isOpen: boolean;
  onOpenChange: (open: boolean) => void;
  onFileSelected?: (files: CloudFile[], connectorType: string) => void;
}
export function CloudConnectorsDialog({
  isOpen,
  onOpenChange,
  onFileSelected,
}: CloudConnectorsDialogProps) {
  const [connectors, setConnectors] = useState<CloudConnector[]>([]);
  const [isLoading, setIsLoading] = useState(true);
  const [selectedFiles, setSelectedFiles] = useState<{
    [connectorId: string]: CloudFile[];
  }>({});
  const [connectorAccessTokens, setConnectorAccessTokens] = useState<{
    [connectorType: string]: string;
  }>({});
  const [activePickerType, setActivePickerType] = useState<string | null>(null);
  const getConnectorIcon = (iconName: string) => {
    const iconMap: { [key: string]: React.ReactElement } = {
      "google-drive": (
        <div className="w-8 h-8 bg-blue-600 rounded flex items-center justify-center text-white font-bold leading-none shrink-0">
          G
        </div>
      ),
      sharepoint: (
        <div className="w-8 h-8 bg-blue-700 rounded flex items-center justify-center text-white font-bold leading-none shrink-0">
          SP
        </div>
      ),
      onedrive: (
        <div className="w-8 h-8 bg-blue-400 rounded flex items-center justify-center text-white font-bold leading-none shrink-0">
          OD
        </div>
      ),
    };
    return (
      iconMap[iconName] || (
        <div className="w-8 h-8 bg-gray-500 rounded flex items-center justify-center text-white font-bold leading-none shrink-0">
          ?
        </div>
      )
    );
  };
  const fetchConnectorStatuses = useCallback(async () => {
    if (!isOpen) return;
    setIsLoading(true);
    try {
      // Fetch available connectors from backend
      const connectorsResponse = await fetch("/api/connectors");
      if (!connectorsResponse.ok) {
        throw new Error("Failed to load connectors");
      }
      const connectorsResult = await connectorsResponse.json();
      const connectorTypes = Object.keys(connectorsResult.connectors);
      // Filter to only cloud connectors
      const cloudConnectorTypes = connectorTypes.filter(
        type =>
          ["google_drive", "onedrive", "sharepoint"].includes(type) &&
          connectorsResult.connectors[type].available
      );
      // Initialize connectors list
      const initialConnectors = cloudConnectorTypes.map(type => ({
@@ -115,82 +105,95 @@ export function CloudConnectorsDialog({
        status: "not_connected" as const,
        type: type,
        hasAccessToken: false,
        accessTokenError: undefined,
        clientId: "",
      }));
      setConnectors(initialConnectors);
      // Check status for each cloud connector type
      for (const connectorType of cloudConnectorTypes) {
        try {
          const response = await fetch(
            `/api/connectors/${connectorType}/status`
          );
          if (response.ok) {
            const data = await response.json();
            const connections = data.connections || [];
            const activeConnection = connections.find(
              (conn: { connection_id: string; is_active: boolean }) =>
                conn.is_active
            );
            const isConnected = activeConnection !== undefined;
            let hasAccessToken = false;
            let accessTokenError: string | undefined = undefined;
            // Try to get access token for connected connectors
            if (isConnected && activeConnection) {
              try {
                const tokenResponse = await fetch(
                  `/api/connectors/${connectorType}/token?connection_id=${activeConnection.connection_id}`
                );
                if (tokenResponse.ok) {
                  const tokenData = await tokenResponse.json();
                  if (tokenData.access_token) {
                    hasAccessToken = true;
                    setConnectorAccessTokens(prev => ({
                      ...prev,
                      [connectorType]: tokenData.access_token,
                    }));
                  }
                } else {
                  const errorData = await tokenResponse
                    .json()
                    .catch(() => ({ error: "Token unavailable" }));
                  accessTokenError =
                    errorData.error || "Access token unavailable";
                }
              } catch {
                accessTokenError = "Failed to fetch access token";
              }
            }
            setConnectors(prev =>
              prev.map(c =>
                c.type === connectorType
                  ? {
                      ...c,
                      status: isConnected ? "connected" : "not_connected",
                      connectionId: activeConnection?.connection_id,
                      clientId: activeConnection?.client_id,
                      hasAccessToken,
                      accessTokenError,
                    }
                  : c
              )
            );
          }
        } catch (error) {
          console.error(`Failed to check status for ${connectorType}:`, error);
        }
      }
    } catch (error) {
      console.error("Failed to load cloud connectors:", error);
    } finally {
      setIsLoading(false);
    }
  }, [isOpen]);
  const handleFileSelection = (connectorId: string, files: CloudFile[]) => {
    setSelectedFiles(prev => ({
      ...prev,
      [connectorId]: files,
    }));
    onFileSelected?.(files, connectorId);
  };
  useEffect(() => {
    fetchConnectorStatuses();
  }, [fetchConnectorStatuses]);
  return (
    <Dialog open={isOpen} onOpenChange={onOpenChange}>
@@ -218,19 +221,24 @@ export function CloudConnectorsDialog({
            <div className="flex flex-wrap gap-3 justify-center">
              {connectors
                .filter(connector => connector.status === "connected")
                .map(connector => (
                  <Button
                    key={connector.id}
                    variant={
                      connector.hasAccessToken ? "default" : "secondary"
                    }
                    disabled={!connector.hasAccessToken}
                    title={
                      !connector.hasAccessToken
                        ? connector.accessTokenError ||
                          "Access token required - try reconnecting your account"
                        : `Select files from ${connector.name}`
                    }
                    onClick={e => {
                      e.preventDefault();
                      e.stopPropagation();
                      if (connector.hasAccessToken) {
                        setActivePickerType(connector.id);
                      }
                    }}
                    className="min-w-[120px]"
@@ -243,54 +251,46 @@ export function CloudConnectorsDialog({
            {connectors.every(c => c.status !== "connected") && (
              <div className="text-center py-8 text-muted-foreground">
                <p>No connected cloud providers found.</p>
                <p className="text-sm mt-1">
                  Go to Settings to connect your cloud storage accounts.
                </p>
              </div>
            )}
            {/* Render unified picker inside dialog */}
            {activePickerType &&
              connectors.find(c => c.id === activePickerType) &&
              (() => {
                const connector = connectors.find(
                  c => c.id === activePickerType
                )!;
                return (
                  <div className="mt-6">
                    <UnifiedCloudPicker
                      provider={
                        connector.type as
                          | "google_drive"
                          | "onedrive"
                          | "sharepoint"
                      }
                      onFileSelected={files => {
                        handleFileSelection(connector.id, files);
                        setActivePickerType(null);
                      }}
                      selectedFiles={selectedFiles[connector.id] || []}
                      isAuthenticated={connector.status === "connected"}
                      accessToken={connectorAccessTokens[connector.type]}
                      onPickerStateChange={() => {}}
                      clientId={connector.clientId}
                    />
                  </div>
                );
              })()}
          </div>
        )}
      </div>
    </DialogContent>
  </Dialog>
  );
}


@ -0,0 +1,67 @@
"use client";
import { Badge } from "@/components/ui/badge";
import { FileText, Folder, Trash } from "lucide-react";
import { CloudFile } from "./types";
interface FileItemProps {
file: CloudFile;
onRemove: (fileId: string) => void;
}
const getFileIcon = (mimeType: string) => {
if (mimeType.includes("folder")) {
return <Folder className="h-6 w-6" />;
}
return <FileText className="h-6 w-6" />;
};
const getMimeTypeLabel = (mimeType: string) => {
const typeMap: { [key: string]: string } = {
"application/vnd.google-apps.document": "Google Doc",
"application/vnd.google-apps.spreadsheet": "Google Sheet",
"application/vnd.google-apps.presentation": "Google Slides",
"application/vnd.google-apps.folder": "Folder",
"application/pdf": "PDF",
"text/plain": "Text",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document":
"Word Doc",
"application/vnd.openxmlformats-officedocument.presentationml.presentation":
"PowerPoint",
};
return typeMap[mimeType] || mimeType?.split("/").pop() || "Document";
};
const formatFileSize = (bytes?: number) => {
  if (bytes === undefined) return "";
  // Handle zero before the logarithm: Math.log(0) is -Infinity, and the
  // original `if (!bytes)` guard made the "0 B" branch unreachable.
  if (bytes === 0) return "0 B";
  const sizes = ["B", "KB", "MB", "GB", "TB"];
  const i = Math.floor(Math.log(bytes) / Math.log(1024));
  return `${(bytes / Math.pow(1024, i)).toFixed(1)} ${sizes[i]}`;
};
export const FileItem = ({ file, onRemove }: FileItemProps) => (
  <div className="flex items-center justify-between p-2 rounded-md text-xs">
<div className="flex items-center gap-2 flex-1 min-w-0">
{getFileIcon(file.mimeType)}
<span className="truncate font-medium text-sm mr-2">{file.name}</span>
<Badge variant="secondary" className="text-xs px-1 py-0.5 h-auto">
{getMimeTypeLabel(file.mimeType)}
</Badge>
</div>
<div className="flex items-center gap-2">
<span className="text-xs text-muted-foreground mr-4" title="file size">
{formatFileSize(file.size) || "—"}
</span>
<Trash
className="text-muted-foreground w-5 h-5 cursor-pointer hover:text-destructive"
onClick={() => onRemove(file.id)}
/>
</div>
</div>
);


@ -0,0 +1,42 @@
"use client";
import { Button } from "@/components/ui/button";
import { CloudFile } from "./types";
import { FileItem } from "./file-item";
interface FileListProps {
files: CloudFile[];
onClearAll: () => void;
onRemoveFile: (fileId: string) => void;
}
export const FileList = ({
files,
onClearAll,
onRemoveFile,
}: FileListProps) => {
if (files.length === 0) {
return null;
}
return (
<div className="space-y-2">
<div className="flex items-center justify-between">
<p className="text-sm font-medium">Added files</p>
<Button
onClick={onClearAll}
size="sm"
variant="ghost"
className="text-sm text-muted-foreground"
>
Clear all
</Button>
</div>
<div className="max-h-64 overflow-y-auto space-y-1">
{files.map(file => (
<FileItem key={file.id} file={file} onRemove={onRemoveFile} />
))}
</div>
</div>
);
};


@ -0,0 +1,7 @@
export { UnifiedCloudPicker } from "./unified-cloud-picker";
export { PickerHeader } from "./picker-header";
export { FileList } from "./file-list";
export { FileItem } from "./file-item";
export { IngestSettings } from "./ingest-settings";
export * from "./types";
export * from "./provider-handlers";


@ -0,0 +1,139 @@
"use client";
import { Input } from "@/components/ui/input";
import { Switch } from "@/components/ui/switch";
import {
Collapsible,
CollapsibleContent,
CollapsibleTrigger,
} from "@/components/ui/collapsible";
import { ChevronRight, Info } from "lucide-react";
import { IngestSettings as IngestSettingsType } from "./types";
interface IngestSettingsProps {
isOpen: boolean;
onOpenChange: (open: boolean) => void;
settings?: IngestSettingsType;
onSettingsChange?: (settings: IngestSettingsType) => void;
}
export const IngestSettings = ({
isOpen,
onOpenChange,
settings,
onSettingsChange,
}: IngestSettingsProps) => {
// Default settings
const defaultSettings: IngestSettingsType = {
chunkSize: 1000,
chunkOverlap: 200,
ocr: false,
pictureDescriptions: false,
embeddingModel: "text-embedding-3-small",
};
// Use provided settings or defaults
const currentSettings = settings || defaultSettings;
const handleSettingsChange = (newSettings: Partial<IngestSettingsType>) => {
const updatedSettings = { ...currentSettings, ...newSettings };
onSettingsChange?.(updatedSettings);
};
return (
<Collapsible
open={isOpen}
onOpenChange={onOpenChange}
className="border rounded-md p-4 border-muted-foreground/20"
>
<CollapsibleTrigger className="flex items-center gap-2 justify-between w-full -m-4 p-4 rounded-md transition-colors">
<div className="flex items-center gap-2">
<ChevronRight
className={`h-4 w-4 text-muted-foreground transition-transform duration-200 ${
isOpen ? "rotate-90" : ""
}`}
/>
<span className="text-sm font-medium">Ingest settings</span>
</div>
</CollapsibleTrigger>
<CollapsibleContent className="data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 data-[state=closed]:slide-up-2 data-[state=open]:slide-down-2">
<div className="pt-5 space-y-5">
<div className="flex items-center gap-4 w-full">
<div className="w-full">
<div className="text-sm mb-2 font-semibold">Chunk size</div>
<Input
type="number"
value={currentSettings.chunkSize}
onChange={e =>
handleSettingsChange({
chunkSize: parseInt(e.target.value) || 0,
})
}
/>
</div>
<div className="w-full">
<div className="text-sm mb-2 font-semibold">Chunk overlap</div>
<Input
type="number"
value={currentSettings.chunkOverlap}
onChange={e =>
handleSettingsChange({
chunkOverlap: parseInt(e.target.value) || 0,
})
}
/>
</div>
</div>
<div className="flex gap-2 items-center justify-between">
<div>
<div className="text-sm font-semibold pb-2">OCR</div>
<div className="text-sm text-muted-foreground">
Extracts text from images/PDFs. Ingest is slower when enabled.
</div>
</div>
<Switch
checked={currentSettings.ocr}
onCheckedChange={checked =>
handleSettingsChange({ ocr: checked })
}
/>
</div>
<div className="flex gap-2 items-center justify-between">
<div>
<div className="text-sm pb-2 font-semibold">
Picture descriptions
</div>
<div className="text-sm text-muted-foreground">
Adds captions for images. Ingest is more expensive when enabled.
</div>
</div>
<Switch
checked={currentSettings.pictureDescriptions}
onCheckedChange={checked =>
handleSettingsChange({ pictureDescriptions: checked })
}
/>
</div>
<div>
<div className="text-sm font-semibold pb-2 flex items-center">
Embedding model
<Info className="w-3.5 h-3.5 text-muted-foreground ml-2" />
</div>
<Input
disabled
value={currentSettings.embeddingModel}
onChange={e =>
handleSettingsChange({ embeddingModel: e.target.value })
}
placeholder="text-embedding-3-small"
/>
</div>
</div>
</CollapsibleContent>
</Collapsible>
);
};
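The partial-update pattern in `handleSettingsChange` above — spreading a patch over the current settings — can be sketched as a pure function. This is an illustration only: `applyPatch` is a hypothetical name, and the interface is trimmed to three fields for brevity.

```typescript
// Illustrative sketch of handleSettingsChange's merge. Object spread applies
// only the patched fields and preserves everything else; `applyPatch` is a
// hypothetical helper, not part of the component.
interface IngestSettings {
  chunkSize: number;
  chunkOverlap: number;
  ocr: boolean;
}

function applyPatch(
  current: IngestSettings,
  patch: Partial<IngestSettings>
): IngestSettings {
  // The later spread wins: fields present in `patch` override `current`.
  return { ...current, ...patch };
}

const defaults: IngestSettings = { chunkSize: 1000, chunkOverlap: 200, ocr: false };
const next = applyPatch(defaults, { ocr: true });
// next keeps chunkSize 1000 and chunkOverlap 200; only ocr changes.
```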


@ -0,0 +1,70 @@
"use client";
import { Button } from "@/components/ui/button";
import { Card, CardContent } from "@/components/ui/card";
import { Plus } from "lucide-react";
import { CloudProvider } from "./types";
interface PickerHeaderProps {
provider: CloudProvider;
onAddFiles: () => void;
isPickerLoaded: boolean;
isPickerOpen: boolean;
accessToken?: string;
isAuthenticated: boolean;
}
const getProviderName = (provider: CloudProvider): string => {
switch (provider) {
case "google_drive":
return "Google Drive";
case "onedrive":
return "OneDrive";
case "sharepoint":
return "SharePoint";
default:
return "Cloud Storage";
}
};
export const PickerHeader = ({
provider,
onAddFiles,
isPickerLoaded,
isPickerOpen,
accessToken,
isAuthenticated,
}: PickerHeaderProps) => {
if (!isAuthenticated) {
return (
<div className="text-sm text-muted-foreground p-4 bg-muted/20 rounded-md">
Please connect to {getProviderName(provider)} first to select specific
files.
</div>
);
}
return (
<Card>
<CardContent className="flex flex-col items-center text-center py-8">
<p className="text-sm text-primary mb-4">
Select files from {getProviderName(provider)} to ingest.
</p>
<Button
size="sm"
onClick={onAddFiles}
disabled={!isPickerLoaded || isPickerOpen || !accessToken}
className="bg-foreground text-background hover:bg-foreground/90 font-semibold"
>
<Plus className="h-4 w-4" />
{isPickerOpen ? "Opening Picker..." : "Add Files"}
</Button>
<div className="text-xs text-muted-foreground pt-4">
csv, json, pdf,{" "}
<a className="underline dark:text-pink-400 text-pink-600">+16 more</a>{" "}
<b>150 MB</b> max
</div>
</CardContent>
</Card>
);
};


@ -0,0 +1,245 @@
"use client";
import {
CloudFile,
CloudProvider,
GooglePickerData,
GooglePickerDocument,
} from "./types";
export class GoogleDriveHandler {
private accessToken: string;
private onPickerStateChange?: (isOpen: boolean) => void;
constructor(
accessToken: string,
onPickerStateChange?: (isOpen: boolean) => void
) {
this.accessToken = accessToken;
this.onPickerStateChange = onPickerStateChange;
}
async loadPickerApi(): Promise<boolean> {
return new Promise(resolve => {
if (typeof window !== "undefined" && window.gapi) {
window.gapi.load("picker", {
callback: () => resolve(true),
onerror: () => resolve(false),
});
} else {
// Load Google API script
const script = document.createElement("script");
script.src = "https://apis.google.com/js/api.js";
script.async = true;
script.defer = true;
script.onload = () => {
window.gapi.load("picker", {
callback: () => resolve(true),
onerror: () => resolve(false),
});
};
script.onerror = () => resolve(false);
document.head.appendChild(script);
}
});
}
openPicker(onFileSelected: (files: CloudFile[]) => void): void {
if (!window.google?.picker) {
return;
}
try {
this.onPickerStateChange?.(true);
const picker = new window.google.picker.PickerBuilder()
.addView(window.google.picker.ViewId.DOCS)
.addView(window.google.picker.ViewId.FOLDERS)
.setOAuthToken(this.accessToken)
.enableFeature(window.google.picker.Feature.MULTISELECT_ENABLED)
.setTitle("Select files from Google Drive")
.setCallback(data => this.pickerCallback(data, onFileSelected))
.build();
picker.setVisible(true);
// Apply z-index fix
setTimeout(() => {
const pickerElements = document.querySelectorAll(
".picker-dialog, .goog-modalpopup"
);
pickerElements.forEach(el => {
(el as HTMLElement).style.zIndex = "10000";
});
const bgElements = document.querySelectorAll(
".picker-dialog-bg, .goog-modalpopup-bg"
);
bgElements.forEach(el => {
(el as HTMLElement).style.zIndex = "9999";
});
}, 100);
} catch (error) {
console.error("Error creating picker:", error);
this.onPickerStateChange?.(false);
}
}
private async pickerCallback(
data: GooglePickerData,
onFileSelected: (files: CloudFile[]) => void
): Promise<void> {
if (data.action === window.google.picker.Action.PICKED) {
const files: CloudFile[] = data.docs.map((doc: GooglePickerDocument) => ({
id: doc[window.google.picker.Document.ID],
name: doc[window.google.picker.Document.NAME],
mimeType: doc[window.google.picker.Document.MIME_TYPE],
webViewLink: doc[window.google.picker.Document.URL],
iconLink: doc[window.google.picker.Document.ICON_URL],
size: doc["sizeBytes"] ? parseInt(doc["sizeBytes"]) : undefined,
modifiedTime: doc["lastEditedUtc"],
isFolder:
doc[window.google.picker.Document.MIME_TYPE] ===
"application/vnd.google-apps.folder",
}));
// Enrich with additional file data if needed
if (files.some(f => !f.size && !f.isFolder)) {
try {
const enrichedFiles = await Promise.all(
files.map(async file => {
if (!file.size && !file.isFolder) {
try {
const response = await fetch(
`https://www.googleapis.com/drive/v3/files/${file.id}?fields=size,modifiedTime`,
{
headers: {
Authorization: `Bearer ${this.accessToken}`,
},
}
);
if (response.ok) {
const fileDetails = await response.json();
return {
...file,
size: fileDetails.size
? parseInt(fileDetails.size)
: undefined,
modifiedTime:
fileDetails.modifiedTime || file.modifiedTime,
};
}
} catch (error) {
console.warn("Failed to fetch file details:", error);
}
}
return file;
})
);
onFileSelected(enrichedFiles);
} catch (error) {
console.warn("Failed to enrich file data:", error);
onFileSelected(files);
}
} else {
onFileSelected(files);
}
}
this.onPickerStateChange?.(false);
}
}
export class OneDriveHandler {
private accessToken: string;
private clientId: string;
private provider: CloudProvider;
private baseUrl?: string;
constructor(
accessToken: string,
clientId: string,
provider: CloudProvider = "onedrive",
baseUrl?: string
) {
this.accessToken = accessToken;
this.clientId = clientId;
this.provider = provider;
this.baseUrl = baseUrl;
}
  async loadPickerApi(): Promise<boolean> {
    return new Promise(resolve => {
      // Avoid injecting the SDK script again if it is already present,
      // which would otherwise happen on every re-render.
      if (window.OneDrive) {
        resolve(true);
        return;
      }
      const script = document.createElement("script");
      script.src = "https://js.live.net/v7.2/OneDrive.js";
      script.onload = () => resolve(true);
      script.onerror = () => resolve(false);
      document.head.appendChild(script);
    });
  }
openPicker(onFileSelected: (files: CloudFile[]) => void): void {
if (!window.OneDrive) {
return;
}
window.OneDrive.open({
clientId: this.clientId,
action: "query",
multiSelect: true,
advanced: {
endpointHint: "api.onedrive.com",
accessToken: this.accessToken,
},
success: (response: any) => {
const newFiles: CloudFile[] =
response.value?.map((item: any, index: number) => ({
id: item.id,
name:
item.name ||
`${this.getProviderName()} File ${index + 1} (${item.id.slice(
-8
)})`,
mimeType: item.file?.mimeType || "application/octet-stream",
webUrl: item.webUrl || "",
downloadUrl: item["@microsoft.graph.downloadUrl"] || "",
size: item.size,
modifiedTime: item.lastModifiedDateTime,
isFolder: !!item.folder,
})) || [];
onFileSelected(newFiles);
},
cancel: () => {
console.log("Picker cancelled");
},
error: (error: any) => {
console.error("Picker error:", error);
},
});
}
private getProviderName(): string {
return this.provider === "sharepoint" ? "SharePoint" : "OneDrive";
}
}
export const createProviderHandler = (
provider: CloudProvider,
accessToken: string,
onPickerStateChange?: (isOpen: boolean) => void,
clientId?: string,
baseUrl?: string
) => {
switch (provider) {
case "google_drive":
return new GoogleDriveHandler(accessToken, onPickerStateChange);
case "onedrive":
case "sharepoint":
if (!clientId) {
throw new Error("Client ID required for OneDrive/SharePoint");
}
return new OneDriveHandler(accessToken, clientId, provider, baseUrl);
default:
throw new Error(`Unsupported provider: ${provider}`);
}
};
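The dispatch rule in `createProviderHandler` above can be shown as a stand-alone sketch, with plain data in place of the browser-bound handler classes so it runs anywhere. `pickHandler` and `HandlerInfo` are illustrative names, not exports of `provider-handlers.ts`.

```typescript
// Stand-alone sketch of createProviderHandler's dispatch: Google needs only
// the OAuth token, while the OneDrive.js picker also requires an app client
// ID. SharePoint reuses the OneDrive path because both use the same SDK.
type CloudProvider = "google_drive" | "onedrive" | "sharepoint";

interface HandlerInfo {
  sdk: "google-picker" | "onedrive-js";
  clientId?: string;
}

function pickHandler(provider: CloudProvider, clientId?: string): HandlerInfo {
  switch (provider) {
    case "google_drive":
      // Google's picker authenticates with the OAuth access token alone.
      return { sdk: "google-picker" };
    case "onedrive":
    case "sharepoint":
      // OneDrive/SharePoint additionally require an app client ID.
      if (!clientId) {
        throw new Error("Client ID required for OneDrive/SharePoint");
      }
      return { sdk: "onedrive-js", clientId };
  }
}
```

Only the display name distinguishes SharePoint from OneDrive in this design; the handler and SDK surface are shared.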


@ -0,0 +1,106 @@
export interface CloudFile {
id: string;
name: string;
mimeType: string;
webViewLink?: string;
iconLink?: string;
size?: number;
modifiedTime?: string;
isFolder?: boolean;
webUrl?: string;
downloadUrl?: string;
}
export type CloudProvider = "google_drive" | "onedrive" | "sharepoint";
export interface UnifiedCloudPickerProps {
provider: CloudProvider;
onFileSelected: (files: CloudFile[]) => void;
selectedFiles?: CloudFile[];
isAuthenticated: boolean;
accessToken?: string;
onPickerStateChange?: (isOpen: boolean) => void;
// OneDrive/SharePoint specific props
clientId?: string;
baseUrl?: string;
// Ingest settings
onSettingsChange?: (settings: IngestSettings) => void;
}
export interface GoogleAPI {
load: (
api: string,
options: { callback: () => void; onerror?: () => void }
) => void;
}
export interface GooglePickerData {
action: string;
docs: GooglePickerDocument[];
}
export interface GooglePickerDocument {
[key: string]: string;
}
declare global {
interface Window {
gapi: GoogleAPI;
google: {
picker: {
api: {
load: (callback: () => void) => void;
};
PickerBuilder: new () => GooglePickerBuilder;
ViewId: {
DOCS: string;
FOLDERS: string;
DOCS_IMAGES_AND_VIDEOS: string;
DOCUMENTS: string;
PRESENTATIONS: string;
SPREADSHEETS: string;
};
Feature: {
MULTISELECT_ENABLED: string;
NAV_HIDDEN: string;
SIMPLE_UPLOAD_ENABLED: string;
};
Action: {
PICKED: string;
CANCEL: string;
};
Document: {
ID: string;
NAME: string;
MIME_TYPE: string;
URL: string;
ICON_URL: string;
};
};
};
OneDrive?: any;
}
}
export interface GooglePickerBuilder {
addView: (view: string) => GooglePickerBuilder;
setOAuthToken: (token: string) => GooglePickerBuilder;
setCallback: (
callback: (data: GooglePickerData) => void
) => GooglePickerBuilder;
enableFeature: (feature: string) => GooglePickerBuilder;
setTitle: (title: string) => GooglePickerBuilder;
build: () => GooglePicker;
}
export interface GooglePicker {
setVisible: (visible: boolean) => void;
}
export interface IngestSettings {
chunkSize: number;
chunkOverlap: number;
ocr: boolean;
pictureDescriptions: boolean;
embeddingModel: string;
}


@ -0,0 +1,195 @@
"use client";
import { useState, useEffect } from "react";
import {
UnifiedCloudPickerProps,
CloudFile,
IngestSettings as IngestSettingsType,
} from "./types";
import { PickerHeader } from "./picker-header";
import { FileList } from "./file-list";
import { IngestSettings } from "./ingest-settings";
import { createProviderHandler } from "./provider-handlers";
export const UnifiedCloudPicker = ({
provider,
onFileSelected,
selectedFiles = [],
isAuthenticated,
accessToken,
onPickerStateChange,
clientId,
baseUrl,
onSettingsChange,
}: UnifiedCloudPickerProps) => {
const [isPickerLoaded, setIsPickerLoaded] = useState(false);
const [isPickerOpen, setIsPickerOpen] = useState(false);
const [isIngestSettingsOpen, setIsIngestSettingsOpen] = useState(false);
const [isLoadingBaseUrl, setIsLoadingBaseUrl] = useState(false);
const [autoBaseUrl, setAutoBaseUrl] = useState<string | undefined>(undefined);
// Settings state with defaults
const [ingestSettings, setIngestSettings] = useState<IngestSettingsType>({
chunkSize: 1000,
chunkOverlap: 200,
ocr: false,
pictureDescriptions: false,
embeddingModel: "text-embedding-3-small",
});
// Handle settings changes and notify parent
const handleSettingsChange = (newSettings: IngestSettingsType) => {
setIngestSettings(newSettings);
onSettingsChange?.(newSettings);
};
const effectiveBaseUrl = baseUrl || autoBaseUrl;
// Auto-detect base URL for OneDrive personal accounts
useEffect(() => {
if (
(provider === "onedrive" || provider === "sharepoint") &&
!baseUrl &&
accessToken &&
!autoBaseUrl
) {
const getBaseUrl = async () => {
setIsLoadingBaseUrl(true);
try {
setAutoBaseUrl("https://onedrive.live.com/picker");
} catch (error) {
console.error("Auto-detect baseUrl failed:", error);
} finally {
setIsLoadingBaseUrl(false);
}
};
getBaseUrl();
}
}, [accessToken, baseUrl, autoBaseUrl, provider]);
// Load picker API
useEffect(() => {
if (!accessToken || !isAuthenticated) return;
const loadApi = async () => {
try {
const handler = createProviderHandler(
provider,
accessToken,
onPickerStateChange,
clientId,
effectiveBaseUrl
);
const loaded = await handler.loadPickerApi();
setIsPickerLoaded(loaded);
} catch (error) {
console.error("Failed to create provider handler:", error);
setIsPickerLoaded(false);
}
};
loadApi();
}, [
accessToken,
isAuthenticated,
provider,
clientId,
effectiveBaseUrl,
onPickerStateChange,
]);
const handleAddFiles = () => {
if (!isPickerLoaded || !accessToken) {
return;
}
if ((provider === "onedrive" || provider === "sharepoint") && !clientId) {
console.error("Client ID required for OneDrive/SharePoint");
return;
}
try {
setIsPickerOpen(true);
onPickerStateChange?.(true);
const handler = createProviderHandler(
provider,
accessToken,
isOpen => {
setIsPickerOpen(isOpen);
onPickerStateChange?.(isOpen);
},
clientId,
effectiveBaseUrl
);
handler.openPicker((files: CloudFile[]) => {
// Merge new files with existing ones, avoiding duplicates
const existingIds = new Set(selectedFiles.map(f => f.id));
const newFiles = files.filter(f => !existingIds.has(f.id));
onFileSelected([...selectedFiles, ...newFiles]);
});
} catch (error) {
console.error("Error opening picker:", error);
setIsPickerOpen(false);
onPickerStateChange?.(false);
}
};
const handleRemoveFile = (fileId: string) => {
const updatedFiles = selectedFiles.filter(file => file.id !== fileId);
onFileSelected(updatedFiles);
};
const handleClearAll = () => {
onFileSelected([]);
};
if (isLoadingBaseUrl) {
return (
<div className="text-sm text-muted-foreground p-4 bg-muted/20 rounded-md">
Loading...
</div>
);
}
if (
(provider === "onedrive" || provider === "sharepoint") &&
!clientId &&
isAuthenticated
) {
return (
<div className="text-sm text-muted-foreground p-4 bg-muted/20 rounded-md">
Configuration required: Client ID missing for{" "}
{provider === "sharepoint" ? "SharePoint" : "OneDrive"}.
</div>
);
}
return (
<div className="space-y-6">
<PickerHeader
provider={provider}
onAddFiles={handleAddFiles}
isPickerLoaded={isPickerLoaded}
isPickerOpen={isPickerOpen}
accessToken={accessToken}
isAuthenticated={isAuthenticated}
/>
<FileList
files={selectedFiles}
onClearAll={handleClearAll}
onRemoveFile={handleRemoveFile}
/>
<IngestSettings
isOpen={isIngestSettingsOpen}
onOpenChange={setIsIngestSettingsOpen}
settings={ingestSettings}
onSettingsChange={handleSettingsChange}
/>
</div>
);
};
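The duplicate-avoiding merge inside `handleAddFiles` above can be extracted as a pure function. In this sketch, `FileRef` trims `CloudFile` to the fields the merge actually needs, and `mergeSelections` is an illustrative name.

```typescript
// Pure-function sketch of the merge in handleAddFiles: previously selected
// files are kept, and only genuinely new picks are appended.
interface FileRef {
  id: string;
  name: string;
}

function mergeSelections(existing: FileRef[], picked: FileRef[]): FileRef[] {
  // A Set of existing ids makes the duplicate check O(1) per picked file.
  const existingIds = new Set(existing.map(f => f.id));
  return [...existing, ...picked.filter(f => !existingIds.has(f.id))];
}
```

Re-picking an already selected file is therefore a no-op, so opening the picker repeatedly only ever grows the selection with new files.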


@ -1,341 +0,0 @@
"use client"
import { useState, useEffect } from "react"
import { Button } from "@/components/ui/button"
import { Badge } from "@/components/ui/badge"
import { FileText, Folder, Plus, Trash2 } from "lucide-react"
import { Card, CardContent } from "@/components/ui/card"
interface GoogleDrivePickerProps {
onFileSelected: (files: GoogleDriveFile[]) => void
selectedFiles?: GoogleDriveFile[]
isAuthenticated: boolean
accessToken?: string
onPickerStateChange?: (isOpen: boolean) => void
}
interface GoogleDriveFile {
id: string
name: string
mimeType: string
webViewLink?: string
iconLink?: string
size?: number
modifiedTime?: string
isFolder?: boolean
}
interface GoogleAPI {
load: (api: string, options: { callback: () => void; onerror?: () => void }) => void
}
interface GooglePickerData {
action: string
docs: GooglePickerDocument[]
}
interface GooglePickerDocument {
[key: string]: string
}
declare global {
interface Window {
gapi: GoogleAPI
google: {
picker: {
api: {
load: (callback: () => void) => void
}
PickerBuilder: new () => GooglePickerBuilder
ViewId: {
DOCS: string
FOLDERS: string
DOCS_IMAGES_AND_VIDEOS: string
DOCUMENTS: string
PRESENTATIONS: string
SPREADSHEETS: string
}
Feature: {
MULTISELECT_ENABLED: string
NAV_HIDDEN: string
SIMPLE_UPLOAD_ENABLED: string
}
Action: {
PICKED: string
CANCEL: string
}
Document: {
ID: string
NAME: string
MIME_TYPE: string
URL: string
ICON_URL: string
}
}
}
}
}
interface GooglePickerBuilder {
addView: (view: string) => GooglePickerBuilder
setOAuthToken: (token: string) => GooglePickerBuilder
setCallback: (callback: (data: GooglePickerData) => void) => GooglePickerBuilder
enableFeature: (feature: string) => GooglePickerBuilder
setTitle: (title: string) => GooglePickerBuilder
build: () => GooglePicker
}
interface GooglePicker {
setVisible: (visible: boolean) => void
}
export function GoogleDrivePicker({
onFileSelected,
selectedFiles = [],
isAuthenticated,
accessToken,
onPickerStateChange
}: GoogleDrivePickerProps) {
const [isPickerLoaded, setIsPickerLoaded] = useState(false)
const [isPickerOpen, setIsPickerOpen] = useState(false)
useEffect(() => {
const loadPickerApi = () => {
if (typeof window !== 'undefined' && window.gapi) {
window.gapi.load('picker', {
callback: () => {
setIsPickerLoaded(true)
},
onerror: () => {
console.error('Failed to load Google Picker API')
}
})
}
}
// Load Google API script if not already loaded
if (typeof window !== 'undefined') {
if (!window.gapi) {
const script = document.createElement('script')
script.src = 'https://apis.google.com/js/api.js'
script.async = true
script.defer = true
script.onload = loadPickerApi
script.onerror = () => {
console.error('Failed to load Google API script')
}
document.head.appendChild(script)
return () => {
if (document.head.contains(script)) {
document.head.removeChild(script)
}
}
} else {
loadPickerApi()
}
}
}, [])
const openPicker = () => {
if (!isPickerLoaded || !accessToken || !window.google?.picker) {
return
}
try {
setIsPickerOpen(true)
onPickerStateChange?.(true)
// Create picker with higher z-index and focus handling
const picker = new window.google.picker.PickerBuilder()
.addView(window.google.picker.ViewId.DOCS)
.addView(window.google.picker.ViewId.FOLDERS)
.setOAuthToken(accessToken)
.enableFeature(window.google.picker.Feature.MULTISELECT_ENABLED)
.setTitle('Select files from Google Drive')
.setCallback(pickerCallback)
.build()
picker.setVisible(true)
// Apply z-index fix after a short delay to ensure picker is rendered
setTimeout(() => {
const pickerElements = document.querySelectorAll('.picker-dialog, .goog-modalpopup')
pickerElements.forEach(el => {
(el as HTMLElement).style.zIndex = '10000'
})
const bgElements = document.querySelectorAll('.picker-dialog-bg, .goog-modalpopup-bg')
bgElements.forEach(el => {
(el as HTMLElement).style.zIndex = '9999'
})
}, 100)
} catch (error) {
console.error('Error creating picker:', error)
setIsPickerOpen(false)
onPickerStateChange?.(false)
}
}
const pickerCallback = async (data: GooglePickerData) => {
if (data.action === window.google.picker.Action.PICKED) {
const files: GoogleDriveFile[] = data.docs.map((doc: GooglePickerDocument) => ({
id: doc[window.google.picker.Document.ID],
name: doc[window.google.picker.Document.NAME],
mimeType: doc[window.google.picker.Document.MIME_TYPE],
webViewLink: doc[window.google.picker.Document.URL],
iconLink: doc[window.google.picker.Document.ICON_URL],
size: doc['sizeBytes'] ? parseInt(doc['sizeBytes']) : undefined,
modifiedTime: doc['lastEditedUtc'],
isFolder: doc[window.google.picker.Document.MIME_TYPE] === 'application/vnd.google-apps.folder'
}))
// If size is still missing, try to fetch it via Google Drive API
if (accessToken && files.some(f => !f.size && !f.isFolder)) {
try {
const enrichedFiles = await Promise.all(files.map(async (file) => {
if (!file.size && !file.isFolder) {
try {
const response = await fetch(`https://www.googleapis.com/drive/v3/files/${file.id}?fields=size,modifiedTime`, {
headers: {
'Authorization': `Bearer ${accessToken}`
}
})
if (response.ok) {
const fileDetails = await response.json()
return {
...file,
size: fileDetails.size ? parseInt(fileDetails.size) : undefined,
modifiedTime: fileDetails.modifiedTime || file.modifiedTime
}
}
} catch (error) {
console.warn('Failed to fetch file details:', error)
}
}
return file
}))
onFileSelected(enrichedFiles)
} catch (error) {
console.warn('Failed to enrich file data:', error)
onFileSelected(files)
}
} else {
onFileSelected(files)
}
}
setIsPickerOpen(false)
onPickerStateChange?.(false)
}
const removeFile = (fileId: string) => {
const updatedFiles = selectedFiles.filter(file => file.id !== fileId)
onFileSelected(updatedFiles)
}
const getFileIcon = (mimeType: string) => {
if (mimeType.includes('folder')) {
return <Folder className="h-4 w-4" />
}
return <FileText className="h-4 w-4" />
}
const getMimeTypeLabel = (mimeType: string) => {
const typeMap: { [key: string]: string } = {
'application/vnd.google-apps.document': 'Google Doc',
'application/vnd.google-apps.spreadsheet': 'Google Sheet',
'application/vnd.google-apps.presentation': 'Google Slides',
'application/vnd.google-apps.folder': 'Folder',
'application/pdf': 'PDF',
'text/plain': 'Text',
'application/vnd.openxmlformats-officedocument.wordprocessingml.document': 'Word Doc',
'application/vnd.openxmlformats-officedocument.presentationml.presentation': 'PowerPoint'
}
return typeMap[mimeType] || 'Document'
}
const formatFileSize = (bytes?: number) => {
if (!bytes) return ''
const sizes = ['B', 'KB', 'MB', 'GB', 'TB']
if (bytes === 0) return '0 B'
const i = Math.floor(Math.log(bytes) / Math.log(1024))
return `${(bytes / Math.pow(1024, i)).toFixed(1)} ${sizes[i]}`
}
if (!isAuthenticated) {
return (
<div className="text-sm text-muted-foreground p-4 bg-muted/20 rounded-md">
Please connect to Google Drive first to select specific files.
</div>
)
}
return (
<div className="space-y-4">
<Card>
<CardContent className="flex flex-col items-center text-center p-6">
<p className="text-sm text-muted-foreground mb-4">
Select files from Google Drive to ingest.
</p>
<Button
onClick={openPicker}
disabled={!isPickerLoaded || isPickerOpen || !accessToken}
className="bg-foreground text-background hover:bg-foreground/90"
>
<Plus className="h-4 w-4" />
{isPickerOpen ? 'Opening Picker...' : 'Add Files'}
</Button>
</CardContent>
</Card>
{selectedFiles.length > 0 && (
<div className="space-y-2">
<div className="flex items-center justify-between">
<p className="text-xs text-muted-foreground">
Added files
</p>
<Button
onClick={() => onFileSelected([])}
size="sm"
variant="ghost"
className="text-xs h-6"
>
Clear all
</Button>
</div>
<div className="max-h-64 overflow-y-auto space-y-1">
{selectedFiles.map((file) => (
<div
key={file.id}
className="flex items-center justify-between p-2 bg-muted/30 rounded-md text-xs"
>
<div className="flex items-center gap-2 flex-1 min-w-0">
{getFileIcon(file.mimeType)}
<span className="truncate font-medium">{file.name}</span>
<Badge variant="secondary" className="text-xs px-1 py-0.5 h-auto">
{getMimeTypeLabel(file.mimeType)}
</Badge>
</div>
<div className="flex items-center gap-2">
<span className="text-xs text-muted-foreground">{formatFileSize(file.size)}</span>
<Button
onClick={() => removeFile(file.id)}
size="sm"
variant="ghost"
className="h-6 w-6 p-0"
>
<Trash2 className="h-3 w-3" />
</Button>
</div>
</div>
))}
</div>
</div>
)}
</div>
)
}


@@ -1,320 +0,0 @@
"use client"
import { useState, useEffect } from "react"
import { Button } from "@/components/ui/button"
import { Badge } from "@/components/ui/badge"
import { FileText, Folder, Trash2, X } from "lucide-react"
interface OneDrivePickerProps {
onFileSelected: (files: OneDriveFile[]) => void
selectedFiles?: OneDriveFile[]
isAuthenticated: boolean
accessToken?: string
connectorType?: "onedrive" | "sharepoint"
onPickerStateChange?: (isOpen: boolean) => void
}
interface OneDriveFile {
id: string
name: string
mimeType?: string
webUrl?: string
driveItem?: {
file?: { mimeType: string }
folder?: unknown
}
}
interface GraphResponse {
value: OneDriveFile[]
}
declare global {
interface Window {
mgt?: {
Providers?: {
globalProvider?: unknown
}
}
}
}
export function OneDrivePicker({
onFileSelected,
selectedFiles = [],
isAuthenticated,
accessToken,
connectorType = "onedrive",
onPickerStateChange
}: OneDrivePickerProps) {
const [isLoading, setIsLoading] = useState(false)
const [files, setFiles] = useState<OneDriveFile[]>([])
const [isPickerOpen, setIsPickerOpen] = useState(false)
const [currentPath, setCurrentPath] = useState<string>(
connectorType === "sharepoint" ? 'sites?search=' : 'me/drive/root/children'
)
const [breadcrumbs, setBreadcrumbs] = useState<{id: string, name: string}[]>([
{id: 'root', name: connectorType === "sharepoint" ? 'SharePoint' : 'OneDrive'}
])
useEffect(() => {
const loadMGT = async () => {
if (typeof window !== 'undefined' && !window.mgt) {
try {
await import('@microsoft/mgt-components')
await import('@microsoft/mgt-msal2-provider')
// For simplicity, we'll use direct Graph API calls instead of MGT components
// MGT provider initialization would go here if needed
} catch {
console.warn('MGT not available, falling back to direct API calls')
}
}
}
loadMGT()
}, [accessToken])
const fetchFiles = async (path: string = currentPath) => {
if (!accessToken) return
setIsLoading(true)
try {
const response = await fetch(`https://graph.microsoft.com/v1.0/${path}`, {
headers: {
'Authorization': `Bearer ${accessToken}`,
'Content-Type': 'application/json'
}
})
if (response.ok) {
const data: GraphResponse = await response.json()
setFiles(data.value || [])
} else {
console.error('Failed to fetch OneDrive files:', response.statusText)
}
} catch (error) {
console.error('Error fetching OneDrive files:', error)
} finally {
setIsLoading(false)
}
}
const openPicker = () => {
if (!accessToken) return
setIsPickerOpen(true)
onPickerStateChange?.(true)
fetchFiles()
}
const closePicker = () => {
setIsPickerOpen(false)
onPickerStateChange?.(false)
setFiles([])
setCurrentPath(
connectorType === "sharepoint" ? 'sites?search=' : 'me/drive/root/children'
)
setBreadcrumbs([
{id: 'root', name: connectorType === "sharepoint" ? 'SharePoint' : 'OneDrive'}
])
}
const handleFileClick = (file: OneDriveFile) => {
if (file.driveItem?.folder) {
// Navigate to folder
const newPath = `me/drive/items/${file.id}/children`
setCurrentPath(newPath)
setBreadcrumbs([...breadcrumbs, {id: file.id, name: file.name}])
fetchFiles(newPath)
} else {
// Select file
const isAlreadySelected = selectedFiles.some(f => f.id === file.id)
if (!isAlreadySelected) {
onFileSelected([...selectedFiles, file])
}
}
}
const navigateToBreadcrumb = (index: number) => {
if (index === 0) {
setCurrentPath('me/drive/root/children')
setBreadcrumbs([{id: 'root', name: 'OneDrive'}])
fetchFiles('me/drive/root/children')
} else {
const targetCrumb = breadcrumbs[index]
const newPath = `me/drive/items/${targetCrumb.id}/children`
setCurrentPath(newPath)
setBreadcrumbs(breadcrumbs.slice(0, index + 1))
fetchFiles(newPath)
}
}
const removeFile = (fileId: string) => {
const updatedFiles = selectedFiles.filter(file => file.id !== fileId)
onFileSelected(updatedFiles)
}
const getFileIcon = (file: OneDriveFile) => {
if (file.driveItem?.folder) {
return <Folder className="h-4 w-4" />
}
return <FileText className="h-4 w-4" />
}
const getMimeTypeLabel = (file: OneDriveFile) => {
const mimeType = file.driveItem?.file?.mimeType || file.mimeType || ''
const typeMap: { [key: string]: string } = {
'application/vnd.openxmlformats-officedocument.wordprocessingml.document': 'Word Doc',
'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet': 'Excel',
'application/vnd.openxmlformats-officedocument.presentationml.presentation': 'PowerPoint',
'application/pdf': 'PDF',
'text/plain': 'Text',
'image/jpeg': 'Image',
'image/png': 'Image',
}
if (file.driveItem?.folder) return 'Folder'
return typeMap[mimeType] || 'Document'
}
const serviceName = connectorType === "sharepoint" ? "SharePoint" : "OneDrive"
if (!isAuthenticated) {
return (
<div className="text-sm text-muted-foreground p-4 bg-muted/20 rounded-md">
Please connect to {serviceName} first to select specific files.
</div>
)
}
return (
<div className="space-y-4">
<div className="flex items-center justify-between">
<div>
<h4 className="text-sm font-medium">{serviceName} File Selection</h4>
<p className="text-xs text-muted-foreground">
Choose specific files to sync instead of syncing everything
</p>
</div>
<Button
onClick={openPicker}
disabled={!accessToken}
size="sm"
variant="outline"
title={!accessToken ? `Access token required - try disconnecting and reconnecting ${serviceName}` : ""}
>
{!accessToken ? "No Access Token" : "Select Files"}
</Button>
</div>
{/* Status message when access token is missing */}
{isAuthenticated && !accessToken && (
<div className="text-xs text-amber-600 bg-amber-50 p-3 rounded-md border border-amber-200">
<div className="font-medium mb-1">Access token unavailable</div>
<div>The file picker requires an access token. Try disconnecting and reconnecting your {serviceName} account.</div>
</div>
)}
{/* File Picker Modal */}
{isPickerOpen && (
<div className="fixed inset-0 bg-black bg-opacity-50 flex items-center justify-center z-[100]">
<div className="bg-white rounded-lg p-6 max-w-2xl w-full max-h-[80vh] overflow-hidden flex flex-col">
<div className="flex items-center justify-between mb-4">
<h3 className="text-lg font-semibold">Select Files from {serviceName}</h3>
<Button onClick={closePicker} size="sm" variant="ghost">
<X className="h-4 w-4" />
</Button>
</div>
{/* Breadcrumbs */}
<div className="flex items-center space-x-2 mb-4 text-sm">
{breadcrumbs.map((crumb, index) => (
<div key={crumb.id} className="flex items-center">
{index > 0 && <span className="mx-2 text-gray-400">/</span>}
<button
onClick={() => navigateToBreadcrumb(index)}
className="text-blue-600 hover:underline"
>
{crumb.name}
</button>
</div>
))}
</div>
{/* File List */}
<div className="flex-1 overflow-y-auto border rounded-md">
{isLoading ? (
<div className="p-4 text-center text-muted-foreground">Loading...</div>
) : files.length === 0 ? (
<div className="p-4 text-center text-muted-foreground">No files found</div>
) : (
<div className="divide-y">
{files.map((file) => (
<div
key={file.id}
className="flex items-center p-3 hover:bg-gray-50 cursor-pointer"
onClick={() => handleFileClick(file)}
>
<div className="flex items-center gap-3 flex-1">
{getFileIcon(file)}
<span className="font-medium">{file.name}</span>
<Badge variant="secondary" className="text-xs">
{getMimeTypeLabel(file)}
</Badge>
</div>
{selectedFiles.some(f => f.id === file.id) && (
<Badge variant="default" className="text-xs">Selected</Badge>
)}
</div>
))}
</div>
)}
</div>
</div>
</div>
)}
{selectedFiles.length > 0 && (
<div className="space-y-2">
<p className="text-xs text-muted-foreground">
Selected files ({selectedFiles.length}):
</p>
<div className="max-h-48 overflow-y-auto space-y-1">
{selectedFiles.map((file) => (
<div
key={file.id}
className="flex items-center justify-between p-2 bg-muted/30 rounded-md text-xs"
>
<div className="flex items-center gap-2 flex-1 min-w-0">
{getFileIcon(file)}
<span className="truncate font-medium">{file.name}</span>
<Badge variant="secondary" className="text-xs px-1 py-0.5 h-auto">
{getMimeTypeLabel(file)}
</Badge>
</div>
<Button
onClick={() => removeFile(file.id)}
size="sm"
variant="ghost"
className="h-6 w-6 p-0"
>
<Trash2 className="h-3 w-3" />
</Button>
</div>
))}
</div>
<Button
onClick={() => onFileSelected([])}
size="sm"
variant="ghost"
className="text-xs h-6"
>
Clear all
</Button>
</div>
)}
</div>
)
}


@@ -127,6 +127,23 @@ async def connector_status(request: Request, connector_service, session_manager)
         user_id=user.user_id, connector_type=connector_type
     )
 
+    # Get the connector for each connection
+    connection_client_ids = {}
+    for connection in connections:
+        try:
+            connector = await connector_service._get_connector(connection.connection_id)
+            if connector is not None:
+                connection_client_ids[connection.connection_id] = connector.get_client_id()
+            else:
+                connection_client_ids[connection.connection_id] = None
+        except Exception as e:
+            logger.warning(
+                "Could not get connector for connection",
+                connection_id=connection.connection_id,
+                error=str(e),
+            )
+            connection.connector = None
+
     # Check if there are any active connections
     active_connections = [conn for conn in connections if conn.is_active]
     has_authenticated_connection = len(active_connections) > 0
@@ -140,6 +157,7 @@ async def connector_status(request: Request, connector_service, session_manager)
                 {
                     "connection_id": conn.connection_id,
                     "name": conn.name,
+                    "client_id": connection_client_ids.get(conn.connection_id),
                     "is_active": conn.is_active,
                     "created_at": conn.created_at.isoformat(),
                     "last_sync": conn.last_sync.isoformat() if conn.last_sync else None,
@@ -323,8 +341,8 @@ async def connector_webhook(request: Request, connector_service, session_manager
     )
 
 async def connector_token(request: Request, connector_service, session_manager):
-    """Get access token for connector API calls (e.g., Google Picker)"""
-    connector_type = request.path_params.get("connector_type")
+    """Get access token for connector API calls (e.g., Pickers)."""
+    url_connector_type = request.path_params.get("connector_type")
     connection_id = request.query_params.get("connection_id")
 
     if not connection_id:
@@ -333,37 +351,81 @@ async def connector_token(request: Request, connector_service, session_manager):
     user = request.state.user
 
     try:
-        # Get the connection and verify it belongs to the user
+        # 1) Load the connection and verify ownership
        connection = await connector_service.connection_manager.get_connection(connection_id)
         if not connection or connection.user_id != user.user_id:
             return JSONResponse({"error": "Connection not found"}, status_code=404)
 
-        # Get the connector instance
+        # 2) Get the ACTUAL connector instance/type for this connection_id
         connector = await connector_service._get_connector(connection_id)
         if not connector:
-            return JSONResponse({"error": f"Connector not available - authentication may have failed for {connector_type}"}, status_code=404)
+            return JSONResponse(
+                {"error": f"Connector not available - authentication may have failed for {url_connector_type}"},
+                status_code=404,
+            )
 
-        # For Google Drive, get the access token
-        if connector_type == "google_drive" and hasattr(connector, 'oauth'):
+        real_type = getattr(connector, "type", None) or getattr(connection, "connector_type", None)
+        if real_type is None:
+            return JSONResponse({"error": "Unable to determine connector type"}, status_code=500)
+
+        # Optional: warn if URL path type disagrees with real type
+        if url_connector_type and url_connector_type != real_type:
+            # You can downgrade this to debug if you expect cross-routing.
+            return JSONResponse(
+                {
+                    "error": "Connector type mismatch",
+                    "detail": {
+                        "requested_type": url_connector_type,
+                        "actual_type": real_type,
+                        "hint": "Call the token endpoint using the correct connector_type for this connection_id.",
+                    },
+                },
+                status_code=400,
+            )
+
+        # 3) Branch by the actual connector type
+
+        # GOOGLE DRIVE (google-auth)
+        if real_type == "google_drive" and hasattr(connector, "oauth"):
             await connector.oauth.load_credentials()
             if connector.oauth.creds and connector.oauth.creds.valid:
-                return JSONResponse({
-                    "access_token": connector.oauth.creds.token,
-                    "expires_in": (connector.oauth.creds.expiry.timestamp() -
-                                   __import__('time').time()) if connector.oauth.creds.expiry else None
-                })
-            else:
-                return JSONResponse({"error": "Invalid or expired credentials"}, status_code=401)
+                expires_in = None
+                try:
+                    if connector.oauth.creds.expiry:
+                        import time
+                        expires_in = max(0, int(connector.oauth.creds.expiry.timestamp() - time.time()))
+                except Exception:
+                    expires_in = None
+                return JSONResponse(
+                    {
+                        "access_token": connector.oauth.creds.token,
+                        "expires_in": expires_in,
+                    }
+                )
+            return JSONResponse({"error": "Invalid or expired credentials"}, status_code=401)
 
-        # For OneDrive and SharePoint, get the access token
-        elif connector_type in ["onedrive", "sharepoint"] and hasattr(connector, 'oauth'):
+        # ONEDRIVE / SHAREPOINT (MSAL or custom)
+        if real_type in ("onedrive", "sharepoint") and hasattr(connector, "oauth"):
+            # Ensure cache/credentials are loaded before trying to use them
             try:
+                # Prefer a dedicated is_authenticated() that loads cache internally
+                if hasattr(connector.oauth, "is_authenticated"):
+                    ok = await connector.oauth.is_authenticated()
+                else:
+                    # Fallback: try to load credentials explicitly if available
+                    ok = True
+                    if hasattr(connector.oauth, "load_credentials"):
+                        ok = await connector.oauth.load_credentials()
+                if not ok:
+                    return JSONResponse({"error": "Not authenticated"}, status_code=401)
+
+                # Now safe to fetch access token
                 access_token = connector.oauth.get_access_token()
-                return JSONResponse({
-                    "access_token": access_token,
-                    "expires_in": None  # MSAL handles token expiry internally
-                })
+                # MSAL result has expiry, but we're returning a raw token; keep expires_in None for simplicity
+                return JSONResponse({"access_token": access_token, "expires_in": None})
             except ValueError as e:
+                # Typical when acquire_token_silent fails (e.g., needs re-auth)
                 return JSONResponse({"error": f"Failed to get access token: {str(e)}"}, status_code=401)
             except Exception as e:
                 return JSONResponse({"error": f"Authentication error: {str(e)}"}, status_code=500)
@@ -371,7 +433,5 @@ async def connector_token(request: Request, connector_service, session_manager):
         return JSONResponse({"error": "Token not available for this connector type"}, status_code=400)
 
     except Exception as e:
-        logger.error("Error getting connector token", error=str(e))
+        logger.error("Error getting connector token", exc_info=True)
         return JSONResponse({"error": str(e)}, status_code=500)


@@ -556,6 +556,19 @@ async def onboarding(request, flows_service):
             )
             # Continue even if setting global variables fails
 
+    # Initialize the OpenSearch index now that we have the embedding model configured
+    try:
+        # Import here to avoid circular imports
+        from main import init_index
+        logger.info("Initializing OpenSearch index after onboarding configuration")
+        await init_index()
+        logger.info("OpenSearch index initialization completed successfully")
+    except Exception as e:
+        logger.error("Failed to initialize OpenSearch index after onboarding", error=str(e))
+        # Don't fail the entire onboarding process if index creation fails
+        # The application can still work, but document operations may fail
+
     # Handle sample data ingestion if requested
     if should_ingest_sample_data:
         try:


@@ -16,6 +16,8 @@ class ProviderConfig:
     model_provider: str = "openai"  # openai, anthropic, etc.
     api_key: str = ""
+    endpoint: str = ""  # For providers like Watson/IBM that need custom endpoints
+    project_id: str = ""  # For providers like Watson/IBM that need project IDs
 
 
 @dataclass
@@ -129,6 +131,10 @@ class ConfigManager:
             config_data["provider"]["model_provider"] = os.getenv("MODEL_PROVIDER")
         if os.getenv("PROVIDER_API_KEY"):
             config_data["provider"]["api_key"] = os.getenv("PROVIDER_API_KEY")
+        if os.getenv("PROVIDER_ENDPOINT"):
+            config_data["provider"]["endpoint"] = os.getenv("PROVIDER_ENDPOINT")
+        if os.getenv("PROVIDER_PROJECT_ID"):
+            config_data["provider"]["project_id"] = os.getenv("PROVIDER_PROJECT_ID")
 
         # Backward compatibility for OpenAI
         if os.getenv("OPENAI_API_KEY"):
             config_data["provider"]["api_key"] = os.getenv("OPENAI_API_KEY")


@@ -78,6 +78,31 @@
 INDEX_NAME = "documents"
 VECTOR_DIM = 1536
 EMBED_MODEL = "text-embedding-3-small"
 
+OPENAI_EMBEDDING_DIMENSIONS = {
+    "text-embedding-3-small": 1536,
+    "text-embedding-3-large": 3072,
+    "text-embedding-ada-002": 1536,
+}
+
+OLLAMA_EMBEDDING_DIMENSIONS = {
+    "nomic-embed-text": 768,
+    "all-minilm": 384,
+    "mxbai-embed-large": 1024,
+}
+
+WATSONX_EMBEDDING_DIMENSIONS = {
+    # IBM Models
+    "ibm/granite-embedding-107m-multilingual": 384,
+    "ibm/granite-embedding-278m-multilingual": 1024,
+    "ibm/slate-125m-english-rtrvr": 768,
+    "ibm/slate-125m-english-rtrvr-v2": 768,
+    "ibm/slate-30m-english-rtrvr": 384,
+    "ibm/slate-30m-english-rtrvr-v2": 384,
+    # Third Party Models
+    "intfloat/multilingual-e5-large": 1024,
+    "sentence-transformers/all-minilm-l6-v2": 384,
+}
+
 INDEX_BODY = {
     "settings": {
         "index": {"knn": True},


@@ -294,31 +294,38 @@ class ConnectionManager:
     async def get_connector(self, connection_id: str) -> Optional[BaseConnector]:
         """Get an active connector instance"""
+        logger.debug(f"Getting connector for connection_id: {connection_id}")
+
         # Return cached connector if available
         if connection_id in self.active_connectors:
             connector = self.active_connectors[connection_id]
             if connector.is_authenticated:
+                logger.debug(f"Returning cached authenticated connector for {connection_id}")
                 return connector
             else:
                 # Remove unauthenticated connector from cache
+                logger.debug(f"Removing unauthenticated connector from cache for {connection_id}")
                 del self.active_connectors[connection_id]
 
         # Try to create and authenticate connector
         connection_config = self.connections.get(connection_id)
         if not connection_config or not connection_config.is_active:
+            logger.debug(f"No active connection config found for {connection_id}")
             return None
 
+        logger.debug(f"Creating connector for {connection_config.connector_type}")
         connector = self._create_connector(connection_config)
-        if await connector.authenticate():
+
+        logger.debug(f"Attempting authentication for {connection_id}")
+        auth_result = await connector.authenticate()
+        logger.debug(f"Authentication result for {connection_id}: {auth_result}")
+        if auth_result:
             self.active_connectors[connection_id] = connector
-            # ... rest of the method
+
+            # Setup webhook subscription if not already set up
+            await self._setup_webhook_if_needed(
+                connection_id, connection_config, connector
+            )
             return connector
+        else:
+            logger.warning(f"Authentication failed for {connection_id}")
             return None
 
     def get_available_connector_types(self) -> Dict[str, Dict[str, Any]]:
@@ -363,6 +370,7 @@ class ConnectionManager:
     def _create_connector(self, config: ConnectionConfig) -> BaseConnector:
         """Factory method to create connector instances"""
+        try:
             if config.connector_type == "google_drive":
                 return GoogleDriveConnector(config.config)
             elif config.connector_type == "sharepoint":
@@ -370,13 +378,15 @@ class ConnectionManager:
             elif config.connector_type == "onedrive":
                 return OneDriveConnector(config.config)
             elif config.connector_type == "box":
+                # Future: BoxConnector(config.config)
                 raise NotImplementedError("Box connector not implemented yet")
             elif config.connector_type == "dropbox":
+                # Future: DropboxConnector(config.config)
                 raise NotImplementedError("Dropbox connector not implemented yet")
             else:
                 raise ValueError(f"Unknown connector type: {config.connector_type}")
+        except Exception as e:
+            logger.error(f"Failed to create {config.connector_type} connector: {e}")
+            # Re-raise the exception so caller can handle appropriately
+            raise
 
     async def update_last_sync(self, connection_id: str):
         """Update the last sync timestamp for a connection"""


@@ -477,7 +477,7 @@ class GoogleDriveConnector(BaseConnector):
                 "next_page_token": None,  # no more pages
             }
         except Exception as e:
-            # Optionally log error with your base class logger
+            # Log the error
             try:
                 logger.error(f"GoogleDriveConnector.list_files failed: {e}")
             except Exception:
@@ -495,7 +495,6 @@ class GoogleDriveConnector(BaseConnector):
         try:
             blob = self._download_file_bytes(meta)
         except Exception as e:
-            # Use your base class logger if available
             try:
                 logger.error(f"Download failed for {file_id}: {e}")
             except Exception:
@@ -562,7 +561,6 @@ class GoogleDriveConnector(BaseConnector):
             if not self.cfg.changes_page_token:
                 self.cfg.changes_page_token = self.get_start_page_token()
         except Exception as e:
-            # Optional: use your base logger
             try:
                 logger.error(f"Failed to get start page token: {e}")
             except Exception:
@@ -593,7 +591,6 @@ class GoogleDriveConnector(BaseConnector):
         expiration = result.get("expiration")
 
         # Persist in-memory so cleanup can stop this channel later.
-        # If your project has a persistence layer, save these values there.
         self._active_channel = {
             "channel_id": channel_id,
             "resource_id": resource_id,
@@ -803,7 +800,7 @@ class GoogleDriveConnector(BaseConnector):
         """
         Perform a one-shot sync of the currently selected scope and emit documents.
 
-        Emits ConnectorDocument instances (adapt to your BaseConnector ingestion).
+        Emits ConnectorDocument instances
         """
         items = self._iter_selected_items()
         for meta in items:


@@ -1,223 +1,494 @@
import logging
from pathlib import Path
from typing import List, Dict, Any, Optional
from datetime import datetime

import httpx

from ..base import BaseConnector, ConnectorDocument, DocumentACL
from .oauth import OneDriveOAuth

logger = logging.getLogger(__name__)


class OneDriveConnector(BaseConnector):
    """OneDrive connector using MSAL-based OAuth for authentication."""

    # Required BaseConnector class attributes
    CLIENT_ID_ENV_VAR = "MICROSOFT_GRAPH_OAUTH_CLIENT_ID"
    CLIENT_SECRET_ENV_VAR = "MICROSOFT_GRAPH_OAUTH_CLIENT_SECRET"

    # Connector metadata
    CONNECTOR_NAME = "OneDrive"
    CONNECTOR_DESCRIPTION = "Connect to OneDrive (personal) to sync documents and files"
    CONNECTOR_ICON = "onedrive"

    def __init__(self, config: Dict[str, Any]):
        logger.debug(f"OneDrive connector __init__ called with config type: {type(config)}")
        logger.debug(f"OneDrive connector __init__ config value: {config}")
        if config is None:
            logger.debug("Config was None, using empty dict")
            config = {}
        try:
            logger.debug("Calling super().__init__")
            super().__init__(config)
            logger.debug("super().__init__ completed successfully")
        except Exception as e:
            logger.error(f"super().__init__ failed: {e}")
            raise

        # Initialize with defaults that allow the connector to be listed
        self.client_id = None
        self.client_secret = None
        self.redirect_uri = config.get("redirect_uri", "http://localhost")

        # Try to get credentials, but don't fail if they're missing
        try:
            self.client_id = self.get_client_id()
            logger.debug(f"Got client_id: {self.client_id is not None}")
        except Exception as e:
            logger.debug(f"Failed to get client_id: {e}")
        try:
            self.client_secret = self.get_client_secret()
            logger.debug(f"Got client_secret: {self.client_secret is not None}")
        except Exception as e:
            logger.debug(f"Failed to get client_secret: {e}")

        # Token file setup
        project_root = Path(__file__).resolve().parent.parent.parent.parent
        token_file = config.get("token_file") or str(project_root / "onedrive_token.json")
        Path(token_file).parent.mkdir(parents=True, exist_ok=True)

        # Only initialize OAuth if we have credentials
        if self.client_id and self.client_secret:
            connection_id = config.get("connection_id", "default")
            # Use token_file from config if provided, otherwise generate one
            if config.get("token_file"):
                oauth_token_file = config["token_file"]
            else:
                # Use a per-connection cache file to avoid collisions with other connectors
                oauth_token_file = f"onedrive_token_{connection_id}.json"

            # MSA & org both work via /common for OneDrive personal testing
            authority = "https://login.microsoftonline.com/common"
            self.oauth = OneDriveOAuth(
                client_id=self.client_id,
                client_secret=self.client_secret,
                token_file=oauth_token_file,
                authority=authority,
                allow_json_refresh=True,  # allows one-time migration from legacy JSON if present
            )
        else:
            self.oauth = None

        # Track subscription ID for webhooks
        # (note: change notifications might not be available for personal accounts)
        self._subscription_id: Optional[str] = None

        # Graph API defaults
        self._graph_api_version = "v1.0"
        self._default_params = {
            "$select": "id,name,size,lastModifiedDateTime,createdDateTime,webUrl,file,folder,@microsoft.graph.downloadUrl"
        }

    @property
    def _graph_base_url(self) -> str:
        """Base URL for Microsoft Graph API calls."""
        return f"https://graph.microsoft.com/{self._graph_api_version}"

    def emit(self, doc: ConnectorDocument) -> None:
        """Emit a ConnectorDocument instance."""
        logger.debug(f"Emitting OneDrive document: {doc.id} ({doc.filename})")

    async def authenticate(self) -> bool:
        """Test authentication - BaseConnector interface."""
        logger.debug(f"OneDrive authenticate() called, oauth is None: {self.oauth is None}")
        try:
            if not self.oauth:
                logger.debug("OneDrive authentication failed: OAuth not initialized")
                self._authenticated = False
                return False
            logger.debug("Loading OneDrive credentials...")
            load_result = await self.oauth.load_credentials()
            logger.debug(f"Load credentials result: {load_result}")
            logger.debug("Checking OneDrive authentication status...")
            authenticated = await self.oauth.is_authenticated()
            logger.debug(f"OneDrive is_authenticated result: {authenticated}")
            self._authenticated = authenticated
            return authenticated
        except Exception as e:
            logger.error(f"OneDrive authentication failed: {e}")
            import traceback
            traceback.print_exc()
            self._authenticated = False
            return False

    def get_auth_url(self) -> str:
        """Get OAuth authorization URL."""
        if not self.oauth:
            raise RuntimeError("OneDrive OAuth not initialized - missing credentials")
        return self.oauth.create_authorization_url(self.redirect_uri)

    async def handle_oauth_callback(self, auth_code: str) -> Dict[str, Any]:
        """Handle OAuth callback."""
        if not self.oauth:
            raise RuntimeError("OneDrive OAuth not initialized - missing credentials")
        try:
            success = await self.oauth.handle_authorization_callback(auth_code, self.redirect_uri)
            if success:
                self._authenticated = True
                return {"status": "success"}
            else:
                raise ValueError("OAuth callback failed")
        except Exception as e:
            logger.error(f"OAuth callback failed: {e}")
            raise

    def sync_once(self) -> None:
        """Perform a one-shot sync of OneDrive files and emit documents."""
        import asyncio

        async def _async_sync():
            try:
                file_list = await self.list_files(max_files=1000)
                files = file_list.get("files", [])
                for file_info in files:
                    try:
                        file_id = file_info.get("id")
                        if not file_id:
                            continue
                        doc = await self.get_file_content(file_id)
                        self.emit(doc)
                    except Exception as e:
                        logger.error(f"Failed to sync OneDrive file {file_info.get('name', 'unknown')}: {e}")
                        continue
            except Exception as e:
                logger.error(f"OneDrive sync_once failed: {e}")
                raise

        if hasattr(asyncio, 'run'):
            asyncio.run(_async_sync())
        else:
            loop = asyncio.get_event_loop()
            loop.run_until_complete(_async_sync())

    async def setup_subscription(self) -> str:
        """
        Set up real-time subscription for file changes.
        NOTE: Change notifications may not be available for personal OneDrive accounts.
        """
        webhook_url = self.config.get('webhook_url')
        if not webhook_url:
            logger.warning("No webhook URL configured, skipping OneDrive subscription setup")
            return "no-webhook-configured"
        try:
            if not await self.authenticate():
                raise RuntimeError("OneDrive authentication failed during subscription setup")
            token = self.oauth.get_access_token()
            # For OneDrive personal we target the user's drive
            resource = "/me/drive/root"
            subscription_data = {
                "changeType": "created,updated,deleted",
                "notificationUrl": f"{webhook_url}/webhook/onedrive",
                "resource": resource,
                "expirationDateTime": self._get_subscription_expiry(),
                "clientState": "onedrive_personal",
            }
            headers = {
                "Authorization": f"Bearer {token}",
                "Content-Type": "application/json",
            }
            url = f"{self._graph_base_url}/subscriptions"
            async with httpx.AsyncClient() as client:
                response = await client.post(url, json=subscription_data, headers=headers, timeout=30)
                response.raise_for_status()
                result = response.json()

            subscription_id = result.get("id")
            if subscription_id:
                self._subscription_id = subscription_id
                logger.info(f"OneDrive subscription created: {subscription_id}")
                return subscription_id
            else:
                raise ValueError("No subscription ID returned from Microsoft Graph")
        except Exception as e:
            logger.error(f"Failed to setup OneDrive subscription: {e}")
            raise

    def _get_subscription_expiry(self) -> str:
        """Get subscription expiry time (Graph caps duration; often <= 3 days)."""
        from datetime import datetime, timedelta
        expiry = datetime.utcnow() + timedelta(days=3)
        return expiry.strftime("%Y-%m-%dT%H:%M:%S.%fZ")

    async def list_files(
        self,
        page_token: Optional[str] = None,
        max_files: Optional[int] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """List files from OneDrive using Microsoft Graph."""
        try:
            if not await self.authenticate():
                raise RuntimeError("OneDrive authentication failed during file listing")

            files: List[Dict[str, Any]] = []
            max_files_value = max_files if max_files is not None else 100
            base_url = f"{self._graph_base_url}/me/drive/root/children"
            params = dict(self._default_params)
            params["$top"] = str(max_files_value)
            if page_token:
                params["$skiptoken"] = page_token

            response = await self._make_graph_request(base_url, params=params)
            data = response.json()

            items = data.get("value", [])
            for item in items:
                if item.get("file"):  # include files only
                    files.append({
                        "id": item.get("id", ""),
                        "name": item.get("name", ""),
                        "path": f"/drive/items/{item.get('id')}",
                        "size": int(item.get("size", 0)),
                        "modified": item.get("lastModifiedDateTime"),
                        "created": item.get("createdDateTime"),
                        "mime_type": item.get("file", {}).get("mimeType", self._get_mime_type(item.get("name", ""))),
                        "url": item.get("webUrl", ""),
                        "download_url": item.get("@microsoft.graph.downloadUrl"),
                    })

            # Next page
            next_page_token = None
            next_link = data.get("@odata.nextLink")
            if next_link:
                from urllib.parse import urlparse, parse_qs
                parsed = urlparse(next_link)
                query_params = parse_qs(parsed.query)
                if "$skiptoken" in query_params:
                    next_page_token = query_params["$skiptoken"][0]

            return {"files": files, "next_page_token": next_page_token}
        except Exception as e:
            logger.error(f"Failed to list OneDrive files: {e}")
            return {"files": [], "next_page_token": None}
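The `@odata.nextLink` handling above can be checked in isolation. A minimal sketch of the same `$skiptoken` extraction (the URL is a made-up example, not a real Graph response):

```python
from urllib.parse import urlparse, parse_qs

# Hypothetical nextLink in the shape Microsoft Graph uses for paging
next_link = "https://graph.microsoft.com/v1.0/me/drive/root/children?$top=100&$skiptoken=ABC123"

query_params = parse_qs(urlparse(next_link).query)
# parse_qs maps each key to a list of values; take the first $skiptoken if present
token = query_params.get("$skiptoken", [None])[0]
print(token)  # ABC123
```

The `[None]` default keeps the expression safe on the last page, where Graph omits `@odata.nextLink` entirely.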
    async def get_file_content(self, file_id: str) -> ConnectorDocument:
        """Get file content and metadata."""
        try:
            if not await self.authenticate():
                raise RuntimeError("OneDrive authentication failed during file content retrieval")

            file_metadata = await self._get_file_metadata_by_id(file_id)
            if not file_metadata:
                raise ValueError(f"File not found: {file_id}")

            download_url = file_metadata.get("download_url")
            if download_url:
                content = await self._download_file_from_url(download_url)
            else:
                content = await self._download_file_content(file_id)

            acl = DocumentACL(
                owner="",
                user_permissions={},
                group_permissions={},
            )
            modified_time = self._parse_graph_date(file_metadata.get("modified"))
            created_time = self._parse_graph_date(file_metadata.get("created"))

            return ConnectorDocument(
                id=file_id,
                filename=file_metadata.get("name", ""),
                mimetype=file_metadata.get("mime_type", "application/octet-stream"),
                content=content,
                source_url=file_metadata.get("url", ""),
                acl=acl,
                modified_time=modified_time,
                created_time=created_time,
                metadata={
                    "onedrive_path": file_metadata.get("path", ""),
                    "size": file_metadata.get("size", 0),
                },
            )
        except Exception as e:
            logger.error(f"Failed to get OneDrive file content {file_id}: {e}")
            raise

    async def _get_file_metadata_by_id(self, file_id: str) -> Optional[Dict[str, Any]]:
        """Get file metadata by ID using Graph API."""
        try:
            url = f"{self._graph_base_url}/me/drive/items/{file_id}"
            params = dict(self._default_params)
            response = await self._make_graph_request(url, params=params)
            item = response.json()
            if item.get("file"):
                return {
                    "id": file_id,
                    "name": item.get("name", ""),
                    "path": f"/drive/items/{file_id}",
                    "size": int(item.get("size", 0)),
                    "modified": item.get("lastModifiedDateTime"),
                    "created": item.get("createdDateTime"),
                    "mime_type": item.get("file", {}).get("mimeType", self._get_mime_type(item.get("name", ""))),
                    "url": item.get("webUrl", ""),
                    "download_url": item.get("@microsoft.graph.downloadUrl"),
                }
            return None
        except Exception as e:
            logger.error(f"Failed to get file metadata for {file_id}: {e}")
            return None

    async def _download_file_content(self, file_id: str) -> bytes:
        """Download file content by file ID using Graph API."""
        try:
            url = f"{self._graph_base_url}/me/drive/items/{file_id}/content"
            token = self.oauth.get_access_token()
            headers = {"Authorization": f"Bearer {token}"}
            async with httpx.AsyncClient() as client:
                response = await client.get(url, headers=headers, timeout=60)
                response.raise_for_status()
                return response.content
        except Exception as e:
            logger.error(f"Failed to download file content for {file_id}: {e}")
            raise

    async def _download_file_from_url(self, download_url: str) -> bytes:
        """Download file content from direct download URL."""
        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(download_url, timeout=60)
                response.raise_for_status()
                return response.content
        except Exception as e:
            logger.error(f"Failed to download from URL {download_url}: {e}")
            raise

    def _parse_graph_date(self, date_str: Optional[str]) -> datetime:
        """Parse Microsoft Graph date string to datetime."""
        if not date_str:
            return datetime.now()
        try:
            if date_str.endswith('Z'):
                return datetime.fromisoformat(date_str[:-1]).replace(tzinfo=None)
            else:
                return datetime.fromisoformat(date_str.replace('T', ' '))
        except (ValueError, AttributeError):
            return datetime.now()
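The date handling above is easy to verify standalone. A small sketch that mirrors `_parse_graph_date` (note that before Python 3.11, `datetime.fromisoformat` rejects Graph's seven-digit fractional seconds, so such timestamps land in the `datetime.now()` fallback):

```python
from datetime import datetime

def parse_graph_date(date_str):
    """Mirror of the connector's _parse_graph_date: strip a trailing 'Z'
    and parse as a naive ISO-8601 timestamp, falling back to now()."""
    if not date_str:
        return datetime.now()
    try:
        if date_str.endswith('Z'):
            return datetime.fromisoformat(date_str[:-1]).replace(tzinfo=None)
        return datetime.fromisoformat(date_str.replace('T', ' '))
    except (ValueError, AttributeError):
        return datetime.now()

print(parse_graph_date("2024-05-01T12:30:00Z"))  # 2024-05-01 12:30:00
```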
    async def _make_graph_request(self, url: str, method: str = "GET",
                                  data: Optional[Dict] = None, params: Optional[Dict] = None) -> httpx.Response:
        """Make authenticated API request to Microsoft Graph."""
        token = self.oauth.get_access_token()
        headers = {
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
        }
        async with httpx.AsyncClient() as client:
            if method.upper() == "GET":
                response = await client.get(url, headers=headers, params=params, timeout=30)
            elif method.upper() == "POST":
                response = await client.post(url, headers=headers, json=data, timeout=30)
            elif method.upper() == "DELETE":
                response = await client.delete(url, headers=headers, timeout=30)
            else:
                raise ValueError(f"Unsupported HTTP method: {method}")
            response.raise_for_status()
            return response

    def _get_mime_type(self, filename: str) -> str:
        """Get MIME type based on file extension."""
        import mimetypes
        mime_type, _ = mimetypes.guess_type(filename)
        return mime_type or "application/octet-stream"
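The extension-based fallback relies on the standard library's `mimetypes` table; a quick sketch of the same lookup-with-default pattern:

```python
import mimetypes

# Filenames are illustrative; .unknownext has no registered MIME type
for name in ["report.pdf", "photo.jpg", "data.unknownext"]:
    guessed, _ = mimetypes.guess_type(name)
    # Fall back to the generic binary type when the extension is unknown
    print(name, guessed or "application/octet-stream")
```

This is only a guess from the filename; the connector prefers the `mimeType` Graph reports in the item's `file` facet and uses this lookup as a last resort.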
    # Webhook methods - BaseConnector interface
    def handle_webhook_validation(self, request_method: str,
                                  headers: Dict[str, str],
                                  query_params: Dict[str, str]) -> Optional[str]:
        """Handle webhook validation (Graph API specific)."""
        if request_method == "POST" and "validationToken" in query_params:
            return query_params["validationToken"]
        return None

    def extract_webhook_channel_id(self, payload: Dict[str, Any],
                                   headers: Dict[str, str]) -> Optional[str]:
        """Extract channel/subscription ID from webhook payload."""
        notifications = payload.get("value", [])
        if notifications:
            return notifications[0].get("subscriptionId")
        return None

    async def handle_webhook(self, payload: Dict[str, Any]) -> List[str]:
        """Handle webhook notification and return affected file IDs."""
        affected_files: List[str] = []
        notifications = payload.get("value", [])
        for notification in notifications:
            resource = notification.get("resource")
            if resource and "/drive/items/" in resource:
                file_id = resource.split("/drive/items/")[-1]
                affected_files.append(file_id)
        return affected_files

    async def cleanup_subscription(self, subscription_id: str) -> bool:
        """Clean up subscription - BaseConnector interface."""
        if subscription_id == "no-webhook-configured":
            logger.info("No subscription to cleanup (webhook was not configured)")
            return True
        try:
            if not await self.authenticate():
                logger.error("OneDrive authentication failed during subscription cleanup")
                return False

            token = self.oauth.get_access_token()
            headers = {"Authorization": f"Bearer {token}"}
            url = f"{self._graph_base_url}/subscriptions/{subscription_id}"
            async with httpx.AsyncClient() as client:
                response = await client.delete(url, headers=headers, timeout=30)

            if response.status_code in [200, 204, 404]:
                logger.info(f"OneDrive subscription {subscription_id} cleaned up successfully")
                return True
            else:
                logger.warning(f"Unexpected response cleaning up subscription: {response.status_code}")
                return False
        except Exception as e:
            logger.error(f"Failed to cleanup OneDrive subscription {subscription_id}: {e}")
            return False


@@ -1,17 +1,28 @@
import os
import json
import logging
from typing import Optional, Dict, Any

import aiofiles
import msal

logger = logging.getLogger(__name__)


class OneDriveOAuth:
    """Handles Microsoft Graph OAuth for OneDrive (personal Microsoft accounts by default)."""

    # Reserved scopes that must NOT be sent on token or silent calls
    RESERVED_SCOPES = {"openid", "profile", "offline_access"}

    # For PERSONAL Microsoft Accounts (OneDrive consumer):
    # - Use AUTH_SCOPES for interactive auth (consent + refresh token issuance)
    # - Use RESOURCE_SCOPES for acquire_token_silent / refresh paths
    AUTH_SCOPES = ["User.Read", "Files.Read.All", "offline_access"]
    RESOURCE_SCOPES = ["User.Read", "Files.Read.All"]
    SCOPES = AUTH_SCOPES  # Backward-compat alias if something references .SCOPES

    # Kept for reference; MSAL derives endpoints from `authority`
    AUTH_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
    TOKEN_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/token"
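The class comments above note that reserved OIDC scopes must not be sent on token or silent calls, so `RESOURCE_SCOPES` can be derived mechanically from `AUTH_SCOPES`; a small sketch of that filtering:

```python
# Mirrors the scope constants defined on OneDriveOAuth
RESERVED_SCOPES = {"openid", "profile", "offline_access"}
AUTH_SCOPES = ["User.Read", "Files.Read.All", "offline_access"]

# Scopes safe to pass to acquire_token_silent / refresh-token calls
resource_scopes = [s for s in AUTH_SCOPES if s not in RESERVED_SCOPES]
print(resource_scopes)  # ['User.Read', 'Files.Read.All']
```

Keeping `offline_access` in the interactive list is what makes the identity provider issue a refresh token on first consent; dropping it on silent calls avoids scope-validation errors later.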
@ -21,18 +32,29 @@ class OneDriveOAuth:
client_secret: str, client_secret: str,
token_file: str = "onedrive_token.json", token_file: str = "onedrive_token.json",
authority: str = "https://login.microsoftonline.com/common", authority: str = "https://login.microsoftonline.com/common",
allow_json_refresh: bool = True,
): ):
"""
Initialize OneDriveOAuth.
Args:
client_id: Azure AD application (client) ID.
client_secret: Azure AD application client secret.
token_file: Path to persisted token cache file (MSAL cache format).
authority: Usually "https://login.microsoftonline.com/common" for MSA + org,
or tenant-specific for work/school.
allow_json_refresh: If True, permit one-time migration from legacy flat JSON
{"access_token","refresh_token",...}. Otherwise refuse it.
"""
self.client_id = client_id self.client_id = client_id
self.client_secret = client_secret self.client_secret = client_secret
self.token_file = token_file self.token_file = token_file
self.authority = authority self.authority = authority
self.allow_json_refresh = allow_json_refresh
self.token_cache = msal.SerializableTokenCache() self.token_cache = msal.SerializableTokenCache()
self._current_account = None
# Load existing cache if available # Initialize MSAL Confidential Client
if os.path.exists(self.token_file):
with open(self.token_file, "r") as f:
self.token_cache.deserialize(f.read())
self.app = msal.ConfidentialClientApplication( self.app = msal.ConfidentialClientApplication(
client_id=self.client_id, client_id=self.client_id,
client_credential=self.client_secret, client_credential=self.client_secret,
@ -40,56 +62,261 @@ class OneDriveOAuth:
token_cache=self.token_cache, token_cache=self.token_cache,
) )
async def save_cache(self): async def load_credentials(self) -> bool:
"""Persist the token cache to file""" """Load existing credentials from token file (async)."""
async with aiofiles.open(self.token_file, "w") as f: try:
await f.write(self.token_cache.serialize()) logger.debug(f"OneDrive OAuth loading credentials from: {self.token_file}")
if os.path.exists(self.token_file):
logger.debug(f"Token file exists, reading: {self.token_file}")
def create_authorization_url(self, redirect_uri: str) -> str: # Read the token file
"""Create authorization URL for OAuth flow""" async with aiofiles.open(self.token_file, "r") as f:
return self.app.get_authorization_request_url( cache_data = await f.read()
self.SCOPES, redirect_uri=redirect_uri logger.debug(f"Read {len(cache_data)} chars from token file")
if cache_data.strip():
# 1) Try legacy flat JSON first
try:
json_data = json.loads(cache_data)
if isinstance(json_data, dict) and "refresh_token" in json_data:
if self.allow_json_refresh:
logger.debug(
"Found legacy JSON refresh_token and allow_json_refresh=True; attempting migration refresh"
) )
return await self._refresh_from_json_token(json_data)
else:
logger.warning(
"Token file contains a legacy JSON refresh_token, but allow_json_refresh=False. "
"Delete the file and re-auth."
)
return False
except json.JSONDecodeError:
logger.debug("Token file is not flat JSON; attempting MSAL cache format")
# 2) Try MSAL cache format
logger.debug("Attempting MSAL cache deserialization")
self.token_cache.deserialize(cache_data)
# Get accounts from loaded cache
accounts = self.app.get_accounts()
logger.debug(f"Found {len(accounts)} accounts in MSAL cache")
if accounts:
self._current_account = accounts[0]
logger.debug(f"Set current account: {self._current_account.get('username', 'no username')}")
# Use RESOURCE_SCOPES (no reserved scopes) for silent acquisition
result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=self._current_account)
logger.debug(f"Silent token acquisition result keys: {list(result.keys()) if result else 'None'}")
if result and "access_token" in result:
logger.debug("Silent token acquisition successful")
await self.save_cache()
return True
else:
error_msg = (result or {}).get("error") or "No result"
logger.warning(f"Silent token acquisition failed: {error_msg}")
else:
logger.debug(f"Token file {self.token_file} is empty")
else:
logger.debug(f"Token file does not exist: {self.token_file}")
return False
except Exception as e:
logger.error(f"Failed to load OneDrive credentials: {e}")
import traceback
traceback.print_exc()
return False
async def _refresh_from_json_token(self, token_data: dict) -> bool:
"""
Use refresh token from a legacy JSON file to get new tokens (one-time migration path).
Prefer using an MSAL cache file and acquire_token_silent(); this path is only for migrating older files.
"""
try:
refresh_token = token_data.get("refresh_token")
if not refresh_token:
logger.error("No refresh_token found in JSON file - cannot refresh")
logger.error("You must re-authenticate interactively to obtain a valid token")
return False
# Use only RESOURCE_SCOPES when refreshing (no reserved scopes)
refresh_scopes = [s for s in self.RESOURCE_SCOPES if s not in self.RESERVED_SCOPES]
logger.debug(f"Using refresh token; refresh scopes = {refresh_scopes}")
result = self.app.acquire_token_by_refresh_token(
refresh_token=refresh_token,
scopes=refresh_scopes,
)
if result and "access_token" in result:
logger.debug("Successfully refreshed token via legacy JSON path")
await self.save_cache()
accounts = self.app.get_accounts()
logger.debug(f"After refresh, found {len(accounts)} accounts")
if accounts:
self._current_account = accounts[0]
logger.debug(f"Set current account after refresh: {self._current_account.get('username', 'no username')}")
return True
# Error handling
err = (result or {}).get("error_description") or (result or {}).get("error") or "Unknown error"
logger.error(f"Refresh token failed: {err}")
if any(code in err for code in ("AADSTS70000", "invalid_grant", "interaction_required")):
logger.warning(
"Refresh denied due to unauthorized/expired scopes or invalid grant. "
"Delete the token file and perform interactive sign-in with correct scopes."
)
return False
except Exception as e:
logger.error(f"Exception during refresh from JSON token: {e}")
import traceback
traceback.print_exc()
return False
async def save_cache(self):
"""Persist the token cache to file."""
try:
# Ensure parent directory exists
parent = os.path.dirname(os.path.abspath(self.token_file))
if parent and not os.path.exists(parent):
os.makedirs(parent, exist_ok=True)
cache_data = self.token_cache.serialize()
if cache_data:
async with aiofiles.open(self.token_file, "w") as f:
await f.write(cache_data)
logger.debug(f"Token cache saved to {self.token_file}")
except Exception as e:
logger.error(f"Failed to save token cache: {e}")
def create_authorization_url(self, redirect_uri: str, state: Optional[str] = None) -> str:
"""Create authorization URL for OAuth flow."""
# Store redirect URI for later use in callback
self._redirect_uri = redirect_uri
kwargs: Dict[str, Any] = {
# Interactive auth includes offline_access
"scopes": self.AUTH_SCOPES,
"redirect_uri": redirect_uri,
"prompt": "consent", # ensure refresh token on first run
}
if state:
kwargs["state"] = state # Optional CSRF protection
auth_url = self.app.get_authorization_request_url(**kwargs)
logger.debug(f"Generated auth URL: {auth_url}")
logger.debug(f"Auth scopes: {self.AUTH_SCOPES}")
return auth_url
    async def handle_authorization_callback(
        self, authorization_code: str, redirect_uri: str
    ) -> bool:
        """Handle OAuth callback and exchange code for tokens."""
        try:
            result = self.app.acquire_token_by_authorization_code(
                authorization_code,
                scopes=self.AUTH_SCOPES,  # same as authorize step
                redirect_uri=redirect_uri,
            )
            if result and "access_token" in result:
                accounts = self.app.get_accounts()
                if accounts:
                    self._current_account = accounts[0]
                await self.save_cache()
                logger.info("OneDrive OAuth authorization successful")
                return True
            error_msg = (result or {}).get("error_description") or (result or {}).get("error") or "Unknown error"
            logger.error(f"OneDrive OAuth authorization failed: {error_msg}")
            return False
        except Exception as e:
            logger.error(f"Exception during OneDrive OAuth authorization: {e}")
            return False
    async def is_authenticated(self) -> bool:
        """Check if we have valid credentials."""
        try:
            # First try to load credentials if we haven't already
            if not self._current_account:
                await self.load_credentials()

            # Try to get a token (MSAL will refresh if needed)
            if self._current_account:
                result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=self._current_account)
                if result and "access_token" in result:
                    return True
                else:
                    error_msg = (result or {}).get("error") or "No result returned"
                    logger.debug(f"Token acquisition failed for current account: {error_msg}")

            # Fallback: try without specific account
            result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=None)
            if result and "access_token" in result:
                accounts = self.app.get_accounts()
                if accounts:
                    self._current_account = accounts[0]
                return True
            return False
        except Exception as e:
            logger.error(f"Authentication check failed: {e}")
            return False
    def get_access_token(self) -> str:
        """Get an access token for Microsoft Graph."""
        try:
            # Try with current account first
            if self._current_account:
                result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=self._current_account)
                if result and "access_token" in result:
                    return result["access_token"]

            # Fallback: try without specific account
            result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=None)
            if result and "access_token" in result:
                return result["access_token"]

            # If we get here, authentication has failed
            error_msg = (result or {}).get("error_description") or (result or {}).get("error") or "No valid authentication"
            raise ValueError(f"Failed to acquire access token: {error_msg}")
        except Exception as e:
            logger.error(f"Failed to get access token: {e}")
            raise
    async def revoke_credentials(self):
        """Clear token cache and remove token file."""
        try:
            # Clear in-memory state
            self._current_account = None
            self.token_cache = msal.SerializableTokenCache()
            # Recreate MSAL app with fresh cache
            self.app = msal.ConfidentialClientApplication(
                client_id=self.client_id,
                client_credential=self.client_secret,
                authority=self.authority,
                token_cache=self.token_cache,
            )
            # Remove token file
            if os.path.exists(self.token_file):
                os.remove(self.token_file)
                logger.info(f"Removed OneDrive token file: {self.token_file}")
        except Exception as e:
            logger.error(f"Failed to revoke OneDrive credentials: {e}")

    def get_service(self) -> str:
        """Return an access token (Graph client is just the bearer)."""
        return self.get_access_token()
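The persistence pattern used by `save_cache`/`load_credentials` is a plain serialize-to-disk round trip. A minimal synchronous sketch, using a dict plus JSON as a stand-in for MSAL's serialized cache string (the helper names below are illustrative):

```python
import json
import os
import tempfile

def save_cache(path, cache):
    # Ensure the parent directory exists, then persist the serialized cache
    os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)
    with open(path, "w") as f:
        json.dump(cache, f)

def load_cache(path):
    # Missing file means "no credentials yet", not an error
    if not os.path.exists(path):
        return {}
    with open(path) as f:
        return json.load(f)

path = os.path.join(tempfile.mkdtemp(), "token.json")
save_cache(path, {"AccessToken": {"k": "v"}})
restored = load_cache(path)
```

The real code additionally distinguishes MSAL's cache format from a legacy flat-JSON token file, which is why `load_credentials` attempts `json.loads` before `token_cache.deserialize`.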


@@ -1,229 +1,567 @@
import logging
from pathlib import Path
from typing import List, Dict, Any, Optional
from urllib.parse import urlparse
from datetime import datetime

import httpx

from ..base import BaseConnector, ConnectorDocument, DocumentACL
from .oauth import SharePointOAuth

logger = logging.getLogger(__name__)
class SharePointConnector(BaseConnector):
    """SharePoint connector using MSAL-based OAuth for authentication"""

    # Required BaseConnector class attributes
    CLIENT_ID_ENV_VAR = "MICROSOFT_GRAPH_OAUTH_CLIENT_ID"
    CLIENT_SECRET_ENV_VAR = "MICROSOFT_GRAPH_OAUTH_CLIENT_SECRET"

    # Connector metadata
    CONNECTOR_NAME = "SharePoint"
    CONNECTOR_DESCRIPTION = "Connect to SharePoint to sync documents and files"
    CONNECTOR_ICON = "sharepoint"

    def __init__(self, config: Dict[str, Any]):
        logger.debug(f"SharePoint connector __init__ called with config type: {type(config)}")
        logger.debug(f"SharePoint connector __init__ config value: {config}")
        # Ensure we always pass a valid config to the base class
        if config is None:
            logger.debug("Config was None, using empty dict")
            config = {}
        try:
            logger.debug("Calling super().__init__")
            super().__init__(config)  # Now safe to call with empty dict instead of None
            logger.debug("super().__init__ completed successfully")
        except Exception as e:
            logger.error(f"super().__init__ failed: {e}")
            raise
        # Initialize with defaults that allow the connector to be listed
        self.client_id = None
        self.client_secret = None
        self.tenant_id = config.get("tenant_id", "common")
        self.sharepoint_url = config.get("sharepoint_url")
        self.redirect_uri = config.get("redirect_uri", "http://localhost")

        # Try to get credentials, but don't fail if they're missing
        try:
            logger.debug("Attempting to get client_id")
            self.client_id = self.get_client_id()
            logger.debug(f"Got client_id: {self.client_id is not None}")
        except Exception as e:
            # Credentials not available, that's OK for listing
            logger.debug(f"Failed to get client_id: {e}")
        try:
            logger.debug("Attempting to get client_secret")
            self.client_secret = self.get_client_secret()
            logger.debug(f"Got client_secret: {self.client_secret is not None}")
        except Exception as e:
            # Credentials not available, that's OK for listing
            logger.debug(f"Failed to get client_secret: {e}")

        # Token file setup
        project_root = Path(__file__).resolve().parent.parent.parent.parent
        token_file = config.get("token_file") or str(project_root / "sharepoint_token.json")
        Path(token_file).parent.mkdir(parents=True, exist_ok=True)

        # Only initialize OAuth if we have credentials
        if self.client_id and self.client_secret:
            connection_id = config.get("connection_id", "default")
            # Use token_file from config if provided, otherwise generate one
            if config.get("token_file"):
                oauth_token_file = config["token_file"]
            else:
                oauth_token_file = f"sharepoint_token_{connection_id}.json"
            authority = f"https://login.microsoftonline.com/{self.tenant_id}" if self.tenant_id != "common" else "https://login.microsoftonline.com/common"
            self.oauth = SharePointOAuth(
                client_id=self.client_id,
                client_secret=self.client_secret,
                token_file=oauth_token_file,
                authority=authority,
            )
        else:
            self.oauth = None

        # Track subscription ID for webhooks
        self._subscription_id: Optional[str] = None

        # Add Graph API defaults similar to Google Drive flags
        self._graph_api_version = "v1.0"
        self._default_params = {
            "$select": "id,name,size,lastModifiedDateTime,createdDateTime,webUrl,file,folder,@microsoft.graph.downloadUrl"
        }

    @property
    def _graph_base_url(self) -> str:
        """Base URL for Microsoft Graph API calls"""
        return f"https://graph.microsoft.com/{self._graph_api_version}"

    def emit(self, doc: ConnectorDocument) -> None:
        """Emit a ConnectorDocument instance."""
        logger.debug(f"Emitting SharePoint document: {doc.id} ({doc.filename})")
    async def authenticate(self) -> bool:
        """Test authentication - BaseConnector interface"""
        logger.debug(f"SharePoint authenticate() called, oauth is None: {self.oauth is None}")
        try:
            if not self.oauth:
                logger.debug("SharePoint authentication failed: OAuth not initialized")
                self._authenticated = False
                return False

            logger.debug("Loading SharePoint credentials...")
            # Try to load existing credentials first
            load_result = await self.oauth.load_credentials()
            logger.debug(f"Load credentials result: {load_result}")

            logger.debug("Checking SharePoint authentication status...")
            authenticated = await self.oauth.is_authenticated()
            logger.debug(f"SharePoint is_authenticated result: {authenticated}")
            self._authenticated = authenticated
            return authenticated
        except Exception as e:
            logger.error(f"SharePoint authentication failed: {e}")
            import traceback
            traceback.print_exc()
            self._authenticated = False
            return False

    def get_auth_url(self) -> str:
        """Get OAuth authorization URL"""
        if not self.oauth:
            raise RuntimeError("SharePoint OAuth not initialized - missing credentials")
        return self.oauth.create_authorization_url(self.redirect_uri)

    async def handle_oauth_callback(self, auth_code: str) -> Dict[str, Any]:
        """Handle OAuth callback"""
        if not self.oauth:
            raise RuntimeError("SharePoint OAuth not initialized - missing credentials")
        try:
            success = await self.oauth.handle_authorization_callback(auth_code, self.redirect_uri)
            if success:
                self._authenticated = True
                return {"status": "success"}
            else:
                raise ValueError("OAuth callback failed")
        except Exception as e:
            logger.error(f"OAuth callback failed: {e}")
            raise
    def sync_once(self) -> None:
        """
        Perform a one-shot sync of SharePoint files and emit documents.
        This method mirrors the Google Drive connector's sync_once functionality.
        """
        import asyncio

        async def _async_sync():
            try:
                # Get list of files
                file_list = await self.list_files(max_files=1000)  # Adjust as needed
                files = file_list.get("files", [])
                for file_info in files:
                    try:
                        file_id = file_info.get("id")
                        if not file_id:
                            continue
                        # Get full document content
                        doc = await self.get_file_content(file_id)
                        self.emit(doc)
                    except Exception as e:
                        logger.error(f"Failed to sync SharePoint file {file_info.get('name', 'unknown')}: {e}")
                        continue
            except Exception as e:
                logger.error(f"SharePoint sync_once failed: {e}")
                raise

        # Run the async sync
        if hasattr(asyncio, 'run'):
            asyncio.run(_async_sync())
        else:
            # Python < 3.7 compatibility
            loop = asyncio.get_event_loop()
            loop.run_until_complete(_async_sync())
    async def setup_subscription(self) -> str:
        """Set up real-time subscription for file changes - BaseConnector interface"""
        webhook_url = self.config.get('webhook_url')
        if not webhook_url:
            logger.warning("No webhook URL configured, skipping SharePoint subscription setup")
            return "no-webhook-configured"

        try:
            # Ensure we're authenticated
            if not await self.authenticate():
                raise RuntimeError("SharePoint authentication failed during subscription setup")

            token = self.oauth.get_access_token()

            # Microsoft Graph subscription for SharePoint site
            site_info = self._parse_sharepoint_url()
            if site_info:
                resource = f"sites/{site_info['host_name']}:/sites/{site_info['site_name']}:/drive/root"
            else:
                resource = "/me/drive/root"

            subscription_data = {
                "changeType": "created,updated,deleted",
                "notificationUrl": f"{webhook_url}/webhook/sharepoint",
                "resource": resource,
                "expirationDateTime": self._get_subscription_expiry(),
                "clientState": f"sharepoint_{self.tenant_id}"
            }
            headers = {
                "Authorization": f"Bearer {token}",
                "Content-Type": "application/json"
            }
            url = f"{self._graph_base_url}/subscriptions"

            async with httpx.AsyncClient() as client:
                response = await client.post(url, json=subscription_data, headers=headers, timeout=30)
                response.raise_for_status()
                result = response.json()

            subscription_id = result.get("id")
            if subscription_id:
                self._subscription_id = subscription_id
                logger.info(f"SharePoint subscription created: {subscription_id}")
                return subscription_id
            else:
                raise ValueError("No subscription ID returned from Microsoft Graph")
        except Exception as e:
            logger.error(f"Failed to setup SharePoint subscription: {e}")
            raise

    def _get_subscription_expiry(self) -> str:
        """Get subscription expiry time (max 3 days for Graph API)"""
        from datetime import datetime, timedelta
        expiry = datetime.utcnow() + timedelta(days=3)  # 3 days max for Graph
        return expiry.strftime("%Y-%m-%dT%H:%M:%S.%fZ")
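The expiry helper just formats a UTC timestamp a few days out in the ISO-8601-with-`Z` shape Graph expects. A self-contained sketch (the `subscription_expiry` name is illustrative, and the 3-day cap is the value this code chose, not necessarily Graph's exact limit for every resource type):

```python
from datetime import datetime, timedelta

def subscription_expiry(days=3):
    # Graph expects an ISO-8601 UTC timestamp with a trailing 'Z'
    expiry = datetime.utcnow() + timedelta(days=days)
    return expiry.strftime("%Y-%m-%dT%H:%M:%S.%fZ")

stamp = subscription_expiry()
# Round-trips through the same format string
parsed = datetime.strptime(stamp, "%Y-%m-%dT%H:%M:%S.%fZ")
```

Because subscriptions expire, a production connector also needs a renewal path (PATCH on the subscription) before this timestamp passes.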
    def _parse_sharepoint_url(self) -> Optional[Dict[str, str]]:
        """Parse SharePoint URL to extract site information for Graph API"""
        if not self.sharepoint_url:
            return None
        try:
            parsed = urlparse(self.sharepoint_url)
            # Extract hostname and site name from URL like: https://contoso.sharepoint.com/sites/teamsite
            host_name = parsed.netloc
            path_parts = parsed.path.strip('/').split('/')
            if len(path_parts) >= 2 and path_parts[0] == 'sites':
                site_name = path_parts[1]
                return {
                    "host_name": host_name,
                    "site_name": site_name
                }
        except Exception as e:
            logger.warning(f"Could not parse SharePoint URL {self.sharepoint_url}: {e}")
        return None
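The parsing logic above can be exercised standalone; this sketch reproduces it as a free function (the function name is illustrative) to show the host/site split that Graph's `sites/{host}:/sites/{name}:` addressing needs:

```python
from urllib.parse import urlparse

def parse_sharepoint_url(url):
    """Split https://contoso.sharepoint.com/sites/teamsite into host and site name."""
    parsed = urlparse(url)
    parts = parsed.path.strip("/").split("/")
    if len(parts) >= 2 and parts[0] == "sites":
        return {"host_name": parsed.netloc, "site_name": parts[1]}
    return None

info = parse_sharepoint_url("https://contoso.sharepoint.com/sites/teamsite")
# A bare tenant URL has no /sites/<name> segment, so it falls through to None
root = parse_sharepoint_url("https://contoso.sharepoint.com/")
```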
    async def list_files(
        self,
        page_token: Optional[str] = None,
        max_files: Optional[int] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """List all files using Microsoft Graph API - BaseConnector interface"""
        try:
            # Ensure authentication
            if not await self.authenticate():
                raise RuntimeError("SharePoint authentication failed during file listing")

            files = []
            max_files_value = max_files if max_files is not None else 100

            # Build Graph API URL for the site or fallback to user's OneDrive
            site_info = self._parse_sharepoint_url()
            if site_info:
                base_url = f"{self._graph_base_url}/sites/{site_info['host_name']}:/sites/{site_info['site_name']}:/drive/root/children"
            else:
                base_url = f"{self._graph_base_url}/me/drive/root/children"

            params = dict(self._default_params)
            params["$top"] = str(max_files_value)
            if page_token:
                params["$skiptoken"] = page_token

            response = await self._make_graph_request(base_url, params=params)
            data = response.json()

            items = data.get("value", [])
            for item in items:
                # Only include files, not folders
                if item.get("file"):
                    files.append({
                        "id": item.get("id", ""),
                        "name": item.get("name", ""),
                        "path": f"/drive/items/{item.get('id')}",
                        "size": int(item.get("size", 0)),
                        "modified": item.get("lastModifiedDateTime"),
                        "created": item.get("createdDateTime"),
                        "mime_type": item.get("file", {}).get("mimeType", self._get_mime_type(item.get("name", ""))),
                        "url": item.get("webUrl", ""),
                        "download_url": item.get("@microsoft.graph.downloadUrl")
                    })

            # Check for next page
            next_page_token = None
            next_link = data.get("@odata.nextLink")
            if next_link:
                from urllib.parse import urlparse, parse_qs
                parsed = urlparse(next_link)
                query_params = parse_qs(parsed.query)
                if "$skiptoken" in query_params:
                    next_page_token = query_params["$skiptoken"][0]

            return {
                "files": files,
                "next_page_token": next_page_token
            }
        except Exception as e:
            logger.error(f"Failed to list SharePoint files: {e}")
            return {"files": [], "next_page_token": None}  # Return empty result instead of raising
    async def get_file_content(self, file_id: str) -> ConnectorDocument:
        """Get file content and metadata - BaseConnector interface"""
        try:
            # Ensure authentication
            if not await self.authenticate():
                raise RuntimeError("SharePoint authentication failed during file content retrieval")

            # First get file metadata using Graph API
            file_metadata = await self._get_file_metadata_by_id(file_id)
            if not file_metadata:
                raise ValueError(f"File not found: {file_id}")

            # Download file content
            download_url = file_metadata.get("download_url")
            if download_url:
                content = await self._download_file_from_url(download_url)
            else:
                content = await self._download_file_content(file_id)

            # Create ACL from metadata
            acl = DocumentACL(
                owner="",  # Graph API requires additional calls for detailed permissions
                user_permissions={},
                group_permissions={}
            )

            # Parse dates
            modified_time = self._parse_graph_date(file_metadata.get("modified"))
            created_time = self._parse_graph_date(file_metadata.get("created"))

            return ConnectorDocument(
                id=file_id,
                filename=file_metadata.get("name", ""),
                mimetype=file_metadata.get("mime_type", "application/octet-stream"),
                content=content,
                source_url=file_metadata.get("url", ""),
                acl=acl,
                modified_time=modified_time,
                created_time=created_time,
                metadata={
                    "sharepoint_path": file_metadata.get("path", ""),
                    "sharepoint_url": self.sharepoint_url,
                    "size": file_metadata.get("size", 0)
                }
            )
        except Exception as e:
            logger.error(f"Failed to get SharePoint file content {file_id}: {e}")
            raise
    async def _get_file_metadata_by_id(self, file_id: str) -> Optional[Dict[str, Any]]:
        """Get file metadata by ID using Graph API"""
        try:
            # Try site-specific path first, then fallback to user drive
            site_info = self._parse_sharepoint_url()
            if site_info:
                url = f"{self._graph_base_url}/sites/{site_info['host_name']}:/sites/{site_info['site_name']}:/drive/items/{file_id}"
            else:
                url = f"{self._graph_base_url}/me/drive/items/{file_id}"

            params = dict(self._default_params)
            response = await self._make_graph_request(url, params=params)
            item = response.json()

            if item.get("file"):
                return {
                    "id": file_id,
                    "name": item.get("name", ""),
                    "path": f"/drive/items/{file_id}",
                    "size": int(item.get("size", 0)),
                    "modified": item.get("lastModifiedDateTime"),
                    "created": item.get("createdDateTime"),
                    "mime_type": item.get("file", {}).get("mimeType", self._get_mime_type(item.get("name", ""))),
                    "url": item.get("webUrl", ""),
                    "download_url": item.get("@microsoft.graph.downloadUrl")
                }
            return None
        except Exception as e:
            logger.error(f"Failed to get file metadata for {file_id}: {e}")
            return None
    async def _download_file_content(self, file_id: str) -> bytes:
        """Download file content by file ID using Graph API"""
        try:
            site_info = self._parse_sharepoint_url()
            if site_info:
                url = f"{self._graph_base_url}/sites/{site_info['host_name']}:/sites/{site_info['site_name']}:/drive/items/{file_id}/content"
            else:
                url = f"{self._graph_base_url}/me/drive/items/{file_id}/content"

            token = self.oauth.get_access_token()
            headers = {"Authorization": f"Bearer {token}"}
            async with httpx.AsyncClient() as client:
                response = await client.get(url, headers=headers, timeout=60)
                response.raise_for_status()
                return response.content
        except Exception as e:
            logger.error(f"Failed to download file content for {file_id}: {e}")
            raise

    async def _download_file_from_url(self, download_url: str) -> bytes:
        """Download file content from direct download URL"""
        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(download_url, timeout=60)
                response.raise_for_status()
                return response.content
        except Exception as e:
            logger.error(f"Failed to download from URL {download_url}: {e}")
            raise

    def _parse_graph_date(self, date_str: Optional[str]) -> datetime:
        """Parse Microsoft Graph date string to datetime"""
        if not date_str:
            return datetime.now()
        try:
            if date_str.endswith('Z'):
                return datetime.fromisoformat(date_str[:-1]).replace(tzinfo=None)
            else:
                return datetime.fromisoformat(date_str.replace('T', ' '))
        except (ValueError, AttributeError):
            return datetime.now()

    async def _make_graph_request(self, url: str, method: str = "GET",
                                  data: Optional[Dict] = None, params: Optional[Dict] = None) -> httpx.Response:
        """Make authenticated API request to Microsoft Graph"""
        token = self.oauth.get_access_token()
        headers = {
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json"
        }
        async with httpx.AsyncClient() as client:
            if method.upper() == "GET":
                response = await client.get(url, headers=headers, params=params, timeout=30)
            elif method.upper() == "POST":
                response = await client.post(url, headers=headers, json=data, timeout=30)
            elif method.upper() == "DELETE":
                response = await client.delete(url, headers=headers, timeout=30)
            else:
                raise ValueError(f"Unsupported HTTP method: {method}")
            response.raise_for_status()
            return response

    def _get_mime_type(self, filename: str) -> str:
        """Get MIME type based on file extension"""
        import mimetypes
        mime_type, _ = mimetypes.guess_type(filename)
        return mime_type or "application/octet-stream"
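The MIME fallback is worth seeing in isolation: `mimetypes.guess_type` keys off the file extension and returns `None` for unknown ones, which the connector maps to the generic `application/octet-stream`:

```python
import mimetypes

def guess_mime(filename):
    # guess_type returns (type, encoding); type is None for unknown extensions
    mime, _ = mimetypes.guess_type(filename)
    return mime or "application/octet-stream"

pdf = guess_mime("report.pdf")
unknown = guess_mime("archive.xyz123")
```

This only matters when Graph omits `file.mimeType` in the item metadata; otherwise the server-reported type wins.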
    # Webhook methods - BaseConnector interface
    def handle_webhook_validation(self, request_method: str, headers: Dict[str, str],
                                  query_params: Dict[str, str]) -> Optional[str]:
        """Handle webhook validation (Graph API specific)"""
        if request_method == "POST" and "validationToken" in query_params:
            return query_params["validationToken"]
        return None

    def extract_webhook_channel_id(self, payload: Dict[str, Any],
                                   headers: Dict[str, str]) -> Optional[str]:
        """Extract channel/subscription ID from webhook payload"""
        notifications = payload.get("value", [])
        if notifications:
            return notifications[0].get("subscriptionId")
        return None

    async def handle_webhook(self, payload: Dict[str, Any]) -> List[str]:
        """Handle webhook notification and return affected file IDs"""
        affected_files = []
        # Process Microsoft Graph webhook payload
        notifications = payload.get("value", [])
        for notification in notifications:
            resource = notification.get("resource")
            if resource and "/drive/items/" in resource:
                file_id = resource.split("/drive/items/")[-1]
                affected_files.append(file_id)
        return affected_files

    async def cleanup_subscription(self, subscription_id: str) -> bool:
        """Clean up subscription - BaseConnector interface"""
        if subscription_id == "no-webhook-configured":
            logger.info("No subscription to cleanup (webhook was not configured)")
            return True
        try:
            # Ensure authentication
            if not await self.authenticate():
                logger.error("SharePoint authentication failed during subscription cleanup")
                return False

            token = self.oauth.get_access_token()
            headers = {"Authorization": f"Bearer {token}"}
            url = f"{self._graph_base_url}/subscriptions/{subscription_id}"

            async with httpx.AsyncClient() as client:
                response = await client.delete(url, headers=headers, timeout=30)

            if response.status_code in [200, 204, 404]:
                logger.info(f"SharePoint subscription {subscription_id} cleaned up successfully")
                return True
            else:
                logger.warning(f"Unexpected response cleaning up subscription: {response.status_code}")
                return False
        except Exception as e:
            logger.error(f"Failed to cleanup SharePoint subscription {subscription_id}: {e}")
            return False
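The webhook handler's ID extraction is plain payload parsing and can be verified without Graph. A standalone sketch (function name and sample payload are illustrative; the payload shape mirrors Graph change notifications with `value`, `subscriptionId`, and `resource` fields):

```python
def affected_file_ids(payload):
    """Collect drive-item IDs from a Graph change-notification payload."""
    ids = []
    for note in payload.get("value", []):
        resource = note.get("resource")
        if resource and "/drive/items/" in resource:
            # The item ID is the final path segment after /drive/items/
            ids.append(resource.split("/drive/items/")[-1])
    return ids

sample = {"value": [{"subscriptionId": "sub-1",
                     "resource": "sites/contoso/drive/items/ABC123"}]}
ids = affected_file_ids(sample)
```

Note that Graph notifications carry no file content; the connector must re-fetch each affected item, which is why `handle_webhook` returns IDs for a later `get_file_content` pass.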


@@ -1,18 +1,28 @@
import os
import json
import logging
from typing import Optional, Dict, Any

import aiofiles
import msal

logger = logging.getLogger(__name__)
class SharePointOAuth:
    """Handles Microsoft Graph OAuth authentication flow following Google Drive pattern."""

    # Reserved scopes that must NOT be sent on token or silent calls
    RESERVED_SCOPES = {"openid", "profile", "offline_access"}

    # For PERSONAL Microsoft Accounts (OneDrive consumer):
    # - Use AUTH_SCOPES for interactive auth (consent + refresh token issuance)
    # - Use RESOURCE_SCOPES for acquire_token_silent / refresh paths
    AUTH_SCOPES = ["User.Read", "Files.Read.All", "offline_access"]
    RESOURCE_SCOPES = ["User.Read", "Files.Read.All"]
    SCOPES = AUTH_SCOPES  # Backward compatibility alias

    # Kept for reference; MSAL derives endpoints from `authority`
    AUTH_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
    TOKEN_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/token"
@@ -22,18 +32,29 @@ class SharePointOAuth:
        client_secret: str,
        token_file: str = "sharepoint_token.json",
        authority: str = "https://login.microsoftonline.com/common",
        allow_json_refresh: bool = True,
    ):
        """
        Initialize SharePointOAuth.

        Args:
            client_id: Azure AD application (client) ID.
            client_secret: Azure AD application client secret.
            token_file: Path to persisted token cache file (MSAL cache format).
            authority: Usually "https://login.microsoftonline.com/common" for MSA + org,
                or tenant-specific for work/school.
            allow_json_refresh: If True, permit one-time migration from legacy flat JSON
                {"access_token","refresh_token",...}. Otherwise refuse it.
        """
        self.client_id = client_id
        self.client_secret = client_secret
        self.token_file = token_file
        self.authority = authority
        self.allow_json_refresh = allow_json_refresh
        self.token_cache = msal.SerializableTokenCache()
        self._current_account = None

        # Initialize MSAL Confidential Client
        self.app = msal.ConfidentialClientApplication(
            client_id=self.client_id,
            client_credential=self.client_secret,
@@ -41,56 +62,268 @@ class SharePointOAuth:
            token_cache=self.token_cache,
        )
    async def load_credentials(self) -> bool:
        """Load existing credentials from token file (async)."""
        try:
            logger.debug(f"SharePoint OAuth loading credentials from: {self.token_file}")
            if os.path.exists(self.token_file):
                logger.debug(f"Token file exists, reading: {self.token_file}")
                # Read the token file
                async with aiofiles.open(self.token_file, "r") as f:
                    cache_data = await f.read()
                logger.debug(f"Read {len(cache_data)} chars from token file")

                if cache_data.strip():
                    # 1) Try legacy flat JSON first
                    try:
                        json_data = json.loads(cache_data)
                        if isinstance(json_data, dict) and "refresh_token" in json_data:
                            if self.allow_json_refresh:
                                logger.debug(
                                    "Found legacy JSON refresh_token and allow_json_refresh=True; attempting migration refresh"
                                )
                                return await self._refresh_from_json_token(json_data)
                            else:
                                logger.warning(
                                    "Token file contains a legacy JSON refresh_token, but allow_json_refresh=False. "
                                    "Delete the file and re-auth."
                                )
                                return False
                    except json.JSONDecodeError:
                        logger.debug("Token file is not flat JSON; attempting MSAL cache format")

                    # 2) Try MSAL cache format
                    logger.debug("Attempting MSAL cache deserialization")
                    self.token_cache.deserialize(cache_data)

                    # Get accounts from loaded cache
                    accounts = self.app.get_accounts()
                    logger.debug(f"Found {len(accounts)} accounts in MSAL cache")
                    if accounts:
                        self._current_account = accounts[0]
                        logger.debug(f"Set current account: {self._current_account.get('username', 'no username')}")
                        # IMPORTANT: Use RESOURCE_SCOPES (no reserved scopes) for silent acquisition
                        result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=self._current_account)
                        logger.debug(f"Silent token acquisition result keys: {list(result.keys()) if result else 'None'}")
                        if result and "access_token" in result:
                            logger.debug("Silent token acquisition successful")
                            await self.save_cache()
                            return True
                        else:
                            error_msg = (result or {}).get("error") or "No result"
                            logger.warning(f"Silent token acquisition failed: {error_msg}")
                else:
                    logger.debug(f"Token file {self.token_file} is empty")
            else:
                logger.debug(f"Token file does not exist: {self.token_file}")
            return False
        except Exception as e:
            logger.error(f"Failed to load SharePoint credentials: {e}")
            import traceback
            traceback.print_exc()
            return False
async def _refresh_from_json_token(self, token_data: dict) -> bool:
"""
Use refresh token from a legacy JSON file to get new tokens (one-time migration path).
Notes:
- Prefer using an MSAL cache file and acquire_token_silent().
- This path is only for migrating older refresh_token JSON files.
"""
try:
refresh_token = token_data.get("refresh_token")
if not refresh_token:
logger.error("No refresh_token found in JSON file - cannot refresh")
logger.error("You must re-authenticate interactively to obtain a valid token")
return False
# Use only RESOURCE_SCOPES when refreshing (no reserved scopes)
refresh_scopes = [s for s in self.RESOURCE_SCOPES if s not in self.RESERVED_SCOPES]
logger.debug(f"Using refresh token; refresh scopes = {refresh_scopes}")
result = self.app.acquire_token_by_refresh_token(
refresh_token=refresh_token,
scopes=refresh_scopes,
)
if result and "access_token" in result:
logger.debug("Successfully refreshed token via legacy JSON path")
await self.save_cache()
accounts = self.app.get_accounts()
logger.debug(f"After refresh, found {len(accounts)} accounts")
if accounts:
self._current_account = accounts[0]
logger.debug(f"Set current account after refresh: {self._current_account.get('username', 'no username')}")
return True
# Error handling
err = (result or {}).get("error_description") or (result or {}).get("error") or "Unknown error"
logger.error(f"Refresh token failed: {err}")
if any(code in err for code in ("AADSTS70000", "invalid_grant", "interaction_required")):
logger.warning(
"Refresh denied due to unauthorized/expired scopes or invalid grant. "
"Delete the token file and perform interactive sign-in with correct scopes."
)
return False
except Exception as e:
logger.error(f"Exception during refresh from JSON token: {e}")
import traceback
traceback.print_exc()
return False
async def save_cache(self):
"""Persist the token cache to file."""
try:
# Ensure parent directory exists
parent = os.path.dirname(os.path.abspath(self.token_file))
if parent and not os.path.exists(parent):
os.makedirs(parent, exist_ok=True)
cache_data = self.token_cache.serialize()
if cache_data:
async with aiofiles.open(self.token_file, "w") as f:
await f.write(cache_data)
logger.debug(f"Token cache saved to {self.token_file}")
except Exception as e:
logger.error(f"Failed to save token cache: {e}")
def create_authorization_url(self, redirect_uri: str, state: Optional[str] = None) -> str:
"""Create authorization URL for OAuth flow."""
# Store redirect URI for later use in callback
self._redirect_uri = redirect_uri
kwargs: Dict[str, Any] = {
# IMPORTANT: interactive auth includes offline_access
"scopes": self.AUTH_SCOPES,
"redirect_uri": redirect_uri,
"prompt": "consent", # ensure refresh token on first run
}
if state:
kwargs["state"] = state # Optional CSRF protection
auth_url = self.app.get_authorization_request_url(**kwargs)
logger.debug(f"Generated auth URL: {auth_url}")
logger.debug(f"Auth scopes: {self.AUTH_SCOPES}")
return auth_url
    async def handle_authorization_callback(
        self, authorization_code: str, redirect_uri: str
    ) -> bool:
        """Handle OAuth callback and exchange code for tokens."""
        try:
            # For code exchange, we pass the same auth scopes as used in the authorize step
            result = self.app.acquire_token_by_authorization_code(
                authorization_code,
                scopes=self.AUTH_SCOPES,
                redirect_uri=redirect_uri,
            )
            if result and "access_token" in result:
                # Store the account for future use
                accounts = self.app.get_accounts()
                if accounts:
                    self._current_account = accounts[0]
                await self.save_cache()
                logger.info("SharePoint OAuth authorization successful")
                return True

            error_msg = (result or {}).get("error_description") or (result or {}).get("error") or "Unknown error"
            logger.error(f"SharePoint OAuth authorization failed: {error_msg}")
            return False
        except Exception as e:
            logger.error(f"Exception during SharePoint OAuth authorization: {e}")
            return False
    async def is_authenticated(self) -> bool:
        """Check if we have valid credentials (simplified like Google Drive)."""
        try:
            # First try to load credentials if we haven't already
            if not self._current_account:
                await self.load_credentials()

            # If we have an account, try to get a token (MSAL will refresh if needed)
            if self._current_account:
                # IMPORTANT: use RESOURCE_SCOPES here
                result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=self._current_account)
                if result and "access_token" in result:
                    return True
                else:
                    error_msg = (result or {}).get("error") or "No result returned"
                    logger.debug(f"Token acquisition failed for current account: {error_msg}")

            # Fallback: try without specific account
            result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=None)
            if result and "access_token" in result:
                # Update current account if this worked
                accounts = self.app.get_accounts()
                if accounts:
                    self._current_account = accounts[0]
                return True

            return False
        except Exception as e:
            logger.error(f"Authentication check failed: {e}")
            return False
    def get_access_token(self) -> str:
        """Get an access token for Microsoft Graph (simplified like Google Drive)."""
        try:
            # Try with current account first
            if self._current_account:
                result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=self._current_account)
                if result and "access_token" in result:
                    return result["access_token"]

            # Fallback: try without specific account
            result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=None)
            if result and "access_token" in result:
                return result["access_token"]

            # If we get here, authentication has failed
            error_msg = (result or {}).get("error_description") or (result or {}).get("error") or "No valid authentication"
            raise ValueError(f"Failed to acquire access token: {error_msg}")
        except Exception as e:
            logger.error(f"Failed to get access token: {e}")
            raise
async def revoke_credentials(self): async def revoke_credentials(self):
"""Clear token cache and remove token file""" """Clear token cache and remove token file (like Google Drive)."""
self.token_cache.clear() try:
# Clear in-memory state
self._current_account = None
self.token_cache = msal.SerializableTokenCache()
# Recreate MSAL app with fresh cache
self.app = msal.ConfidentialClientApplication(
client_id=self.client_id,
client_credential=self.client_secret,
authority=self.authority,
token_cache=self.token_cache,
)
# Remove token file
if os.path.exists(self.token_file): if os.path.exists(self.token_file):
os.remove(self.token_file) os.remove(self.token_file)
logger.info(f"Removed SharePoint token file: {self.token_file}")
except Exception as e:
logger.error(f"Failed to revoke SharePoint credentials: {e}")
def get_service(self) -> str:
"""Return an access token (Graph doesn't need a generated client like Google Drive)."""
return self.get_access_token()
@@ -2,6 +2,7 @@
from connectors.langflow_connector_service import LangflowConnectorService
from connectors.service import ConnectorService
from services.flows_service import FlowsService
from utils.embeddings import create_dynamic_index_body
from utils.logging_config import configure_from_env, get_logger

configure_from_env()
@@ -52,11 +53,11 @@ from auth_middleware import optional_auth, require_auth
from config.settings import (
    DISABLE_INGEST_WITH_LANGFLOW,
    EMBED_MODEL,
    INDEX_NAME,
    SESSION_SECRET,
    clients,
    is_no_auth_mode,
    get_openrag_config,
)
from services.auth_service import AuthService
from services.langflow_mcp_service import LangflowMCPService
@@ -81,7 +82,6 @@ logger.info(
    cuda_version=torch.version.cuda,
)

async def wait_for_opensearch():
    """Wait for OpenSearch to be ready with retries"""
    max_retries = 30
@@ -132,12 +132,19 @@ async def init_index():
    """Initialize OpenSearch index and security roles"""
    await wait_for_opensearch()

    # Get the configured embedding model from user configuration
    config = get_openrag_config()
    embedding_model = config.knowledge.embedding_model

    # Create dynamic index body based on the configured embedding model
    dynamic_index_body = create_dynamic_index_body(embedding_model)

    # Create documents index
    if not await clients.opensearch.indices.exists(index=INDEX_NAME):
        await clients.opensearch.indices.create(index=INDEX_NAME, body=dynamic_index_body)
        logger.info("Created OpenSearch index", index_name=INDEX_NAME, embedding_model=embedding_model)
    else:
        logger.info("Index already exists, skipping creation", index_name=INDEX_NAME, embedding_model=embedding_model)

    # Create knowledge filters index
    knowledge_filter_index_name = "knowledge_filters"
@@ -395,7 +402,12 @@ async def _ingest_default_documents_openrag(services, file_paths):
async def startup_tasks(services):
    """Startup tasks"""
    logger.info("Starting startup tasks")
    # Only initialize basic OpenSearch connection, not the index
    # Index will be created after onboarding when we know the embedding model
    await wait_for_opensearch()

    # Configure alerting security
    await configure_alerting_security()

async def initialize_services():
@@ -1,3 +1,4 @@
import asyncio
from config.settings import (
    NUDGES_FLOW_ID,
    LANGFLOW_URL,
@@ -19,6 +20,7 @@ from config.settings import (
    WATSONX_LLM_COMPONENT_ID,
    OLLAMA_EMBEDDING_COMPONENT_ID,
    OLLAMA_LLM_COMPONENT_ID,
    get_openrag_config,
)
import json
import os
@@ -29,6 +31,74 @@ logger = get_logger(__name__)
class FlowsService:
def __init__(self):
# Cache for flow file mappings to avoid repeated filesystem scans
self._flow_file_cache = {}
def _get_flows_directory(self):
"""Get the flows directory path"""
current_file_dir = os.path.dirname(os.path.abspath(__file__)) # src/services/
src_dir = os.path.dirname(current_file_dir) # src/
project_root = os.path.dirname(src_dir) # project root
return os.path.join(project_root, "flows")
def _find_flow_file_by_id(self, flow_id: str):
"""
Scan the flows directory and find the JSON file that contains the specified flow ID.
Args:
flow_id: The flow ID to search for
Returns:
str: The path to the flow file, or None if not found
"""
if not flow_id:
raise ValueError("flow_id is required")
# Check cache first
if flow_id in self._flow_file_cache:
cached_path = self._flow_file_cache[flow_id]
if os.path.exists(cached_path):
return cached_path
else:
# Remove stale cache entry
del self._flow_file_cache[flow_id]
flows_dir = self._get_flows_directory()
if not os.path.exists(flows_dir):
logger.warning(f"Flows directory not found: {flows_dir}")
return None
# Scan all JSON files in the flows directory
try:
for filename in os.listdir(flows_dir):
if not filename.endswith('.json'):
continue
file_path = os.path.join(flows_dir, filename)
try:
with open(file_path, 'r') as f:
flow_data = json.load(f)
# Check if this file contains the flow we're looking for
if flow_data.get('id') == flow_id:
# Cache the result
self._flow_file_cache[flow_id] = file_path
logger.info(f"Found flow {flow_id} in file: {file_path}")
return file_path
except (json.JSONDecodeError, FileNotFoundError) as e:
logger.warning(f"Error reading flow file {file_path}: {e}")
continue
except Exception as e:
logger.error(f"Error scanning flows directory: {e}")
return None
logger.warning(f"Flow with ID {flow_id} not found in flows directory")
return None
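The scan-and-cache lookup above can be exercised in isolation. This is a minimal standalone sketch of the same strategy, using a temporary directory in place of the real `flows/` folder:

```python
import json
import os
import tempfile
from typing import Optional


def find_flow_file_by_id(flows_dir: str, flow_id: str, cache: dict) -> Optional[str]:
    """Scan a directory of flow JSON files for one whose top-level 'id' matches."""
    # Serve from cache when the previously found file still exists
    if flow_id in cache and os.path.exists(cache[flow_id]):
        return cache[flow_id]
    for filename in os.listdir(flows_dir):
        if not filename.endswith(".json"):
            continue
        path = os.path.join(flows_dir, filename)
        try:
            with open(path) as f:
                if json.load(f).get("id") == flow_id:
                    cache[flow_id] = path  # memoize for later lookups
                    return path
        except json.JSONDecodeError:
            continue  # skip malformed files, like the service does
    return None


with tempfile.TemporaryDirectory() as d:
    with open(os.path.join(d, "agent.json"), "w") as f:
        json.dump({"id": "flow-123", "data": {}}, f)
    cache = {}
    found = find_flow_file_by_id(d, "flow-123", cache)
    print(os.path.basename(found))  # agent.json
```

Caching by flow ID avoids rescanning the directory on every reset, while the `os.path.exists` check keeps the cache self-healing when a file is moved or deleted.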
    async def reset_langflow_flow(self, flow_type: str):
        """Reset a Langflow flow by uploading the corresponding JSON file
@@ -41,59 +111,35 @@ class FlowsService:
        if not LANGFLOW_URL:
            raise ValueError("LANGFLOW_URL environment variable is required")

        # Determine flow ID based on type
        if flow_type == "nudges":
            flow_id = NUDGES_FLOW_ID
        elif flow_type == "retrieval":
            flow_id = LANGFLOW_CHAT_FLOW_ID
        elif flow_type == "ingest":
            flow_id = LANGFLOW_INGEST_FLOW_ID
        else:
            raise ValueError(
                "flow_type must be either 'nudges', 'retrieval', or 'ingest'"
            )

        if not flow_id:
            raise ValueError(f"Flow ID not configured for flow_type '{flow_type}'")

        # Dynamically find the flow file by ID
        flow_path = self._find_flow_file_by_id(flow_id)
        if not flow_path:
            raise FileNotFoundError(f"Flow file not found for flow ID: {flow_id}")
        # Load flow JSON file
        try:
            with open(flow_path, "r") as f:
                flow_data = json.load(f)
            logger.info(f"Successfully loaded flow data for {flow_type} from {os.path.basename(flow_path)}")
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON in flow file {flow_path}: {e}")
        except FileNotFoundError:
            raise ValueError(f"Flow file not found: {flow_path}")
        # Make PATCH request to Langflow API to update the flow using shared client
        try:
@@ -106,8 +152,54 @@ class FlowsService:
            logger.info(
                f"Successfully reset {flow_type} flow",
                flow_id=flow_id,
                flow_file=os.path.basename(flow_path),
            )
# Now update the flow with current configuration settings
try:
config = get_openrag_config()
# Check if configuration has been edited (onboarding completed)
if config.edited:
logger.info(f"Updating {flow_type} flow with current configuration settings")
provider = config.provider.model_provider.lower()
# Step 1: Assign model provider (replace components) if not OpenAI
if provider != "openai":
logger.info(f"Assigning {provider} components to {flow_type} flow")
provider_result = await self.assign_model_provider(provider)
if not provider_result.get("success"):
logger.warning(f"Failed to assign {provider} components: {provider_result.get('error', 'Unknown error')}")
# Continue anyway, maybe just value updates will work
# Step 2: Update model values for the specific flow being reset
single_flow_config = [{
"name": flow_type,
"flow_id": flow_id,
}]
logger.info(f"Updating {flow_type} flow model values")
update_result = await self.change_langflow_model_value(
provider=provider,
embedding_model=config.knowledge.embedding_model,
llm_model=config.agent.llm_model,
endpoint=config.provider.endpoint if config.provider.endpoint else None,
flow_configs=single_flow_config
)
if update_result.get("success"):
logger.info(f"Successfully updated {flow_type} flow with current configuration")
else:
logger.warning(f"Failed to update {flow_type} flow with current configuration: {update_result.get('error', 'Unknown error')}")
else:
logger.info(f"Configuration not yet edited (onboarding not completed), skipping model updates for {flow_type} flow")
except Exception as e:
logger.error(f"Error updating {flow_type} flow with current configuration", error=str(e))
# Don't fail the entire reset operation if configuration update fails
            return {
                "success": True,
                "message": f"Successfully reset {flow_type} flow",
@@ -155,11 +247,10 @@ class FlowsService:
        logger.info(f"Assigning {provider} components")

        # Define flow configurations (removed hardcoded file paths)
        flow_configs = [
            {
                "name": "nudges",
                "flow_id": NUDGES_FLOW_ID,
                "embedding_id": OPENAI_EMBEDDING_COMPONENT_ID,
                "llm_id": OPENAI_LLM_COMPONENT_ID,
@@ -167,7 +258,6 @@ class FlowsService:
            },
            {
                "name": "retrieval",
                "flow_id": LANGFLOW_CHAT_FLOW_ID,
                "embedding_id": OPENAI_EMBEDDING_COMPONENT_ID,
                "llm_id": OPENAI_LLM_COMPONENT_ID,
@@ -175,7 +265,6 @@ class FlowsService:
            },
            {
                "name": "ingest",
                "flow_id": LANGFLOW_INGEST_FLOW_ID,
                "embedding_id": OPENAI_EMBEDDING_COMPONENT_ID,
                "llm_id": None,  # Ingestion flow might not have LLM
@@ -272,7 +361,6 @@ class FlowsService:
    async def _update_flow_components(self, config, llm_template, embedding_template, llm_text_template):
        """Update components in a specific flow"""
        flow_name = config["name"]
        flow_id = config["flow_id"]
        old_embedding_id = config["embedding_id"]
        old_llm_id = config["llm_id"]
@@ -281,14 +369,11 @@ class FlowsService:
        new_llm_id = llm_template["data"]["id"]
        new_embedding_id = embedding_template["data"]["id"]
        new_llm_text_id = llm_text_template["data"]["id"]

        # Dynamically find the flow file by ID
        flow_path = self._find_flow_file_by_id(flow_id)
        if not flow_path:
            raise FileNotFoundError(f"Flow file not found for flow ID: {flow_id}")

        # Load flow JSON
        with open(flow_path, "r") as f:
@@ -527,16 +612,17 @@ class FlowsService:
            return False

    async def change_langflow_model_value(
        self, provider: str, embedding_model: str, llm_model: str, endpoint: str = None, flow_configs: list = None
    ):
        """
        Change dropdown values for provider-specific components across flows

        Args:
            provider: The provider ("watsonx", "ollama", "openai")
            embedding_model: The embedding model name to set
            llm_model: The LLM model name to set
            endpoint: The endpoint URL (required for watsonx/ibm provider)
            flow_configs: Optional list of specific flow configs to update. If None, updates all flows.

        Returns:
            dict: Success/error response with details for each flow
@@ -552,21 +638,19 @@ class FlowsService:
        logger.info(
            f"Changing dropdown values for provider {provider}, embedding: {embedding_model}, llm: {llm_model}, endpoint: {endpoint}"
        )

        # Use provided flow_configs or default to all flows
        if flow_configs is None:
            flow_configs = [
                {
                    "name": "nudges",
                    "flow_id": NUDGES_FLOW_ID,
                },
                {
                    "name": "retrieval",
                    "flow_id": LANGFLOW_CHAT_FLOW_ID,
                },
                {
                    "name": "ingest",
                    "flow_id": LANGFLOW_INGEST_FLOW_ID,
                },
            ]
src/utils/embeddings.py Normal file
@@ -0,0 +1,64 @@
from config.settings import OLLAMA_EMBEDDING_DIMENSIONS, OPENAI_EMBEDDING_DIMENSIONS, VECTOR_DIM, WATSONX_EMBEDDING_DIMENSIONS
from utils.logging_config import get_logger
logger = get_logger(__name__)
def get_embedding_dimensions(model_name: str) -> int:
"""Get the embedding dimensions for a given model name."""
# Check all model dictionaries
all_models = {**OPENAI_EMBEDDING_DIMENSIONS, **OLLAMA_EMBEDDING_DIMENSIONS, **WATSONX_EMBEDDING_DIMENSIONS}
if model_name in all_models:
dimensions = all_models[model_name]
logger.info(f"Found dimensions for model '{model_name}': {dimensions}")
return dimensions
logger.warning(
f"Unknown embedding model '{model_name}', using default dimensions: {VECTOR_DIM}"
)
return VECTOR_DIM
def create_dynamic_index_body(embedding_model: str) -> dict:
"""Create a dynamic index body configuration based on the embedding model."""
dimensions = get_embedding_dimensions(embedding_model)
return {
"settings": {
"index": {"knn": True},
"number_of_shards": 1,
"number_of_replicas": 1,
},
"mappings": {
"properties": {
"document_id": {"type": "keyword"},
"filename": {"type": "keyword"},
"mimetype": {"type": "keyword"},
"page": {"type": "integer"},
"text": {"type": "text"},
"chunk_embedding": {
"type": "knn_vector",
"dimension": dimensions,
"method": {
"name": "disk_ann",
"engine": "jvector",
"space_type": "l2",
"parameters": {"ef_construction": 100, "m": 16},
},
},
"source_url": {"type": "keyword"},
"connector_type": {"type": "keyword"},
"owner": {"type": "keyword"},
"allowed_users": {"type": "keyword"},
"allowed_groups": {"type": "keyword"},
"user_permissions": {"type": "object"},
"group_permissions": {"type": "object"},
"created_time": {"type": "date"},
"modified_time": {"type": "date"},
"indexed_time": {"type": "date"},
"metadata": {"type": "object"},
}
},
}
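As a quick sanity check, the dimension lookup and dynamic k-NN mapping above can be exercised standalone. The dimension tables below are illustrative stand-ins for the `*_EMBEDDING_DIMENSIONS` settings (the real values live in `config.settings`), and the mapping is trimmed to the fields relevant to the vector dimension:

```python
# Illustrative stand-ins for the config.settings dimension tables
OPENAI_EMBEDDING_DIMENSIONS = {"text-embedding-3-small": 1536, "text-embedding-3-large": 3072}
OLLAMA_EMBEDDING_DIMENSIONS = {"nomic-embed-text": 768}
VECTOR_DIM = 1536  # fallback default for unknown models


def get_embedding_dimensions(model_name: str) -> int:
    """Look up a model's embedding dimensions, falling back to VECTOR_DIM."""
    all_models = {**OPENAI_EMBEDDING_DIMENSIONS, **OLLAMA_EMBEDDING_DIMENSIONS}
    return all_models.get(model_name, VECTOR_DIM)


def create_dynamic_index_body(embedding_model: str) -> dict:
    """Build a trimmed OpenSearch index body whose knn_vector dimension matches the model."""
    return {
        "settings": {"index": {"knn": True}},
        "mappings": {
            "properties": {
                "text": {"type": "text"},
                "chunk_embedding": {
                    "type": "knn_vector",
                    "dimension": get_embedding_dimensions(embedding_model),
                },
            }
        },
    }


body = create_dynamic_index_body("nomic-embed-text")
print(body["mappings"]["properties"]["chunk_embedding"]["dimension"])  # 768
```

This is why the commit defers index creation until after onboarding: the `knn_vector` dimension is fixed at index-creation time, so the index cannot be built until the embedding model (and therefore its dimension) is known.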