diff --git a/.gitignore b/.gitignore
index b2977194..4f22035a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -17,3 +17,4 @@
 wheels/
 1001*.pdf
 *.json
+.DS_Store
diff --git a/Dockerfile.langflow b/Dockerfile.langflow
new file mode 100644
index 00000000..99e6e155
--- /dev/null
+++ b/Dockerfile.langflow
@@ -0,0 +1,49 @@
+FROM python:3.12-slim
+
+# Set environment variables
+ENV DEBIAN_FRONTEND=noninteractive
+ENV PYTHONUNBUFFERED=1
+ENV RUSTFLAGS="--cfg reqwest_unstable"
+
+# Accept build arguments for git repository and branch
+ARG GIT_REPO=https://github.com/langflow-ai/langflow.git
+ARG GIT_BRANCH=load_flows_autologin_false
+
+WORKDIR /app
+
+# Install system dependencies
+RUN apt-get update && apt-get install -y \
+    build-essential \
+    curl \
+    git \
+    ca-certificates \
+    gnupg \
+    npm \
+    rustc cargo pkg-config libssl-dev \
+    && rm -rf /var/lib/apt/lists/*
+
+# Install uv for faster Python package management
+RUN pip install uv
+
+# Clone the repository and checkout the specified branch
+RUN git clone --depth 1 --branch ${GIT_BRANCH} ${GIT_REPO} /app
+
+# Install backend dependencies
+RUN uv sync --frozen --no-install-project --no-editable --extra postgresql
+
+# Build frontend
+WORKDIR /app/src/frontend
+RUN npm ci && \
+    npm run build && \
+    mkdir -p /app/src/backend/base/langflow/frontend && \
+    cp -r build/* /app/src/backend/base/langflow/frontend/
+
+# Return to app directory and install the project
+WORKDIR /app
+RUN uv sync --frozen --no-dev --no-editable --extra postgresql
+
+# Expose port
+EXPOSE 7860
+
+# Start the backend server
+CMD ["uv", "run", "langflow", "run", "--host", "0.0.0.0", "--port", "7860"]
diff --git a/docker-compose.yml b/docker-compose.yml
index 78059a46..47781eb6 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -104,4 +104,4 @@ services:
       - LANGFLOW_SUPERUSER=${LANGFLOW_SUPERUSER}
       - LANGFLOW_SUPERUSER_PASSWORD=${LANGFLOW_SUPERUSER_PASSWORD}
       - LANGFLOW_NEW_USER_IS_ACTIVE=${LANGFLOW_NEW_USER_IS_ACTIVE}
-      - LANGFLOW_ENABLE_SUPERUSER_CLI=${LANGFLOW_ENABLE_SUPERUSER_CLI}
+      - LANGFLOW_ENABLE_SUPERUSER_CLI=${LANGFLOW_ENABLE_SUPERUSER_CLI}
\ No newline at end of file
diff --git a/frontend/components/knowledge-dropdown.tsx b/frontend/components/knowledge-dropdown.tsx
index 917d64f5..8088964b 100644
--- a/frontend/components/knowledge-dropdown.tsx
+++ b/frontend/components/knowledge-dropdown.tsx
@@ -1,7 +1,6 @@
 "use client"
 
 import { useState, useEffect, useRef } from "react"
-import { useRouter } from "next/navigation"
 import { ChevronDown, Upload, FolderOpen, Cloud, PlugZap, Plus } from "lucide-react"
 import { Button } from "@/components/ui/button"
 import { Dialog, DialogContent, DialogDescription, DialogHeader, DialogTitle } from "@/components/ui/dialog"
@@ -9,6 +8,7 @@ import { Input } from "@/components/ui/input"
 import { Label } from "@/components/ui/label"
 import { cn } from "@/lib/utils"
 import { useTask } from "@/contexts/task-context"
+import { useRouter } from "next/navigation"
 
 interface KnowledgeDropdownProps {
   active?: boolean
@@ -16,8 +16,8 @@ interface KnowledgeDropdownProps {
 }
 
 export function KnowledgeDropdown({ active, variant = 'navigation' }: KnowledgeDropdownProps) {
-  const router = useRouter()
   const { addTask } = useTask()
+  const router = useRouter()
   const [isOpen, setIsOpen] = useState(false)
   const [showFolderDialog, setShowFolderDialog] = useState(false)
   const [showS3Dialog, setShowS3Dialog] = useState(false)
@@ -27,23 +27,76 @@ export function KnowledgeDropdown({ active, variant = 'navigation' }: KnowledgeDropdownProps) {
   const [folderLoading, setFolderLoading] = useState(false)
   const [s3Loading, setS3Loading] = useState(false)
   const [fileUploading, setFileUploading] = useState(false)
+  const [cloudConnectors, setCloudConnectors] = useState<{[key: string]: {name: string, available: boolean, connected: boolean, hasToken: boolean}}>({})
   const fileInputRef = useRef<HTMLInputElement>(null)
   const dropdownRef = useRef<HTMLDivElement>(null)
 
-  // Check AWS availability on mount
+  // Check AWS availability and cloud connectors on mount
   useEffect(() => {
-    const checkAws = async () => {
+    const checkAvailability = async () => {
       try {
-        const res = await fetch("/api/upload_options")
-        if (res.ok) {
-          const data = await res.json()
-          setAwsEnabled(Boolean(data.aws))
+        // Check AWS
+        const awsRes = await fetch("/api/upload_options")
+        if (awsRes.ok) {
+          const awsData = await awsRes.json()
+          setAwsEnabled(Boolean(awsData.aws))
+        }
+
+        // Check cloud connectors
+        const connectorsRes = await fetch('/api/connectors')
+        if (connectorsRes.ok) {
+          const connectorsResult = await connectorsRes.json()
+          const cloudConnectorTypes = ['google_drive', 'onedrive', 'sharepoint']
+          const connectorInfo: {[key: string]: {name: string, available: boolean, connected: boolean, hasToken: boolean}} = {}
+
+          for (const type of cloudConnectorTypes) {
+            if (connectorsResult.connectors[type]) {
+              connectorInfo[type] = {
+                name: connectorsResult.connectors[type].name,
+                available: connectorsResult.connectors[type].available,
+                connected: false,
+                hasToken: false
+              }
+
+              // Check connection status
+              try {
+                const statusRes = await fetch(`/api/connectors/${type}/status`)
+                if (statusRes.ok) {
+                  const statusData = await statusRes.json()
+                  const connections = statusData.connections || []
+                  const activeConnection = connections.find((conn: {is_active: boolean, connection_id: string}) => conn.is_active)
+                  const isConnected = activeConnection !== undefined
+
+                  if (isConnected && activeConnection) {
+                    connectorInfo[type].connected = true
+
+                    // Check token availability
+                    try {
+                      const tokenRes = await fetch(`/api/connectors/${type}/token?connection_id=${activeConnection.connection_id}`)
+                      if (tokenRes.ok) {
+                        const tokenData = await tokenRes.json()
+                        if (tokenData.access_token) {
+                          connectorInfo[type].hasToken = true
+                        }
+                      }
+                    } catch {
+                      // Token check failed
+                    }
+                  }
+                }
+              } catch {
+                // Status check failed
+              }
+            }
+          }
+
+          setCloudConnectors(connectorInfo)
         }
       } catch (err) {
-        console.error("Failed to check AWS availability", err)
+        console.error("Failed to check availability", err)
      }
    }
-    checkAws()
+    checkAvailability()
   }, [])
 
   // Handle click outside to close dropdown
@@ -220,6 +273,25 @@ export function KnowledgeDropdown({ active, variant = 'navigation' }: KnowledgeDropdownProps) {
     }
   }
 
+  const cloudConnectorItems = Object.entries(cloudConnectors)
+    .filter(([, info]) => info.available)
+    .map(([type, info]) => ({
+      label: info.name,
+      icon: PlugZap,
+      onClick: () => {
+        setIsOpen(false)
+        if (info.connected && info.hasToken) {
+          router.push(`/upload/${type}`)
+        } else {
+          router.push('/settings')
+        }
+      },
+      disabled: !info.connected || !info.hasToken,
+      tooltip: !info.connected ? `Connect ${info.name} in Settings first` :
+               !info.hasToken ? `Reconnect ${info.name} - access token required` :
+               undefined
+    }))
+
   const menuItems = [
     {
       label: "Add File",
@@ -242,14 +314,7 @@ export function KnowledgeDropdown({ active, variant = 'navigation' }: KnowledgeDropdownProps) {
         setShowS3Dialog(true)
       }
     }] : []),
-    {
-      label: "Cloud Connectors",
-      icon: PlugZap,
-      onClick: () => {
-        setIsOpen(false)
-        router.push("/settings")
-      }
-    }
+    ...cloudConnectorItems
   ]
 
   return (
@@ -291,7 +356,12 @@ export function KnowledgeDropdown({ active, variant = 'navigation' }: KnowledgeDropdownProps) {
@@ -390,6 +460,7 @@ export function KnowledgeDropdown({ active, variant = 'navigation' }: KnowledgeDropdownProps) {
+
   )
 }
\ No newline at end of file
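The availability check above gates each cloud connector behind three flags before it becomes a clickable menu item. A minimal TypeScript sketch of that gating decision, extracted for clarity — `ConnectorInfo` and `routeFor` are illustrative names, not part of the change:

```typescript
// Sketch of the per-connector gating used by cloudConnectorItems.
interface ConnectorInfo {
  name: string
  available: boolean // backend lists the connector type in /api/connectors
  connected: boolean // an is_active connection was found via .../status
  hasToken: boolean  // .../token returned an access_token for that connection
}

// A connector is directly usable only when all three checks passed;
// otherwise the user is routed to Settings to (re)connect it.
function routeFor(type: string, info: ConnectorInfo): string {
  return info.connected && info.hasToken ? `/upload/${type}` : "/settings"
}
```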
diff --git a/frontend/components/navigation.tsx b/frontend/components/navigation.tsx
index 6581ab68..7419a25a 100644
--- a/frontend/components/navigation.tsx
+++ b/frontend/components/navigation.tsx
@@ -85,6 +85,14 @@ export function Navigation() {
       if (!response.ok) {
         const errorText = await response.text()
         console.error("Upload failed:", errorText)
+
+        // Trigger error event for chat page to handle
+        window.dispatchEvent(new CustomEvent('fileUploadError', {
+          detail: { filename: file.name, error: 'Failed to process document' }
+        }))
+
+        // Trigger loading end event
+        window.dispatchEvent(new CustomEvent('fileUploadComplete'))
         return
       }
 
@@ -111,7 +119,7 @@ export function Navigation() {
 
       // Trigger error event for chat page to handle
       window.dispatchEvent(new CustomEvent('fileUploadError', {
-        detail: { filename: file.name, error: error instanceof Error ? error.message : 'Unknown error' }
+        detail: { filename: file.name, error: 'Failed to process document' }
       }))
     }
   }
diff --git a/frontend/package-lock.json b/frontend/package-lock.json
index 103dd7aa..5d7c9750 100644
--- a/frontend/package-lock.json
+++ b/frontend/package-lock.json
@@ -5402,18 +5402,6 @@
         "@pkgjs/parseargs": "^0.11.0"
       }
     },
-    "node_modules/jiti": {
-      "version": "2.4.2",
-      "resolved": "https://registry.npmjs.org/jiti/-/jiti-2.4.2.tgz",
-      "integrity": "sha512-rg9zJN+G4n2nfJl5MW3BMygZX56zKPNVEYYqq7adpmMh4Jn2QNEwhvQlFy6jPVdcod7txZtKHWnyZiA3a0zP7A==",
-      "dev": true,
-      "license": "MIT",
-      "optional": true,
-      "peer": true,
-      "bin": {
-        "jiti": "lib/jiti-cli.mjs"
-      }
-    },
     "node_modules/js-tokens": {
       "version": "4.0.0",
       "resolved": "https://registry.npmjs.org/js-tokens/-/js-tokens-4.0.0.tgz",
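navigation.tsx and the chat page coordinate uploads through window-level CustomEvents rather than shared React state. A minimal sketch of that contract, assuming only the event names and detail shape visible in the change (`reportUploadFailure` is an illustrative helper, not part of the diff):

```typescript
// Sketch of the window-event contract between navigation.tsx and chat/page.tsx.
interface FileUploadErrorDetail {
  filename: string
  error: string
}

// Producer side (navigation.tsx): report a failure, then end the loading state.
function reportUploadFailure(filename: string): void {
  window.dispatchEvent(
    new CustomEvent<FileUploadErrorDetail>("fileUploadError", {
      detail: { filename, error: "Failed to process document" },
    })
  )
  window.dispatchEvent(new CustomEvent("fileUploadComplete"))
}

// Consumer side (chat page): handlers are cast to EventListener, as in the diff.
window.addEventListener("fileUploadError", ((e: CustomEvent<FileUploadErrorDetail>) => {
  console.log("upload failed:", e.detail.filename, e.detail.error)
}) as EventListener)
```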
diff --git a/frontend/src/app/chat/page.tsx b/frontend/src/app/chat/page.tsx
index dca8084e..a38d8b32 100644
--- a/frontend/src/app/chat/page.tsx
+++ b/frontend/src/app/chat/page.tsx
@@ -1,285 +1,327 @@
-"use client"
-
-import { useState, useRef, useEffect } from "react"
-import { Button } from "@/components/ui/button"
-import { Loader2, User, Bot, Zap, Settings, ChevronDown, ChevronRight, Upload, AtSign, Plus, X, GitBranch } from "lucide-react"
-import { ProtectedRoute } from "@/components/protected-route"
-import { useTask } from "@/contexts/task-context"
-import { useKnowledgeFilter } from "@/contexts/knowledge-filter-context"
-import { useAuth } from "@/contexts/auth-context"
-import { useChat, EndpointType } from "@/contexts/chat-context"
-import { Avatar, AvatarFallback, AvatarImage } from "@/components/ui/avatar"
+"use client";
+import { ProtectedRoute } from "@/components/protected-route";
+import { Avatar, AvatarFallback, AvatarImage } from "@/components/ui/avatar";
+import { Button } from "@/components/ui/button";
+import { useAuth } from "@/contexts/auth-context";
+import { EndpointType, useChat } from "@/contexts/chat-context";
+import { useKnowledgeFilter } from "@/contexts/knowledge-filter-context";
+import { useTask } from "@/contexts/task-context";
+import {
+  AtSign,
+  Bot,
+  ChevronDown,
+  ChevronRight,
+  GitBranch,
+  Loader2,
+  Plus,
+  Settings,
+  Upload,
+  User,
+  X,
+  Zap,
+} from "lucide-react";
+import { useEffect, useRef, useState } from "react";
 
 interface Message {
-  role: "user" | "assistant"
-  content: string
-  timestamp: Date
-  functionCalls?: FunctionCall[]
-  isStreaming?: boolean
+  role: "user" | "assistant";
+  content: string;
+  timestamp: Date;
+  functionCalls?: FunctionCall[];
+  isStreaming?: boolean;
 }
 
 interface FunctionCall {
-  name: string
-  arguments?: Record<string, unknown>
-  result?: Record<string, unknown> | ToolCallResult[]
-  status: "pending" | "completed" | "error"
-  argumentsString?: string
-  id?: string
-  type?: string
+  name: string;
+  arguments?: Record<string, unknown>;
+  result?: Record<string, unknown> | ToolCallResult[];
+  status: "pending" | "completed" | "error";
+  argumentsString?: string;
+  id?: string;
+  type?: string;
 }
 
 interface ToolCallResult {
-  text_key?: string
+  text_key?: string;
   data?: {
-    file_path?: string
-    text?: string
-    [key: string]: unknown
-  }
-  default_value?: string
-  [key: string]: unknown
+    file_path?: string;
+    text?: string;
+    [key: string]: unknown;
+  };
+  default_value?: string;
+  [key: string]: unknown;
 }
-
-
 interface SelectedFilters {
-  data_sources: string[]
-  document_types: string[]
-  owners: string[]
+  data_sources: string[];
+  document_types: string[];
+  owners: string[];
 }
 
 interface KnowledgeFilterData {
-  id: string
-  name: string
-  description: string
-  query_data: string
-  owner: string
-  created_at: string
-  updated_at: string
+  id: string;
+  name: string;
+  description: string;
+  query_data: string;
+  owner: string;
+  created_at: string;
+  updated_at: string;
 }
 
 interface RequestBody {
-  prompt: string
-  stream?: boolean
-  previous_response_id?: string
-  filters?: SelectedFilters
-  limit?: number
-  scoreThreshold?: number
+  prompt: string;
+  stream?: boolean;
+  previous_response_id?: string;
+  filters?: SelectedFilters;
+  limit?: number;
+  scoreThreshold?: number;
 }
 
 function ChatPage() {
-  const isDebugMode = process.env.NODE_ENV === 'development' || process.env.NEXT_PUBLIC_OPENRAG_DEBUG === 'true'
-  const { user } = useAuth()
-  const { endpoint, setEndpoint, currentConversationId, conversationData, setCurrentConversationId, addConversationDoc, forkFromResponse, refreshConversations, previousResponseIds, setPreviousResponseIds } = useChat()
+  const isDebugMode =
+    process.env.NODE_ENV === "development" ||
+    process.env.NEXT_PUBLIC_OPENRAG_DEBUG === "true";
+  const { user } = useAuth();
+  const {
+    endpoint,
+    setEndpoint,
+    currentConversationId,
+    conversationData,
+    setCurrentConversationId,
+    addConversationDoc,
+    forkFromResponse,
+    refreshConversations,
+    refreshConversationsSilent,
+    previousResponseIds,
+    setPreviousResponseIds,
+    placeholderConversation,
+  } = useChat();
   const [messages, setMessages] = useState<Message[]>([
     {
       role: "assistant",
       content: "How can I assist?",
-      timestamp: new Date()
-    }
-  ])
-  const [input, setInput] = useState("")
-  const [loading, setLoading] = useState(false)
-  const [asyncMode, setAsyncMode] = useState(true)
+      timestamp: new Date(),
+    },
+  ]);
+  const [input, setInput] = useState("");
+  const [loading, setLoading] = useState(false);
+  const [asyncMode, setAsyncMode] = useState(true);
   const [streamingMessage, setStreamingMessage] = useState<{
-    content: string
-    functionCalls: FunctionCall[]
-    timestamp: Date
-  } | null>(null)
-  const [expandedFunctionCalls, setExpandedFunctionCalls] = useState<Set<string>>(new Set())
+    content: string;
+    functionCalls: FunctionCall[];
+    timestamp: Date;
+  } | null>(null);
+  const [expandedFunctionCalls, setExpandedFunctionCalls] = useState<
+    Set<string>
+  >(new Set());
   // previousResponseIds now comes from useChat context
-  const [isUploading, setIsUploading] = useState(false)
-  const [isDragOver, setIsDragOver] = useState(false)
-  const [isFilterDropdownOpen, setIsFilterDropdownOpen] = useState(false)
-  const [availableFilters, setAvailableFilters] = useState<KnowledgeFilterData[]>([])
-  const [filterSearchTerm, setFilterSearchTerm] = useState("")
-  const [selectedFilterIndex, setSelectedFilterIndex] = useState(0)
-  const [isFilterHighlighted, setIsFilterHighlighted] = useState(false)
-  const [dropdownDismissed, setDropdownDismissed] = useState(false)
-  const [isUserInteracting, setIsUserInteracting] = useState(false)
-  const [isForkingInProgress, setIsForkingInProgress] = useState(false)
-  const [lastForkTimestamp, setLastForkTimestamp] = useState(0)
-  const dragCounterRef = useRef(0)
-  const messagesEndRef = useRef<HTMLDivElement>(null)
-  const inputRef = useRef<HTMLInputElement>(null)
-  const fileInputRef = useRef<HTMLInputElement>(null)
-  const dropdownRef = useRef<HTMLDivElement>(null)
-  const streamAbortRef = useRef<AbortController | null>(null)
-  const streamIdRef = useRef(0)
-  const { addTask, isMenuOpen } = useTask()
-  const { selectedFilter, parsedFilterData, isPanelOpen, setSelectedFilter } = useKnowledgeFilter()
-
-
+  const [isUploading, setIsUploading] = useState(false);
+  const [isDragOver, setIsDragOver] = useState(false);
+  const [isFilterDropdownOpen, setIsFilterDropdownOpen] = useState(false);
+  const [availableFilters, setAvailableFilters] = useState<
+    KnowledgeFilterData[]
+  >([]);
+  const [filterSearchTerm, setFilterSearchTerm] = useState("");
+  const [selectedFilterIndex, setSelectedFilterIndex] = useState(0);
+  const [isFilterHighlighted, setIsFilterHighlighted] = useState(false);
+  const [dropdownDismissed, setDropdownDismissed] = useState(false);
+  const [isUserInteracting, setIsUserInteracting] = useState(false);
+  const [isForkingInProgress, setIsForkingInProgress] = useState(false);
+  const [lastForkTimestamp, setLastForkTimestamp] = useState(0);
+  const dragCounterRef = useRef(0);
+  const messagesEndRef = useRef<HTMLDivElement>(null);
+  const inputRef = useRef<HTMLInputElement>(null);
+  const fileInputRef = useRef<HTMLInputElement>(null);
+  const dropdownRef = useRef<HTMLDivElement>(null);
+  const streamAbortRef = useRef<AbortController | null>(null);
+  const streamIdRef = useRef(0);
+  const lastLoadedConversationRef = useRef<string | null>(null);
+  const { addTask, isMenuOpen } = useTask();
+  const { selectedFilter, parsedFilterData, isPanelOpen, setSelectedFilter } =
+    useKnowledgeFilter();
 
   const scrollToBottom = () => {
-    messagesEndRef.current?.scrollIntoView({ behavior: "smooth" })
-  }
+    messagesEndRef.current?.scrollIntoView({ behavior: "smooth" });
+  };
 
   const handleEndpointChange = (newEndpoint: EndpointType) => {
-    setEndpoint(newEndpoint)
+    setEndpoint(newEndpoint);
     // Clear the conversation when switching endpoints to avoid response ID conflicts
-    setMessages([])
-    setPreviousResponseIds({ chat: null, langflow: null })
-  }
+    setMessages([]);
+    setPreviousResponseIds({ chat: null, langflow: null });
+  };
 
   const handleFileUpload = async (file: File) => {
-    console.log("handleFileUpload called with file:", file.name)
-
-    if (isUploading) return
-
-    setIsUploading(true)
-    setLoading(true)
-
+    console.log("handleFileUpload called with file:", file.name);
+
+    if (isUploading) return;
+
+    setIsUploading(true);
+    setLoading(true);
+
     // Add initial upload message
     const uploadStartMessage: Message = {
-      role: "assistant", 
+      role: "assistant",
       content: `🔄 Starting upload of **${file.name}**...`,
-      timestamp: new Date()
-    }
-    setMessages(prev => [...prev, uploadStartMessage])
-
+      timestamp: new Date(),
+    };
+    setMessages((prev) => [...prev, uploadStartMessage]);
+
     try {
-      const formData = new FormData()
-      formData.append('file', file)
-      formData.append('endpoint', endpoint)
-
+      const formData = new FormData();
+      formData.append("file", file);
+      formData.append("endpoint", endpoint);
+
       // Add previous_response_id if we have one for this endpoint
-      const currentResponseId = previousResponseIds[endpoint]
+      const currentResponseId = previousResponseIds[endpoint];
       if (currentResponseId) {
-        formData.append('previous_response_id', currentResponseId)
+        formData.append("previous_response_id", currentResponseId);
      }
-
-      const response = await fetch('/api/upload_context', {
-        method: 'POST',
+
+      const response = await fetch("/api/upload_context", {
+        method: "POST",
         body: formData,
-      })
-
-      console.log("Upload response status:", response.status)
-
+      });
+
+      console.log("Upload response status:", response.status);
+
       if (!response.ok) {
-        const errorText = await response.text()
-        console.error("Upload failed with status:", response.status, "Response:", errorText)
-        throw new Error(`Upload failed: ${response.status} - ${errorText}`)
+        const errorText = await response.text();
+        console.error(
+          "Upload failed with status:",
+          response.status,
+          "Response:",
+          errorText
+        );
+        throw new Error("Failed to process document");
      }
-
-      const result = await response.json()
-      console.log("Upload result:", result)
-
+
+      const result = await response.json();
+      console.log("Upload result:", result);
+
       if (response.status === 201) {
         // New flow: Got task ID, start tracking with centralized system
-        const taskId = result.task_id || result.id
-
+        const taskId = result.task_id || result.id;
+
         if (!taskId) {
-          console.error("No task ID in 201 response:", result)
-          throw new Error("No task ID received from server")
+          console.error("No task ID in 201 response:", result);
+          throw new Error("No task ID received from server");
        }
-
+
         // Add task to centralized tracking
-        addTask(taskId)
-
+        addTask(taskId);
+
         // Update message to show task is being tracked
         const pollingMessage: Message = {
           role: "assistant",
           content: `⏳ Upload initiated for **${file.name}**. Processing in background... (Task ID: ${taskId})`,
-          timestamp: new Date()
-        }
-        setMessages(prev => [...prev.slice(0, -1), pollingMessage])
-
+          timestamp: new Date(),
+        };
+        setMessages((prev) => [...prev.slice(0, -1), pollingMessage]);
       } else if (response.ok) {
-        // Original flow: Direct response
-
+        // Original flow: Direct response
+
         const uploadMessage: Message = {
           role: "assistant",
-          content: `📄 Document uploaded: **${result.filename}** (${result.pages} pages, ${result.content_length.toLocaleString()} characters)\n\n${result.confirmation}`,
-          timestamp: new Date()
-        }
-
-        setMessages(prev => [...prev.slice(0, -1), uploadMessage])
-
+          content: `📄 Document uploaded: **${result.filename}** (${
+            result.pages
+          } pages, ${result.content_length.toLocaleString()} characters)\n\n${
+            result.confirmation
+          }`,
+          timestamp: new Date(),
+        };
+
+        setMessages((prev) => [...prev.slice(0, -1), uploadMessage]);
+
         // Add file to conversation docs
         if (result.filename) {
-          addConversationDoc(result.filename)
+          addConversationDoc(result.filename);
        }
-
+
         // Update the response ID for this endpoint
         if (result.response_id) {
-          setPreviousResponseIds(prev => ({
+          setPreviousResponseIds((prev) => ({
             ...prev,
-            [endpoint]: result.response_id
-          }))
+            [endpoint]: result.response_id,
+          }));
+
+          // If this is a new conversation (no currentConversationId), set it now
+          if (!currentConversationId) {
+            setCurrentConversationId(result.response_id);
+            refreshConversations(true);
+          } else {
+            // For existing conversations, do a silent refresh to keep backend in sync
+            refreshConversationsSilent();
+          }
        }
-        // Sidebar should show this conversation after upload creates it
-        try { refreshConversations() } catch {}
-
       } else {
-        throw new Error(`Upload failed: ${response.status}`)
+        throw new Error(`Upload failed: ${response.status}`);
      }
-
     } catch (error) {
-      console.error('Upload failed:', error)
+      console.error("Upload failed:", error);
       const errorMessage: Message = {
         role: "assistant",
-        content: `❌ Upload failed: ${error instanceof Error ? error.message : 'Unknown error'}`,
-        timestamp: new Date()
-      }
-      setMessages(prev => [...prev.slice(0, -1), errorMessage])
+        content: `❌ Failed to process document. Please try again.`,
+        timestamp: new Date(),
+      };
+      setMessages((prev) => [...prev.slice(0, -1), errorMessage]);
     } finally {
-      setIsUploading(false)
-      setLoading(false)
+      setIsUploading(false);
+      setLoading(false);
    }
-  }
+  };
 
   // Remove the old pollTaskStatus function since we're using centralized system
   const handleDragEnter = (e: React.DragEvent) => {
-    e.preventDefault()
-    e.stopPropagation()
-    dragCounterRef.current++
+    e.preventDefault();
+    e.stopPropagation();
+    dragCounterRef.current++;
     if (dragCounterRef.current === 1) {
-      setIsDragOver(true)
+      setIsDragOver(true);
    }
-  }
-
+  };
+
   const handleDragOver = (e: React.DragEvent) => {
-    e.preventDefault()
-    e.stopPropagation()
-  }
-
+    e.preventDefault();
+    e.stopPropagation();
+  };
+
   const handleDragLeave = (e: React.DragEvent) => {
-    e.preventDefault()
-    e.stopPropagation()
-    dragCounterRef.current--
+    e.preventDefault();
+    e.stopPropagation();
+    dragCounterRef.current--;
     if (dragCounterRef.current === 0) {
-      setIsDragOver(false)
+      setIsDragOver(false);
    }
-  }
-
+  };
+
   const handleDrop = (e: React.DragEvent) => {
-    e.preventDefault()
-    e.stopPropagation()
-    dragCounterRef.current = 0
-    setIsDragOver(false)
-
-    const files = Array.from(e.dataTransfer.files)
+    e.preventDefault();
+    e.stopPropagation();
+    dragCounterRef.current = 0;
+    setIsDragOver(false);
+
+    const files = Array.from(e.dataTransfer.files);
     if (files.length > 0) {
-      handleFileUpload(files[0]) // Upload first file only
+      handleFileUpload(files[0]); // Upload first file only
    }
-  }
+  };
 
   const handleFilePickerClick = () => {
-    fileInputRef.current?.click()
-  }
+    fileInputRef.current?.click();
+  };
 
   const handleFilePickerChange = (e: React.ChangeEvent<HTMLInputElement>) => {
-    const files = e.target.files
+    const files = e.target.files;
     if (files && files.length > 0) {
-      handleFileUpload(files[0])
+      handleFileUpload(files[0]);
    }
     // Reset the input so the same file can be selected again
     if (fileInputRef.current) {
-      fileInputRef.current.value = ''
+      fileInputRef.current.value = "";
    }
-  }
+  };
 
   const loadAvailableFilters = async () => {
     try {
@@ -290,74 +332,74 @@ function ChatPage() {
       },
       body: JSON.stringify({
         query: "",
-        limit: 20
+        limit: 20,
       }),
-    })
+    });
 
-    const result = await response.json()
+    const result = await response.json();
 
     if (response.ok && result.success) {
-      setAvailableFilters(result.filters)
+      setAvailableFilters(result.filters);
     } else {
-      console.error("Failed to load knowledge filters:", result.error)
-      setAvailableFilters([])
+      console.error("Failed to load knowledge filters:", result.error);
+      setAvailableFilters([]);
    }
   } catch (error) {
-    console.error('Failed to load knowledge filters:', error)
-    setAvailableFilters([])
+    console.error("Failed to load knowledge filters:", error);
+    setAvailableFilters([]);
   }
-}
+};
 
   const handleFilterDropdownToggle = () => {
     if (!isFilterDropdownOpen) {
-      loadAvailableFilters()
+      loadAvailableFilters();
    }
-    setIsFilterDropdownOpen(!isFilterDropdownOpen)
-  }
+    setIsFilterDropdownOpen(!isFilterDropdownOpen);
+  };
 
   const handleFilterSelect = (filter: KnowledgeFilterData | null) => {
-    setSelectedFilter(filter)
-    setIsFilterDropdownOpen(false)
-    setFilterSearchTerm("")
-    setIsFilterHighlighted(false)
-
+    setSelectedFilter(filter);
+    setIsFilterDropdownOpen(false);
+    setFilterSearchTerm("");
+    setIsFilterHighlighted(false);
+
     // Remove the @searchTerm from the input and replace with filter pill
-    const words = input.split(' ')
-    const lastWord = words[words.length - 1]
-
-    if (lastWord.startsWith('@')) {
+    const words = input.split(" ");
+    const lastWord = words[words.length - 1];
+
+    if (lastWord.startsWith("@")) {
       // Remove the @search term
-      words.pop()
-      setInput(words.join(' ') + (words.length > 0 ? ' ' : ''))
+      words.pop();
+      setInput(words.join(" ") + (words.length > 0 ? " " : ""));
    }
-  }
+  };
 
   useEffect(() => {
     // Only auto-scroll if not in the middle of user interaction
     if (!isUserInteracting) {
       const timer = setTimeout(() => {
-        scrollToBottom()
-      }, 50) // Small delay to avoid conflicts with click events
-
-      return () => clearTimeout(timer)
+        scrollToBottom();
+      }, 50); // Small delay to avoid conflicts with click events
+
+      return () => clearTimeout(timer);
    }
-  }, [messages, streamingMessage, isUserInteracting])
+  }, [messages, streamingMessage, isUserInteracting]);
 
   // Reset selected index when search term changes
   useEffect(() => {
-    setSelectedFilterIndex(0)
-  }, [filterSearchTerm])
+    setSelectedFilterIndex(0);
+  }, [filterSearchTerm]);
 
   // Auto-focus the input on component mount
   useEffect(() => {
-    inputRef.current?.focus()
-  }, [])
+    inputRef.current?.focus();
+  }, []);
 
   // Explicitly handle external new conversation trigger
   useEffect(() => {
     const handleNewConversation = () => {
       // Abort any in-flight streaming so it doesn't bleed into new chat
       if (streamAbortRef.current) {
-        streamAbortRef.current.abort()
+        streamAbortRef.current.abort();
      }
       // Reset chat UI even if context state was already 'new'
       setMessages([
         {
           role: "assistant",
           content: "How can I assist?",
           timestamp: new Date(),
         },
-      ])
-      setInput("")
-      setStreamingMessage(null)
-      setExpandedFunctionCalls(new Set())
-      setIsFilterHighlighted(false)
-      setLoading(false)
-    }
+      ]);
+      setInput("");
+      setStreamingMessage(null);
+      setExpandedFunctionCalls(new Set());
+      setIsFilterHighlighted(false);
+      setLoading(false);
+      lastLoadedConversationRef.current = null;
+    };
 
     const handleFocusInput = () => {
-      inputRef.current?.focus()
-    }
+      inputRef.current?.focus();
+    };
 
-    window.addEventListener('newConversation', handleNewConversation)
-    window.addEventListener('focusInput', handleFocusInput)
+    window.addEventListener("newConversation", handleNewConversation);
+    window.addEventListener("focusInput", handleFocusInput);
     return () => {
-      window.removeEventListener('newConversation', handleNewConversation)
-      window.removeEventListener('focusInput', handleFocusInput)
-    }
-  }, [])
+      window.removeEventListener("newConversation", handleNewConversation);
+      window.removeEventListener("focusInput", handleFocusInput);
+    };
+  }, []);
 
-  // Load conversation when conversationData changes
+  // Load conversation only when user explicitly selects a conversation
   useEffect(() => {
-    const now = Date.now()
-
-    // Don't reset messages if user is in the middle of an interaction (like forking)
-    if (isUserInteracting || isForkingInProgress) {
-      console.log("Skipping conversation load due to user interaction or forking")
-      return
-    }
-
-    // Don't reload if we just forked recently (within 1 second)
-    if (now - lastForkTimestamp < 1000) {
-      console.log("Skipping conversation load - recent fork detected")
-      return
-    }
-
-    if (conversationData && conversationData.messages) {
-      console.log("Loading conversation with", conversationData.messages.length, "messages")
+    // Only load conversation data when:
+    // 1. conversationData exists AND
+    // 2. It's different from the last loaded conversation AND
+    // 3. User is not in the middle of an interaction
+    if (
+      conversationData &&
+      conversationData.messages &&
+      lastLoadedConversationRef.current !== conversationData.response_id &&
+      !isUserInteracting &&
+      !isForkingInProgress
+    ) {
+      console.log(
+        "Loading conversation with",
+        conversationData.messages.length,
+        "messages"
+      );
       // Convert backend message format to frontend Message interface
-      const convertedMessages: Message[] = conversationData.messages.map((msg: {
-        role: string;
-        content: string;
-        timestamp?: string;
-        response_id?: string;
-      }) => ({
-        role: msg.role as "user" | "assistant",
-        content: msg.content,
-        timestamp: new Date(msg.timestamp || new Date()),
-        // Add any other necessary properties
-      }))
-
-      setMessages(convertedMessages)
-
+      const convertedMessages: Message[] = conversationData.messages.map(
+        (msg: {
+          role: string;
+          content: string;
+          timestamp?: string;
+          response_id?: string;
+          chunks?: any[];
+          response_data?: any;
+        }) => {
+          const message: Message = {
+            role: msg.role as "user" | "assistant",
+            content: msg.content,
+            timestamp: new Date(msg.timestamp || new Date()),
+          };
+
+          // Extract function calls from chunks or response_data
+          if (msg.role === "assistant" && (msg.chunks || msg.response_data)) {
+            const functionCalls: FunctionCall[] = [];
+            console.log("Processing assistant message for function calls:", {
+              hasChunks: !!msg.chunks,
+              chunksLength: msg.chunks?.length,
+              hasResponseData: !!msg.response_data,
+            });
+
+            // Process chunks (streaming data)
+            if (msg.chunks && Array.isArray(msg.chunks)) {
+              for (const chunk of msg.chunks) {
+                // Handle Langflow format: chunks[].item.tool_call
+                if (chunk.item && chunk.item.type === "tool_call") {
+                  const toolCall = chunk.item;
+                  console.log("Found Langflow tool call:", toolCall);
+                  functionCalls.push({
+                    id: toolCall.id,
+                    name: toolCall.tool_name,
+                    arguments: toolCall.inputs || {},
+                    argumentsString: JSON.stringify(toolCall.inputs || {}),
+                    result: toolCall.results,
+                    status: toolCall.status || "completed",
+                    type: "tool_call",
+                  });
+                }
+                // Handle OpenAI format: chunks[].delta.tool_calls
+                else if (chunk.delta?.tool_calls) {
+                  for (const toolCall of chunk.delta.tool_calls) {
+                    if (toolCall.function) {
+                      functionCalls.push({
+                        id: toolCall.id,
+                        name: toolCall.function.name,
+                        arguments: toolCall.function.arguments
+                          ? JSON.parse(toolCall.function.arguments)
+                          : {},
+                        argumentsString: toolCall.function.arguments,
+                        status: "completed",
+                        type: toolCall.type || "function",
+                      });
+                    }
+                  }
+                }
+                // Process tool call results from chunks
+                if (
+                  chunk.type === "response.tool_call.result" ||
+                  chunk.type === "tool_call_result"
+                ) {
+                  const lastCall = functionCalls[functionCalls.length - 1];
+                  if (lastCall) {
+                    lastCall.result = chunk.result || chunk;
+                    lastCall.status = "completed";
+                  }
+                }
+              }
+            }
+
+            // Process response_data (non-streaming data)
+            if (msg.response_data && typeof msg.response_data === "object") {
+              // Look for tool_calls in various places in the response data
+              const responseData =
+                typeof msg.response_data === "string"
+                  ? JSON.parse(msg.response_data)
+                  : msg.response_data;
+
+              if (
+                responseData.tool_calls &&
+                Array.isArray(responseData.tool_calls)
+              ) {
+                for (const toolCall of responseData.tool_calls) {
+                  functionCalls.push({
+                    id: toolCall.id,
+                    name: toolCall.function?.name || toolCall.name,
+                    arguments: toolCall.function?.arguments || toolCall.arguments,
+                    argumentsString:
+                      typeof (toolCall.function?.arguments ||
+                        toolCall.arguments) === "string"
+                        ? toolCall.function?.arguments || toolCall.arguments
+                        : JSON.stringify(
+                            toolCall.function?.arguments || toolCall.arguments
+                          ),
+                    result: toolCall.result,
+                    status: "completed",
+                    type: toolCall.type || "function",
+                  });
+                }
+              }
+            }
+
+            if (functionCalls.length > 0) {
+              console.log("Setting functionCalls on message:", functionCalls);
+              message.functionCalls = functionCalls;
+            } else {
+              console.log("No function calls found in message");
+            }
+          }
+
+          return message;
+        }
+      );
+
+      setMessages(convertedMessages);
+      lastLoadedConversationRef.current = conversationData.response_id;
+
       // Set the previous response ID for this conversation
-      setPreviousResponseIds(prev => ({
+      setPreviousResponseIds((prev) => ({
         ...prev,
-        [conversationData.endpoint]: conversationData.response_id
-      }))
+        [conversationData.endpoint]: conversationData.response_id,
+      }));
    }
-    // Reset messages when starting a new conversation (but not during forking)
-    else if (currentConversationId === null && !isUserInteracting && !isForkingInProgress && now - lastForkTimestamp > 1000) {
-      console.log("Resetting to default message for new conversation")
+  }, [
+    conversationData,
+    isUserInteracting,
+    isForkingInProgress,
+  ]);
+
+  // Handle new conversation creation - only reset messages when placeholderConversation is set
+  useEffect(() => {
+    if (placeholderConversation && currentConversationId === null) {
+      console.log("Starting new conversation");
       setMessages([
         {
           role: "assistant",
           content: "How can I assist?",
-          timestamp: new Date()
-        }
-      ])
+          timestamp: new Date(),
+        },
+      ]);
+      lastLoadedConversationRef.current = null;
    }
-  }, [conversationData, currentConversationId, isUserInteracting, isForkingInProgress, lastForkTimestamp, setPreviousResponseIds])
+  }, [placeholderConversation, currentConversationId]);
 
   // Listen for file upload events from navigation
   useEffect(() => {
     const handleFileUploadStart = (event: CustomEvent) => {
-      const { filename } = event.detail
-      console.log("Chat page received file upload start event:", filename)
-
-      setLoading(true)
-      setIsUploading(true)
-
+      const { filename } = event.detail;
+      console.log("Chat page received file upload start event:", filename);
+
+      setLoading(true);
+      setIsUploading(true);
+
       // Add initial upload message
       const uploadStartMessage: Message = {
-        role: "assistant", 
+        role: "assistant",
         content: `🔄 Starting upload of **${filename}**...`,
-        timestamp: new Date()
-      }
-      setMessages(prev => [...prev, uploadStartMessage])
-    }
+        timestamp: new Date(),
+      };
+      setMessages((prev) => [...prev, uploadStartMessage]);
+    };
 
     const handleFileUploaded = (event: CustomEvent) => {
-      const { result } = event.detail
-      console.log("Chat page received file upload event:", result)
-
+      const { result } = event.detail;
+      console.log("Chat page received file upload event:", result);
+
       // Replace the last message with upload complete message
       const uploadMessage: Message = {
         role: "assistant",
-        content: `📄 Document uploaded: **${result.filename}** (${result.pages} pages, ${result.content_length.toLocaleString()} characters)\n\n${result.confirmation}`,
-        timestamp: new Date()
-      }
-
-      setMessages(prev => [...prev.slice(0, -1), uploadMessage])
-
+        content: `📄 Document uploaded: **${result.filename}** (${
+          result.pages
+        } pages, ${result.content_length.toLocaleString()} characters)\n\n${
+          result.confirmation
+        }`,
+        timestamp: new Date(),
+      };
+
+      setMessages((prev) => [...prev.slice(0, -1), uploadMessage]);
+
       // Update the response ID for this endpoint
       if (result.response_id) {
-        setPreviousResponseIds(prev => ({
+        setPreviousResponseIds((prev) => ({
           ...prev,
-          [endpoint]: result.response_id
-        }))
+          [endpoint]: result.response_id,
+        }));
      }
-    }
+    };
 
     const handleFileUploadComplete = () => {
-      console.log("Chat page received file upload complete event")
-      setLoading(false)
-      setIsUploading(false)
-    }
+      console.log("Chat page received file upload complete event");
+      setLoading(false);
+      setIsUploading(false);
+    };
 
     const handleFileUploadError = (event: CustomEvent) => {
-      const { filename, error } = event.detail
-      console.log("Chat page received file upload error event:", filename, error)
-
+      const { filename, error } = event.detail;
+      console.log(
+        "Chat page received file upload error event:",
+        filename,
+        error
+      );
+
       // Replace the last message with error message
       const errorMessage: Message = {
         role: "assistant",
         content: `❌ Upload failed for **${filename}**: ${error}`,
-        timestamp: new Date()
-      }
-      setMessages(prev => [...prev.slice(0, -1), errorMessage])
-    }
+        timestamp: new Date(),
+      };
+      setMessages((prev) => [...prev.slice(0, -1), errorMessage]);
+    };
+
+    window.addEventListener(
+      "fileUploadStart",
+      handleFileUploadStart as EventListener
+    );
+    window.addEventListener(
+      "fileUploaded",
+      handleFileUploaded as EventListener
+    );
+    window.addEventListener(
+      "fileUploadComplete",
+      handleFileUploadComplete as EventListener
+    );
+    window.addEventListener(
+      "fileUploadError",
+      handleFileUploadError as EventListener
+    );
 
-    window.addEventListener('fileUploadStart', handleFileUploadStart as EventListener)
-    window.addEventListener('fileUploaded', handleFileUploaded as EventListener)
-    window.addEventListener('fileUploadComplete', handleFileUploadComplete as EventListener)
-    window.addEventListener('fileUploadError', handleFileUploadError as EventListener)
-
     return () => {
-      window.removeEventListener('fileUploadStart', handleFileUploadStart as EventListener)
-      window.removeEventListener('fileUploaded', handleFileUploaded as EventListener)
-      window.removeEventListener('fileUploadComplete', handleFileUploadComplete as EventListener)
-      window.removeEventListener('fileUploadError', handleFileUploadError as EventListener)
-    }
-  }, [endpoint, setPreviousResponseIds])
+      window.removeEventListener(
+        "fileUploadStart",
+        handleFileUploadStart as EventListener
+      );
+      window.removeEventListener(
+        "fileUploaded",
+        handleFileUploaded as EventListener
+      );
+      window.removeEventListener(
+        "fileUploadComplete",
+        handleFileUploadComplete as EventListener
+      );
+      window.removeEventListener(
+        "fileUploadError",
+        handleFileUploadError as EventListener
+      );
+    };
+  }, [endpoint, setPreviousResponseIds]);
 
   // Handle click outside to close dropdown
   useEffect(() => {
     const handleClickOutside = (event: MouseEvent) => {
-      if (isFilterDropdownOpen &&
-          dropdownRef.current &&
-          !dropdownRef.current.contains(event.target as Node) &&
-          !inputRef.current?.contains(event.target as Node)) {
-        setIsFilterDropdownOpen(false)
-        setFilterSearchTerm("")
-        setSelectedFilterIndex(0)
+      if (
+        isFilterDropdownOpen &&
+        dropdownRef.current &&
+        !dropdownRef.current.contains(event.target as Node) &&
+        !inputRef.current?.contains(event.target as Node)
+      ) {
+        setIsFilterDropdownOpen(false);
+        setFilterSearchTerm("");
+        setSelectedFilterIndex(0);
      }
-    }
+    };
 
-    document.addEventListener('mousedown', handleClickOutside)
+    document.addEventListener("mousedown", handleClickOutside);
     return () => {
-      document.removeEventListener('mousedown', handleClickOutside)
-    }
-  }, [isFilterDropdownOpen])
-
+      document.removeEventListener("mousedown", handleClickOutside);
+    };
+  }, [isFilterDropdownOpen]);
 
   const handleSSEStream = async (userMessage: Message) => {
-    const apiEndpoint = endpoint === "chat" ? "/api/chat" : "/api/langflow"
-
+    const apiEndpoint = endpoint === "chat" ? "/api/chat" : "/api/langflow";
+
     try {
       // Abort any existing stream before starting a new one
       if (streamAbortRef.current) {
-        streamAbortRef.current.abort()
+        streamAbortRef.current.abort();
      }
-      const controller = new AbortController()
-      streamAbortRef.current = controller
-      const thisStreamId = ++streamIdRef.current
+      const controller = new AbortController();
+      streamAbortRef.current = controller;
+      const thisStreamId = ++streamIdRef.current;
 
       const requestBody: RequestBody = {
         prompt: userMessage.content,
         stream: true,
-        ...(parsedFilterData?.filters && (() => {
-          const filters = parsedFilterData.filters
-          const processed: SelectedFilters = {
-            data_sources: [],
-            document_types: [],
-            owners: []
-          }
-          // Only copy non-wildcard arrays
-          processed.data_sources = filters.data_sources.includes("*") ? [] : filters.data_sources
-          processed.document_types = filters.document_types.includes("*") ? [] : filters.document_types
-          processed.owners = filters.owners.includes("*") ? [] : filters.owners
-
-          // Only include filters if any array has values
-          const hasFilters = processed.data_sources.length > 0 ||
-                            processed.document_types.length > 0 ||
-                            processed.owners.length > 0
-          return hasFilters ? { filters: processed } : {}
-        })()),
+        ...(parsedFilterData?.filters &&
+          (() => {
+            const filters = parsedFilterData.filters;
+            const processed: SelectedFilters = {
+              data_sources: [],
+              document_types: [],
+              owners: [],
+            };
+            // Only copy non-wildcard arrays
+            processed.data_sources = filters.data_sources.includes("*")
+              ? []
+              : filters.data_sources;
+            processed.document_types = filters.document_types.includes("*")
+              ? []
+              : filters.document_types;
+            processed.owners = filters.owners.includes("*")
+              ? []
+              : filters.owners;
+
+            // Only include filters if any array has values
+            const hasFilters =
+              processed.data_sources.length > 0 ||
+              processed.document_types.length > 0 ||
+              processed.owners.length > 0;
+            return hasFilters ? { filters: processed } : {};
+          })()),
         limit: parsedFilterData?.limit ?? 10,
-        scoreThreshold: parsedFilterData?.scoreThreshold ?? 0
-      }
-
+        scoreThreshold: parsedFilterData?.scoreThreshold ?? 0,
+      };
+
       // Add previous_response_id if we have one for this endpoint
-      const currentResponseId = previousResponseIds[endpoint]
+      const currentResponseId = previousResponseIds[endpoint];
       if (currentResponseId) {
-        requestBody.previous_response_id = currentResponseId
+        requestBody.previous_response_id = currentResponseId;
      }
-
+
       const response = await fetch(apiEndpoint, {
         method: "POST",
         headers: {
@@ -579,138 +761,183 @@ function ChatPage() {
         },
         body: JSON.stringify(requestBody),
         signal: controller.signal,
-      })
+      });
 
       if (!response.ok) {
-        throw new Error(`HTTP error! status: ${response.status}`)
+        throw new Error(`HTTP error! status: ${response.status}`);
      }
 
-      const reader = response.body?.getReader()
+      const reader = response.body?.getReader();
       if (!reader) {
-        throw new Error("No reader available")
+        throw new Error("No reader available");
      }
 
-      const decoder = new TextDecoder()
-      let buffer = ""
-      let currentContent = ""
-      const currentFunctionCalls: FunctionCall[] = []
-      let newResponseId: string | null = null
-
+      const decoder = new TextDecoder();
+      let buffer = "";
+      let currentContent = "";
+      const currentFunctionCalls: FunctionCall[] = [];
+      let newResponseId: string | null = null;
+
       // Initialize streaming message
       if (!controller.signal.aborted && thisStreamId === streamIdRef.current) {
         setStreamingMessage({
           content: "",
           functionCalls: [],
-          timestamp: new Date()
-        })
+          timestamp: new Date(),
+        });
      }
 
       try {
         while (true) {
-          const { done, value } = await reader.read()
-          if (controller.signal.aborted || thisStreamId !== streamIdRef.current) break
-          if (done) break
-          buffer += decoder.decode(value, { stream: true })
-
+          const { done, value } = await reader.read();
+          if (controller.signal.aborted || thisStreamId !== streamIdRef.current)
+            break;
+          if (done) break;
+          buffer += decoder.decode(value, { stream: true });
+
           // Process complete lines (JSON objects)
-          const lines = buffer.split('\n')
-          buffer = lines.pop() || "" // Keep incomplete line in buffer
-
+          const lines = buffer.split("\n");
+          buffer = lines.pop() || ""; // Keep incomplete line in buffer
+
           for (const line of lines) {
             if (line.trim()) {
               try {
-                const chunk = JSON.parse(line)
-                console.log("Received chunk:", chunk.type || chunk.object, chunk)
-
+                const chunk = JSON.parse(line);
+                console.log(
+                  "Received chunk:",
+                  chunk.type || chunk.object,
+                  chunk
+                );
+
                 // Extract response ID if present
                 if (chunk.id) {
-                  newResponseId = chunk.id
+                  newResponseId = chunk.id;
                 } else if (chunk.response_id) {
-                  newResponseId = chunk.response_id
+                  newResponseId = chunk.response_id;
                }
-
+
                 // Handle OpenAI Chat Completions streaming format
                 if (chunk.object === "response.chunk" && chunk.delta) {
                   // Handle function calls in delta
                   if (chunk.delta.function_call) {
-                    console.log("Function call in delta:", chunk.delta.function_call)
-
+                    console.log(
+                      "Function call in delta:",
+                      chunk.delta.function_call
+                    );
+
                     // Check if this is a new function call
                     if (chunk.delta.function_call.name) {
-                      console.log("New function call:", chunk.delta.function_call.name)
+                      console.log(
+                        "New function call:",
+                        chunk.delta.function_call.name
+                      );
                       const functionCall: FunctionCall = {
                         name: chunk.delta.function_call.name,
                         arguments: undefined,
                         status: "pending",
-                        argumentsString: chunk.delta.function_call.arguments || ""
-                      }
-                      currentFunctionCalls.push(functionCall)
-                      console.log("Added function call:", functionCall)
+                        argumentsString:
+                          chunk.delta.function_call.arguments || "",
+                      };
+                      currentFunctionCalls.push(functionCall);
+                      console.log("Added function call:", functionCall);
                    }
                     // Or if this is arguments continuation
                     else if (chunk.delta.function_call.arguments) {
-                      console.log("Function call arguments delta:", chunk.delta.function_call.arguments)
-                      const lastFunctionCall = currentFunctionCalls[currentFunctionCalls.length - 1]
+                      console.log(
+                        "Function call arguments delta:",
+                        chunk.delta.function_call.arguments
+                      );
+                      const lastFunctionCall =
+                        currentFunctionCalls[currentFunctionCalls.length - 1];
                       if (lastFunctionCall) {
                         if (!lastFunctionCall.argumentsString) {
-                          lastFunctionCall.argumentsString = ""
+                          lastFunctionCall.argumentsString = "";
                        }
-                        lastFunctionCall.argumentsString += chunk.delta.function_call.arguments
-                        console.log("Accumulated arguments:", lastFunctionCall.argumentsString)
-
+                        lastFunctionCall.argumentsString +=
+                          chunk.delta.function_call.arguments;
+                        console.log(
+                          "Accumulated arguments:",
+                          lastFunctionCall.argumentsString
+                        );
+
                         // Try to parse arguments if they look complete
                         if (lastFunctionCall.argumentsString.includes("}")) {
                           try {
-                            const parsed = JSON.parse(lastFunctionCall.argumentsString)
-                            lastFunctionCall.arguments = parsed
-                            lastFunctionCall.status = "completed"
-                            console.log("Parsed function arguments:", parsed)
+                            const parsed = JSON.parse(
+                              lastFunctionCall.argumentsString
+                            );
+                            lastFunctionCall.arguments = parsed;
+                            lastFunctionCall.status = "completed";
+                            console.log("Parsed function arguments:", parsed);
                           } catch (e) {
-                            console.log("Arguments not yet complete or invalid JSON:", e)
+                            console.log(
+                              "Arguments not yet complete or invalid JSON:",
+                              e
+                            );
                          }
                        }
                      }
                    }
                  }
-
-                  // Handle tool calls in delta
-                  else if (chunk.delta.tool_calls && Array.isArray(chunk.delta.tool_calls)) {
-                    console.log("Tool calls in delta:", chunk.delta.tool_calls)
-
+
+                  // Handle tool calls in delta
+                  else if (
+                    chunk.delta.tool_calls &&
+                    Array.isArray(chunk.delta.tool_calls)
+                  ) {
+                    console.log("Tool calls in delta:", chunk.delta.tool_calls);
+
                     for (const toolCall of chunk.delta.tool_calls) {
                       if (toolCall.function) {
                         // Check if this is a new tool call
                         if (toolCall.function.name) {
-                          console.log("New tool call:", toolCall.function.name)
+                          console.log("New tool call:", toolCall.function.name);
                           const functionCall: FunctionCall = {
                             name: toolCall.function.name,
                             arguments: undefined,
                             status: "pending",
-                            argumentsString: toolCall.function.arguments || ""
-                          }
-                          currentFunctionCalls.push(functionCall)
-                          console.log("Added tool call:", functionCall)
+                            argumentsString: toolCall.function.arguments || "",
+                          };
+                          currentFunctionCalls.push(functionCall);
+                          console.log("Added tool call:", functionCall);
                        }
                         // Or if this is arguments continuation
                         else if (toolCall.function.arguments) {
-                          console.log("Tool call arguments delta:", toolCall.function.arguments)
-                          const lastFunctionCall = currentFunctionCalls[currentFunctionCalls.length - 1]
+                          console.log(
+                            "Tool call arguments delta:",
+                            toolCall.function.arguments
+                          );
+                          const lastFunctionCall =
+                            currentFunctionCalls[
+                              currentFunctionCalls.length - 1
+                            ];
                           if (lastFunctionCall) {
                             if (!lastFunctionCall.argumentsString) {
-                              lastFunctionCall.argumentsString = ""
+                              lastFunctionCall.argumentsString = "";
                            }
-                            lastFunctionCall.argumentsString += toolCall.function.arguments
-                            console.log("Accumulated tool arguments:", lastFunctionCall.argumentsString)
-
+                            lastFunctionCall.argumentsString +=
+                              toolCall.function.arguments;
+                            console.log(
+                              "Accumulated tool arguments:",
+                              lastFunctionCall.argumentsString
+                            );
+
                             // Try to parse arguments if they look complete
-                            if (lastFunctionCall.argumentsString.includes("}")) {
+                            if (
+                              lastFunctionCall.argumentsString.includes("}")
+                            ) {
                               try {
-                                const parsed = JSON.parse(lastFunctionCall.argumentsString)
-                                lastFunctionCall.arguments = parsed
-                                lastFunctionCall.status = "completed"
-                                console.log("Parsed tool arguments:", parsed)
+                                const parsed = JSON.parse(
+                                  lastFunctionCall.argumentsString
+                                );
+                                lastFunctionCall.arguments = parsed;
+                                lastFunctionCall.status = "completed";
+                                console.log("Parsed tool arguments:", parsed);
                               } catch (e) {
-                                console.log("Tool arguments not yet complete or invalid JSON:", e)
+                                console.log(
+                                  "Tool arguments not yet complete or invalid JSON:",
+                                  e
+                                );
                              }
                            }
                          }
@@ -718,256 +945,403 @@ function ChatPage() {
                        }
                      }
                    }
-
+
                   // Handle content/text in delta
                   else if (chunk.delta.content) {
-                    console.log("Content delta:", chunk.delta.content)
-                    currentContent += chunk.delta.content
+                    console.log("Content delta:", chunk.delta.content);
+                    currentContent += chunk.delta.content;
                  }
-
+
                   // Handle finish reason
                   if (chunk.delta.finish_reason) {
-                    console.log("Finish reason:", chunk.delta.finish_reason)
+                    console.log("Finish reason:", chunk.delta.finish_reason);
                     // Mark any pending function calls as completed
-                    currentFunctionCalls.forEach(fc => {
+                    currentFunctionCalls.forEach((fc) => {
                       if (fc.status === "pending" && fc.argumentsString) {
                         try {
-                          fc.arguments = JSON.parse(fc.argumentsString)
-                          fc.status = "completed"
-                          console.log("Completed function call on finish:", fc)
+                          fc.arguments = JSON.parse(fc.argumentsString);
+                          fc.status = "completed";
+                          console.log("Completed function call on finish:", fc);
                         } catch (e) {
-                          fc.arguments = { raw: fc.argumentsString }
-                          fc.status = "error"
-                          console.log("Error parsing function call on finish:", fc, e)
+                          fc.arguments = { raw: fc.argumentsString };
+                          fc.status = "error";
+                          console.log(
+                            "Error parsing function call on finish:",
+                            fc,
+                            e
+                          );
                        }
                      }
-                    })
+                    });
                  }
                }
-
+
                 // Handle Realtime API format (this is what you're actually getting!)
-                else if (chunk.type === "response.output_item.added" && chunk.item?.type === "function_call") {
-                  console.log("🟢 CREATING function call (added):", chunk.item.id, chunk.item.tool_name || chunk.item.name)
-
+                else if (
+                  chunk.type === "response.output_item.added" &&
+                  chunk.item?.type === "function_call"
+                ) {
+                  console.log(
+                    "🟢 CREATING function call (added):",
+                    chunk.item.id,
+                    chunk.item.tool_name || chunk.item.name
+                  );
+
                   // Try to find an existing pending call to update (created by earlier deltas)
-                  let existing = currentFunctionCalls.find(fc => fc.id === chunk.item.id)
+                  let existing = currentFunctionCalls.find(
+                    (fc) => fc.id === chunk.item.id
+                  );
                   if (!existing) {
-                    existing = [...currentFunctionCalls].reverse().find(fc =>
-                      fc.status === "pending" &&
-                      !fc.id &&
-                      (fc.name === (chunk.item.tool_name || chunk.item.name))
-                    )
+                    existing = [...currentFunctionCalls]
+                      .reverse()
+                      .find(
+                        (fc) =>
+                          fc.status === "pending" &&
+                          !fc.id &&
+                          fc.name === (chunk.item.tool_name || chunk.item.name)
+                      );
                  }
-
+
                   if (existing) {
-                    existing.id = chunk.item.id
-                    existing.type = chunk.item.type
-                    existing.name = chunk.item.tool_name || chunk.item.name || existing.name
-                    existing.arguments = chunk.item.inputs || existing.arguments
-                    console.log("🟢 UPDATED existing pending function call with id:", existing.id)
+                    existing.id = chunk.item.id;
+                    existing.type = chunk.item.type;
+                    existing.name =
+                      chunk.item.tool_name || chunk.item.name || existing.name;
+                    existing.arguments =
+                      chunk.item.inputs || existing.arguments;
+                    console.log(
+                      "🟢 UPDATED existing pending function call with id:",
+                      existing.id
+                    );
                   } else {
                     const functionCall: FunctionCall = {
-                      name: chunk.item.tool_name || chunk.item.name || "unknown",
+                      name:
+                        chunk.item.tool_name || chunk.item.name || "unknown",
                       arguments: chunk.item.inputs || undefined,
                       status: "pending",
                       argumentsString: "",
                       id: chunk.item.id,
-                      type: chunk.item.type
-                    }
-                    currentFunctionCalls.push(functionCall)
-                    console.log("🟢 Function calls now:", currentFunctionCalls.map(fc => ({ id: fc.id, name: fc.name })))
+                      type: chunk.item.type,
+                    };
+                    currentFunctionCalls.push(functionCall);
+                    console.log(
+                      "🟢 Function calls now:",
+                      currentFunctionCalls.map((fc) => ({
+                        id: fc.id,
+                        name: fc.name,
+                      }))
+                    );
                  }
                }
-
+
                 // Handle function call arguments streaming (Realtime API)
-                else if (chunk.type === "response.function_call_arguments.delta") {
-                  console.log("Function args delta (Realtime API):", chunk.delta)
-                  const lastFunctionCall = currentFunctionCalls[currentFunctionCalls.length - 1]
+                else if (
+                  chunk.type === "response.function_call_arguments.delta"
+                ) {
+                  console.log(
+                    "Function args delta (Realtime API):",
+                    chunk.delta
+                  );
+                  const lastFunctionCall =
+                    currentFunctionCalls[currentFunctionCalls.length - 1];
                   if (lastFunctionCall) {
                     if (!lastFunctionCall.argumentsString) {
-                      lastFunctionCall.argumentsString = ""
+                      lastFunctionCall.argumentsString = "";
                    }
-                    lastFunctionCall.argumentsString += chunk.delta || ""
-                    console.log("Accumulated arguments (Realtime API):", lastFunctionCall.argumentsString)
+                    lastFunctionCall.argumentsString += chunk.delta || "";
+                    console.log(
+                      "Accumulated arguments (Realtime API):",
+                      lastFunctionCall.argumentsString
+                    );
                  }
                }
-
+
                 // Handle function call arguments completion (Realtime API)
-                else if (chunk.type === "response.function_call_arguments.done") {
-                  console.log("Function args done (Realtime API):", chunk.arguments)
-                  const lastFunctionCall = currentFunctionCalls[currentFunctionCalls.length - 1]
+                else if (
+                  chunk.type === "response.function_call_arguments.done"
+                ) {
+                  console.log(
+                    "Function args done (Realtime API):",
+                    chunk.arguments
+                  );
+                  const lastFunctionCall =
+                    currentFunctionCalls[currentFunctionCalls.length - 1];
                   if (lastFunctionCall) {
                     try {
-                      lastFunctionCall.arguments = JSON.parse(chunk.arguments || "{}")
-                      lastFunctionCall.status = "completed"
-                      console.log("Parsed function arguments (Realtime API):", lastFunctionCall.arguments)
+                      lastFunctionCall.arguments = JSON.parse(
+                        chunk.arguments || "{}"
+                      );
+                      lastFunctionCall.status = "completed";
+                      console.log(
+                        "Parsed function arguments (Realtime API):",
+                        lastFunctionCall.arguments
+                      );
                     } catch (e) {
-                      lastFunctionCall.arguments = { raw: chunk.arguments }
-                      lastFunctionCall.status = "error"
-                      console.log("Error parsing function arguments (Realtime API):", e)
+                      lastFunctionCall.arguments = { raw: chunk.arguments };
+                      lastFunctionCall.status = "error";
+                      console.log(
+                        "Error parsing function arguments (Realtime API):",
+                        e
+                      );
                    }
                  }
                }
-
+
                 // Handle function call completion (Realtime API)
-                else if (chunk.type === "response.output_item.done" && chunk.item?.type === "function_call") {
-                  console.log("🔵 UPDATING function call (done):", chunk.item.id, chunk.item.tool_name || chunk.item.name)
-                  console.log("🔵 Looking for existing function calls:", currentFunctionCalls.map(fc => ({ id: fc.id, name: fc.name })))
-
+                else if (
+                  chunk.type === "response.output_item.done" &&
+                  chunk.item?.type === "function_call"
+                ) {
+                  console.log(
+                    "🔵 UPDATING function call (done):",
+                    chunk.item.id,
+                    chunk.item.tool_name || chunk.item.name
+                  );
+                  console.log(
+                    "🔵 Looking for existing function calls:",
+                    currentFunctionCalls.map((fc) => ({
+                      id: fc.id,
+                      name: fc.name,
+                    }))
+                  );
+
                   // Find existing function call by ID or name
-                  const functionCall = currentFunctionCalls.find(fc =>
-                    fc.id === chunk.item.id ||
-                    fc.name === chunk.item.tool_name ||
-                    fc.name === chunk.item.name
-                  )
-
+                  const functionCall = currentFunctionCalls.find(
+                    (fc) =>
+                      fc.id === chunk.item.id ||
+                      fc.name === chunk.item.tool_name ||
+                      fc.name === chunk.item.name
+                  );
+
                   if (functionCall) {
-                    console.log("🔵 FOUND existing function call, updating:", functionCall.id, functionCall.name)
+                    console.log(
+                      "🔵 FOUND existing function call, updating:",
+                      functionCall.id,
+                      functionCall.name
+                    );
                     // Update existing function call with completion data
-                    functionCall.status = chunk.item.status === "completed" ? "completed" : "error"
-                    functionCall.id = chunk.item.id
-                    functionCall.type = chunk.item.type
-                    functionCall.name = chunk.item.tool_name || chunk.item.name || functionCall.name
-                    functionCall.arguments = chunk.item.inputs || functionCall.arguments
-
+                    functionCall.status =
+                      chunk.item.status === "completed" ? "completed" : "error";
+                    functionCall.id = chunk.item.id;
+                    functionCall.type = chunk.item.type;
+                    functionCall.name =
+                      chunk.item.tool_name ||
+                      chunk.item.name ||
+                      functionCall.name;
+                    functionCall.arguments =
+                      chunk.item.inputs || functionCall.arguments;
+
                     // Set results if present
                     if (chunk.item.results) {
-                      functionCall.result = chunk.item.results
+                      functionCall.result = chunk.item.results;
                    }
                   } else {
-                    console.log("🔴 WARNING: Could not find existing function call to update:", chunk.item.id, chunk.item.tool_name, chunk.item.name)
+                    console.log(
+                      "🔴 WARNING: Could not find existing function call to update:",
+                      chunk.item.id,
+                      chunk.item.tool_name,
+                      chunk.item.name
+                    );
                  }
                }
-
+
                 // Handle tool call completion with results
-                else if (chunk.type === "response.output_item.done" && chunk.item?.type?.includes("_call") && chunk.item?.type !== "function_call") {
-                  console.log("Tool call done with results:", chunk.item)
-
+                else if (
+                  chunk.type === "response.output_item.done" &&
+                  chunk.item?.type?.includes("_call") &&
+                  chunk.item?.type !== "function_call"
+                ) {
+                  console.log("Tool call done with results:", chunk.item);
+
                   // Find existing function call by ID, or by name/type if ID not available
-                  const functionCall = currentFunctionCalls.find(fc =>
-                    fc.id === chunk.item.id ||
-                    (fc.name === chunk.item.tool_name) ||
-                    (fc.name === chunk.item.name) ||
-                    (fc.name === chunk.item.type) ||
-                    (fc.name.includes(chunk.item.type.replace('_call', '')) || chunk.item.type.includes(fc.name))
-                  )
-
+                  const functionCall = currentFunctionCalls.find(
+                    (fc) =>
+                      fc.id === chunk.item.id ||
+                      fc.name === chunk.item.tool_name ||
+                      fc.name === chunk.item.name ||
+                      fc.name === chunk.item.type ||
+                      fc.name.includes(chunk.item.type.replace("_call", "")) ||
+                      chunk.item.type.includes(fc.name)
+                  );
+
                   if (functionCall) {
                     // Update existing function call
-                    functionCall.arguments = chunk.item.inputs || functionCall.arguments
-                    functionCall.status = chunk.item.status === "completed" ? "completed" : "error"
-                    functionCall.id = chunk.item.id
-                    functionCall.type = chunk.item.type
-
+                    functionCall.arguments =
+                      chunk.item.inputs || functionCall.arguments;
+                    functionCall.status =
+                      chunk.item.status === "completed" ? "completed" : "error";
+                    functionCall.id = chunk.item.id;
+                    functionCall.type = chunk.item.type;
+
                     // Set the results
                     if (chunk.item.results) {
-                      functionCall.result = chunk.item.results
+                      functionCall.result = chunk.item.results;
                    }
                   } else {
                     // Create new function call if not found
                     const newFunctionCall = {
-                      name: chunk.item.tool_name || chunk.item.name || chunk.item.type || "unknown",
+                      name:
+                        chunk.item.tool_name ||
+                        chunk.item.name ||
+                        chunk.item.type ||
+                        "unknown",
                       arguments: chunk.item.inputs || {},
                       status: "completed" as const,
                       id: chunk.item.id,
                       type: chunk.item.type,
-                      result: chunk.item.results
-                    }
-                    currentFunctionCalls.push(newFunctionCall)
+                      result: chunk.item.results,
+                    };
+                    currentFunctionCalls.push(newFunctionCall);
                  }
                }
-
+
                 // Handle function call output item added (new format)
-                else if (chunk.type === "response.output_item.added" && chunk.item?.type?.includes("_call") && chunk.item?.type !== "function_call") {
-                  console.log("🟡 CREATING tool call (added):", chunk.item.id, chunk.item.tool_name || chunk.item.name, chunk.item.type)
-
+                else if (
+                  chunk.type === "response.output_item.added" &&
+                  chunk.item?.type?.includes("_call") &&
+                  chunk.item?.type !== "function_call"
+                ) {
+                  console.log(
+                    "🟡 CREATING tool call (added):",
+                    chunk.item.id,
+                    chunk.item.tool_name || chunk.item.name,
+                    chunk.item.type
+                  );
+
                   // Dedupe by id or pending with same name
-                  let existing = currentFunctionCalls.find(fc => fc.id === chunk.item.id)
+                  let existing = currentFunctionCalls.find(
+                    (fc) => fc.id === chunk.item.id
+                  );
                   if (!existing) {
-                    existing = [...currentFunctionCalls].reverse().find(fc =>
-                      fc.status === "pending" &&
-                      !fc.id &&
-                      (fc.name === (chunk.item.tool_name || chunk.item.name || chunk.item.type))
-                    )
+                    existing = [...currentFunctionCalls]
+                      .reverse()
+                      .find(
+                        (fc) =>
+                          fc.status === "pending" &&
+                          !fc.id &&
+                          fc.name ===
+                            (chunk.item.tool_name ||
+                              chunk.item.name ||
+                              chunk.item.type)
+                      );
                  }
-
+
                   if (existing) {
-                    existing.id = chunk.item.id
-                    existing.type = chunk.item.type
-                    existing.name = chunk.item.tool_name || chunk.item.name || chunk.item.type || existing.name
-                    existing.arguments = chunk.item.inputs || existing.arguments
-                    console.log("🟡 UPDATED existing pending tool call with id:", existing.id)
+                    existing.id = chunk.item.id;
+                    existing.type = chunk.item.type;
+                    existing.name =
+                      chunk.item.tool_name ||
+                      chunk.item.name ||
+                      chunk.item.type ||
+                      existing.name;
+                    existing.arguments =
+                      chunk.item.inputs || existing.arguments;
+                    console.log(
+                      "🟡 UPDATED existing pending tool call with id:",
+                      existing.id
+                    );
                   } else {
                     const functionCall = {
-                      name: chunk.item.tool_name || chunk.item.name || chunk.item.type || "unknown",
+                      name:
+                        chunk.item.tool_name ||
+                        chunk.item.name ||
+                        chunk.item.type ||
+                        "unknown",
                       arguments: chunk.item.inputs || {},
                       status: "pending" as const,
                       id: chunk.item.id,
-                      type: chunk.item.type
-                    }
-                    currentFunctionCalls.push(functionCall)
-                    console.log("🟡 Function calls now:", currentFunctionCalls.map(fc => ({ id: fc.id, name: fc.name, type: fc.type })))
+                      type: chunk.item.type,
+                    };
+                    currentFunctionCalls.push(functionCall);
+                    console.log(
+                      "🟡 Function calls now:",
+                      currentFunctionCalls.map((fc) => ({
+                        id: fc.id,
+                        name: fc.name,
+                        type: fc.type,
+                      }))
+                    );
                  }
                }
-
+
                 // Handle function call results
-                else if (chunk.type === "response.function_call.result" || chunk.type === "function_call_result") {
-                  console.log("Function call result:", chunk.result || chunk)
-                  const lastFunctionCall = currentFunctionCalls[currentFunctionCalls.length - 1]
+ else if ( + chunk.type === "response.function_call.result" || + chunk.type === "function_call_result" + ) { + console.log("Function call result:", chunk.result || chunk); + const lastFunctionCall = + currentFunctionCalls[currentFunctionCalls.length - 1]; if (lastFunctionCall) { - lastFunctionCall.result = chunk.result || chunk.output || chunk.response - lastFunctionCall.status = "completed" + lastFunctionCall.result = + chunk.result || chunk.output || chunk.response; + lastFunctionCall.status = "completed"; } } - - // Handle tool call results - else if (chunk.type === "response.tool_call.result" || chunk.type === "tool_call_result") { - console.log("Tool call result:", chunk.result || chunk) - const lastFunctionCall = currentFunctionCalls[currentFunctionCalls.length - 1] + + // Handle tool call results + else if ( + chunk.type === "response.tool_call.result" || + chunk.type === "tool_call_result" + ) { + console.log("Tool call result:", chunk.result || chunk); + const lastFunctionCall = + currentFunctionCalls[currentFunctionCalls.length - 1]; if (lastFunctionCall) { - lastFunctionCall.result = chunk.result || chunk.output || chunk.response - lastFunctionCall.status = "completed" + lastFunctionCall.result = + chunk.result || chunk.output || chunk.response; + lastFunctionCall.status = "completed"; } } - + // Handle generic results that might be in different formats - else if ((chunk.type && chunk.type.includes("result")) || chunk.result) { - console.log("Generic result:", chunk) - const lastFunctionCall = currentFunctionCalls[currentFunctionCalls.length - 1] + else if ( + (chunk.type && chunk.type.includes("result")) || + chunk.result + ) { + console.log("Generic result:", chunk); + const lastFunctionCall = + currentFunctionCalls[currentFunctionCalls.length - 1]; if (lastFunctionCall && !lastFunctionCall.result) { - lastFunctionCall.result = chunk.result || chunk.output || chunk.response || chunk - lastFunctionCall.status = "completed" + lastFunctionCall.result = + chunk.result || chunk.output || chunk.response || chunk; + lastFunctionCall.status = "completed"; } } - + // Handle text output streaming (Realtime API) else if (chunk.type === "response.output_text.delta") { - console.log("Text delta (Realtime API):", chunk.delta) - currentContent += chunk.delta || "" + console.log("Text delta (Realtime API):", chunk.delta); + currentContent += chunk.delta || ""; } - + // Log unhandled chunks - else if (chunk.type !== null && chunk.object !== "response.chunk") { - console.log("Unhandled chunk format:", chunk) + else if ( + chunk.type !== null && + chunk.object !== "response.chunk" + ) { + console.log("Unhandled chunk format:", chunk); } - + // Update streaming message - if (!controller.signal.aborted && thisStreamId === streamIdRef.current) { + if ( + !controller.signal.aborted && + thisStreamId === streamIdRef.current + ) { setStreamingMessage({ content: currentContent, functionCalls: [...currentFunctionCalls], - timestamp: new Date() - }) + timestamp: new Date(), + }); } - } catch (parseError) { - console.warn("Failed to parse chunk:", line, parseError) + console.warn("Failed to parse chunk:", line, parseError); } } } } } finally { - reader.releaseLock() + reader.releaseLock(); } // Finalize the message @@ -975,242 +1349,274 @@ function ChatPage() { role: "assistant", content: currentContent, functionCalls: currentFunctionCalls, - timestamp: new Date() - } - + timestamp: new Date(), + }; + if (!controller.signal.aborted && thisStreamId === streamIdRef.current) { - setMessages(prev => 
[...prev, finalMessage]) - setStreamingMessage(null) + setMessages((prev) => [...prev, finalMessage]); + setStreamingMessage(null); } - + // Store the response ID for the next request for this endpoint - if (newResponseId && !controller.signal.aborted && thisStreamId === streamIdRef.current) { - setPreviousResponseIds(prev => ({ + if ( + newResponseId && + !controller.signal.aborted && + thisStreamId === streamIdRef.current + ) { + setPreviousResponseIds((prev) => ({ ...prev, - [endpoint]: newResponseId - })) + [endpoint]: newResponseId, + })); + + // If this is a new conversation (no currentConversationId), set it now + if (!currentConversationId) { + setCurrentConversationId(newResponseId); + refreshConversations(true); + } else { + // For existing conversations, do a silent refresh to keep backend in sync + refreshConversationsSilent(); + } } - - // Trigger sidebar refresh to include this conversation (with small delay to ensure backend has processed) - setTimeout(() => { - try { refreshConversations() } catch {} - }, 100) - } catch (error) { // If stream was aborted (e.g., starting new conversation), do not append errors or final messages if (streamAbortRef.current?.signal.aborted) { - return + return; } - console.error("SSE Stream error:", error) - setStreamingMessage(null) - + console.error("SSE Stream error:", error); + setStreamingMessage(null); + const errorMessage: Message = { role: "assistant", - content: "Sorry, I couldn't connect to the chat service. Please try again.", - timestamp: new Date() - } - setMessages(prev => [...prev, errorMessage]) + content: + "Sorry, I couldn't connect to the chat service. Please try again.", + timestamp: new Date(), + }; + setMessages((prev) => [...prev, errorMessage]); } - } - + }; const handleSubmit = async (e: React.FormEvent) => { - e.preventDefault() - if (!input.trim() || loading) return + e.preventDefault(); + if (!input.trim() || loading) return; const userMessage: Message = { role: "user", content: input.trim(), - timestamp: new Date() - } + timestamp: new Date(), + }; - setMessages(prev => [...prev, userMessage]) - setInput("") - setLoading(true) - setIsFilterHighlighted(false) + setMessages((prev) => [...prev, userMessage]); + setInput(""); + setLoading(true); + setIsFilterHighlighted(false); if (asyncMode) { - await handleSSEStream(userMessage) + await handleSSEStream(userMessage); } else { // Original non-streaming logic try { - const apiEndpoint = endpoint === "chat" ? "/api/chat" : "/api/langflow" - + const apiEndpoint = endpoint === "chat" ? "/api/chat" : "/api/langflow"; + const requestBody: RequestBody = { prompt: userMessage.content, - ...(parsedFilterData?.filters && (() => { - const filters = parsedFilterData.filters - const processed: SelectedFilters = { - data_sources: [], - document_types: [], - owners: [] - } - // Only copy non-wildcard arrays - processed.data_sources = filters.data_sources.includes("*") ? [] : filters.data_sources - processed.document_types = filters.document_types.includes("*") ? [] : filters.document_types - processed.owners = filters.owners.includes("*") ? [] : filters.owners - - // Only include filters if any array has values - const hasFilters = processed.data_sources.length > 0 || - processed.document_types.length > 0 || - processed.owners.length > 0 - return hasFilters ? 
{ filters: processed } : {} - })()), + ...(parsedFilterData?.filters && + (() => { + const filters = parsedFilterData.filters; + const processed: SelectedFilters = { + data_sources: [], + document_types: [], + owners: [], + }; + // Only copy non-wildcard arrays + processed.data_sources = filters.data_sources.includes("*") + ? [] + : filters.data_sources; + processed.document_types = filters.document_types.includes("*") + ? [] + : filters.document_types; + processed.owners = filters.owners.includes("*") + ? [] + : filters.owners; + + // Only include filters if any array has values + const hasFilters = + processed.data_sources.length > 0 || + processed.document_types.length > 0 || + processed.owners.length > 0; + return hasFilters ? { filters: processed } : {}; + })()), limit: parsedFilterData?.limit ?? 10, - scoreThreshold: parsedFilterData?.scoreThreshold ?? 0 - } - + scoreThreshold: parsedFilterData?.scoreThreshold ?? 0, + }; + // Add previous_response_id if we have one for this endpoint - const currentResponseId = previousResponseIds[endpoint] + const currentResponseId = previousResponseIds[endpoint]; if (currentResponseId) { - requestBody.previous_response_id = currentResponseId + requestBody.previous_response_id = currentResponseId; } - + const response = await fetch(apiEndpoint, { method: "POST", headers: { "Content-Type": "application/json", }, body: JSON.stringify(requestBody), - }) + }); + + const result = await response.json(); - const result = await response.json() - if (response.ok) { const assistantMessage: Message = { role: "assistant", content: result.response, - timestamp: new Date() - } - setMessages(prev => [...prev, assistantMessage]) - + timestamp: new Date(), + }; + setMessages((prev) => [...prev, assistantMessage]); + // Store the response ID if present for this endpoint if (result.response_id) { - setPreviousResponseIds(prev => ({ + setPreviousResponseIds((prev) => ({ ...prev, - [endpoint]: result.response_id - })) + [endpoint]: result.response_id, + })); + + // If this is a new conversation (no currentConversationId), set it now + if (!currentConversationId) { + setCurrentConversationId(result.response_id); + refreshConversations(true); + } else { + // For existing conversations, do a silent refresh to keep backend in sync + refreshConversationsSilent(); + } } - // Trigger sidebar refresh to include/update this conversation (with small delay to ensure backend has processed) - setTimeout(() => { - try { refreshConversations() } catch {} - }, 100) } else { - console.error("Chat failed:", result.error) + console.error("Chat failed:", result.error); const errorMessage: Message = { role: "assistant", content: "Sorry, I encountered an error. Please try again.", - timestamp: new Date() - } - setMessages(prev => [...prev, errorMessage]) + timestamp: new Date(), + }; + setMessages((prev) => [...prev, errorMessage]); } } catch (error) { - console.error("Chat error:", error) + console.error("Chat error:", error); const errorMessage: Message = { role: "assistant", - content: "Sorry, I couldn't connect to the chat service. Please try again.", - timestamp: new Date() - } - setMessages(prev => [...prev, errorMessage]) + content: + "Sorry, I couldn't connect to the chat service. 
Please try again.", + timestamp: new Date(), + }; + setMessages((prev) => [...prev, errorMessage]); } } - - setLoading(false) - } + + setLoading(false); + }; const toggleFunctionCall = (functionCallId: string) => { - setExpandedFunctionCalls(prev => { - const newSet = new Set(prev) + setExpandedFunctionCalls((prev) => { + const newSet = new Set(prev); if (newSet.has(functionCallId)) { - newSet.delete(functionCallId) + newSet.delete(functionCallId); } else { - newSet.add(functionCallId) + newSet.add(functionCallId); } - return newSet - }) - } + return newSet; + }); + }; - const handleForkConversation = (messageIndex: number, event?: React.MouseEvent) => { + const handleForkConversation = ( + messageIndex: number, + event?: React.MouseEvent + ) => { // Prevent any default behavior and stop event propagation if (event) { - event.preventDefault() - event.stopPropagation() + event.preventDefault(); + event.stopPropagation(); } - + // Set interaction state to prevent auto-scroll interference - const forkTimestamp = Date.now() - setIsUserInteracting(true) - setIsForkingInProgress(true) - setLastForkTimestamp(forkTimestamp) - - console.log("Fork conversation called for message index:", messageIndex) - + const forkTimestamp = Date.now(); + setIsUserInteracting(true); + setIsForkingInProgress(true); + setLastForkTimestamp(forkTimestamp); + + console.log("Fork conversation called for message index:", messageIndex); + // Get messages up to and including the selected assistant message - const messagesToKeep = messages.slice(0, messageIndex + 1) - + const messagesToKeep = messages.slice(0, messageIndex + 1); + // The selected message should be an assistant message (since fork button is only on assistant messages) - const forkedMessage = messages[messageIndex] - if (forkedMessage.role !== 'assistant') { - console.error('Fork button should only be on assistant messages') - setIsUserInteracting(false) - setIsForkingInProgress(false) - setLastForkTimestamp(0) - return + const forkedMessage = messages[messageIndex]; + if (forkedMessage.role !== "assistant") { + console.error("Fork button should only be on assistant messages"); + setIsUserInteracting(false); + setIsForkingInProgress(false); + setLastForkTimestamp(0); + return; } - + // For forking, we want to continue from the response_id of the assistant message we're forking from // Since we don't store individual response_ids per message yet, we'll use the current conversation's response_id // This means we're continuing the conversation thread from that point - const responseIdToForkFrom = currentConversationId || previousResponseIds[endpoint] - + const responseIdToForkFrom = + currentConversationId || previousResponseIds[endpoint]; + // Create a new conversation by properly forking - setMessages(messagesToKeep) - + setMessages(messagesToKeep); + // Use the chat context's fork method which handles creating a new conversation properly if (forkFromResponse) { - forkFromResponse(responseIdToForkFrom || '') + forkFromResponse(responseIdToForkFrom || ""); } else { // Fallback to manual approach - setCurrentConversationId(null) // This creates a new conversation thread - + setCurrentConversationId(null); // This creates a new conversation thread + // Set the response_id we want to continue from as the previous response ID // This tells the backend to continue the conversation from this point - setPreviousResponseIds(prev => ({ + setPreviousResponseIds((prev) => ({ ...prev, - [endpoint]: responseIdToForkFrom - })) + [endpoint]: responseIdToForkFrom, + })); } - - 
console.log("Forked conversation with", messagesToKeep.length, "messages") - + + console.log("Forked conversation with", messagesToKeep.length, "messages"); + // Reset interaction state after a longer delay to ensure all effects complete setTimeout(() => { - setIsUserInteracting(false) - setIsForkingInProgress(false) - console.log("Fork interaction complete, re-enabling auto effects") - }, 500) - + setIsUserInteracting(false); + setIsForkingInProgress(false); + console.log("Fork interaction complete, re-enabling auto effects"); + }, 500); + // The original conversation remains unchanged in the sidebar // This new forked conversation will get its own response_id when the user sends the next message - } + }; + + const renderFunctionCalls = ( + functionCalls: FunctionCall[], + messageIndex?: number + ) => { + if (!functionCalls || functionCalls.length === 0) return null; - const renderFunctionCalls = (functionCalls: FunctionCall[], messageIndex?: number) => { - if (!functionCalls || functionCalls.length === 0) return null - return (
{functionCalls.map((fc, index) => { - const functionCallId = `${messageIndex || 'streaming'}-${index}` - const isExpanded = expandedFunctionCalls.has(functionCallId) - + const functionCallId = `${messageIndex || "streaming"}-${index}`; + const isExpanded = expandedFunctionCalls.has(functionCallId); + // Determine display name - show both name and type if available - const displayName = fc.type && fc.type !== fc.name - ? `${fc.name} (${fc.type})` - : fc.name - + const displayName = + fc.type && fc.type !== fc.name + ? `${fc.name} (${fc.type})` + : fc.name; + return ( -
-
+
toggleFunctionCall(functionCallId)} > @@ -1223,11 +1629,15 @@ function ChatPage() { {fc.id.substring(0, 8)}... )} -
+
{fc.status}
{isExpanded ? ( @@ -1236,7 +1646,7 @@ function ChatPage() { )}
- + {isExpanded && (
{/* Show type information if available */} @@ -1248,7 +1658,7 @@ function ChatPage() {
)} - + {/* Show ID if available */} {fc.id && (
@@ -1258,20 +1668,19 @@ function ChatPage() {
)} - + {/* Show arguments - either completed or streaming */} {(fc.arguments || fc.argumentsString) && (
Arguments:
-                        {fc.arguments 
+                        {fc.arguments
                           ? JSON.stringify(fc.arguments, null, 2)
-                          : fc.argumentsString || "..."
-                        }
+                          : fc.argumentsString || "..."}
                       
)} - + {fc.result && (
Result: @@ -1279,37 +1688,43 @@ function ChatPage() {
{(() => { // Handle different result formats - let resultsToRender = fc.result - + let resultsToRender = fc.result; + // Check if this is function_call format with nested results // Function call format: results = [{ results: [...] }] // Tool call format: results = [{ text_key: ..., data: {...} }] - if (fc.result.length > 0 && - fc.result[0]?.results && - Array.isArray(fc.result[0].results) && - !fc.result[0].text_key) { - resultsToRender = fc.result[0].results + if ( + fc.result.length > 0 && + fc.result[0]?.results && + Array.isArray(fc.result[0].results) && + !fc.result[0].text_key + ) { + resultsToRender = fc.result[0].results; } - + type ToolResultItem = { - text_key?: string - data?: { file_path?: string; text?: string } - filename?: string - page?: number - score?: number - source_url?: string | null - text?: string - } - const items = resultsToRender as unknown as ToolResultItem[] + text_key?: string; + data?: { file_path?: string; text?: string }; + filename?: string; + page?: number; + score?: number; + source_url?: string | null; + text?: string; + }; + const items = + resultsToRender as unknown as ToolResultItem[]; return items.map((result, idx: number) => ( -
+
{/* Handle tool_call format (file_path in data) */} {result.data?.file_path && (
📄 {result.data.file_path || "Unknown file"}
)} - + {/* Handle function_call format (filename directly) */} {result.filename && !result.data?.file_path && (
@@ -1322,63 +1737,74 @@ function ChatPage() { )}
)} - + {/* Handle tool_call text format */} {result.data?.text && (
- {result.data.text.length > 300 - ? result.data.text.substring(0, 300) + "..." - : result.data.text - } + {result.data.text.length > 300 + ? result.data.text.substring(0, 300) + + "..." + : result.data.text}
)} - + {/* Handle function_call text format */} {result.text && !result.data?.text && (
- {result.text.length > 300 - ? result.text.substring(0, 300) + "..." - : result.text - } + {result.text.length > 300 + ? result.text.substring(0, 300) + "..." + : result.text}
)} - + {/* Show additional metadata for function_call format */} {result.source_url && ( )} - + {result.text_key && (
Key: {result.text_key}
)}
- )) + )); })()}
- Found {(() => { - let resultsToCount = fc.result - if (fc.result.length > 0 && - fc.result[0]?.results && - Array.isArray(fc.result[0].results) && - !fc.result[0].text_key) { - resultsToCount = fc.result[0].results + Found{" "} + {(() => { + let resultsToCount = fc.result; + if ( + fc.result.length > 0 && + fc.result[0]?.results && + Array.isArray(fc.result[0].results) && + !fc.result[0].text_key + ) { + resultsToCount = fc.result[0].results; } - return resultsToCount.length - })()} result{(() => { - let resultsToCount = fc.result - if (fc.result.length > 0 && - fc.result[0]?.results && - Array.isArray(fc.result[0].results) && - !fc.result[0].text_key) { - resultsToCount = fc.result[0].results + return resultsToCount.length; + })()}{" "} + result + {(() => { + let resultsToCount = fc.result; + if ( + fc.result.length > 0 && + fc.result[0]?.results && + Array.isArray(fc.result[0].results) && + !fc.result[0].text_key + ) { + resultsToCount = fc.result[0].results; } - return resultsToCount.length !== 1 ? 's' : '' + return resultsToCount.length !== 1 ? "s" : ""; })()}
@@ -1392,35 +1818,39 @@ function ChatPage() {
)}
- ) + ); })}
- ) - } + ); + }; const suggestionChips = [ "Show me this quarter's top 10 deals", "Summarize recent client interactions", - "Search OpenSearch for mentions of our competitors" - ] + "Search OpenSearch for mentions of our competitors", + ]; const handleSuggestionClick = (suggestion: string) => { - setInput(suggestion) - inputRef.current?.focus() - } + setInput(suggestion); + inputRef.current?.focus(); + }; return ( -
+
{/* Debug header - only show in debug mode */} {isDebugMode && (
-
-
+
{/* Async Mode Toggle */}
@@ -1430,7 +1860,7 @@ function ChatPage() { onClick={() => setAsyncMode(false)} className="h-7 text-xs" > - Streaming Off + Streaming Off
))} - + {/* Streaming Message Display */} {streamingMessage && (
@@ -1547,7 +1999,10 @@ function ChatPage() {
- {renderFunctionCalls(streamingMessage.functionCalls, messages.length)} + {renderFunctionCalls( + streamingMessage.functionCalls, + messages.length + )}

{streamingMessage.content} @@ -1555,7 +2010,7 @@ function ChatPage() {

)} - + {/* Loading animation - shows immediately after user submits */} {loading && (
@@ -1565,7 +2020,9 @@ function ChatPage() {
- Thinking... + + Thinking... +
@@ -1573,21 +2030,22 @@ function ChatPage() {
)} - + {/* Drag overlay for existing messages */} {isDragOver && messages.length > 0 && (
-

Drop document to add context

+

+ Drop document to add context +

)}
-
- + {/* Suggestion chips - always show unless streaming */} {!streamingMessage && (
@@ -1608,7 +2066,7 @@ function ChatPage() {
)} - + {/* Input Area - Fixed at bottom */}
@@ -1616,17 +2074,19 @@ function ChatPage() {
{selectedFilter && (
- + @filter:{selectedFilter.name}
@@ -1786,7 +2267,10 @@ function ChatPage() { {isFilterDropdownOpen && ( -
+
{filterSearchTerm && (
@@ -1803,7 +2287,7 @@ function ChatPage() { )} {availableFilters - .filter(filter => - filter.name.toLowerCase().includes(filterSearchTerm.toLowerCase()) + .filter((filter) => + filter.name + .toLowerCase() + .includes(filterSearchTerm.toLowerCase()) ) .map((filter, index) => ( ))} - {availableFilters.filter(filter => - filter.name.toLowerCase().includes(filterSearchTerm.toLowerCase()) - ).length === 0 && filterSearchTerm && ( -
- No filters match "{filterSearchTerm}" -
- )} + {availableFilters.filter((filter) => + filter.name + .toLowerCase() + .includes(filterSearchTerm.toLowerCase()) + ).length === 0 && + filterSearchTerm && ( +
+ No filters match "{filterSearchTerm}" +
+ )} )}
@@ -1864,17 +2353,13 @@ function ChatPage() { disabled={!input.trim() || loading} className="absolute bottom-3 right-3 rounded-lg h-10 px-4" > - {loading ? ( - - ) : ( - "Send" - )} + {loading ? : "Send"}
- ) + ); } export default function ProtectedChatPage() { @@ -1882,5 +2367,5 @@ export default function ProtectedChatPage() { - ) -} + ); +} diff --git a/frontend/src/app/connectors/page.tsx b/frontend/src/app/connectors/page.tsx index 3516338d..8011d1bd 100644 --- a/frontend/src/app/connectors/page.tsx +++ b/frontend/src/app/connectors/page.tsx @@ -1,495 +1,128 @@ "use client" -import { useState, useEffect, useCallback, Suspense } from "react" -import { useSearchParams } from "next/navigation" -import { Button } from "@/components/ui/button" -import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card" -import { Badge } from "@/components/ui/badge" -import { Input } from "@/components/ui/input" -import { Label } from "@/components/ui/label" -import { Loader2, PlugZap, CheckCircle, XCircle, RefreshCw, Download, AlertCircle } from "lucide-react" -import { useAuth } from "@/contexts/auth-context" +import React, { useState } from "react"; +import { GoogleDrivePicker } from "@/components/google-drive-picker" import { useTask } from "@/contexts/task-context" -import { ProtectedRoute } from "@/components/protected-route" -interface Connector { - id: string - name: string - description: string - icon: React.ReactNode - status: "not_connected" | "connecting" | "connected" | "error" - type: string - connectionId?: string // Store the active connection ID for syncing - access_token?: string // For connectors that use OAuth +interface GoogleDriveFile { + id: string; + name: string; + mimeType: string; + webViewLink?: string; + iconLink?: string; } -interface SyncResult { - processed?: number; - added?: number; - skipped?: number; - errors?: number; - error?: string; - message?: string; // For sync started messages - isStarted?: boolean; // For sync started state -} +export default function ConnectorsPage() { + const { addTask } = useTask() + const [selectedFiles, setSelectedFiles] = useState([]); + const [isSyncing, setIsSyncing] = useState(false); + const [syncResult, setSyncResult] = useState(null); -interface Connection { - connection_id: string - name: string - is_active: boolean - created_at: string - last_sync?: string -} + const handleFileSelection = (files: GoogleDriveFile[]) => { + setSelectedFiles(files); + }; -function ConnectorsPage() { - const { isAuthenticated } = useAuth() - const { addTask, refreshTasks } = useTask() - const searchParams = useSearchParams() - const [connectors, setConnectors] = useState([]) - - const [isConnecting, setIsConnecting] = useState(null) - const [isSyncing, setIsSyncing] = useState(null) - const [syncResults, setSyncResults] = useState<{[key: string]: SyncResult | null}>({}) - const [maxFiles, setMaxFiles] = useState(10) - - // Helper function to get connector icon - const getConnectorIcon = (iconName: string) => { - const iconMap: { [key: string]: React.ReactElement } = { - 'google-drive':
G
, - 'sharepoint':
SP
, - 'onedrive':
OD
, - } - return iconMap[iconName] ||
?
- } - - // Function definitions first - const checkConnectorStatuses = useCallback(async () => { - try { - // Fetch available connectors from backend - const connectorsResponse = await fetch('/api/connectors') - if (!connectorsResponse.ok) { - throw new Error('Failed to load connectors') - } - - const connectorsResult = await connectorsResponse.json() - const connectorTypes = Object.keys(connectorsResult.connectors) - - // Initialize connectors list with metadata from backend - const initialConnectors = connectorTypes - .filter(type => connectorsResult.connectors[type].available) // Only show available connectors - .map(type => ({ - id: type, - name: connectorsResult.connectors[type].name, - description: connectorsResult.connectors[type].description, - icon: getConnectorIcon(connectorsResult.connectors[type].icon), - status: "not_connected" as const, - type: type - })) - - setConnectors(initialConnectors) - - // Check status for each connector type - - for (const connectorType of connectorTypes) { - const response = await fetch(`/api/connectors/${connectorType}/status`) - if (response.ok) { - const data = await response.json() - const connections = data.connections || [] - const activeConnection = connections.find((conn: Connection) => conn.is_active) - const isConnected = activeConnection !== undefined - - setConnectors(prev => prev.map(c => - c.type === connectorType - ? { - ...c, - status: isConnected ? "connected" : "not_connected", - connectionId: activeConnection?.connection_id - } - : c - )) - } - } - } catch (error) { - console.error('Failed to check connector statuses:', error) - } - }, [setConnectors]) - - const handleConnect = async (connector: Connector) => { - setIsConnecting(connector.id) - setConnectors(prev => prev.map(c => - c.id === connector.id ? 
{ ...c, status: "connecting" } : c - )) + const handleSync = async (connector: { connectionId: string, type: string }) => { + if (!connector.connectionId || selectedFiles.length === 0) return + + setIsSyncing(true) + setSyncResult(null) try { - // Use the shared auth callback URL, not a separate connectors callback - const redirectUri = `${window.location.origin}/auth/callback` - - const response = await fetch('/api/auth/init', { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - }, - body: JSON.stringify({ - connector_type: connector.type, - purpose: "data_source", - name: `${connector.name} Connection`, - redirect_uri: redirectUri - }), - }) - - const result = await response.json() - - if (response.ok) { - // Store connector ID for callback - localStorage.setItem('connecting_connector_id', result.connection_id) - localStorage.setItem('connecting_connector_type', connector.type) - - // Handle client-side OAuth with Google's library - if (result.oauth_config) { - // Use the redirect URI provided by the backend - const authUrl = `${result.oauth_config.authorization_endpoint}?` + - `client_id=${result.oauth_config.client_id}&` + - `response_type=code&` + - `scope=${result.oauth_config.scopes.join(' ')}&` + - `redirect_uri=${encodeURIComponent(result.oauth_config.redirect_uri)}&` + - `access_type=offline&` + - `prompt=consent&` + - `state=${result.connection_id}` - - window.location.href = authUrl - } - } else { - throw new Error(result.error || 'Failed to initialize OAuth') + const syncBody: { + connection_id: string; + max_files?: number; + selected_files?: string[]; + } = { + connection_id: connector.connectionId, + selected_files: selectedFiles.map(file => file.id) } - } catch (error) { - console.error('OAuth initialization failed:', error) - setConnectors(prev => prev.map(c => - c.id === connector.id ? 
{ ...c, status: "error" } : c - )) - } finally { - setIsConnecting(null) - } - } - - const handleSync = async (connector: Connector) => { - if (!connector.connectionId) { - console.error('No connection ID available for connector') - return - } - - setIsSyncing(connector.id) - setSyncResults(prev => ({ ...prev, [connector.id]: null })) // Clear any existing progress - - try { + const response = await fetch(`/api/connectors/${connector.type}/sync`, { method: 'POST', headers: { 'Content-Type': 'application/json', }, - body: JSON.stringify({ - max_files: maxFiles - }), + body: JSON.stringify(syncBody), }) - + const result = await response.json() - - if (response.status === 201 && result.task_id) { - // Task-based sync, use centralized tracking - addTask(result.task_id) - console.log(`Sync task ${result.task_id} added to central tracking for connector ${connector.id}`) - - // Immediately refresh task notifications to show the new task - await refreshTasks() - - // Show sync started message - setSyncResults(prev => ({ - ...prev, - [connector.id]: { - message: "Check task notification panel for progress", - isStarted: true - } - })) - setIsSyncing(null) + + if (response.status === 201) { + const taskId = result.task_id + if (taskId) { + addTask(taskId) + setSyncResult({ + processed: 0, + total: selectedFiles.length, + status: 'started' + }) + } } else if (response.ok) { - // Direct sync result - still show "sync started" message - setSyncResults(prev => ({ - ...prev, - [connector.id]: { - message: "Check task notification panel for progress", - isStarted: true - } - })) - setIsSyncing(null) + setSyncResult(result) } else { - throw new Error(result.error || 'Sync failed') + console.error('Sync failed:', result.error) + setSyncResult({ error: result.error || 'Sync failed' }) } } catch (error) { - console.error('Sync failed:', error) - setSyncResults(prev => ({ - ...prev, - [connector.id]: { - error: error instanceof Error ? error.message : 'Sync failed' - } - })) - setIsSyncing(null) + console.error('Sync error:', error) + setSyncResult({ error: 'Network error occurred' }) + } finally { + setIsSyncing(false) } - } - - const handleDisconnect = async (connector: Connector) => { - // This would call a disconnect endpoint when implemented - setConnectors(prev => prev.map(c => - c.id === connector.id ? { ...c, status: "not_connected", connectionId: undefined } : c - )) - setSyncResults(prev => ({ ...prev, [connector.id]: null })) - } - - const getStatusIcon = (status: Connector['status']) => { - switch (status) { - case "connected": - return - case "connecting": - return - case "error": - return - default: - return - } - } - - const getStatusBadge = (status: Connector['status']) => { - switch (status) { - case "connected": - return Connected - case "connecting": - return Connecting... - case "error": - return Error - default: - return Not Connected - } - } - - // Check connector status on mount and when returning from OAuth - useEffect(() => { - if (isAuthenticated) { - checkConnectorStatuses() - } - - // If we just returned from OAuth, clear the URL parameter - if (searchParams.get('oauth_success') === 'true') { - // Clear the URL parameter without causing a page reload - const url = new URL(window.location.href) - url.searchParams.delete('oauth_success') - window.history.replaceState({}, '', url.toString()) - } - }, [searchParams, isAuthenticated, checkConnectorStatuses]) + }; return ( -
-
-

Connectors

-

- Connect external services to automatically sync and index your documents +

+

Connectors

+ +
+

+ This is a demo page for the Google Drive picker component. + For full connector functionality, visit the Settings page.

+ +
- {/* Sync Settings */} - - - - - Sync Settings - - - Configure how many files to sync when manually triggering a sync - - - -
-
- - setMaxFiles(parseInt(e.target.value) || 10)} - className="w-24" - min="1" - max="100" - /> - - (Leave blank or set to 0 for unlimited) - -
-
-
-
- - {/* Connectors Grid */} -
- {connectors.map((connector) => ( - - -
-
- {connector.icon} -
- {connector.name} -
- {getStatusIcon(connector.status)} - {getStatusBadge(connector.status)} -
-
+ {selectedFiles.length > 0 && ( +
+ + + {syncResult && ( +
+ {syncResult.error ? ( +
Error: {syncResult.error}
+ ) : syncResult.status === 'started' ? ( +
+ Sync started for {syncResult.total} files. Check the task notification for progress.
-
- - {connector.description} - - - -
- {connector.status === "not_connected" && ( - - )} - - {connector.status === "connected" && ( - <> - - - - )} - - {connector.status === "error" && ( - - )} -
- - {/* Sync Results */} - {syncResults[connector.id] && ( -
- {syncResults[connector.id]?.isStarted && ( -
-
- - Task initiated: -
-
- {syncResults[connector.id]?.message} -
-
- )} - {syncResults[connector.id]?.error && ( -
-
- - Sync Failed -
-
- {syncResults[connector.id]?.error} -
-
- )} + ) : ( +
+
Processed: {syncResult.processed || 0}
+
Added: {syncResult.added || 0}
+ {syncResult.errors &&
Errors: {syncResult.errors}
}
)} - - - ))} -
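
This page and the new upload/[provider] page later in the diff call the same sync endpoint but read different fields from a 201 response: `task_id` here versus `task_ids[0]` there. A sketch of the request contract as these two callers assume it, inferred from usage in this diff rather than from a documented API schema:

// Body shape both callers send to POST /api/connectors/{type}/sync.
interface SyncRequest {
  connection_id: string;
  max_files?: number;        // used by the older bulk-sync path
  selected_files?: string[]; // picker file ids, used by the new flows
}

// 201 payload: this page reads `task_id`; upload/[provider] reads `task_ids[0]`.
interface SyncAccepted {
  task_id?: string;
  task_ids?: string[];
}

async function startSync(connectorType: string, body: SyncRequest): Promise<string | undefined> {
  const res = await fetch(`/api/connectors/${connectorType}/sync`, {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify(body),
  });
  if (res.status !== 201) return undefined;
  const result: SyncAccepted = await res.json();
  return result.task_id ?? result.task_ids?.[0];
}
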
- - {/* Coming Soon Section */} - - - Coming Soon - - Additional connectors are in development - - - -
-
-
D
-
-
Dropbox
-
File storage
-
-
-
O
-
-
OneDrive
-
Microsoft cloud storage
-
-
-
-
B
-
-
Box
-
Enterprise file sharing
-
-
-
-
-
+ )} +
+ )}
- ) + ); } - -export default function ProtectedConnectorsPage() { - return ( - - Loading connectors...
}> - - - - ) -} \ No newline at end of file diff --git a/frontend/src/app/settings/page.tsx b/frontend/src/app/settings/page.tsx index ccb43eac..89b5b97f 100644 --- a/frontend/src/app/settings/page.tsx +++ b/frontend/src/app/settings/page.tsx @@ -19,6 +19,25 @@ import { ProtectedRoute } from "@/components/protected-route"; import { useTask } from "@/contexts/task-context"; import { useAuth } from "@/contexts/auth-context"; +interface GoogleDriveFile { + id: string + name: string + mimeType: string + webViewLink?: string + iconLink?: string +} + +interface OneDriveFile { + id: string + name: string + mimeType?: string + webUrl?: string + driveItem?: { + file?: { mimeType: string } + folder?: any + } +} + interface Connector { id: string; name: string; diff --git a/frontend/src/app/upload/[provider]/page.tsx b/frontend/src/app/upload/[provider]/page.tsx new file mode 100644 index 00000000..000c9202 --- /dev/null +++ b/frontend/src/app/upload/[provider]/page.tsx @@ -0,0 +1,370 @@ +"use client" + +import { useState, useEffect } from "react" +import { useParams, useRouter } from "next/navigation" +import { Button } from "@/components/ui/button" +import { ArrowLeft, AlertCircle } from "lucide-react" +import { GoogleDrivePicker } from "@/components/google-drive-picker" +import { OneDrivePicker } from "@/components/onedrive-picker" +import { useTask } from "@/contexts/task-context" +import { Toast } from "@/components/ui/toast" + +interface GoogleDriveFile { + id: string + name: string + mimeType: string + webViewLink?: string + iconLink?: string +} + +interface OneDriveFile { + id: string + name: string + mimeType?: string + webUrl?: string + driveItem?: { + file?: { mimeType: string } + folder?: object + } +} + +interface CloudConnector { + id: string + name: string + description: string + status: "not_connected" | "connecting" | "connected" | "error" + type: string + connectionId?: string + hasAccessToken: boolean + accessTokenError?: string +} + +export default function UploadProviderPage() { + const params = useParams() + const router = useRouter() + const provider = params.provider as string + const { addTask, tasks } = useTask() + + const [connector, setConnector] = useState(null) + const [isLoading, setIsLoading] = useState(true) + const [error, setError] = useState(null) + const [accessToken, setAccessToken] = useState(null) + const [selectedFiles, setSelectedFiles] = useState([]) + const [isIngesting, setIsIngesting] = useState(false) + const [currentSyncTaskId, setCurrentSyncTaskId] = useState(null) + const [showSuccessToast, setShowSuccessToast] = useState(false) + + useEffect(() => { + const fetchConnectorInfo = async () => { + setIsLoading(true) + setError(null) + + try { + // Fetch available connectors to validate the provider + const connectorsResponse = await fetch('/api/connectors') + if (!connectorsResponse.ok) { + throw new Error('Failed to load connectors') + } + + const connectorsResult = await connectorsResponse.json() + const providerInfo = connectorsResult.connectors[provider] + + if (!providerInfo || !providerInfo.available) { + setError(`Cloud provider "${provider}" is not available or configured.`) + return + } + + // Check connector status + const statusResponse = await fetch(`/api/connectors/${provider}/status`) + if (!statusResponse.ok) { + throw new Error(`Failed to check ${provider} status`) + } + + const statusData = await statusResponse.json() + const connections = statusData.connections || [] + const activeConnection = connections.find((conn: {is_active: 
boolean, connection_id: string}) => conn.is_active) + const isConnected = activeConnection !== undefined + + let hasAccessToken = false + let accessTokenError: string | undefined = undefined + + // Try to get access token for connected connectors + if (isConnected && activeConnection) { + try { + const tokenResponse = await fetch(`/api/connectors/${provider}/token?connection_id=${activeConnection.connection_id}`) + if (tokenResponse.ok) { + const tokenData = await tokenResponse.json() + if (tokenData.access_token) { + hasAccessToken = true + setAccessToken(tokenData.access_token) + } + } else { + const errorData = await tokenResponse.json().catch(() => ({ error: 'Token unavailable' })) + accessTokenError = errorData.error || 'Access token unavailable' + } + } catch { + accessTokenError = 'Failed to fetch access token' + } + } + + setConnector({ + id: provider, + name: providerInfo.name, + description: providerInfo.description, + status: isConnected ? "connected" : "not_connected", + type: provider, + connectionId: activeConnection?.connection_id, + hasAccessToken, + accessTokenError + }) + + } catch (error) { + console.error('Failed to load connector info:', error) + setError(error instanceof Error ? error.message : 'Failed to load connector information') + } finally { + setIsLoading(false) + } + } + + if (provider) { + fetchConnectorInfo() + } + }, [provider]) + + // Watch for sync task completion and redirect + useEffect(() => { + if (!currentSyncTaskId) return + + const currentTask = tasks.find(task => task.task_id === currentSyncTaskId) + + if (currentTask && currentTask.status === 'completed') { + // Task completed successfully, show toast and redirect + setIsIngesting(false) + setShowSuccessToast(true) + setTimeout(() => { + router.push('/knowledge') + }, 2000) // 2 second delay to let user see toast + } else if (currentTask && currentTask.status === 'failed') { + // Task failed, clear the tracking but don't redirect + setIsIngesting(false) + setCurrentSyncTaskId(null) + } + }, [tasks, currentSyncTaskId, router]) + + const handleFileSelected = (files: GoogleDriveFile[] | OneDriveFile[]) => { + setSelectedFiles(files) + console.log(`Selected ${files.length} files from ${provider}:`, files) + // You can add additional handling here like triggering sync, etc. + } + + const handleSync = async (connector: CloudConnector) => { + if (!connector.connectionId || selectedFiles.length === 0) return + + setIsIngesting(true) + + try { + const syncBody: { + connection_id: string; + max_files?: number; + selected_files?: string[]; + } = { + connection_id: connector.connectionId, + selected_files: selectedFiles.map(file => file.id) + } + + const response = await fetch(`/api/connectors/${connector.type}/sync`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify(syncBody), + }) + + const result = await response.json() + + if (response.status === 201) { + const taskIds = result.task_ids + if (taskIds && taskIds.length > 0) { + const taskId = taskIds[0] // Use the first task ID + addTask(taskId) + setCurrentSyncTaskId(taskId) + } + } else { + console.error('Sync failed:', result.error) + } + } catch (error) { + console.error('Sync error:', error) + setIsIngesting(false) + } + } + + const getProviderDisplayName = () => { + const nameMap: { [key: string]: string } = { + 'google_drive': 'Google Drive', + 'onedrive': 'OneDrive', + 'sharepoint': 'SharePoint' + } + return nameMap[provider] || provider + } + + if (isLoading) { + return ( +
+
+
+
+

Loading {getProviderDisplayName()} connector...

+
+
+
+ ) + } + + if (error || !connector) { + return ( +
+
+ +
+ +
+
+ +

Provider Not Available

+

{error}

+ +
+
+
+ ) + } + + if (connector.status !== "connected") { + return ( +
+
+ +
+ +
+
+ +

{connector.name} Not Connected

+

+ You need to connect your {connector.name} account before you can select files. +

+ +
+
+
+ ) + } + + if (!connector.hasAccessToken) { + return ( +
+
+ +
+ +
+
+ +

Access Token Required

+

+ {connector.accessTokenError || `Unable to get access token for ${connector.name}. Try reconnecting your account.`} +

+ +
+
+
+ ) + } + + return ( +
+
+ +

Add Cloud Knowledge

+
+ +
+ {connector.type === "google_drive" && ( + + )} + + {(connector.type === "onedrive" || connector.type === "sharepoint") && ( + + )} +
+ + {selectedFiles.length > 0 && ( +
+
+ +
+
+ )} + + {/* Success toast notification */} + setShowSuccessToast(false)} + duration={20000} + /> +
+ ) +} \ No newline at end of file diff --git a/frontend/src/components/cloud-connectors-dialog.tsx b/frontend/src/components/cloud-connectors-dialog.tsx new file mode 100644 index 00000000..a9fefbd1 --- /dev/null +++ b/frontend/src/components/cloud-connectors-dialog.tsx @@ -0,0 +1,299 @@ +"use client" + +import { useState, useEffect, useCallback } from "react" +import { Button } from "@/components/ui/button" +import { Badge } from "@/components/ui/badge" +import { + Dialog, + DialogContent, + DialogDescription, + DialogHeader, + DialogTitle, +} from "@/components/ui/dialog" +import { GoogleDrivePicker } from "@/components/google-drive-picker" +import { OneDrivePicker } from "@/components/onedrive-picker" +import { Loader2 } from "lucide-react" + +interface GoogleDriveFile { + id: string + name: string + mimeType: string + webViewLink?: string + iconLink?: string +} + +interface OneDriveFile { + id: string + name: string + mimeType?: string + webUrl?: string + driveItem?: { + file?: { mimeType: string } + folder?: any + } +} + +interface CloudConnector { + id: string + name: string + description: string + icon: React.ReactNode + status: "not_connected" | "connecting" | "connected" | "error" + type: string + connectionId?: string + hasAccessToken: boolean + accessTokenError?: string +} + +interface CloudConnectorsDialogProps { + isOpen: boolean + onOpenChange: (open: boolean) => void + onFileSelected?: (files: GoogleDriveFile[] | OneDriveFile[], connectorType: string) => void +} + +export function CloudConnectorsDialog({ + isOpen, + onOpenChange, + onFileSelected +}: CloudConnectorsDialogProps) { + const [connectors, setConnectors] = useState([]) + const [isLoading, setIsLoading] = useState(true) + const [selectedFiles, setSelectedFiles] = useState<{[connectorId: string]: GoogleDriveFile[] | OneDriveFile[]}>({}) + const [connectorAccessTokens, setConnectorAccessTokens] = useState<{[connectorType: string]: string}>({}) + const [activePickerType, setActivePickerType] = useState(null) + const [isGooglePickerOpen, setIsGooglePickerOpen] = useState(false) + + const getConnectorIcon = (iconName: string) => { + const iconMap: { [key: string]: React.ReactElement } = { + 'google-drive': ( +
+ G +
+ ), + 'sharepoint': ( +
+ SP +
+ ), + 'onedrive': ( +
+ OD +
+ ), + } + return iconMap[iconName] || ( +
+ ? +
+ ) + } + + const fetchConnectorStatuses = useCallback(async () => { + if (!isOpen) return + + setIsLoading(true) + try { + // Fetch available connectors from backend + const connectorsResponse = await fetch('/api/connectors') + if (!connectorsResponse.ok) { + throw new Error('Failed to load connectors') + } + + const connectorsResult = await connectorsResponse.json() + const connectorTypes = Object.keys(connectorsResult.connectors) + + // Filter to only cloud connectors + const cloudConnectorTypes = connectorTypes.filter(type => + ['google_drive', 'onedrive', 'sharepoint'].includes(type) && + connectorsResult.connectors[type].available + ) + + // Initialize connectors list + const initialConnectors = cloudConnectorTypes.map(type => ({ + id: type, + name: connectorsResult.connectors[type].name, + description: connectorsResult.connectors[type].description, + icon: getConnectorIcon(connectorsResult.connectors[type].icon), + status: "not_connected" as const, + type: type, + hasAccessToken: false, + accessTokenError: undefined + })) + + setConnectors(initialConnectors) + + // Check status for each cloud connector type + for (const connectorType of cloudConnectorTypes) { + try { + const response = await fetch(`/api/connectors/${connectorType}/status`) + if (response.ok) { + const data = await response.json() + const connections = data.connections || [] + const activeConnection = connections.find((conn: any) => conn.is_active) + const isConnected = activeConnection !== undefined + + let hasAccessToken = false + let accessTokenError: string | undefined = undefined + + // Try to get access token for connected connectors + if (isConnected && activeConnection) { + try { + const tokenResponse = await fetch(`/api/connectors/${connectorType}/token?connection_id=${activeConnection.connection_id}`) + if (tokenResponse.ok) { + const tokenData = await tokenResponse.json() + if (tokenData.access_token) { + hasAccessToken = true + setConnectorAccessTokens(prev => ({ + ...prev, + [connectorType]: tokenData.access_token + })) + } + } else { + const errorData = await tokenResponse.json().catch(() => ({ error: 'Token unavailable' })) + accessTokenError = errorData.error || 'Access token unavailable' + } + } catch (e) { + accessTokenError = 'Failed to fetch access token' + } + } + + setConnectors(prev => prev.map(c => + c.type === connectorType + ? { + ...c, + status: isConnected ? "connected" : "not_connected", + connectionId: activeConnection?.connection_id, + hasAccessToken, + accessTokenError + } + : c + )) + } + } catch (error) { + console.error(`Failed to check status for ${connectorType}:`, error) + } + } + } catch (error) { + console.error('Failed to load cloud connectors:', error) + } finally { + setIsLoading(false) + } + }, [isOpen]) + + const handleFileSelection = (connectorId: string, files: GoogleDriveFile[] | OneDriveFile[]) => { + setSelectedFiles(prev => ({ + ...prev, + [connectorId]: files + })) + + onFileSelected?.(files, connectorId) + } + + useEffect(() => { + fetchConnectorStatuses() + }, [fetchConnectorStatuses]) + + + return ( + + + + Cloud File Connectors + + Select files from your connected cloud storage providers + + + +
+ {isLoading ? ( +
+ + Loading connectors... +
+ ) : connectors.length === 0 ? ( +
+ No cloud connectors available. Configure them in Settings first. +
+ ) : ( +
+ {/* Service Buttons Row */} +
+ {connectors + .filter(connector => connector.status === "connected") + .map((connector) => ( + + ))} +
+ + {connectors.every(c => c.status !== "connected") && ( +
+

No connected cloud providers found.

+

Go to Settings to connect your cloud storage accounts.

+
+ )} + + {/* Render pickers inside dialog */} + {activePickerType && connectors.find(c => c.id === activePickerType) && (() => { + const connector = connectors.find(c => c.id === activePickerType)! + + if (connector.type === "google_drive") { + return ( +
+ { + handleFileSelection(connector.id, files) + setActivePickerType(null) + setIsGooglePickerOpen(false) + }} + selectedFiles={selectedFiles[connector.id] as GoogleDriveFile[] || []} + isAuthenticated={connector.status === "connected"} + accessToken={connectorAccessTokens[connector.type]} + onPickerStateChange={setIsGooglePickerOpen} + /> +
+ ) + } + + if (connector.type === "onedrive" || connector.type === "sharepoint") { + return ( +
+ { + handleFileSelection(connector.id, files) + setActivePickerType(null) + }} + selectedFiles={selectedFiles[connector.id] as OneDriveFile[] || []} + isAuthenticated={connector.status === "connected"} + accessToken={connectorAccessTokens[connector.type]} + connectorType={connector.type as "onedrive" | "sharepoint"} + /> +
+ ) + } + + return null + })()} +
+ )} +
+
+
+ ) +} \ No newline at end of file diff --git a/frontend/src/components/cloud-connectors-dropdown.tsx b/frontend/src/components/cloud-connectors-dropdown.tsx new file mode 100644 index 00000000..1989132a --- /dev/null +++ b/frontend/src/components/cloud-connectors-dropdown.tsx @@ -0,0 +1,77 @@ +"use client" + +import { useState } from "react" +import { Button } from "@/components/ui/button" +import { + DropdownMenu, + DropdownMenuContent, + DropdownMenuItem, + DropdownMenuTrigger, +} from "@/components/ui/dropdown-menu" +import { CloudConnectorsDialog } from "@/components/cloud-connectors-dialog" +import { Cloud, ChevronDown } from "lucide-react" + +interface GoogleDriveFile { + id: string + name: string + mimeType: string + webViewLink?: string + iconLink?: string +} + +interface OneDriveFile { + id: string + name: string + mimeType?: string + webUrl?: string + driveItem?: { + file?: { mimeType: string } + folder?: any + } +} + +interface CloudConnectorsDropdownProps { + onFileSelected?: (files: GoogleDriveFile[] | OneDriveFile[], connectorType: string) => void + buttonText?: string + variant?: "default" | "outline" | "secondary" | "ghost" | "link" | "destructive" + size?: "default" | "sm" | "lg" | "icon" +} + +export function CloudConnectorsDropdown({ + onFileSelected, + buttonText = "Cloud Files", + variant = "outline", + size = "default" +}: CloudConnectorsDropdownProps) { + const [isDialogOpen, setIsDialogOpen] = useState(false) + + const handleOpenDialog = () => { + setIsDialogOpen(true) + } + + return ( + <> + + + + + + + + Select Cloud Files + + + + + + + ) +} \ No newline at end of file diff --git a/frontend/src/components/google-drive-picker.tsx b/frontend/src/components/google-drive-picker.tsx new file mode 100644 index 00000000..c9dee19a --- /dev/null +++ b/frontend/src/components/google-drive-picker.tsx @@ -0,0 +1,341 @@ +"use client" + +import { useState, useEffect } from "react" +import { Button } from "@/components/ui/button" +import { Badge } from "@/components/ui/badge" +import { FileText, Folder, Plus, Trash2 } from "lucide-react" +import { Card, CardContent } from "@/components/ui/card" + +interface GoogleDrivePickerProps { + onFileSelected: (files: GoogleDriveFile[]) => void + selectedFiles?: GoogleDriveFile[] + isAuthenticated: boolean + accessToken?: string + onPickerStateChange?: (isOpen: boolean) => void +} + +interface GoogleDriveFile { + id: string + name: string + mimeType: string + webViewLink?: string + iconLink?: string + size?: number + modifiedTime?: string + isFolder?: boolean +} + +interface GoogleAPI { + load: (api: string, options: { callback: () => void; onerror?: () => void }) => void +} + +interface GooglePickerData { + action: string + docs: GooglePickerDocument[] +} + +interface GooglePickerDocument { + [key: string]: string +} + +declare global { + interface Window { + gapi: GoogleAPI + google: { + picker: { + api: { + load: (callback: () => void) => void + } + PickerBuilder: new () => GooglePickerBuilder + ViewId: { + DOCS: string + FOLDERS: string + DOCS_IMAGES_AND_VIDEOS: string + DOCUMENTS: string + PRESENTATIONS: string + SPREADSHEETS: string + } + Feature: { + MULTISELECT_ENABLED: string + NAV_HIDDEN: string + SIMPLE_UPLOAD_ENABLED: string + } + Action: { + PICKED: string + CANCEL: string + } + Document: { + ID: string + NAME: string + MIME_TYPE: string + URL: string + ICON_URL: string + } + } + } + } +} + +interface GooglePickerBuilder { + addView: (view: string) => GooglePickerBuilder + setOAuthToken: (token: string) => GooglePickerBuilder 
+ setCallback: (callback: (data: GooglePickerData) => void) => GooglePickerBuilder + enableFeature: (feature: string) => GooglePickerBuilder + setTitle: (title: string) => GooglePickerBuilder + build: () => GooglePicker +} + +interface GooglePicker { + setVisible: (visible: boolean) => void +} + +export function GoogleDrivePicker({ + onFileSelected, + selectedFiles = [], + isAuthenticated, + accessToken, + onPickerStateChange +}: GoogleDrivePickerProps) { + const [isPickerLoaded, setIsPickerLoaded] = useState(false) + const [isPickerOpen, setIsPickerOpen] = useState(false) + + useEffect(() => { + const loadPickerApi = () => { + if (typeof window !== 'undefined' && window.gapi) { + window.gapi.load('picker', { + callback: () => { + setIsPickerLoaded(true) + }, + onerror: () => { + console.error('Failed to load Google Picker API') + } + }) + } + } + + // Load Google API script if not already loaded + if (typeof window !== 'undefined') { + if (!window.gapi) { + const script = document.createElement('script') + script.src = 'https://apis.google.com/js/api.js' + script.async = true + script.defer = true + script.onload = loadPickerApi + script.onerror = () => { + console.error('Failed to load Google API script') + } + document.head.appendChild(script) + + return () => { + if (document.head.contains(script)) { + document.head.removeChild(script) + } + } + } else { + loadPickerApi() + } + } + }, []) + + + const openPicker = () => { + if (!isPickerLoaded || !accessToken || !window.google?.picker) { + return + } + + try { + setIsPickerOpen(true) + onPickerStateChange?.(true) + + // Create picker with higher z-index and focus handling + const picker = new window.google.picker.PickerBuilder() + .addView(window.google.picker.ViewId.DOCS) + .addView(window.google.picker.ViewId.FOLDERS) + .setOAuthToken(accessToken) + .enableFeature(window.google.picker.Feature.MULTISELECT_ENABLED) + .setTitle('Select files from Google Drive') + .setCallback(pickerCallback) + .build() + + picker.setVisible(true) + + // Apply z-index fix after a short delay to ensure picker is rendered + setTimeout(() => { + const pickerElements = document.querySelectorAll('.picker-dialog, .goog-modalpopup') + pickerElements.forEach(el => { + (el as HTMLElement).style.zIndex = '10000' + }) + const bgElements = document.querySelectorAll('.picker-dialog-bg, .goog-modalpopup-bg') + bgElements.forEach(el => { + (el as HTMLElement).style.zIndex = '9999' + }) + }, 100) + + } catch (error) { + console.error('Error creating picker:', error) + setIsPickerOpen(false) + onPickerStateChange?.(false) + } + } + + const pickerCallback = async (data: GooglePickerData) => { + if (data.action === window.google.picker.Action.PICKED) { + const files: GoogleDriveFile[] = data.docs.map((doc: GooglePickerDocument) => ({ + id: doc[window.google.picker.Document.ID], + name: doc[window.google.picker.Document.NAME], + mimeType: doc[window.google.picker.Document.MIME_TYPE], + webViewLink: doc[window.google.picker.Document.URL], + iconLink: doc[window.google.picker.Document.ICON_URL], + size: doc['sizeBytes'] ? 
parseInt(doc['sizeBytes']) : undefined, + modifiedTime: doc['lastEditedUtc'], + isFolder: doc[window.google.picker.Document.MIME_TYPE] === 'application/vnd.google-apps.folder' + })) + + // If size is still missing, try to fetch it via Google Drive API + if (accessToken && files.some(f => !f.size && !f.isFolder)) { + try { + const enrichedFiles = await Promise.all(files.map(async (file) => { + if (!file.size && !file.isFolder) { + try { + const response = await fetch(`https://www.googleapis.com/drive/v3/files/${file.id}?fields=size,modifiedTime`, { + headers: { + 'Authorization': `Bearer ${accessToken}` + } + }) + if (response.ok) { + const fileDetails = await response.json() + return { + ...file, + size: fileDetails.size ? parseInt(fileDetails.size) : undefined, + modifiedTime: fileDetails.modifiedTime || file.modifiedTime + } + } + } catch (error) { + console.warn('Failed to fetch file details:', error) + } + } + return file + })) + onFileSelected(enrichedFiles) + } catch (error) { + console.warn('Failed to enrich file data:', error) + onFileSelected(files) + } + } else { + onFileSelected(files) + } + } + + setIsPickerOpen(false) + onPickerStateChange?.(false) + } + + const removeFile = (fileId: string) => { + const updatedFiles = selectedFiles.filter(file => file.id !== fileId) + onFileSelected(updatedFiles) + } + + const getFileIcon = (mimeType: string) => { + if (mimeType.includes('folder')) { + return + } + return + } + + const getMimeTypeLabel = (mimeType: string) => { + const typeMap: { [key: string]: string } = { + 'application/vnd.google-apps.document': 'Google Doc', + 'application/vnd.google-apps.spreadsheet': 'Google Sheet', + 'application/vnd.google-apps.presentation': 'Google Slides', + 'application/vnd.google-apps.folder': 'Folder', + 'application/pdf': 'PDF', + 'text/plain': 'Text', + 'application/vnd.openxmlformats-officedocument.wordprocessingml.document': 'Word Doc', + 'application/vnd.openxmlformats-officedocument.presentationml.presentation': 'PowerPoint' + } + + return typeMap[mimeType] || 'Document' + } + + const formatFileSize = (bytes?: number) => { + if (!bytes) return '' + const sizes = ['B', 'KB', 'MB', 'GB', 'TB'] + if (bytes === 0) return '0 B' + const i = Math.floor(Math.log(bytes) / Math.log(1024)) + return `${(bytes / Math.pow(1024, i)).toFixed(1)} ${sizes[i]}` + } + + if (!isAuthenticated) { + return ( +
+ Please connect to Google Drive first to select specific files. +
+ ) + } + + return ( +
+ + +

+ Select files from Google Drive to ingest. +

+ +
+
+ + {selectedFiles.length > 0 && ( +
+
+

+ Added files +

+ +
+
+ {selectedFiles.map((file) => ( +
+
+ {getFileIcon(file.mimeType)} + {file.name} + + {getMimeTypeLabel(file.mimeType)} + +
+
+ {formatFileSize(file.size)} + +
+
+ ))} +
+ +
+ )} +
+ ) +} diff --git a/frontend/src/components/onedrive-picker.tsx b/frontend/src/components/onedrive-picker.tsx new file mode 100644 index 00000000..6d4cfc78 --- /dev/null +++ b/frontend/src/components/onedrive-picker.tsx @@ -0,0 +1,322 @@ +"use client" + +import { useState, useEffect } from "react" +import { Button } from "@/components/ui/button" +import { Badge } from "@/components/ui/badge" +import { FileText, Folder, Trash2, X } from "lucide-react" + +interface OneDrivePickerProps { + onFileSelected: (files: OneDriveFile[]) => void + selectedFiles?: OneDriveFile[] + isAuthenticated: boolean + accessToken?: string + connectorType?: "onedrive" | "sharepoint" + onPickerStateChange?: (isOpen: boolean) => void +} + +interface OneDriveFile { + id: string + name: string + mimeType?: string + webUrl?: string + driveItem?: { + file?: { mimeType: string } + folder?: any + } +} + +interface GraphResponse { + value: OneDriveFile[] +} + +declare global { + interface Window { + mgt?: { + Providers: { + globalProvider: any + } + } + } +} + +export function OneDrivePicker({ + onFileSelected, + selectedFiles = [], + isAuthenticated, + accessToken, + connectorType = "onedrive", + onPickerStateChange +}: OneDrivePickerProps) { + const [isLoading, setIsLoading] = useState(false) + const [files, setFiles] = useState([]) + const [isPickerOpen, setIsPickerOpen] = useState(false) + const [currentPath, setCurrentPath] = useState( + connectorType === "sharepoint" ? 'sites?search=' : 'me/drive/root/children' + ) + const [breadcrumbs, setBreadcrumbs] = useState<{id: string, name: string}[]>([ + {id: 'root', name: connectorType === "sharepoint" ? 'SharePoint' : 'OneDrive'} + ]) + + useEffect(() => { + const loadMGT = async () => { + if (typeof window !== 'undefined' && !window.mgt) { + try { + const mgtModule = await import('@microsoft/mgt-components') + const mgtProvider = await import('@microsoft/mgt-msal2-provider') + + // Initialize provider if needed + if (!window.mgt?.Providers?.globalProvider && accessToken) { + // For simplicity, we'll use direct Graph API calls instead of MGT components + } + } catch (error) { + console.warn('MGT not available, falling back to direct API calls') + } + } + } + + loadMGT() + }, [accessToken]) + + + const fetchFiles = async (path: string = currentPath) => { + if (!accessToken) return + + setIsLoading(true) + try { + const response = await fetch(`https://graph.microsoft.com/v1.0/${path}`, { + headers: { + 'Authorization': `Bearer ${accessToken}`, + 'Content-Type': 'application/json' + } + }) + + if (response.ok) { + const data: GraphResponse = await response.json() + setFiles(data.value || []) + } else { + console.error('Failed to fetch OneDrive files:', response.statusText) + } + } catch (error) { + console.error('Error fetching OneDrive files:', error) + } finally { + setIsLoading(false) + } + } + + const openPicker = () => { + if (!accessToken) return + + setIsPickerOpen(true) + onPickerStateChange?.(true) + fetchFiles() + } + + const closePicker = () => { + setIsPickerOpen(false) + onPickerStateChange?.(false) + setFiles([]) + setCurrentPath( + connectorType === "sharepoint" ? 'sites?search=' : 'me/drive/root/children' + ) + setBreadcrumbs([ + {id: 'root', name: connectorType === "sharepoint" ? 
'SharePoint' : 'OneDrive'}
+    ])
+  }
+
+  const handleFileClick = (file: OneDriveFile) => {
+    if (file.driveItem?.folder) {
+      // Navigate to folder
+      const newPath = `me/drive/items/${file.id}/children`
+      setCurrentPath(newPath)
+      setBreadcrumbs([...breadcrumbs, {id: file.id, name: file.name}])
+      fetchFiles(newPath)
+    } else {
+      // Select file
+      const isAlreadySelected = selectedFiles.some(f => f.id === file.id)
+      if (!isAlreadySelected) {
+        onFileSelected([...selectedFiles, file])
+      }
+    }
+  }
+
+  const navigateToBreadcrumb = (index: number) => {
+    if (index === 0) {
+      // Reset to the service-appropriate root (SharePoint search vs. OneDrive root)
+      const rootPath = connectorType === "sharepoint" ? 'sites?search=' : 'me/drive/root/children'
+      const rootName = connectorType === "sharepoint" ? 'SharePoint' : 'OneDrive'
+      setCurrentPath(rootPath)
+      setBreadcrumbs([{id: 'root', name: rootName}])
+      fetchFiles(rootPath)
+    } else {
+      const targetCrumb = breadcrumbs[index]
+      const newPath = `me/drive/items/${targetCrumb.id}/children`
+      setCurrentPath(newPath)
+      setBreadcrumbs(breadcrumbs.slice(0, index + 1))
+      fetchFiles(newPath)
+    }
+  }
+
+  const removeFile = (fileId: string) => {
+    const updatedFiles = selectedFiles.filter(file => file.id !== fileId)
+    onFileSelected(updatedFiles)
+  }
+
+  const getFileIcon = (file: OneDriveFile) => {
+    if (file.driveItem?.folder) {
+      return <Folder />
+    }
+    return <FileText />
+  }
+
+  const getMimeTypeLabel = (file: OneDriveFile) => {
+    const mimeType = file.driveItem?.file?.mimeType || file.mimeType || ''
+    const typeMap: { [key: string]: string } = {
+      'application/vnd.openxmlformats-officedocument.wordprocessingml.document': 'Word Doc',
+      'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet': 'Excel',
+      'application/vnd.openxmlformats-officedocument.presentationml.presentation': 'PowerPoint',
+      'application/pdf': 'PDF',
+      'text/plain': 'Text',
+      'image/jpeg': 'Image',
+      'image/png': 'Image',
+    }
+
+    if (file.driveItem?.folder) return 'Folder'
+    return typeMap[mimeType] || 'Document'
+  }
+
+  const serviceName = connectorType === "sharepoint" ? "SharePoint" : "OneDrive"
+
+  if (!isAuthenticated) {
+    return (
+ Please connect to {serviceName} first to select specific files. +
+ ) + } + + return ( +
+
+
+

{serviceName} File Selection

+

+ Choose specific files to sync instead of syncing everything +

+
+ +
+ + {/* Status message when access token is missing */} + {isAuthenticated && !accessToken && ( +
+
Access token unavailable
+
The file picker requires an access token. Try disconnecting and reconnecting your {serviceName} account.
+
+ )} + + {/* File Picker Modal */} + {isPickerOpen && ( +
+
+
+

Select Files from {serviceName}

+ +
+ + {/* Breadcrumbs */} +
+ {breadcrumbs.map((crumb, index) => ( +
+ {index > 0 && /} + +
+ ))} +
+ + {/* File List */} +
+ {isLoading ? ( +
Loading...
+ ) : files.length === 0 ? ( +
No files found
+ ) : ( +
+ {files.map((file) => ( +
handleFileClick(file)} + > +
+ {getFileIcon(file)} + {file.name} + + {getMimeTypeLabel(file)} + +
+ {selectedFiles.some(f => f.id === file.id) && ( + Selected + )} +
+ ))} +
+ )} +
+
+
+ )} + + {selectedFiles.length > 0 && ( +
+

+ Selected files ({selectedFiles.length}): +

+
+ {selectedFiles.map((file) => ( +
+
+ {getFileIcon(file)} + {file.name} + + {getMimeTypeLabel(file)} + +
+ +
+ ))} +
+ +
+ )} +
+ ) +} \ No newline at end of file diff --git a/frontend/src/components/ui/toast.tsx b/frontend/src/components/ui/toast.tsx new file mode 100644 index 00000000..4d765f49 --- /dev/null +++ b/frontend/src/components/ui/toast.tsx @@ -0,0 +1,39 @@ +"use client" + +import { useState, useEffect } from 'react' +import { Check } from 'lucide-react' + +interface ToastProps { + message: string + show: boolean + onHide?: () => void + duration?: number +} + +export function Toast({ message, show, onHide, duration = 3000 }: ToastProps) { + const [isVisible, setIsVisible] = useState(show) + + useEffect(() => { + setIsVisible(show) + + if (show && duration > 0) { + const timer = setTimeout(() => { + setIsVisible(false) + onHide?.() + }, duration) + + return () => clearTimeout(timer) + } + }, [show, duration, onHide]) + + if (!isVisible) return null + + return ( +
+
+ + {message} +
+
+ ) +} \ No newline at end of file diff --git a/frontend/src/contexts/chat-context.tsx b/frontend/src/contexts/chat-context.tsx index cc734d99..db79e0d3 100644 --- a/frontend/src/contexts/chat-context.tsx +++ b/frontend/src/contexts/chat-context.tsx @@ -1,161 +1,244 @@ -"use client" +"use client"; -import React, { createContext, useContext, useState, ReactNode } from 'react' +import { + createContext, + ReactNode, + useCallback, + useContext, + useEffect, + useMemo, + useRef, + useState, +} from "react"; -export type EndpointType = 'chat' | 'langflow' +export type EndpointType = "chat" | "langflow"; interface ConversationDocument { - filename: string - uploadTime: Date + filename: string; + uploadTime: Date; } interface ConversationMessage { - role: string - content: string - timestamp?: string - response_id?: string + role: string; + content: string; + timestamp?: string; + response_id?: string; } interface ConversationData { - messages: ConversationMessage[] - endpoint: EndpointType - response_id: string - title: string - [key: string]: unknown + messages: ConversationMessage[]; + endpoint: EndpointType; + response_id: string; + title: string; + [key: string]: unknown; } interface ChatContextType { - endpoint: EndpointType - setEndpoint: (endpoint: EndpointType) => void - currentConversationId: string | null - setCurrentConversationId: (id: string | null) => void + endpoint: EndpointType; + setEndpoint: (endpoint: EndpointType) => void; + currentConversationId: string | null; + setCurrentConversationId: (id: string | null) => void; previousResponseIds: { - chat: string | null - langflow: string | null - } - setPreviousResponseIds: (ids: { chat: string | null; langflow: string | null } | ((prev: { chat: string | null; langflow: string | null }) => { chat: string | null; langflow: string | null })) => void - refreshConversations: () => void - refreshTrigger: number - loadConversation: (conversation: ConversationData) => void - startNewConversation: () => void - conversationData: ConversationData | null - forkFromResponse: (responseId: string) => void - conversationDocs: ConversationDocument[] - addConversationDoc: (filename: string) => void - clearConversationDocs: () => void - placeholderConversation: ConversationData | null - setPlaceholderConversation: (conversation: ConversationData | null) => void + chat: string | null; + langflow: string | null; + }; + setPreviousResponseIds: ( + ids: + | { chat: string | null; langflow: string | null } + | ((prev: { chat: string | null; langflow: string | null }) => { + chat: string | null; + langflow: string | null; + }) + ) => void; + refreshConversations: (force?: boolean) => void; + refreshConversationsSilent: () => Promise; + refreshTrigger: number; + refreshTriggerSilent: number; + loadConversation: (conversation: ConversationData) => void; + startNewConversation: () => void; + conversationData: ConversationData | null; + forkFromResponse: (responseId: string) => void; + conversationDocs: ConversationDocument[]; + addConversationDoc: (filename: string) => void; + clearConversationDocs: () => void; + placeholderConversation: ConversationData | null; + setPlaceholderConversation: (conversation: ConversationData | null) => void; } -const ChatContext = createContext(undefined) +const ChatContext = createContext(undefined); interface ChatProviderProps { - children: ReactNode + children: ReactNode; } export function ChatProvider({ children }: ChatProviderProps) { - const [endpoint, setEndpoint] = useState('langflow') - const [currentConversationId, 
setCurrentConversationId] = useState(null) + const [endpoint, setEndpoint] = useState("langflow"); + const [currentConversationId, setCurrentConversationId] = useState< + string | null + >(null); const [previousResponseIds, setPreviousResponseIds] = useState<{ - chat: string | null - langflow: string | null - }>({ chat: null, langflow: null }) - const [refreshTrigger, setRefreshTrigger] = useState(0) - const [conversationData, setConversationData] = useState(null) - const [conversationDocs, setConversationDocs] = useState([]) - const [placeholderConversation, setPlaceholderConversation] = useState(null) + chat: string | null; + langflow: string | null; + }>({ chat: null, langflow: null }); + const [refreshTrigger, setRefreshTrigger] = useState(0); + const [refreshTriggerSilent, setRefreshTriggerSilent] = useState(0); + const [conversationData, setConversationData] = + useState(null); + const [conversationDocs, setConversationDocs] = useState< + ConversationDocument[] + >([]); + const [placeholderConversation, setPlaceholderConversation] = + useState(null); - const refreshConversations = () => { - setRefreshTrigger(prev => prev + 1) - } + // Debounce refresh requests to prevent excessive reloads + const refreshTimeoutRef = useRef(null); - const loadConversation = (conversation: ConversationData) => { - setCurrentConversationId(conversation.response_id) - setEndpoint(conversation.endpoint) - // Store the full conversation data for the chat page to use - // We'll pass it through a ref or state that the chat page can access - setConversationData(conversation) - // Clear placeholder when loading a real conversation - setPlaceholderConversation(null) - } - - const startNewConversation = () => { - // Create a temporary placeholder conversation - const placeholderConversation: ConversationData = { - response_id: 'new-conversation-' + Date.now(), - title: 'New conversation', - endpoint: endpoint, - messages: [{ - role: 'assistant', - content: 'How can I assist?', - timestamp: new Date().toISOString() - }], - created_at: new Date().toISOString(), - last_activity: new Date().toISOString() + const refreshConversations = useCallback((force = false) => { + if (force) { + // Immediate refresh for important updates like new conversations + setRefreshTrigger((prev) => prev + 1); + return; } - - setCurrentConversationId(null) - setPreviousResponseIds({ chat: null, langflow: null }) - setConversationData(null) - setConversationDocs([]) - setPlaceholderConversation(placeholderConversation) - // Force a refresh to ensure sidebar shows correct state - setRefreshTrigger(prev => prev + 1) - } - const addConversationDoc = (filename: string) => { - setConversationDocs(prev => [...prev, { filename, uploadTime: new Date() }]) - } + // Clear any existing timeout + if (refreshTimeoutRef.current) { + clearTimeout(refreshTimeoutRef.current); + } - const clearConversationDocs = () => { - setConversationDocs([]) - } + // Set a new timeout to debounce multiple rapid refresh calls + refreshTimeoutRef.current = setTimeout(() => { + setRefreshTrigger((prev) => prev + 1); + }, 250); // 250ms debounce + }, []); - const forkFromResponse = (responseId: string) => { - // Start a new conversation with the messages up to the fork point - setCurrentConversationId(null) // Clear current conversation to indicate new conversation - setConversationData(null) // Clear conversation data to prevent reloading - // Set the response ID that we're forking from as the previous response ID - setPreviousResponseIds(prev => ({ + // Cleanup timeout 
on unmount + useEffect(() => { + return () => { + if (refreshTimeoutRef.current) { + clearTimeout(refreshTimeoutRef.current); + } + }; + }, []); + + // Silent refresh - updates data without loading states + const refreshConversationsSilent = useCallback(async () => { + // Trigger silent refresh that updates conversation data without showing loading states + setRefreshTriggerSilent((prev) => prev + 1); + }, []); + + const loadConversation = useCallback((conversation: ConversationData) => { + setCurrentConversationId(conversation.response_id); + setEndpoint(conversation.endpoint); + // Store the full conversation data for the chat page to use + setConversationData(conversation); + // Clear placeholder when loading a real conversation + setPlaceholderConversation(null); + }, []); + + const startNewConversation = useCallback(() => { + // Clear current conversation data and reset state + setCurrentConversationId(null); + setPreviousResponseIds({ chat: null, langflow: null }); + setConversationData(null); + setConversationDocs([]); + + // Create a temporary placeholder conversation to show in sidebar + const placeholderConversation: ConversationData = { + response_id: "new-conversation-" + Date.now(), + title: "New conversation", + endpoint: endpoint, + messages: [ + { + role: "assistant", + content: "How can I assist?", + timestamp: new Date().toISOString(), + }, + ], + created_at: new Date().toISOString(), + last_activity: new Date().toISOString(), + }; + + setPlaceholderConversation(placeholderConversation); + // Force immediate refresh to ensure sidebar shows correct state + refreshConversations(true); + }, [endpoint, refreshConversations]); + + const addConversationDoc = useCallback((filename: string) => { + setConversationDocs((prev) => [ ...prev, - [endpoint]: responseId - })) - // Clear placeholder when forking - setPlaceholderConversation(null) - // The messages are already set by the chat page component before calling this - } + { filename, uploadTime: new Date() }, + ]); + }, []); - const value: ChatContextType = { - endpoint, - setEndpoint, - currentConversationId, - setCurrentConversationId, - previousResponseIds, - setPreviousResponseIds, - refreshConversations, - refreshTrigger, - loadConversation, - startNewConversation, - conversationData, - forkFromResponse, - conversationDocs, - addConversationDoc, - clearConversationDocs, - placeholderConversation, - setPlaceholderConversation, - } + const clearConversationDocs = useCallback(() => { + setConversationDocs([]); + }, []); - return ( - - {children} - - ) + const forkFromResponse = useCallback( + (responseId: string) => { + // Start a new conversation with the messages up to the fork point + setCurrentConversationId(null); // Clear current conversation to indicate new conversation + setConversationData(null); // Clear conversation data to prevent reloading + // Set the response ID that we're forking from as the previous response ID + setPreviousResponseIds((prev) => ({ + ...prev, + [endpoint]: responseId, + })); + // Clear placeholder when forking + setPlaceholderConversation(null); + // The messages are already set by the chat page component before calling this + }, + [endpoint] + ); + + const value = useMemo( + () => ({ + endpoint, + setEndpoint, + currentConversationId, + setCurrentConversationId, + previousResponseIds, + setPreviousResponseIds, + refreshConversations, + refreshConversationsSilent, + refreshTrigger, + refreshTriggerSilent, + loadConversation, + startNewConversation, + conversationData, + forkFromResponse, + 
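+      // NOTE: the value object is memoized so context consumers only
+      // re-render when one of the dependencies listed after it changes.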
+      conversationDocs,
+      addConversationDoc,
+      clearConversationDocs,
+      placeholderConversation,
+      setPlaceholderConversation,
+    }),
+    [
+      endpoint,
+      currentConversationId,
+      previousResponseIds,
+      refreshConversations,
+      refreshConversationsSilent,
+      refreshTrigger,
+      refreshTriggerSilent,
+      loadConversation,
+      startNewConversation,
+      conversationData,
+      forkFromResponse,
+      conversationDocs,
+      addConversationDoc,
+      clearConversationDocs,
+      placeholderConversation,
+    ]
+  );
+
+  return <ChatContext.Provider value={value}>{children}</ChatContext.Provider>;
 }
 
 export function useChat(): ChatContextType {
-  const context = useContext(ChatContext)
+  const context = useContext(ChatContext);
   if (context === undefined) {
-    throw new Error('useChat must be used within a ChatProvider')
+    throw new Error("useChat must be used within a ChatProvider");
   }
-  return context
-} \ No newline at end of file
+  return context;
+}
diff --git a/pyproject.toml b/pyproject.toml
index 20d8f5c4..3ed39646 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -29,11 +29,15 @@ dependencies = [
     "structlog>=25.4.0",
 ]
 
+[project.scripts]
+openrag = "tui.main:run_tui"
+
+[tool.uv]
+package = true
+
 [tool.uv.sources]
-#agentd = { path = "/home/tato/Desktop/agentd" }
 torch = [
     { index = "pytorch-cu128", marker = "sys_platform == 'linux' and platform_machine == 'x86_64'" },
-    # macOS & other platforms use PyPI (no index entry needed)
 ]
 torchvision = [
     { index = "pytorch-cu128", marker = "sys_platform == 'linux' and platform_machine == 'x86_64'" },
diff --git a/src/agent.py b/src/agent.py
index d77ac8b1..6776a317 100644
--- a/src/agent.py
+++ b/src/agent.py
@@ -2,31 +2,31 @@ from utils.logging_config import get_logger
 
 logger = get_logger(__name__)
 
-# User-scoped conversation state - keyed by user_id -> response_id -> conversation
-user_conversations = {}  # user_id -> {response_id: {"messages": [...], "previous_response_id": parent_id, "created_at": timestamp, "last_activity": timestamp}}
+# Import persistent storage
+from services.conversation_persistence_service import conversation_persistence
 
+# In-memory storage for active conversation threads (preserves function calls)
+active_conversations = {}
 
 def get_user_conversations(user_id: str):
-    """Get all conversations for a user"""
-    if user_id not in user_conversations:
-        user_conversations[user_id] = {}
-    return user_conversations[user_id]
+    """Get conversation metadata for a user from persistent storage"""
+    return conversation_persistence.get_user_conversations(user_id)
 
 
 def get_conversation_thread(user_id: str, previous_response_id: str = None):
-    """Get or create a specific conversation thread"""
-    conversations = get_user_conversations(user_id)
-
-    if previous_response_id and previous_response_id in conversations:
-        # Update last activity and return existing conversation
-        conversations[previous_response_id]["last_activity"] = __import__(
-            "datetime"
-        ).datetime.now()
-        return conversations[previous_response_id]
-
-    # Create new conversation thread
+    """Get or create a specific conversation thread with function call preservation"""
     from datetime import datetime
 
+    # Create user namespace if it doesn't exist
+    if user_id not in active_conversations:
+        active_conversations[user_id] = {}
+
+    # If we have a previous_response_id, try to get the existing conversation
+    if previous_response_id and previous_response_id in active_conversations[user_id]:
+        logger.debug(f"Retrieved existing conversation for user {user_id}, response_id {previous_response_id}")
+        return active_conversations[user_id][previous_response_id]
+
+    # Create new
conversation thread new_conversation = { "messages": [ { @@ -43,19 +43,49 @@ def get_conversation_thread(user_id: str, previous_response_id: str = None): def store_conversation_thread(user_id: str, response_id: str, conversation_state: dict): - """Store a conversation thread with its response_id""" - conversations = get_user_conversations(user_id) - conversations[response_id] = conversation_state + """Store conversation both in memory (with function calls) and persist metadata to disk""" + # 1. Store full conversation in memory for function call preservation + if user_id not in active_conversations: + active_conversations[user_id] = {} + active_conversations[user_id][response_id] = conversation_state + + # 2. Store only essential metadata to disk (simplified JSON) + messages = conversation_state.get("messages", []) + first_user_msg = next((msg for msg in messages if msg.get("role") == "user"), None) + title = "New Chat" + if first_user_msg: + content = first_user_msg.get("content", "") + title = content[:50] + "..." if len(content) > 50 else content + + metadata_only = { + "response_id": response_id, + "title": title, + "endpoint": "langflow", + "created_at": conversation_state.get("created_at"), + "last_activity": conversation_state.get("last_activity"), + "previous_response_id": conversation_state.get("previous_response_id"), + "total_messages": len([msg for msg in messages if msg.get("role") in ["user", "assistant"]]), + # Don't store actual messages - Langflow has them + } + + conversation_persistence.store_conversation_thread(user_id, response_id, metadata_only) # Legacy function for backward compatibility def get_user_conversation(user_id: str): """Get the most recent conversation for a user (for backward compatibility)""" + # Check in-memory conversations first (with function calls) + if user_id in active_conversations and active_conversations[user_id]: + latest_response_id = max(active_conversations[user_id].keys(), + key=lambda k: active_conversations[user_id][k]["last_activity"]) + return active_conversations[user_id][latest_response_id] + + # Fallback to metadata-only conversations conversations = get_user_conversations(user_id) if not conversations: return get_conversation_thread(user_id) - # Return the most recently active conversation + # Return the most recently active conversation metadata latest_conversation = max(conversations.values(), key=lambda c: c["last_activity"]) return latest_conversation @@ -183,7 +213,7 @@ async def async_response( response, "response_id", None ) - return response_text, response_id + return response_text, response_id, response # Unified streaming function for both chat and langflow @@ -214,7 +244,7 @@ async def async_langflow( extra_headers: dict = None, previous_response_id: str = None, ): - response_text, response_id = await async_response( + response_text, response_id, response_obj = await async_response( langflow_client, prompt, flow_id, @@ -284,7 +314,7 @@ async def async_chat( "Added user message", message_count=len(conversation_state["messages"]) ) - response_text, response_id = await async_response( + response_text, response_id, response_obj = await async_response( async_client, prompt, model, @@ -295,12 +325,13 @@ async def async_chat( "Got response", response_preview=response_text[:50], response_id=response_id ) - # Add assistant response to conversation with response_id and timestamp + # Add assistant response to conversation with response_id, timestamp, and full response object assistant_message = { "role": "assistant", "content": 
response_text, "response_id": response_id, "timestamp": datetime.now(), + "response_data": response_obj.model_dump() if hasattr(response_obj, "model_dump") else str(response_obj), # Store complete response for function calls } conversation_state["messages"].append(assistant_message) logger.debug( @@ -422,7 +453,7 @@ async def async_langflow_chat( message_count=len(conversation_state["messages"]), ) - response_text, response_id = await async_response( + response_text, response_id, response_obj = await async_response( langflow_client, prompt, flow_id, @@ -436,12 +467,13 @@ async def async_langflow_chat( response_id=response_id, ) - # Add assistant response to conversation with response_id and timestamp + # Add assistant response to conversation with response_id, timestamp, and full response object assistant_message = { "role": "assistant", "content": response_text, "response_id": response_id, "timestamp": datetime.now(), + "response_data": response_obj.model_dump() if hasattr(response_obj, "model_dump") else str(response_obj), # Store complete response for function calls } conversation_state["messages"].append(assistant_message) logger.debug( @@ -453,11 +485,19 @@ async def async_langflow_chat( if response_id: conversation_state["last_activity"] = datetime.now() store_conversation_thread(user_id, response_id, conversation_state) - logger.debug( - "Stored langflow conversation thread", - user_id=user_id, - response_id=response_id, + + # Claim session ownership for this user + try: + from services.session_ownership_service import session_ownership_service + session_ownership_service.claim_session(user_id, response_id) + print(f"[DEBUG] Claimed session {response_id} for user {user_id}") + except Exception as e: + print(f"[WARNING] Failed to claim session ownership: {e}") + + print( + f"[DEBUG] Stored langflow conversation thread for user {user_id} with response_id: {response_id}" ) + logger.debug("Stored langflow conversation thread", user_id=user_id, response_id=response_id) # Debug: Check what's in user_conversations now conversations = get_user_conversations(user_id) @@ -499,6 +539,8 @@ async def async_langflow_chat_stream( full_response = "" response_id = None + collected_chunks = [] # Store all chunks for function call data + async for chunk in async_stream( langflow_client, prompt, @@ -512,6 +554,8 @@ async def async_langflow_chat_stream( import json chunk_data = json.loads(chunk.decode("utf-8")) + collected_chunks.append(chunk_data) # Collect all chunk data + if "delta" in chunk_data and "content" in chunk_data["delta"]: full_response += chunk_data["delta"]["content"] # Extract response_id from chunk @@ -523,13 +567,14 @@ async def async_langflow_chat_stream( pass yield chunk - # Add the complete assistant response to message history with response_id and timestamp + # Add the complete assistant response to message history with response_id, timestamp, and function call data if full_response: assistant_message = { "role": "assistant", "content": full_response, "response_id": response_id, "timestamp": datetime.now(), + "chunks": collected_chunks, # Store complete chunk data for function calls } conversation_state["messages"].append(assistant_message) @@ -537,8 +582,16 @@ async def async_langflow_chat_stream( if response_id: conversation_state["last_activity"] = datetime.now() store_conversation_thread(user_id, response_id, conversation_state) - logger.debug( - "Stored langflow conversation thread", - user_id=user_id, - response_id=response_id, + + # Claim session ownership for this user + 
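+        # NOTE (assumption): claim_session ties this response_id to the
+        # requesting user so later lookups can enforce per-user access;
+        # failures are intentionally non-fatal and only logged.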
+        try:
+            from services.session_ownership_service import session_ownership_service
+            session_ownership_service.claim_session(user_id, response_id)
+            print(f"[DEBUG] Claimed session {response_id} for user {user_id}")
+        except Exception as e:
+            print(f"[WARNING] Failed to claim session ownership: {e}")
+
+        print(
+            f"[DEBUG] Stored langflow conversation thread for user {user_id} with response_id: {response_id}"
         )
+        logger.debug("Stored langflow conversation thread", user_id=user_id, response_id=response_id)
diff --git a/src/api/connectors.py b/src/api/connectors.py
index 4365eac1..b7b603f0 100644
--- a/src/api/connectors.py
+++ b/src/api/connectors.py
@@ -22,6 +22,7 @@ async def connector_sync(request: Request, connector_service, session_manager):
     connector_type = request.path_params.get("connector_type", "google_drive")
     data = await request.json()
     max_files = data.get("max_files")
+    selected_files = data.get("selected_files")
 
     try:
         logger.debug(
@@ -29,10 +30,8 @@
             connector_type=connector_type,
             max_files=max_files,
         )
         user = request.state.user
         jwt_token = request.state.jwt_token
-        logger.debug("User authenticated", user_id=user.user_id)
 
         # Get all active connections for this connector type and user
         connections = await connector_service.connection_manager.list_connections(
@@ -53,12 +52,20 @@
                 "About to call sync_connector_files for connection",
                 connection_id=connection.connection_id,
             )
-            task_id = await connector_service.sync_connector_files(
-                connection.connection_id, user.user_id, max_files, jwt_token=jwt_token
-            )
-            task_ids.append(task_id)
-            logger.debug("Got task ID", task_id=task_id)
-
+            if selected_files:
+                task_id = await connector_service.sync_specific_files(
+                    connection.connection_id,
+                    user.user_id,
+                    selected_files,
+                    jwt_token=jwt_token,
+                )
+            else:
+                task_id = await connector_service.sync_connector_files(
+                    connection.connection_id,
+                    user.user_id,
+                    max_files,
+                    jwt_token=jwt_token,
+                )
+            task_ids.append(task_id)
 
         return JSONResponse(
             {
                 "task_ids": task_ids,
@@ -70,14 +77,7 @@
         )
 
     except Exception as e:
-        import sys
-        import traceback
-
-        error_msg = f"[ERROR] Connector sync failed: {str(e)}"
-        logger.error(error_msg)
-        traceback.print_exc(file=sys.stderr)
-        sys.stderr.flush()
-
+        logger.error("Connector sync failed", error=str(e))
         return JSONResponse({"error": f"Sync failed: {str(e)}"}, status_code=500)
 
 
@@ -117,6 +117,8 @@ async def connector_status(request: Request, connector_service, session_manager)
 
 async def connector_webhook(request: Request, connector_service, session_manager):
     """Handle webhook notifications from any connector type"""
     connector_type = request.path_params.get("connector_type")
+    if connector_type is None:
+        connector_type = "unknown"
 
     # Handle webhook validation (connector-specific)
     temp_config = {"token_file": "temp.json"}
@@ -124,7 +126,7 @@
     temp_connection = ConnectionConfig(
         connection_id="temp",
-        connector_type=connector_type,
+        connector_type=str(connector_type),
         name="temp",
         config=temp_config,
     )
@@ -194,7 +196,6 @@
     )
 
     # Process webhook for the specific connection
-    results = []
     try:
         # Get the connector instance
        connector = await
connector_service._get_connector(connection.connection_id) @@ -268,6 +269,7 @@ async def connector_webhook(request: Request, connector_service, session_manager import traceback traceback.print_exc() + return JSONResponse( { "status": "error", @@ -279,10 +281,59 @@ async def connector_webhook(request: Request, connector_service, session_manager ) except Exception as e: - import traceback - logger.error("Webhook processing failed", error=str(e)) - traceback.print_exc() return JSONResponse( {"error": f"Webhook processing failed: {str(e)}"}, status_code=500 ) + +async def connector_token(request: Request, connector_service, session_manager): + """Get access token for connector API calls (e.g., Google Picker)""" + connector_type = request.path_params.get("connector_type") + connection_id = request.query_params.get("connection_id") + + if not connection_id: + return JSONResponse({"error": "connection_id is required"}, status_code=400) + + user = request.state.user + + try: + # Get the connection and verify it belongs to the user + connection = await connector_service.connection_manager.get_connection(connection_id) + if not connection or connection.user_id != user.user_id: + return JSONResponse({"error": "Connection not found"}, status_code=404) + + # Get the connector instance + connector = await connector_service._get_connector(connection_id) + if not connector: + return JSONResponse({"error": f"Connector not available - authentication may have failed for {connector_type}"}, status_code=404) + + # For Google Drive, get the access token + if connector_type == "google_drive" and hasattr(connector, 'oauth'): + await connector.oauth.load_credentials() + if connector.oauth.creds and connector.oauth.creds.valid: + return JSONResponse({ + "access_token": connector.oauth.creds.token, + "expires_in": (connector.oauth.creds.expiry.timestamp() - + __import__('time').time()) if connector.oauth.creds.expiry else None + }) + else: + return JSONResponse({"error": "Invalid or expired credentials"}, status_code=401) + + # For OneDrive and SharePoint, get the access token + elif connector_type in ["onedrive", "sharepoint"] and hasattr(connector, 'oauth'): + try: + access_token = connector.oauth.get_access_token() + return JSONResponse({ + "access_token": access_token, + "expires_in": None # MSAL handles token expiry internally + }) + except ValueError as e: + return JSONResponse({"error": f"Failed to get access token: {str(e)}"}, status_code=401) + except Exception as e: + return JSONResponse({"error": f"Authentication error: {str(e)}"}, status_code=500) + + return JSONResponse({"error": "Token not available for this connector type"}, status_code=400) + + except Exception as e: + logger.error("Error getting connector token", error=str(e)) + return JSONResponse({"error": str(e)}, status_code=500) diff --git a/src/connectors/base.py b/src/connectors/base.py index d16fe4cf..35c43555 100644 --- a/src/connectors/base.py +++ b/src/connectors/base.py @@ -108,7 +108,7 @@ class BaseConnector(ABC): pass @abstractmethod - async def list_files(self, page_token: Optional[str] = None) -> Dict[str, Any]: + async def list_files(self, page_token: Optional[str] = None, max_files: Optional[int] = None) -> Dict[str, Any]: """List all files. 
Returns files and next_page_token if any.""" pass diff --git a/src/connectors/google_drive/connector.py b/src/connectors/google_drive/connector.py index aa1b5ef9..887ffeca 100644 --- a/src/connectors/google_drive/connector.py +++ b/src/connectors/google_drive/connector.py @@ -1,585 +1,989 @@ -import asyncio import io import os -import uuid -from datetime import datetime -from typing import Dict, List, Any, Optional -from googleapiclient.discovery import build +from pathlib import Path +import time +from collections import deque +from dataclasses import dataclass +from typing import Dict, List, Any, Optional, Iterable, Set + from googleapiclient.errors import HttpError from googleapiclient.http import MediaIoBaseDownload from utils.logging_config import get_logger logger = get_logger(__name__) +# Project-specific base types (adjust imports to your project) from ..base import BaseConnector, ConnectorDocument, DocumentACL from .oauth import GoogleDriveOAuth -# Global worker service cache for process pools -_worker_drive_service = None - - -def get_worker_drive_service(client_id: str, client_secret: str, token_file: str): - """Get or create a Google Drive service instance for this worker process""" - global _worker_drive_service - if _worker_drive_service is None: - logger.info( - "Initializing Google Drive service in worker process", pid=os.getpid() - ) - - # Create OAuth instance and load credentials in worker - from .oauth import GoogleDriveOAuth - - oauth = GoogleDriveOAuth( - client_id=client_id, client_secret=client_secret, token_file=token_file - ) - - # Load credentials synchronously in worker - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete(oauth.load_credentials()) - _worker_drive_service = oauth.get_service() - logger.info("Google Drive service ready in worker process", pid=os.getpid()) - finally: - loop.close() - - return _worker_drive_service - - -# Module-level functions for process pool execution (must be pickleable) -def _sync_list_files_worker( - client_id, client_secret, token_file, query, page_token, page_size -): - """Worker function for listing files in process pool""" - service = get_worker_drive_service(client_id, client_secret, token_file) - return ( - service.files() - .list( - q=query, - pageSize=page_size, - pageToken=page_token, - fields="nextPageToken, files(id, name, mimeType, modifiedTime, createdTime, webViewLink, permissions, owners)", - ) - .execute() - ) - - -def _sync_get_metadata_worker(client_id, client_secret, token_file, file_id): - """Worker function for getting file metadata in process pool""" - service = get_worker_drive_service(client_id, client_secret, token_file) - return ( - service.files() - .get( - fileId=file_id, - fields="id, name, mimeType, modifiedTime, createdTime, webViewLink, permissions, owners, size", - ) - .execute() - ) - - -def _sync_download_worker( - client_id, client_secret, token_file, file_id, mime_type, file_size=None -): - """Worker function for downloading files in process pool""" - import signal - import time - - # File size limits (in bytes) - MAX_REGULAR_FILE_SIZE = 100 * 1024 * 1024 # 100MB for regular files - MAX_GOOGLE_WORKSPACE_SIZE = ( - 50 * 1024 * 1024 - ) # 50MB for Google Workspace docs (they can't be streamed) - - # Check file size limits - if file_size: - if ( - mime_type.startswith("application/vnd.google-apps.") - and file_size > MAX_GOOGLE_WORKSPACE_SIZE - ): - raise ValueError( - f"Google Workspace file too large: {file_size} bytes (max 
{MAX_GOOGLE_WORKSPACE_SIZE})" - ) - elif ( - not mime_type.startswith("application/vnd.google-apps.") - and file_size > MAX_REGULAR_FILE_SIZE - ): - raise ValueError( - f"File too large: {file_size} bytes (max {MAX_REGULAR_FILE_SIZE})" - ) - - # Dynamic timeout based on file size (minimum 60s, 10s per MB, max 300s) - if file_size: - file_size_mb = file_size / (1024 * 1024) - timeout_seconds = min(300, max(60, int(file_size_mb * 10))) - else: - timeout_seconds = 60 # Default timeout if size unknown - - # Set a timeout for the entire download operation - def timeout_handler(signum, frame): - raise TimeoutError(f"File download timed out after {timeout_seconds} seconds") - - signal.signal(signal.SIGALRM, timeout_handler) - signal.alarm(timeout_seconds) - - try: - service = get_worker_drive_service(client_id, client_secret, token_file) - - # For Google native formats, export as PDF - if mime_type.startswith("application/vnd.google-apps."): - export_format = "application/pdf" - request = service.files().export_media( - fileId=file_id, mimeType=export_format - ) - else: - # For regular files, download directly - request = service.files().get_media(fileId=file_id) - - # Download file with chunked approach - file_io = io.BytesIO() - downloader = MediaIoBaseDownload( - file_io, request, chunksize=1024 * 1024 - ) # 1MB chunks - - done = False - retry_count = 0 - max_retries = 2 - - while not done and retry_count < max_retries: - try: - status, done = downloader.next_chunk() - retry_count = 0 # Reset retry count on successful chunk - except Exception as e: - retry_count += 1 - if retry_count >= max_retries: - raise e - time.sleep(1) # Brief pause before retry - - return file_io.getvalue() - - finally: - # Cancel the alarm - signal.alarm(0) +# ------------------------- +# Config model +# ------------------------- +@dataclass +class GoogleDriveConfig: + client_id: str + client_secret: str + token_file: str + + # Selective sync + file_ids: Optional[List[str]] = None + folder_ids: Optional[List[str]] = None + recursive: bool = True + + # Shared Drives control + drive_id: Optional[str] = None # when set, we use corpora='drive' + corpora: Optional[str] = None # 'user' | 'drive' | 'domain'; auto-picked if None + + # Optional filtering + include_mime_types: Optional[List[str]] = None + exclude_mime_types: Optional[List[str]] = None + + # Export overrides for Google-native types + export_format_overrides: Optional[Dict[str, str]] = None # mime -> export-mime + + # Changes API state persistence (store these in your DB/kv if needed) + changes_page_token: Optional[str] = None + + # Optional: resource_id for webhook cleanup + resource_id: Optional[str] = None +# ------------------------- +# Connector implementation +# ------------------------- class GoogleDriveConnector(BaseConnector): - """Google Drive connector with OAuth and webhook support""" + """ + Google Drive connector with first-class support for selective sync: + - Sync specific file IDs + - Sync specific folder IDs (optionally recursive) + - Works across My Drive and Shared Drives + - Resolves shortcuts to their targets + - Robust changes page token management - # OAuth environment variables - CLIENT_ID_ENV_VAR = "GOOGLE_OAUTH_CLIENT_ID" - CLIENT_SECRET_ENV_VAR = "GOOGLE_OAUTH_CLIENT_SECRET" + Integration points: + - `BaseConnector` is your project’s base class; minimum methods used here: + * self.emit(doc: ConnectorDocument) -> None (or adapt to your ingestion pipeline) + * self.log/info/warn/error (optional) + - Adjust paths, logging, and error 
handling to match your project style. + """ + + # Names of env vars that hold your OAuth client creds + CLIENT_ID_ENV_VAR: str = "GOOGLE_OAUTH_CLIENT_ID" + CLIENT_SECRET_ENV_VAR: str = "GOOGLE_OAUTH_CLIENT_SECRET" # Connector metadata CONNECTOR_NAME = "Google Drive" CONNECTOR_DESCRIPTION = "Connect your Google Drive to automatically sync documents" CONNECTOR_ICON = "google-drive" - # Supported file types that can be processed by docling - SUPPORTED_MIMETYPES = { - "application/pdf", - "application/vnd.openxmlformats-officedocument.wordprocessingml.document", # .docx - "application/msword", # .doc - "application/vnd.openxmlformats-officedocument.presentationml.presentation", # .pptx - "application/vnd.ms-powerpoint", # .ppt - "text/plain", - "text/html", - "application/rtf", - # Google Docs native formats - we'll export these - "application/vnd.google-apps.document", # Google Docs -> PDF - "application/vnd.google-apps.presentation", # Google Slides -> PDF - "application/vnd.google-apps.spreadsheet", # Google Sheets -> PDF - } + # Supported alias keys coming from various frontends / pickers + _FILE_ID_ALIASES = ("file_ids", "selected_file_ids", "selected_files") + _FOLDER_ID_ALIASES = ("folder_ids", "selected_folder_ids", "selected_folders") - def __init__(self, config: Dict[str, Any]): - super().__init__(config) + def log(self, message: str) -> None: + print(message) + + def emit(self, doc: ConnectorDocument) -> None: + """ + Emit a ConnectorDocument instance. + Override this method to integrate with your ingestion pipeline. + """ + # If BaseConnector has an emit method, call super().emit(doc) + # Otherwise, implement your custom logic here. + print(f"Emitting document: {doc.id} ({doc.filename})") + + def __init__(self, config: Dict[str, Any]) -> None: + # Read from config OR env (backend env, not NEXT_PUBLIC_*): + env_client_id = os.getenv(self.CLIENT_ID_ENV_VAR) + env_client_secret = os.getenv(self.CLIENT_SECRET_ENV_VAR) + + client_id = config.get("client_id") or env_client_id + client_secret = config.get("client_secret") or env_client_secret + + # Token file default (so callback & workers don’t need to pass it) + token_file = config.get("token_file") or os.getenv("GOOGLE_DRIVE_TOKEN_FILE") + if not token_file: + token_file = str(Path.home() / ".config" / "openrag" / "google_drive" / "token.json") + Path(token_file).parent.mkdir(parents=True, exist_ok=True) + + if not isinstance(client_id, str) or not client_id.strip(): + raise RuntimeError( + f"Missing Google Drive OAuth client_id. " + f"Provide config['client_id'] or set {self.CLIENT_ID_ENV_VAR}." + ) + if not isinstance(client_secret, str) or not client_secret.strip(): + raise RuntimeError( + f"Missing Google Drive OAuth client_secret. " + f"Provide config['client_secret'] or set {self.CLIENT_SECRET_ENV_VAR}." 
+ ) + + # Normalize incoming IDs from any of the supported alias keys + def _first_present_list(cfg: Dict[str, Any], keys: Iterable[str]) -> Optional[List[str]]: + for k in keys: + v = cfg.get(k) + if v: # accept non-empty list + return list(v) + return None + + normalized_file_ids = _first_present_list(config, self._FILE_ID_ALIASES) + normalized_folder_ids = _first_present_list(config, self._FOLDER_ID_ALIASES) + + self.cfg = GoogleDriveConfig( + client_id=client_id, + client_secret=client_secret, + token_file=token_file, + # Accept "selected_files" and "selected_folders" used by the Drive Picker flow + file_ids=normalized_file_ids, + folder_ids=normalized_folder_ids, + recursive=bool(config.get("recursive", True)), + drive_id=config.get("drive_id"), + corpora=config.get("corpora"), + include_mime_types=config.get("include_mime_types"), + exclude_mime_types=config.get("exclude_mime_types"), + export_format_overrides=config.get("export_format_overrides"), + changes_page_token=config.get("changes_page_token"), + resource_id=config.get("resource_id"), + ) + + # Build OAuth wrapper; DO NOT load creds here (it's async) self.oauth = GoogleDriveOAuth( - client_id=self.get_client_id(), - client_secret=self.get_client_secret(), - token_file=config.get("token_file", "gdrive_token.json"), + client_id=self.cfg.client_id, + client_secret=self.cfg.client_secret, + token_file=self.cfg.token_file, ) - self.service = None - # Load existing webhook channel ID from config if available - self.webhook_channel_id = config.get("webhook_channel_id") or config.get( - "subscription_id" - ) - # Load existing webhook resource ID (Google Drive requires this to stop a channel) - self.webhook_resource_id = config.get("resource_id") + # Drive client is built in authenticate() + from google.oauth2.credentials import Credentials + self.creds: Optional[Credentials] = None + self.service: Any = None + + # cache of resolved shortcutId -> target file metadata + self._shortcut_cache: Dict[str, Dict[str, Any]] = {} + + # Authentication state + self._authenticated: bool = False + + # ------------------------- + # Helpers + # ------------------------- + @property + def _drives_get_flags(self) -> Dict[str, Any]: + """ + Flags valid for GET-like calls (files.get, changes.getStartPageToken). + """ + return {"supportsAllDrives": True} + + @property + def _drives_list_flags(self) -> Dict[str, Any]: + """ + Flags valid for LIST-like calls (files.list, changes.list). + """ + return {"supportsAllDrives": True, "includeItemsFromAllDrives": True} + + def _pick_corpora_args(self) -> Dict[str, Any]: + """ + Decide corpora/driveId based on config. + + If drive_id is provided, prefer corpora='drive' with that driveId. + Otherwise, default to allDrives (so Shared Drive selections from the Picker still work). + """ + if self.cfg.drive_id: + return {"corpora": "drive", "driveId": self.cfg.drive_id} + if self.cfg.corpora: + return {"corpora": self.cfg.corpora} + # Default to allDrives so Picker selections from Shared Drives work without explicit drive_id + return {"corpora": "allDrives"} + + def _resolve_shortcut(self, file_obj: Dict[str, Any]) -> Dict[str, Any]: + """ + If a file is a shortcut, fetch and return the real target metadata. 
+        """
+        if file_obj.get("mimeType") != "application/vnd.google-apps.shortcut":
+            return file_obj
+
+        target_id = file_obj.get("shortcutDetails", {}).get("targetId")
+        if not target_id:
+            return file_obj
+
+        if target_id in self._shortcut_cache:
+            return self._shortcut_cache[target_id]
+
+        try:
+            meta = (
+                self.service.files()
+                .get(
+                    fileId=target_id,
+                    fields=(
+                        "id, name, mimeType, modifiedTime, createdTime, size, "
+                        "webViewLink, parents, owners, driveId"
+                    ),
+                    **self._drives_get_flags,
+                )
+                .execute()
+            )
+            self._shortcut_cache[target_id] = meta
+            return meta
+        except HttpError:
+            # shortcut target not accessible
+            return file_obj
+
+    def _list_children(self, folder_id: str) -> List[Dict[str, Any]]:
+        """
+        List immediate children of a folder.
+        """
+        query = f"'{folder_id}' in parents and trashed = false"
+        page_token = None
+        results: List[Dict[str, Any]] = []
+
+        while True:
+            resp = (
+                self.service.files()
+                .list(
+                    q=query,
+                    pageSize=1000,
+                    pageToken=page_token,
+                    fields=(
+                        "nextPageToken, files("
+                        "id, name, mimeType, modifiedTime, createdTime, size, "
+                        "webViewLink, parents, shortcutDetails, driveId)"
+                    ),
+                    **self._drives_list_flags,
+                    **self._pick_corpora_args(),
+                )
+                .execute()
+            )
+            for f in resp.get("files", []):
+                results.append(f)
+            page_token = resp.get("nextPageToken")
+            if not page_token:
+                break
+
+        return results
+
+    def _bfs_expand_folders(self, folder_ids: Iterable[str]) -> List[Dict[str, Any]]:
+        """
+        Breadth-first traversal to expand folders to all descendant files (if recursive),
+        or just immediate children (if not recursive). Folders themselves are returned
+        as items too, but filtered later.
+        """
+        out: List[Dict[str, Any]] = []
+        queue = deque(folder_ids)
+
+        while queue:
+            fid = queue.popleft()
+            children = self._list_children(fid)
+            out.extend(children)
+
+            if self.cfg.recursive:
+                # Enqueue subfolders
+                for c in children:
+                    c = self._resolve_shortcut(c)
+                    if c.get("mimeType") == "application/vnd.google-apps.folder":
+                        queue.append(c["id"])
+
+        return out
+
+    def _get_file_meta_by_id(self, file_id: str) -> Optional[Dict[str, Any]]:
+        """
+        Fetch metadata for a file by ID (resolving shortcuts).
+        """
+        if self.service is None:
+            raise RuntimeError("Google Drive service is not initialized. Please authenticate first.")
+        try:
+            meta = (
+                self.service.files()
+                .get(
+                    fileId=file_id,
+                    fields=(
+                        "id, name, mimeType, modifiedTime, createdTime, size, "
+                        "webViewLink, parents, shortcutDetails, driveId"
+                    ),
+                    **self._drives_get_flags,
+                )
+                .execute()
+            )
+            return self._resolve_shortcut(meta)
+        except HttpError:
+            return None
+
+    def _filter_by_mime(self, items: Iterable[Dict[str, Any]]) -> List[Dict[str, Any]]:
+        """
+        Apply include/exclude mime filters if configured.
+        """
+        include = set(self.cfg.include_mime_types or [])
+        exclude = set(self.cfg.exclude_mime_types or [])
+
+        def keep(m: Dict[str, Any]) -> bool:
+            mt = m.get("mimeType")
+            if exclude and mt in exclude:
+                return False
+            if include and mt not in include:
+                return False
+            return True
+
+        return [m for m in items if keep(m)]
+
+    def _iter_selected_items(self) -> List[Dict[str, Any]]:
+        """
+        Return a de-duplicated list of file metadata for the selected scope:
+          - explicit file_ids
+          - items inside folder_ids (with optional recursion)
+        Shortcuts are resolved to their targets automatically.
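+
+        Example (hypothetical IDs): with file_ids=["abc123"],
+        folder_ids=["fld1"] and recursive=True, the result is the metadata
+        for "abc123" plus every non-folder descendant of "fld1",
+        de-duplicated and passed through the MIME include/exclude filters.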
+ """ + seen: Set[str] = set() + items: List[Dict[str, Any]] = [] + + # Explicit files + if self.cfg.file_ids: + for fid in self.cfg.file_ids: + meta = self._get_file_meta_by_id(fid) + if meta and meta["id"] not in seen: + seen.add(meta["id"]) + items.append(meta) + + # Folders + if self.cfg.folder_ids: + folder_children = self._bfs_expand_folders(self.cfg.folder_ids) + for meta in folder_children: + meta = self._resolve_shortcut(meta) + if meta.get("id") in seen: + continue + seen.add(meta["id"]) + items.append(meta) + + # If neither file_ids nor folder_ids are set, you could: + # - return [] to force explicit selection + # - OR default to entire drive. + # Here we choose to require explicit selection: + if not self.cfg.file_ids and not self.cfg.folder_ids: + return [] + + items = self._filter_by_mime(items) + # Exclude folders from final emits: + items = [m for m in items if m.get("mimeType") != "application/vnd.google-apps.folder"] + return items + + # ------------------------- + # Download logic + # ------------------------- + def _pick_export_mime(self, source_mime: str) -> Optional[str]: + """ + Choose export mime for Google-native docs if needed. + """ + overrides = self.cfg.export_format_overrides or {} + if source_mime == "application/vnd.google-apps.document": + return overrides.get( + source_mime, + "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + ) + if source_mime == "application/vnd.google-apps.spreadsheet": + return overrides.get( + source_mime, + "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + ) + if source_mime == "application/vnd.google-apps.presentation": + return overrides.get( + source_mime, + "application/vnd.openxmlformats-officedocument.presentationml.presentation", + ) + # Return None for non-Google-native or unsupported types + return overrides.get(source_mime) + + def _download_file_bytes(self, file_meta: Dict[str, Any]) -> bytes: + """ + Download bytes for a given file (exporting if Google-native). + """ + file_id = file_meta["id"] + mime_type = file_meta.get("mimeType") or "" + + # Google-native: export + export_mime = self._pick_export_mime(mime_type) + if mime_type.startswith("application/vnd.google-apps."): + # default fallback if not overridden + if not export_mime: + export_mime = "application/pdf" + # NOTE: export_media does not accept supportsAllDrives/includeItemsFromAllDrives + request = self.service.files().export_media(fileId=file_id, mimeType=export_mime) + else: + # Binary download (get_media also doesn't accept the Drive flags) + request = self.service.files().get_media(fileId=file_id) + + fh = io.BytesIO() + downloader = MediaIoBaseDownload(fh, request, chunksize=1024 * 1024) + done = False + while not done: + status, done = downloader.next_chunk() + # Optional: you can log progress via status.progress() + + return fh.getvalue() + + # ------------------------- + # Public sync surface + # ------------------------- + # ---- Required by BaseConnector: start OAuth flow async def authenticate(self) -> bool: - """Authenticate with Google Drive""" + """ + Ensure we have valid Google Drive credentials and an authenticated service. + Returns True if ready to use; False otherwise. 
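+
+        Typical call order (sketch): authenticate() once, then list_files()
+        and get_file_content(file_id) for each item to ingest.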
+        """
         try:
-            if await self.oauth.is_authenticated():
-                self.service = self.oauth.get_service()
-                self._authenticated = True
-                return True
-            return False
+            # Load/refresh creds from token file (async)
+            self.creds = await self.oauth.load_credentials()
+
+            # If still not authenticated, bail (caller should kick off OAuth init)
+            if not await self.oauth.is_authenticated():
+                self.log("authenticate: no valid credentials; run OAuth init/callback first.")
+                return False
+
+            # Build Drive service from OAuth helper
+            self.service = self.oauth.get_service()
+
+            # Optional sanity check (small, fast request)
+            _ = self.service.files().get(fileId="root", fields="id").execute()
+            self._authenticated = True
+            return True
+
         except Exception as e:
-            logger.error("Authentication failed", error=str(e))
+            self._authenticated = False
+            logger.error(f"GoogleDriveConnector.authenticate failed: {e}")
             return False

-    async def setup_subscription(self) -> str:
-        """Set up Google Drive push notifications"""
-        if not self._authenticated:
-            raise ValueError("Not authenticated")
+    async def list_files(
+        self,
+        page_token: Optional[str] = None,
+        max_files: Optional[int] = None,
+        **kwargs
+    ) -> Dict[str, Any]:
+        """
+        List files in the currently selected scope (file_ids/folder_ids/recursive).
+        Returns a dict with 'files' and 'next_page_token'.

-        # Generate unique channel ID
-        channel_id = str(uuid.uuid4())
+        Since we pre-compute the selected set, pagination is simulated:
+        - If page_token is None: return all selected files in one batch.
+        - Otherwise: return an empty batch with no next_page_token, since the
+          first call already returned everything.
+        """
+        try:
+            items = self._iter_selected_items()

-        # Set up push notification
-        # Note: This requires a publicly accessible webhook endpoint
-        webhook_url = self.config.get("webhook_url")
-        if not webhook_url:
-            raise ValueError("webhook_url required in config for subscriptions")
+            # Optionally honor a request-scoped max_files (e.g., from your API payload)
+            if isinstance(max_files, int) and max_files > 0:
+                items = items[:max_files]
+
+            # Simulated pagination: a non-None page_token means the single batch
+            # was already returned. Slice `items` above if you want real pagination.
+            if page_token:
+                return {"files": [], "next_page_token": None}
+
+            return {
+                "files": items,
+                "next_page_token": None,  # no more pages
+            }
+        except Exception as e:
+            # Optionally log error with your base class logger
+            try:
+                self.log(f"GoogleDriveConnector.list_files failed: {e}")
+            except Exception:
+                pass
+            return {"files": [], "next_page_token": None}
+
+    async def get_file_content(self, file_id: str) -> ConnectorDocument:
+        """
+        Fetch a file's metadata and content from Google Drive and wrap it in a ConnectorDocument.
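+
+        Usage sketch (assumes authenticate() has already succeeded;
+        `save_bytes` is a hypothetical sink):
+
+            doc = await connector.get_file_content("<file-id>")
+            save_bytes(doc.filename, doc.content)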
+ """ + meta = self._get_file_meta_by_id(file_id) + if not meta: + raise FileNotFoundError(f"Google Drive file not found: {file_id}") try: + blob = self._download_file_bytes(meta) + except Exception as e: + # Use your base class logger if available + try: + self.log(f"Download failed for {file_id}: {e}") + except Exception: + pass + raise + + from datetime import datetime + + def parse_datetime(dt_str): + if not dt_str: + return None + try: + # Google Drive returns RFC3339 format + return datetime.strptime(dt_str, "%Y-%m-%dT%H:%M:%S.%fZ") + except ValueError: + try: + return datetime.strptime(dt_str, "%Y-%m-%dT%H:%M:%SZ") + except ValueError: + return None + + doc = ConnectorDocument( + id=meta["id"], + filename=meta.get("name", ""), + source_url=meta.get("webViewLink", ""), + created_time=parse_datetime(meta.get("createdTime")), + modified_time=parse_datetime(meta.get("modifiedTime")), + mimetype=str(meta.get("mimeType", "")), + acl=DocumentACL(), # TODO: map Google Drive permissions if you want ACLs + content=blob, + metadata={ + "parents": meta.get("parents"), + "driveId": meta.get("driveId"), + "size": int(meta.get("size", 0)) if str(meta.get("size", "")).isdigit() else None, + }, + ) + return doc + + async def setup_subscription(self) -> str: + """ + Start a Google Drive Changes API watch (webhook). + Returns the channel ID (subscription ID) as a string. + + Requires a webhook URL to be configured. This implementation looks for: + 1) self.cfg.webhook_address (preferred if you have it in your config dataclass) + 2) os.environ["GOOGLE_DRIVE_WEBHOOK_URL"] + """ + import os + + # 1) Ensure we are authenticated and have a live Drive service + ok = await self.authenticate() + if not ok: + raise RuntimeError("GoogleDriveConnector.setup_subscription: not authenticated") + + # 2) Resolve webhook address (no param in ABC, so pull from config/env) + webhook_address = getattr(self.cfg, "webhook_address", None) or os.getenv("GOOGLE_DRIVE_WEBHOOK_URL") + if not webhook_address: + raise RuntimeError( + "GoogleDriveConnector.setup_subscription: webhook URL not configured. " + "Set cfg.webhook_address or GOOGLE_DRIVE_WEBHOOK_URL." 
+ ) + + # 3) Ensure we have a starting page token (checkpoint) + try: + if not self.cfg.changes_page_token: + self.cfg.changes_page_token = self.get_start_page_token() + except Exception as e: + # Optional: use your base logger + try: + self.log(f"Failed to get start page token: {e}") + except Exception: + pass + raise + + # 4) Start the watch on the current token + try: + # Build a simple watch body; customize id if you want a stable deterministic value body = { - "id": channel_id, + "id": f"drive-channel-{int(time.time())}", # subscription (channel) ID to return "type": "web_hook", - "address": webhook_url, - "payload": True, - "expiration": str( - int((datetime.now().timestamp() + 86400) * 1000) - ), # 24 hours + "address": webhook_address, } + # Shared Drives flags so we see everything we’re scoped to + flags = dict(supportsAllDrives=True) + result = ( self.service.changes() - .watch(pageToken=self._get_start_page_token(), body=body) + .watch(pageToken=self.cfg.changes_page_token, body=body, **flags) .execute() ) - self.webhook_channel_id = channel_id - # Persist the resourceId returned by Google to allow proper cleanup - try: - self.webhook_resource_id = result.get("resourceId") - except Exception: - self.webhook_resource_id = None + # Example fields: id, resourceId, expiration, kind + channel_id = result.get("id") + resource_id = result.get("resourceId") + expiration = result.get("expiration") + + # Persist in-memory so cleanup can stop this channel later. + # If your project has a persistence layer, save these values there. + self._active_channel = { + "channel_id": channel_id, + "resource_id": resource_id, + "expiration": expiration, + "webhook_address": webhook_address, + "page_token": self.cfg.changes_page_token, + } + + if not isinstance(channel_id, str) or not channel_id: + raise RuntimeError(f"Drive watch returned invalid channel id: {channel_id!r}") + return channel_id - except HttpError as e: - logger.error("Failed to set up subscription", error=str(e)) + except Exception as e: + try: + logger.error(f"GoogleDriveConnector.setup_subscription failed: {e}") + except Exception: + pass raise - def _get_start_page_token(self) -> str: - """Get the current page token for change notifications""" - return self.service.changes().getStartPageToken().execute()["startPageToken"] - - async def list_files( - self, page_token: Optional[str] = None, limit: Optional[int] = None - ) -> Dict[str, Any]: - """List all supported files in Google Drive. - - Uses a thread pool (not the shared process pool) to avoid issues with - Google API clients in forked processes and adds light retries for - transient BrokenPipe/connection errors. + async def cleanup_subscription(self, subscription_id: str) -> bool: """ - if not self._authenticated: - raise ValueError("Not authenticated") + Stop an active Google Drive Changes API watch (webhook) channel. - # Build query for supported file types - mimetype_query = " or ".join( - [f"mimeType='{mt}'" for mt in self.SUPPORTED_MIMETYPES] - ) - query = f"({mimetype_query}) and trashed=false" + Google requires BOTH the channel id (subscription_id) AND its resource_id. 
+ We try to retrieve resource_id from: + 1) self._active_channel (single-channel use) + 2) self._subscriptions[subscription_id] (multi-channel use, if present) + 3) self.cfg.resource_id (as a last-resort override provided by caller/config) - # Use provided limit or default to 100, max 1000 (Google Drive API limit) - page_size = min(limit or 100, 1000) + Returns: + bool: True if the stop call succeeded, otherwise False. + """ + # 1) Ensure auth/service + ok = await self.authenticate() + if not ok: + try: + self.log("cleanup_subscription: not authenticated") + except Exception: + pass + return False - def _sync_list_files_inner(): - import time + # 2) Resolve resource_id + resource_id = None - attempts = 0 - max_attempts = 3 - backoff = 1.0 - while True: - try: - return ( - self.service.files() - .list( - q=query, - pageSize=page_size, - pageToken=page_token, - fields="nextPageToken, files(id, name, mimeType, modifiedTime, createdTime, webViewLink, permissions, owners)", - ) - .execute() - ) - except Exception as e: - attempts += 1 - is_broken_pipe = isinstance(e, BrokenPipeError) or ( - isinstance(e, OSError) and getattr(e, "errno", None) == 32 - ) - if attempts < max_attempts and is_broken_pipe: - time.sleep(backoff) - backoff = min(4.0, backoff * 2) - continue - raise + # Single-channel memory + if getattr(self, "_active_channel", None): + ch = getattr(self, "_active_channel") + if isinstance(ch, dict) and ch.get("channel_id") == subscription_id: + resource_id = ch.get("resource_id") + + # Multi-channel memory + if resource_id is None and hasattr(self, "_subscriptions"): + subs = getattr(self, "_subscriptions") + if isinstance(subs, dict): + entry = subs.get(subscription_id) + if isinstance(entry, dict): + resource_id = entry.get("resource_id") + + # Config override (optional) + if resource_id is None and getattr(self.cfg, "resource_id", None): + resource_id = self.cfg.resource_id + + if not resource_id: + try: + self.log( + f"cleanup_subscription: missing resource_id for channel {subscription_id}. " + f"Persist (channel_id, resource_id) when creating the subscription." + ) + except Exception: + pass + return False try: - # Offload blocking HTTP call to default ThreadPoolExecutor - import asyncio + self.service.channels().stop(body={"id": subscription_id, "resourceId": resource_id}).execute() - loop = asyncio.get_event_loop() - results = await loop.run_in_executor(None, _sync_list_files_inner) + # 4) Clear local bookkeeping + if getattr(self, "_active_channel", None) and self._active_channel.get("channel_id") == subscription_id: + self._active_channel = {} - files = [] - for file in results.get("files", []): - files.append( - { - "id": file["id"], - "name": file["name"], - "mimeType": file["mimeType"], - "modifiedTime": file["modifiedTime"], - "createdTime": file["createdTime"], - "webViewLink": file["webViewLink"], - "permissions": file.get("permissions", []), - "owners": file.get("owners", []), - } + if hasattr(self, "_subscriptions") and isinstance(self._subscriptions, dict): + self._subscriptions.pop(subscription_id, None) + + return True + + except Exception as e: + try: + self.log(f"cleanup_subscription failed for {subscription_id}: {e}") + except Exception: + pass + return False + + async def handle_webhook(self, payload: Dict[str, Any]) -> List[str]: + """ + Process a Google Drive Changes webhook. + Drive push notifications do NOT include the changed files themselves; they merely tell us + "there are changes". We must pull them using the Changes API with our saved page token. 
+ + Args: + payload: Arbitrary dict your framework passes. We *may* log/use headers like + X-Goog-Resource-State / X-Goog-Message-Number if present, but we don't rely on them. + + Returns: + List[str]: unique list of affected file IDs (filtered to our selected scope). + """ + affected: List[str] = [] + try: + # 1) Ensure we're authenticated / service ready + ok = await self.authenticate() + if not ok: + try: + self.log("handle_webhook: not authenticated") + except Exception: + pass + return affected + + # 2) Establish/restore our checkpoint page token + page_token = self.cfg.changes_page_token + if not page_token: + # First time / missing state: initialize + page_token = self.get_start_page_token() + self.cfg.changes_page_token = page_token + + # 3) Build current selected scope to filter changes + # (file_ids + expanded folder descendants) + try: + selected_items = self._iter_selected_items() + selected_ids = {m["id"] for m in selected_items} + except Exception as e: + selected_ids = set() + try: + self.log(f"handle_webhook: scope build failed, proceeding unfiltered: {e}") + except Exception: + pass + + # 4) Pull changes until nextPageToken is exhausted, then advance to newStartPageToken + while True: + resp = ( + self.service.changes() + .list( + pageToken=page_token, + fields=( + "nextPageToken, newStartPageToken, " + "changes(fileId, file(id, name, mimeType, trashed, parents, " + "shortcutDetails, driveId, modifiedTime, webViewLink))" + ), + supportsAllDrives=True, + includeItemsFromAllDrives=True, + ) + .execute() ) - return {"files": files, "nextPageToken": results.get("nextPageToken")} + for ch in resp.get("changes", []): + fid = ch.get("fileId") + fobj = ch.get("file") or {} - except HttpError as e: - logger.error("Failed to list files", error=str(e)) - raise + # Skip if no file or explicitly trashed (you can choose to still return these IDs) + if not fid or fobj.get("trashed"): + # If you want to *include* deletions, collect fid here instead of skipping. 
+ continue - async def get_file_content(self, file_id: str) -> ConnectorDocument: - """Get file content and metadata""" - if not self._authenticated: - raise ValueError("Not authenticated") + # Resolve shortcuts to target + resolved = self._resolve_shortcut(fobj) + rid = resolved.get("id", fid) - try: - # Get file metadata (run in thread pool to avoid blocking) - import asyncio + # Filter to our selected scope if we have one; otherwise accept all + if selected_ids and (rid not in selected_ids): + # Shortcut target might be in scope even if the shortcut isn't + tgt = fobj.get("shortcutDetails", {}).get("targetId") if fobj else None + if not (tgt and tgt in selected_ids): + continue - loop = asyncio.get_event_loop() + affected.append(rid) - # Use the same process pool as docling processing - from utils.process_pool import process_pool + # Handle pagination of the changes feed + next_token = resp.get("nextPageToken") + if next_token: + page_token = next_token + continue - file_metadata = await loop.run_in_executor( - process_pool, - _sync_get_metadata_worker, - self.oauth.client_id, - self.oauth.client_secret, - self.oauth.token_file, - file_id, - ) + # No nextPageToken: checkpoint with newStartPageToken + new_start = resp.get("newStartPageToken") + if new_start: + self.cfg.changes_page_token = new_start + else: + # Fallback: keep the last consumed token if API didn't return newStartPageToken + self.cfg.changes_page_token = page_token + break - # Download file content (pass file size for timeout calculation) - file_size = file_metadata.get("size") - if file_size: - file_size = int(file_size) # Ensure it's an integer - content = await self._download_file_content( - file_id, file_metadata["mimeType"], file_size - ) + # Deduplicate while preserving order + seen = set() + deduped: List[str] = [] + for x in affected: + if x not in seen: + seen.add(x) + deduped.append(x) + return deduped - # Extract ACL information - acl = self._extract_acl(file_metadata) + except Exception as e: + try: + self.log(f"handle_webhook failed: {e}") + except Exception: + pass + return [] - return ConnectorDocument( - id=file_id, - filename=file_metadata["name"], - mimetype=file_metadata["mimeType"], - content=content, - source_url=file_metadata["webViewLink"], - acl=acl, - modified_time=datetime.fromisoformat( - file_metadata["modifiedTime"].replace("Z", "+00:00") - ).replace(tzinfo=None), - created_time=datetime.fromisoformat( - file_metadata["createdTime"].replace("Z", "+00:00") - ).replace(tzinfo=None), + def sync_once(self) -> None: + """ + Perform a one-shot sync of the currently selected scope and emit documents. + + Emits ConnectorDocument instances (adapt to your BaseConnector ingestion). 
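+
+        Usage sketch (assumes authenticate() has already succeeded and that
+        self.emit() hands documents to your ingestion pipeline):
+
+            connector.sync_once()  # downloads and emits every in-scope file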
+ """ + items = self._iter_selected_items() + for meta in items: + try: + blob = self._download_file_bytes(meta) + except HttpError as e: + # Skip/record failures + self.log(f"Failed to download {meta.get('name')} ({meta.get('id')}): {e}") + continue + + from datetime import datetime + + def parse_datetime(dt_str): + if not dt_str: + return None + try: + # Google Drive returns RFC3339 format + return datetime.strptime(dt_str, "%Y-%m-%dT%H:%M:%S.%fZ") + except ValueError: + try: + return datetime.strptime(dt_str, "%Y-%m-%dT%H:%M:%SZ") + except ValueError: + return None + + doc = ConnectorDocument( + id=meta["id"], + filename=meta.get("name", ""), + source_url=meta.get("webViewLink", ""), + created_time=parse_datetime(meta.get("createdTime")), + modified_time=parse_datetime(meta.get("modifiedTime")), + mimetype=str(meta.get("mimeType", "")), + acl=DocumentACL(), # TODO: set appropriate ACL instance or value metadata={ - "size": file_metadata.get("size"), - "owners": file_metadata.get("owners", []), + "name": meta.get("name"), + "webViewLink": meta.get("webViewLink"), + "parents": meta.get("parents"), + "driveId": meta.get("driveId"), + "size": int(meta.get("size", 0)) if str(meta.get("size", "")).isdigit() else None, }, + content=blob, ) + self.emit(doc) - except HttpError as e: - logger.error("Failed to get file content", error=str(e)) - raise + # ------------------------- + # Changes API (polling or webhook-backed) + # ------------------------- + def get_start_page_token(self) -> str: + # getStartPageToken accepts supportsAllDrives (not includeItemsFromAllDrives) + resp = self.service.changes().getStartPageToken(**self._drives_get_flags).execute() + return resp["startPageToken"] - async def _download_file_content( - self, file_id: str, mime_type: str, file_size: int = None - ) -> bytes: - """Download file content, converting Google Docs formats if needed""" + def poll_changes_and_sync(self) -> Optional[str]: + """ + Incrementally process changes since the last page token in cfg.changes_page_token. - # Download file (run in process pool to avoid blocking) - import asyncio + Returns the new page token you should persist (or None if unchanged). 
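+
+        Polling-loop sketch (hypothetical scheduler; persisting the token
+        between runs is assumed to happen elsewhere):
+
+            while True:
+                token = connector.poll_changes_and_sync()
+                if token:
+                    persist_page_token(token)  # hypothetical persistence hook
+                time.sleep(60)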
+ """ + page_token = self.cfg.changes_page_token or self.get_start_page_token() - loop = asyncio.get_event_loop() - - # Use the same process pool as docling processing - from utils.process_pool import process_pool - - return await loop.run_in_executor( - process_pool, - _sync_download_worker, - self.oauth.client_id, - self.oauth.client_secret, - self.oauth.token_file, - file_id, - mime_type, - file_size, - ) - - def _extract_acl(self, file_metadata: Dict[str, Any]) -> DocumentACL: - """Extract ACL information from file metadata""" - user_permissions = {} - group_permissions = {} - - owner = None - if file_metadata.get("owners"): - owner = file_metadata["owners"][0].get("emailAddress") - - # Process permissions - for perm in file_metadata.get("permissions", []): - email = perm.get("emailAddress") - role = perm.get("role", "reader") - perm_type = perm.get("type") - - if perm_type == "user" and email: - user_permissions[email] = role - elif perm_type == "group" and email: - group_permissions[email] = role - elif perm_type == "domain": - # Domain-wide permissions - could be treated as a group - domain = perm.get("domain", "unknown-domain") - group_permissions[f"domain:{domain}"] = role - - return DocumentACL( - owner=owner, - user_permissions=user_permissions, - group_permissions=group_permissions, - ) - - def extract_webhook_channel_id( - self, payload: Dict[str, Any], headers: Dict[str, str] - ) -> Optional[str]: - """Extract Google Drive channel ID from webhook headers""" - return headers.get("x-goog-channel-id") - - def extract_webhook_resource_id(self, headers: Dict[str, str]) -> Optional[str]: - """Extract Google Drive resource ID from webhook headers""" - return headers.get("x-goog-resource-id") - - async def handle_webhook(self, payload: Dict[str, Any]) -> List[str]: - """Handle Google Drive webhook notification""" - if not self._authenticated: - raise ValueError("Not authenticated") - - # Google Drive sends headers with the important info - headers = payload.get("_headers", {}) - - # Extract Google Drive specific headers - channel_id = headers.get("x-goog-channel-id") - resource_state = headers.get("x-goog-resource-state") - - if not channel_id: - logger.warning("No channel ID found in Google Drive webhook") - return [] - - # Check if this webhook belongs to this connection - if self.webhook_channel_id != channel_id: - logger.warning( - "Channel ID mismatch", - expected=self.webhook_channel_id, - received=channel_id, - ) - return [] - - # Only process certain states (ignore 'sync' which is just a ping) - if resource_state not in ["exists", "not_exists", "change"]: - logger.debug("Ignoring resource state", state=resource_state) - return [] - - try: - # Extract page token from the resource URI if available - page_token = None - headers = payload.get("_headers", {}) - resource_uri = headers.get("x-goog-resource-uri") - - if resource_uri and "pageToken=" in resource_uri: - # Extract page token from URI like: - # https://www.googleapis.com/drive/v3/changes?alt=json&pageToken=4337807 - import urllib.parse - - parsed = urllib.parse.urlparse(resource_uri) - query_params = urllib.parse.parse_qs(parsed.query) - page_token = query_params.get("pageToken", [None])[0] - - if not page_token: - logger.warning("No page token found, cannot identify specific changes") - return [] - - logger.info("Getting changes since page token", page_token=page_token) - - # Get list of changes since the page token - changes = ( + while True: + resp = ( self.service.changes() .list( pageToken=page_token, - 
fields="changes(fileId, file(id, name, mimeType, trashed, parents))",
+                    fields=(
+                        "nextPageToken, newStartPageToken, "
+                        "changes(fileId, file(id, name, mimeType, trashed, parents, "
+                        "shortcutDetails, driveId, modifiedTime, webViewLink))"
+                    ),
+                    **self._drives_list_flags,
                 )
                 .execute()
             )

-            affected_files = []
-            for change in changes.get("changes", []):
-                file_info = change.get("file", {})
-                file_id = change.get("fileId")
+            changes = resp.get("changes", [])

-                if not file_id:
+            # Filter to our selected scope (files and folder descendants):
+            selected_ids = {m["id"] for m in self._iter_selected_items()}
+            for ch in changes:
+                fid = ch.get("fileId")
+                file_obj = ch.get("file") or {}
+                if not fid or not file_obj or file_obj.get("trashed"):
                     continue

-                # Only include supported file types that aren't trashed
-                mime_type = file_info.get("mimeType", "")
-                is_trashed = file_info.get("trashed", False)
+                # Match scope; skip anything outside it (a shortcut still counts
+                # if its target is in scope)
+                if fid not in selected_ids:
+                    tgt = None
+                    if file_obj.get("mimeType") == "application/vnd.google-apps.shortcut":
+                        tgt = file_obj.get("shortcutDetails", {}).get("targetId")
+                    if not (tgt and tgt in selected_ids):
+                        continue

-                if not is_trashed and mime_type in self.SUPPORTED_MIMETYPES:
-                    logger.info(
-                        "File changed",
-                        filename=file_info.get("name", "Unknown"),
-                        file_id=file_id,
-                    )
-                    affected_files.append(file_id)
-                elif is_trashed:
-                    logger.info(
-                        "File deleted/trashed",
-                        filename=file_info.get("name", "Unknown"),
-                        file_id=file_id,
-                    )
-                    # TODO: Handle file deletion (remove from index)
-                else:
-                    logger.debug("Ignoring unsupported file type", mime_type=mime_type)
+                # Download and emit the updated file
+                resolved = self._resolve_shortcut(file_obj)
+                try:
+                    blob = self._download_file_bytes(resolved)
+                except HttpError:
+                    continue

-            logger.info("Found affected supported files", count=len(affected_files))
-            return affected_files
+                from datetime import datetime

-        except HttpError as e:
-            logger.error("Failed to handle webhook", error=str(e))
-            return []
+                def parse_datetime(dt_str):
+                    if not dt_str:
+                        return None
+                    try:
+                        return datetime.strptime(dt_str, "%Y-%m-%dT%H:%M:%S.%fZ")
+                    except ValueError:
+                        try:
+                            return datetime.strptime(dt_str, "%Y-%m-%dT%H:%M:%SZ")
+                        except ValueError:
+                            return None

-    async def cleanup_subscription(self, subscription_id: str) -> bool:
-        """Clean up Google Drive subscription for this connection.
-
-        Uses the stored resource_id captured during subscription setup.
- """ - if not self._authenticated: - return False - - try: - # Google Channels API requires both 'id' (channel) and 'resourceId' - if not self.webhook_resource_id: - raise ValueError( - "Missing resource_id for cleanup; ensure subscription state is persisted" + doc = ConnectorDocument( + id=resolved["id"], + filename=resolved.get("name", ""), + source_url=resolved.get("webViewLink", ""), + created_time=parse_datetime(resolved.get("createdTime")), + modified_time=parse_datetime(resolved.get("modifiedTime")), + mimetype=str(resolved.get("mimeType", "")), + acl=DocumentACL(), # Set appropriate ACL if needed + metadata={"parents": resolved.get("parents"), "driveId": resolved.get("driveId")}, + content=blob, ) - body = {"id": subscription_id, "resourceId": self.webhook_resource_id} + self.emit(doc) - self.service.channels().stop(body=body).execute() + new_page_token = resp.get("nextPageToken") + if new_page_token: + page_token = new_page_token + continue + + # No nextPageToken: advance to newStartPageToken (checkpoint) + new_start = resp.get("newStartPageToken") + if new_start: + self.cfg.changes_page_token = new_start + return new_start + + # Should not happen often + return page_token + + # ------------------------- + # Optional: webhook stubs + # ------------------------- + def build_watch_body(self, webhook_address: str, channel_id: Optional[str] = None) -> Dict[str, Any]: + """ + Prepare the request body for changes.watch if you use webhooks. + """ + return { + "id": channel_id or f"drive-channel-{int(time.time())}", + "type": "web_hook", + "address": webhook_address, + } + + def start_watch(self, webhook_address: str) -> Dict[str, Any]: + """ + Start a webhook watch on changes using the current page token. + Persist the returned resourceId/expiration on your side. + """ + page_token = self.cfg.changes_page_token or self.get_start_page_token() + body = self.build_watch_body(webhook_address) + result = ( + self.service.changes() + .watch(pageToken=page_token, body=body, **self._drives_flags) + .execute() + ) + return result + + def stop_watch(self, channel_id: str, resource_id: str) -> bool: + """ + Stop a previously started webhook watch. 
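+
+        Example (sketch; both values come from the changes.watch response
+        returned by start_watch):
+
+            result = connector.start_watch("https://example.com/notifications")
+            ...
+            connector.stop_watch(result["id"], result["resourceId"])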
+ """ + try: + self.service.channels().stop(body={"id": channel_id, "resourceId": resource_id}).execute() return True + except HttpError as e: logger.error("Failed to cleanup subscription", error=str(e)) + return False diff --git a/src/connectors/google_drive/oauth.py b/src/connectors/google_drive/oauth.py index 1c33079f..f23e4796 100644 --- a/src/connectors/google_drive/oauth.py +++ b/src/connectors/google_drive/oauth.py @@ -1,7 +1,6 @@ import os import json -import asyncio -from typing import Dict, Any, Optional +from typing import Optional from google.auth.transport.requests import Request from google.oauth2.credentials import Credentials from google_auth_oauthlib.flow import Flow @@ -25,8 +24,8 @@ class GoogleDriveOAuth: def __init__( self, - client_id: str = None, - client_secret: str = None, + client_id: Optional[str] = None, + client_secret: Optional[str] = None, token_file: str = "token.json", ): self.client_id = client_id @@ -133,7 +132,7 @@ class GoogleDriveOAuth: if not self.creds: await self.load_credentials() - return self.creds and self.creds.valid + return bool(self.creds and self.creds.valid) def get_service(self): """Get authenticated Google Drive service""" diff --git a/src/connectors/service.py b/src/connectors/service.py index 8461d043..01a41519 100644 --- a/src/connectors/service.py +++ b/src/connectors/service.py @@ -1,4 +1,3 @@ -import asyncio import tempfile import os from typing import Dict, Any, List, Optional @@ -12,6 +11,8 @@ from .sharepoint import SharePointConnector from .onedrive import OneDriveConnector from .connection_manager import ConnectionManager +logger = get_logger(__name__) + class ConnectorService: """Service to manage document connectors and process files""" @@ -267,9 +268,6 @@ class ConnectorService: page_token = file_list.get("nextPageToken") - if not files_to_process: - raise ValueError("No files found to sync") - # Get user information user = self.session_manager.get_user(user_id) if self.session_manager else None owner_name = user.name if user else None diff --git a/src/main.py b/src/main.py index 29dea7ca..a907bd61 100644 --- a/src/main.py +++ b/src/main.py @@ -1,12 +1,5 @@ import sys -# Check for TUI flag FIRST, before any heavy imports -if __name__ == "__main__" and len(sys.argv) > 1 and sys.argv[1] == "--tui": - from tui.main import run_tui - - run_tui() - sys.exit(0) - # Configure structured logging early from utils.logging_config import configure_from_env, get_logger @@ -27,6 +20,8 @@ from starlette.routing import Route multiprocessing.set_start_method("spawn", force=True) # Create process pool FIRST, before any torch/CUDA imports +from utils.process_pool import process_pool + import torch # API endpoints @@ -73,6 +68,7 @@ from utils.process_pool import process_pool # API endpoints + logger.info( "CUDA device information", cuda_available=torch.cuda.is_available(), @@ -336,8 +332,6 @@ async def initialize_services(): else: logger.info("[CONNECTORS] Skipping connection loading in no-auth mode") - # New: Langflow file service - langflow_file_service = LangflowFileService() return { @@ -730,6 +724,17 @@ async def create_app(): ), methods=["GET"], ), + Route( + "/connectors/{connector_type}/token", + require_auth(services["session_manager"])( + partial( + connectors.connector_token, + connector_service=services["connector_service"], + session_manager=services["session_manager"], + ) + ), + methods=["GET"], + ), Route( "/connectors/{connector_type}/webhook", partial( diff --git a/src/services/auth_service.py b/src/services/auth_service.py 
index e5361233..a29c197f 100644 --- a/src/services/auth_service.py +++ b/src/services/auth_service.py @@ -107,11 +107,27 @@ class AuthService: auth_endpoint = oauth_class.AUTH_ENDPOINT token_endpoint = oauth_class.TOKEN_ENDPOINT - # Get client_id from environment variable using connector's env var name - client_id = os.getenv(connector_class.CLIENT_ID_ENV_VAR) - if not client_id: - raise ValueError( - f"{connector_class.CLIENT_ID_ENV_VAR} environment variable not set" + # src/services/auth_service.py + client_key = getattr(connector_class, "CLIENT_ID_ENV_VAR", None) + secret_key = getattr(connector_class, "CLIENT_SECRET_ENV_VAR", None) + + def _assert_env_key(name, val): + if not isinstance(val, str) or not val.strip(): + raise RuntimeError( + f"{connector_class.__name__} misconfigured: {name} must be a non-empty string " + f"(got {val!r}). Define it as a class attribute on the connector." + ) + + _assert_env_key("CLIENT_ID_ENV_VAR", client_key) + _assert_env_key("CLIENT_SECRET_ENV_VAR", secret_key) + + client_id = os.getenv(client_key) + client_secret = os.getenv(secret_key) + + if not client_id or not client_secret: + raise RuntimeError( + f"Missing OAuth env vars for {connector_class.__name__}. " + f"Set {client_key} and {secret_key} in the environment." ) oauth_config = { @@ -267,12 +283,11 @@ class AuthService: ) if jwt_token: - # Get the user info to create a persistent Google Drive connection + # Get the user info to create a persistent connector connection user_info = await self.session_manager.get_user_info_from_token( token_data["access_token"] ) - user_id = user_info["id"] if user_info else None - + response_data = { "status": "authenticated", "purpose": "app_auth", @@ -280,13 +295,13 @@ class AuthService: "jwt_token": jwt_token, # Include JWT token in response } - if user_id: - # Convert the temporary auth connection to a persistent Google Drive connection + if user_info and user_info.get("id"): + # Convert the temporary auth connection to a persistent OAuth connection await self.connector_service.connection_manager.update_connection( connection_id=connection_id, connector_type="google_drive", name=f"Google Drive ({user_info.get('email', 'Unknown')})", - user_id=user_id, + user_id=user_info.get("id"), config={ **connection_config.config, "purpose": "data_source", @@ -335,7 +350,7 @@ class AuthService: user = getattr(request.state, "user", None) if user: - return { + user_data = { "authenticated": True, "user": { "user_id": user.user_id, @@ -348,5 +363,7 @@ class AuthService: else None, }, } + + return user_data else: return {"authenticated": False, "user": None} diff --git a/src/services/chat_service.py b/src/services/chat_service.py index 65e1e37f..122c90fc 100644 --- a/src/services/chat_service.py +++ b/src/services/chat_service.py @@ -199,21 +199,29 @@ class ChatService: async def get_chat_history(self, user_id: str): """Get chat conversation history for a user""" - from agent import get_user_conversations + from agent import get_user_conversations, active_conversations if not user_id: return {"error": "User ID is required", "conversations": []} + # Get metadata from persistent storage conversations_dict = get_user_conversations(user_id) + + # Get in-memory conversations (with function calls) + in_memory_conversations = active_conversations.get(user_id, {}) + logger.debug( "Getting chat history for user", user_id=user_id, - conversation_count=len(conversations_dict), + persistent_count=len(conversations_dict), + in_memory_count=len(in_memory_conversations), ) # Convert 
conversations dict to list format with metadata conversations = [] - for response_id, conversation_state in conversations_dict.items(): + + # First, process in-memory conversations (they have function calls) + for response_id, conversation_state in in_memory_conversations.items(): # Filter out system messages messages = [] for msg in conversation_state.get("messages", []): @@ -227,6 +235,13 @@ class ChatService: } if msg.get("response_id"): message_data["response_id"] = msg["response_id"] + + # Include function call data if present + if msg.get("chunks"): + message_data["chunks"] = msg["chunks"] + if msg.get("response_data"): + message_data["response_data"] = msg["response_data"] + messages.append(message_data) if messages: # Only include conversations with actual messages @@ -260,11 +275,28 @@ class ChatService: "previous_response_id" ), "total_messages": len(messages), + "source": "in_memory" } ) + + # Then, add any persistent metadata that doesn't have in-memory data + for response_id, metadata in conversations_dict.items(): + if response_id not in in_memory_conversations: + # This is metadata-only conversation (no function calls) + conversations.append({ + "response_id": response_id, + "title": metadata.get("title", "New Chat"), + "endpoint": "chat", + "messages": [], # No messages in metadata-only + "created_at": metadata.get("created_at"), + "last_activity": metadata.get("last_activity"), + "previous_response_id": metadata.get("previous_response_id"), + "total_messages": metadata.get("total_messages", 0), + "source": "metadata_only" + }) # Sort by last activity (most recent first) - conversations.sort(key=lambda c: c["last_activity"], reverse=True) + conversations.sort(key=lambda c: c.get("last_activity", ""), reverse=True) return { "user_id": user_id, @@ -274,72 +306,117 @@ class ChatService: } async def get_langflow_history(self, user_id: str): - """Get langflow conversation history for a user""" + """Get langflow conversation history for a user - now fetches from both OpenRAG memory and Langflow database""" from agent import get_user_conversations - + from services.langflow_history_service import langflow_history_service + if not user_id: return {"error": "User ID is required", "conversations": []} - - conversations_dict = get_user_conversations(user_id) - - # Convert conversations dict to list format with metadata - conversations = [] - for response_id, conversation_state in conversations_dict.items(): - # Filter out system messages - messages = [] - for msg in conversation_state.get("messages", []): - if msg.get("role") in ["user", "assistant"]: - message_data = { - "role": msg["role"], - "content": msg["content"], - "timestamp": msg.get("timestamp").isoformat() - if msg.get("timestamp") - else None, - } - if msg.get("response_id"): - message_data["response_id"] = msg["response_id"] - messages.append(message_data) - - if messages: # Only include conversations with actual messages - # Generate title from first user message - first_user_msg = next( - (msg for msg in messages if msg["role"] == "user"), None - ) - title = ( - first_user_msg["content"][:50] + "..." - if first_user_msg and len(first_user_msg["content"]) > 50 - else first_user_msg["content"] - if first_user_msg - else "New chat" - ) - - conversations.append( - { + + all_conversations = [] + + try: + # 1. 
Get local conversation metadata (no actual messages stored here) + conversations_dict = get_user_conversations(user_id) + local_metadata = {} + + for response_id, conversation_metadata in conversations_dict.items(): + # Store metadata for later use with Langflow data + local_metadata[response_id] = conversation_metadata + + # 2. Get actual conversations from Langflow database (source of truth for messages) + print(f"[DEBUG] Attempting to fetch Langflow history for user: {user_id}") + langflow_history = await langflow_history_service.get_user_conversation_history(user_id, flow_id=FLOW_ID) + + if langflow_history.get("conversations"): + for conversation in langflow_history["conversations"]: + session_id = conversation["session_id"] + + # Only process sessions that belong to this user (exist in local metadata) + if session_id not in local_metadata: + continue + + # Use Langflow messages (with function calls) as source of truth + messages = [] + for msg in conversation.get("messages", []): + message_data = { + "role": msg["role"], + "content": msg["content"], + "timestamp": msg.get("timestamp"), + "langflow_message_id": msg.get("langflow_message_id"), + "source": "langflow" + } + + # Include function call data if present + if msg.get("chunks"): + message_data["chunks"] = msg["chunks"] + if msg.get("response_data"): + message_data["response_data"] = msg["response_data"] + + messages.append(message_data) + + if messages: + # Use local metadata if available, otherwise generate from Langflow data + metadata = local_metadata.get(session_id, {}) + + if not metadata.get("title"): + first_user_msg = next((msg for msg in messages if msg["role"] == "user"), None) + title = ( + first_user_msg["content"][:50] + "..." + if first_user_msg and len(first_user_msg["content"]) > 50 + else first_user_msg["content"] + if first_user_msg + else "Langflow chat" + ) + else: + title = metadata["title"] + + all_conversations.append({ + "response_id": session_id, + "title": title, + "endpoint": "langflow", + "messages": messages, # Function calls preserved from Langflow + "created_at": metadata.get("created_at") or conversation.get("created_at"), + "last_activity": metadata.get("last_activity") or conversation.get("last_activity"), + "total_messages": len(messages), + "source": "langflow_enhanced", + "langflow_session_id": session_id, + "langflow_flow_id": conversation.get("flow_id") + }) + + # 3. 
Add any local metadata that doesn't have Langflow data yet (recent conversations) + for response_id, metadata in local_metadata.items(): + if not any(c["response_id"] == response_id for c in all_conversations): + all_conversations.append({ "response_id": response_id, - "title": title, - "endpoint": "langflow", - "messages": messages, - "created_at": conversation_state.get("created_at").isoformat() - if conversation_state.get("created_at") - else None, - "last_activity": conversation_state.get( - "last_activity" - ).isoformat() - if conversation_state.get("last_activity") - else None, - "previous_response_id": conversation_state.get( - "previous_response_id" - ), - "total_messages": len(messages), - } - ) - + "title": metadata.get("title", "New Chat"), + "endpoint": "langflow", + "messages": [], # Will be filled when Langflow sync catches up + "created_at": metadata.get("created_at"), + "last_activity": metadata.get("last_activity"), + "total_messages": metadata.get("total_messages", 0), + "source": "metadata_only" + }) + + if langflow_history.get("conversations"): + print(f"[DEBUG] Added {len(langflow_history['conversations'])} historical conversations from Langflow") + elif langflow_history.get("error"): + print(f"[DEBUG] Could not fetch Langflow history for user {user_id}: {langflow_history['error']}") + else: + print(f"[DEBUG] No Langflow conversations found for user {user_id}") + + except Exception as e: + print(f"[ERROR] Failed to fetch Langflow history: {e}") + # Continue with just in-memory conversations + # Sort by last activity (most recent first) - conversations.sort(key=lambda c: c["last_activity"], reverse=True) - + all_conversations.sort(key=lambda c: c.get("last_activity", ""), reverse=True) + + print(f"[DEBUG] Returning {len(all_conversations)} conversations ({len(local_metadata)} from local metadata)") + return { "user_id": user_id, "endpoint": "langflow", - "conversations": conversations, - "total_conversations": len(conversations), + "conversations": all_conversations, + "total_conversations": len(all_conversations), } diff --git a/src/services/conversation_persistence_service.py b/src/services/conversation_persistence_service.py new file mode 100644 index 00000000..1b37eb4e --- /dev/null +++ b/src/services/conversation_persistence_service.py @@ -0,0 +1,126 @@ +""" +Conversation Persistence Service +Simple service to persist chat conversations to disk so they survive server restarts +""" + +import json +import os +from typing import Dict, Any +from datetime import datetime +import threading + + +class ConversationPersistenceService: + """Simple service to persist conversations to disk""" + + def __init__(self, storage_file: str = "conversations.json"): + self.storage_file = storage_file + self.lock = threading.Lock() + self._conversations = self._load_conversations() + + def _load_conversations(self) -> Dict[str, Dict[str, Any]]: + """Load conversations from disk""" + if os.path.exists(self.storage_file): + try: + with open(self.storage_file, 'r', encoding='utf-8') as f: + data = json.load(f) + print(f"Loaded {self._count_total_conversations(data)} conversations from {self.storage_file}") + return data + except Exception as e: + print(f"Error loading conversations from {self.storage_file}: {e}") + return {} + return {} + + def _save_conversations(self): + """Save conversations to disk""" + try: + with self.lock: + with open(self.storage_file, 'w', encoding='utf-8') as f: + json.dump(self._conversations, f, indent=2, ensure_ascii=False, default=str) + print(f"Saved 
{self._count_total_conversations(self._conversations)} conversations to {self.storage_file}") + except Exception as e: + print(f"Error saving conversations to {self.storage_file}: {e}") + + def _count_total_conversations(self, data: Dict[str, Any]) -> int: + """Count total conversations across all users""" + total = 0 + for user_conversations in data.values(): + if isinstance(user_conversations, dict): + total += len(user_conversations) + return total + + def get_user_conversations(self, user_id: str) -> Dict[str, Any]: + """Get all conversations for a user""" + if user_id not in self._conversations: + self._conversations[user_id] = {} + return self._conversations[user_id] + + def _serialize_datetime(self, obj: Any) -> Any: + """Recursively convert datetime objects to ISO strings for JSON serialization""" + if isinstance(obj, datetime): + return obj.isoformat() + elif isinstance(obj, dict): + return {key: self._serialize_datetime(value) for key, value in obj.items()} + elif isinstance(obj, list): + return [self._serialize_datetime(item) for item in obj] + else: + return obj + + def store_conversation_thread(self, user_id: str, response_id: str, conversation_state: Dict[str, Any]): + """Store a conversation thread and persist to disk""" + if user_id not in self._conversations: + self._conversations[user_id] = {} + + # Recursively convert datetime objects to strings for JSON serialization + serialized_conversation = self._serialize_datetime(conversation_state) + + self._conversations[user_id][response_id] = serialized_conversation + + # Save to disk (we could optimize this with batching if needed) + self._save_conversations() + + def get_conversation_thread(self, user_id: str, response_id: str) -> Dict[str, Any]: + """Get a specific conversation thread""" + user_conversations = self.get_user_conversations(user_id) + return user_conversations.get(response_id, {}) + + def delete_conversation_thread(self, user_id: str, response_id: str): + """Delete a specific conversation thread""" + if user_id in self._conversations and response_id in self._conversations[user_id]: + del self._conversations[user_id][response_id] + self._save_conversations() + print(f"Deleted conversation {response_id} for user {user_id}") + + def clear_user_conversations(self, user_id: str): + """Clear all conversations for a user""" + if user_id in self._conversations: + del self._conversations[user_id] + self._save_conversations() + print(f"Cleared all conversations for user {user_id}") + + def get_storage_stats(self) -> Dict[str, Any]: + """Get statistics about stored conversations""" + total_users = len(self._conversations) + total_conversations = self._count_total_conversations(self._conversations) + + user_stats = {} + for user_id, conversations in self._conversations.items(): + user_stats[user_id] = { + 'conversation_count': len(conversations), + 'latest_activity': max( + (conv.get('last_activity', '') for conv in conversations.values()), + default='' + ) + } + + return { + 'total_users': total_users, + 'total_conversations': total_conversations, + 'storage_file': self.storage_file, + 'file_exists': os.path.exists(self.storage_file), + 'user_stats': user_stats + } + + +# Global instance +conversation_persistence = ConversationPersistenceService() \ No newline at end of file diff --git a/src/services/langflow_file_service.py b/src/services/langflow_file_service.py index 60056a09..dc2c19fe 100644 --- a/src/services/langflow_file_service.py +++ b/src/services/langflow_file_service.py @@ -1,9 +1,9 @@ -import logging from 
typing import Any, Dict, List, Optional from config.settings import LANGFLOW_INGEST_FLOW_ID, clients +from utils.logging_config import get_logger -logger = logging.getLogger(__name__) +logger = get_logger(__name__) class LangflowFileService: @@ -24,14 +24,16 @@ class LangflowFileService: headers={"Content-Type": None}, ) logger.debug( - "[LF] Upload response: %s %s", resp.status_code, resp.reason_phrase + "[LF] Upload response", + status_code=resp.status_code, + reason=resp.reason_phrase, ) if resp.status_code >= 400: logger.error( - "[LF] Upload failed: %s %s | body=%s", - resp.status_code, - resp.reason_phrase, - resp.text[:500], + "[LF] Upload failed", + status_code=resp.status_code, + reason=resp.reason_phrase, + body=resp.text[:500], ) resp.raise_for_status() return resp.json() @@ -39,17 +41,19 @@ class LangflowFileService: async def delete_user_file(self, file_id: str) -> None: """Delete a file by id using v2: DELETE /api/v2/files/{id}.""" # NOTE: use v2 root, not /api/v1 - logger.debug("[LF] Delete (v2) -> /api/v2/files/%s", file_id) + logger.debug("[LF] Delete (v2) -> /api/v2/files/{id}", file_id=file_id) resp = await clients.langflow_request("DELETE", f"/api/v2/files/{file_id}") logger.debug( - "[LF] Delete response: %s %s", resp.status_code, resp.reason_phrase + "[LF] Delete response", + status_code=resp.status_code, + reason=resp.reason_phrase, ) if resp.status_code >= 400: logger.error( - "[LF] Delete failed: %s %s | body=%s", - resp.status_code, - resp.reason_phrase, - resp.text[:500], + "[LF] Delete failed", + status_code=resp.status_code, + reason=resp.reason_phrase, + body=resp.text[:500], ) resp.raise_for_status() @@ -84,9 +88,11 @@ class LangflowFileService: if jwt_token: # Using the global variable pattern that Langflow expects for OpenSearch components tweaks["OpenSearchHybrid-Ve6bS"] = {"jwt_token": jwt_token} - logger.error("[LF] Adding JWT token to tweaks for OpenSearch components") + logger.debug( + "[LF] Added JWT token to tweaks for OpenSearch components" + ) else: - logger.error("[LF] No JWT token provided") + logger.warning("[LF] No JWT token provided") if tweaks: payload["tweaks"] = tweaks if session_id: @@ -101,22 +107,32 @@ class LangflowFileService: bool(jwt_token), ) - # Log the full payload for debugging - logger.debug("[LF] Request payload: %s", payload) + # Avoid logging full payload to prevent leaking sensitive data (e.g., JWT) resp = await clients.langflow_request( "POST", f"/api/v1/run/{self.flow_id_ingest}", json=payload ) - logger.debug("[LF] Run response: %s %s", resp.status_code, resp.reason_phrase) + logger.debug( + "[LF] Run response", status_code=resp.status_code, reason=resp.reason_phrase + ) if resp.status_code >= 400: logger.error( - "[LF] Run failed: %s %s | body=%s", - resp.status_code, - resp.reason_phrase, - resp.text[:1000], + "[LF] Run failed", + status_code=resp.status_code, + reason=resp.reason_phrase, + body=resp.text[:1000], ) resp.raise_for_status() - return resp.json() + try: + resp_json = resp.json() + except Exception as e: + logger.error( + "[LF] Failed to parse run response as JSON", + body=resp.text[:1000], + error=str(e), + ) + raise + return resp_json async def upload_and_ingest_file( self, @@ -251,4 +267,4 @@ class LangflowFileService: elif delete_error: result["message"] += f" (cleanup warning: {delete_error})" - return result \ No newline at end of file + return result diff --git a/src/services/langflow_history_service.py b/src/services/langflow_history_service.py new file mode 100644 index 00000000..0b04a2e9 --- 
/dev/null +++ b/src/services/langflow_history_service.py @@ -0,0 +1,227 @@ +""" +Langflow Message History Service +Simplified service that retrieves message history from Langflow using a single token +""" + +import httpx +from typing import List, Dict, Optional, Any + +from config.settings import LANGFLOW_URL, LANGFLOW_SUPERUSER, LANGFLOW_SUPERUSER_PASSWORD + + +class LangflowHistoryService: + """Simplified service to retrieve message history from Langflow""" + + def __init__(self): + self.langflow_url = LANGFLOW_URL + self.auth_token = None + + async def _authenticate(self) -> Optional[str]: + """Authenticate with Langflow and get access token""" + if self.auth_token: + return self.auth_token + + if not all([LANGFLOW_SUPERUSER, LANGFLOW_SUPERUSER_PASSWORD]): + print("Missing Langflow credentials") + return None + + try: + login_data = { + "username": LANGFLOW_SUPERUSER, + "password": LANGFLOW_SUPERUSER_PASSWORD + } + + async with httpx.AsyncClient() as client: + response = await client.post( + f"{self.langflow_url.rstrip('/')}/api/v1/login", + data=login_data, + headers={"Content-Type": "application/x-www-form-urlencoded"} + ) + + if response.status_code == 200: + result = response.json() + self.auth_token = result.get('access_token') + print(f"Successfully authenticated with Langflow for history retrieval") + return self.auth_token + else: + print(f"Langflow authentication failed: {response.status_code}") + return None + + except Exception as e: + print(f"Error authenticating with Langflow: {e}") + return None + + async def get_user_sessions(self, user_id: str, flow_id: Optional[str] = None) -> List[str]: + """Get all session IDs for a user's conversations + + Since we use one Langflow token, we get all sessions and filter by user_id locally + """ + token = await self._authenticate() + if not token: + return [] + + try: + headers = {"Authorization": f"Bearer {token}"} + params = {} + + if flow_id: + params["flow_id"] = flow_id + + async with httpx.AsyncClient() as client: + response = await client.get( + f"{self.langflow_url.rstrip('/')}/api/v1/monitor/messages/sessions", + headers=headers, + params=params + ) + + if response.status_code == 200: + session_ids = response.json() + print(f"Found {len(session_ids)} total sessions from Langflow") + + # Since we use a single Langflow instance, return all sessions + # Session filtering is handled by user_id at the application level + return session_ids + else: + print(f"Failed to get sessions: {response.status_code} - {response.text}") + return [] + + except Exception as e: + print(f"Error getting user sessions: {e}") + return [] + + async def get_session_messages(self, user_id: str, session_id: str) -> List[Dict[str, Any]]: + """Get all messages for a specific session""" + token = await self._authenticate() + if not token: + return [] + + try: + headers = {"Authorization": f"Bearer {token}"} + + async with httpx.AsyncClient() as client: + response = await client.get( + f"{self.langflow_url.rstrip('/')}/api/v1/monitor/messages", + headers=headers, + params={ + "session_id": session_id, + "order_by": "timestamp" + } + ) + + if response.status_code == 200: + messages = response.json() + # Convert to OpenRAG format + return self._convert_langflow_messages(messages) + else: + print(f"Failed to get messages for session {session_id}: {response.status_code}") + return [] + + except Exception as e: + print(f"Error getting session messages: {e}") + return [] + + def _convert_langflow_messages(self, langflow_messages: List[Dict[str, Any]]) -> 
List[Dict[str, Any]]: + """Convert Langflow messages to OpenRAG format""" + converted_messages = [] + + for msg in langflow_messages: + try: + # Map Langflow message format to OpenRAG format + converted_msg = { + "role": "user" if msg.get("sender") == "User" else "assistant", + "content": msg.get("text", ""), + "timestamp": msg.get("timestamp"), + "langflow_message_id": msg.get("id"), + "langflow_session_id": msg.get("session_id"), + "langflow_flow_id": msg.get("flow_id"), + "sender": msg.get("sender"), + "sender_name": msg.get("sender_name"), + "files": msg.get("files", []), + "properties": msg.get("properties", {}), + "error": msg.get("error", False), + "edit": msg.get("edit", False) + } + + # Extract function calls from content_blocks if present + content_blocks = msg.get("content_blocks", []) + if content_blocks: + chunks = [] + for block in content_blocks: + if block.get("title") == "Agent Steps" and block.get("contents"): + for content in block["contents"]: + if content.get("type") == "tool_use": + # Convert Langflow tool_use format to OpenRAG chunks format + chunk = { + "type": "function", + "function": { + "name": content.get("name", ""), + "arguments": content.get("tool_input", {}), + "response": content.get("output", {}) + }, + "function_call_result": content.get("output", {}), + "duration": content.get("duration"), + "error": content.get("error") + } + chunks.append(chunk) + + if chunks: + converted_msg["chunks"] = chunks + converted_msg["response_data"] = {"tool_calls": chunks} + + converted_messages.append(converted_msg) + + except Exception as e: + print(f"Error converting message: {e}") + continue + + return converted_messages + + async def get_user_conversation_history(self, user_id: str, flow_id: Optional[str] = None) -> Dict[str, Any]: + """Get all conversation history for a user, organized by session + + Simplified version - gets all sessions and lets the frontend filter by user_id + """ + try: + # Get all sessions (no complex filtering needed) + session_ids = await self.get_user_sessions(user_id, flow_id) + + conversations = [] + for session_id in session_ids: + messages = await self.get_session_messages(user_id, session_id) + if messages: + # Create conversation metadata + first_message = messages[0] if messages else None + last_message = messages[-1] if messages else None + + conversation = { + "session_id": session_id, + "langflow_session_id": session_id, # For compatibility + "response_id": session_id, # Map session_id to response_id for frontend compatibility + "messages": messages, + "message_count": len(messages), + "created_at": first_message.get("timestamp") if first_message else None, + "last_activity": last_message.get("timestamp") if last_message else None, + "flow_id": first_message.get("langflow_flow_id") if first_message else None, + "source": "langflow" + } + conversations.append(conversation) + + # Sort by last activity (most recent first) + conversations.sort(key=lambda c: c.get("last_activity", ""), reverse=True) + + return { + "conversations": conversations, + "total_conversations": len(conversations), + "user_id": user_id + } + + except Exception as e: + print(f"Error getting user conversation history: {e}") + return { + "error": str(e), + "conversations": [] + } + + +# Global instance +langflow_history_service = LangflowHistoryService() \ No newline at end of file diff --git a/src/services/session_ownership_service.py b/src/services/session_ownership_service.py new file mode 100644 index 00000000..9e3677fd --- /dev/null +++ 
diff --git a/src/services/session_ownership_service.py b/src/services/session_ownership_service.py
new file mode 100644
index 00000000..9e3677fd
--- /dev/null
+++ b/src/services/session_ownership_service.py
@@ -0,0 +1,93 @@
+"""
+Session Ownership Service
+Simple service that tracks which user owns which session
+"""
+
+import json
+import os
+from typing import Any, Dict, List, Optional
+from datetime import datetime
+
+
+class SessionOwnershipService:
+    """Simple service to track which user owns which session"""
+
+    def __init__(self):
+        self.ownership_file = "session_ownership.json"
+        self.ownership_data = self._load_ownership_data()
+
+    def _load_ownership_data(self) -> Dict[str, Dict[str, Any]]:
+        """Load session ownership data from JSON file"""
+        if os.path.exists(self.ownership_file):
+            try:
+                with open(self.ownership_file, 'r') as f:
+                    return json.load(f)
+            except Exception as e:
+                print(f"Error loading session ownership data: {e}")
+                return {}
+        return {}
+
+    def _save_ownership_data(self):
+        """Save session ownership data to JSON file"""
+        try:
+            with open(self.ownership_file, 'w') as f:
+                json.dump(self.ownership_data, f, indent=2)
+            print(f"Saved session ownership data to {self.ownership_file}")
+        except Exception as e:
+            print(f"Error saving session ownership data: {e}")
+
+    def claim_session(self, user_id: str, session_id: str):
+        """Claim a session for a user"""
+        if session_id not in self.ownership_data:
+            self.ownership_data[session_id] = {
+                "user_id": user_id,
+                "created_at": datetime.now().isoformat(),
+                "last_accessed": datetime.now().isoformat()
+            }
+            self._save_ownership_data()
+            print(f"Claimed session {session_id} for user {user_id}")
+        else:
+            # Update last accessed time
+            self.ownership_data[session_id]["last_accessed"] = datetime.now().isoformat()
+            self._save_ownership_data()
+
+    def get_session_owner(self, session_id: str) -> Optional[str]:
+        """Get the user ID that owns a session"""
+        session_data = self.ownership_data.get(session_id)
+        return session_data.get("user_id") if session_data else None
+
+    def get_user_sessions(self, user_id: str) -> List[str]:
+        """Get all sessions owned by a user"""
+        return [
+            session_id
+            for session_id, session_data in self.ownership_data.items()
+            if session_data.get("user_id") == user_id
+        ]
+
+    def is_session_owned_by_user(self, session_id: str, user_id: str) -> bool:
+        """Check if a session is owned by a specific user"""
+        return self.get_session_owner(session_id) == user_id
+
+    def filter_sessions_for_user(self, session_ids: List[str], user_id: str) -> List[str]:
+        """Filter a list of sessions to only include those owned by the user"""
+        user_sessions = self.get_user_sessions(user_id)
+        return [session for session in session_ids if session in user_sessions]
+
+    def get_ownership_stats(self) -> Dict[str, Any]:
+        """Get statistics about session ownership"""
+        users = set()
+        for session_data in self.ownership_data.values():
+            users.add(session_data.get("user_id"))
+
+        return {
+            "total_tracked_sessions": len(self.ownership_data),
+            "unique_users": len(users),
+            "sessions_per_user": {
+                user: len(self.get_user_sessions(user))
+                for user in users if user
+            }
+        }
+
+
+# Global instance
+session_ownership_service = SessionOwnershipService()
\ No newline at end of file
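One way the two new services could compose (a sketch under the assumption that sessions are claimed via `claim_session` when created; this wiring is not shown in the diff):

```python
# Hypothetical composition of the two services added in this PR.
from services.langflow_history_service import langflow_history_service
from services.session_ownership_service import session_ownership_service

async def sessions_for_user(user_id: str):
    # Langflow returns every session under the shared superuser token...
    all_sessions = await langflow_history_service.get_user_sessions(user_id)
    # ...so ownership is enforced locally against the JSON-backed registry.
    return session_ownership_service.filter_sessions_for_user(all_sessions, user_id)
```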
diff --git a/src/services/task_service.py b/src/services/task_service.py
index ad24b188..0537e933 100644
--- a/src/services/task_service.py
+++ b/src/services/task_service.py
@@ -1,11 +1,11 @@
 import asyncio
 import random
 import time
-import uuid
-from typing import Dict
+from typing import Dict, Optional
 
-from models.tasks import FileTask, TaskStatus, UploadTask
+from models.tasks import TaskStatus, UploadTask, FileTask
 from utils.gpu_detection import get_worker_count
+from session_manager import AnonymousUser
 from utils.logging_config import get_logger
 
 logger = get_logger(__name__)
@@ -179,16 +179,29 @@ class TaskService:
             self.task_store[user_id][task_id].status = TaskStatus.FAILED
             self.task_store[user_id][task_id].updated_at = time.time()
 
-    def get_task_status(self, user_id: str, task_id: str) -> dict:
-        """Get the status of a specific upload task"""
-        if (
-            not task_id
-            or user_id not in self.task_store
-            or task_id not in self.task_store[user_id]
-        ):
+    def get_task_status(self, user_id: str, task_id: str) -> Optional[dict]:
+        """Get the status of a specific upload task
+
+        Includes a fallback to shared tasks stored under the "anonymous" user key
+        so default system tasks are visible to all users.
+        """
+        if not task_id:
             return None
 
-        upload_task = self.task_store[user_id][task_id]
+        # Prefer the caller's user_id; otherwise check shared/anonymous tasks
+        candidate_user_ids = [user_id, AnonymousUser().user_id]
+
+        upload_task = None
+        for candidate_user_id in candidate_user_ids:
+            if (
+                candidate_user_id in self.task_store
+                and task_id in self.task_store[candidate_user_id]
+            ):
+                upload_task = self.task_store[candidate_user_id][task_id]
+                break
+
+        if upload_task is None:
+            return None
 
         file_statuses = {}
         for file_path, file_task in upload_task.file_tasks.items():
@@ -214,14 +227,21 @@ class TaskService:
         }
 
     def get_all_tasks(self, user_id: str) -> list:
-        """Get all tasks for a user"""
-        if user_id not in self.task_store:
-            return []
+        """Get all tasks for a user
 
-        tasks = []
-        for task_id, upload_task in self.task_store[user_id].items():
-            tasks.append(
-                {
+        Returns the union of the user's own tasks and shared default tasks stored
+        under the "anonymous" user key. User-owned tasks take precedence
+        if a task_id overlaps.
+        """
+        tasks_by_id = {}
+
+        def add_tasks_from_store(store_user_id):
+            if store_user_id not in self.task_store:
+                return
+            for task_id, upload_task in self.task_store[store_user_id].items():
+                if task_id in tasks_by_id:
+                    continue
+                tasks_by_id[task_id] = {
                     "task_id": upload_task.task_id,
                     "status": upload_task.status.value,
                     "total_files": upload_task.total_files,
@@ -231,18 +251,36 @@ class TaskService:
                     "created_at": upload_task.created_at,
                     "updated_at": upload_task.updated_at,
                 }
-            )
 
-        # Sort by creation time, most recent first
+        # First add user-owned tasks, then shared anonymous tasks
+        add_tasks_from_store(user_id)
+        add_tasks_from_store(AnonymousUser().user_id)
+
+        tasks = list(tasks_by_id.values())
         tasks.sort(key=lambda x: x["created_at"], reverse=True)
         return tasks
 
     def cancel_task(self, user_id: str, task_id: str) -> bool:
-        """Cancel a task if it exists and is not already completed"""
-        if user_id not in self.task_store or task_id not in self.task_store[user_id]:
+        """Cancel a task if it exists and is not already completed.
+
+        Supports cancellation of shared default tasks stored under the anonymous user.
+        """
+        # Check the caller's user ID first, then the anonymous user, to find where the task is stored
+        candidate_user_ids = [user_id, AnonymousUser().user_id]
+
+        store_user_id = None
+        for candidate_user_id in candidate_user_ids:
+            if (
+                candidate_user_id in self.task_store
+                and task_id in self.task_store[candidate_user_id]
+            ):
+                store_user_id = candidate_user_id
+                break
+
+        if store_user_id is None:
             return False
 
-        upload_task = self.task_store[user_id][task_id]
+        upload_task = self.task_store[store_user_id][task_id]
 
         # Can only cancel pending or running tasks
         if upload_task.status in [TaskStatus.COMPLETED, TaskStatus.FAILED]:
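`get_task_status` and `cancel_task` now share the same two-step probe; a condensed, standalone illustration (`AnonymousUser` is stubbed here because `session_manager` is not part of this diff):

```python
# Standalone sketch of the fallback-lookup pattern used in task_service.py.
class AnonymousUser:  # stub; the real class lives in session_manager
    user_id = "anonymous"

def find_task(task_store: dict, user_id: str, task_id: str):
    # Probe the caller's own bucket first, then the shared/anonymous bucket.
    for candidate in (user_id, AnonymousUser().user_id):
        if task_id in task_store.get(candidate, {}):
            return task_store[candidate][task_id]
    return None  # not found under either owner
```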
+ """ + # Check candidate user IDs first, then anonymous to find which user ID the task is mapped to + candidate_user_ids = [user_id, AnonymousUser().user_id] + + store_user_id = None + for candidate_user_id in candidate_user_ids: + if ( + candidate_user_id in self.task_store + and task_id in self.task_store[candidate_user_id] + ): + store_user_id = candidate_user_id + break + + if store_user_id is None: return False - upload_task = self.task_store[user_id][task_id] + upload_task = self.task_store[store_user_id][task_id] # Can only cancel pending or running tasks if upload_task.status in [TaskStatus.COMPLETED, TaskStatus.FAILED]: diff --git a/uv.lock b/uv.lock index a08b7457..87734b48 100644 --- a/uv.lock +++ b/uv.lock @@ -1406,7 +1406,7 @@ wheels = [ [[package]] name = "openrag" version = "0.1.0" -source = { virtual = "." } +source = { editable = "." } dependencies = [ { name = "agentd" }, { name = "aiofiles" }, diff --git a/warm_up_docling.py b/warm_up_docling.py index 7e865ae4..c605bef5 100644 --- a/warm_up_docling.py +++ b/warm_up_docling.py @@ -1,16 +1,18 @@ -from docling.document_converter import DocumentConverter -from src.utils.logging_config import get_logger +import logging -logger = get_logger(__name__) +from docling.document_converter import DocumentConverter + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) logger.info("Warming up docling models") try: # Use the sample document to warm up docling test_file = "/app/warmup_ocr.pdf" - logger.info("Using test file to warm up docling", test_file=test_file) + logger.info(f"Using test file to warm up docling: {test_file}") DocumentConverter().convert(test_file) logger.info("Docling models warmed up successfully") except Exception as e: - logger.info("Docling warm-up completed with exception", error=str(e)) + logger.info(f"Docling warm-up completed with exception: {str(e)}") # This is expected - we just want to trigger the model downloads