From 0fc0be855cc97e6ebaec9ab2fdc4fbb6deef5949 Mon Sep 17 00:00:00 2001 From: phact Date: Tue, 2 Dec 2025 02:31:59 -0500 Subject: [PATCH] onboarding filter creation + sticky filters --- .../app/api/queries/useDoclingHealthQuery.ts | 4 +- .../api/queries/useGetConversationsQuery.ts | 15 +- frontend/app/api/queries/useGetNudgesQuery.ts | 11 +- .../app/api/queries/useProviderHealthQuery.ts | 4 +- frontend/app/chat/page.tsx | 38 +++ .../_components/onboarding-card.tsx | 10 + .../_components/onboarding-content.tsx | 55 ++++- .../_components/onboarding-upload.tsx | 83 ++++++- frontend/components/chat-renderer.tsx | 77 +++++- frontend/components/navigation.tsx | 4 +- frontend/contexts/chat-context.tsx | 125 ++++++++-- frontend/hooks/useChatStreaming.ts | 9 + src/agent.py | 48 +++- src/api/chat.py | 6 + src/api/router.py | 25 +- src/api/settings.py | 80 +----- src/config/settings.py | 232 ++++++++---------- src/services/chat_service.py | 9 + src/services/search_service.py | 32 ++- 19 files changed, 571 insertions(+), 296 deletions(-) diff --git a/frontend/app/api/queries/useDoclingHealthQuery.ts b/frontend/app/api/queries/useDoclingHealthQuery.ts index 01441f4b..b41effd4 100644 --- a/frontend/app/api/queries/useDoclingHealthQuery.ts +++ b/frontend/app/api/queries/useDoclingHealthQuery.ts @@ -60,9 +60,9 @@ export const useDoclingHealthQuery = ( // If healthy, check every 30 seconds; otherwise check every 3 seconds return query.state.data?.status === "healthy" ? 30000 : 3000; }, - refetchOnWindowFocus: true, + refetchOnWindowFocus: false, // Disabled to reduce unnecessary calls on tab switches refetchOnMount: true, - staleTime: 30000, // Consider data stale after 25 seconds + staleTime: 30000, // Consider data fresh for 30 seconds ...options, }, queryClient, diff --git a/frontend/app/api/queries/useGetConversationsQuery.ts b/frontend/app/api/queries/useGetConversationsQuery.ts index f7e579b3..d77b7eff 100644 --- a/frontend/app/api/queries/useGetConversationsQuery.ts +++ b/frontend/app/api/queries/useGetConversationsQuery.ts @@ -51,13 +51,15 @@ export const useGetConversationsQuery = ( ) => { const queryClient = useQueryClient(); - async function getConversations(): Promise { + async function getConversations(context: { signal?: AbortSignal }): Promise { try { // Fetch from the selected endpoint only const apiEndpoint = endpoint === "chat" ? 
"/api/chat/history" : "/api/langflow/history"; - const response = await fetch(apiEndpoint); + const response = await fetch(apiEndpoint, { + signal: context.signal, + }); if (!response.ok) { console.error(`Failed to fetch conversations: ${response.status}`); @@ -84,6 +86,10 @@ export const useGetConversationsQuery = ( return conversations; } catch (error) { + // Ignore abort errors - these are expected when requests are cancelled + if (error instanceof Error && error.name === 'AbortError') { + return []; + } console.error(`Failed to fetch ${endpoint} conversations:`, error); return []; } @@ -94,8 +100,11 @@ export const useGetConversationsQuery = ( queryKey: ["conversations", endpoint, refreshTrigger], placeholderData: (prev) => prev, queryFn: getConversations, - staleTime: 0, // Always consider data stale to ensure fresh data on trigger changes + staleTime: 5000, // Consider data fresh for 5 seconds to prevent excessive refetching gcTime: 5 * 60 * 1000, // Keep in cache for 5 minutes + networkMode: 'always', // Ensure requests can be cancelled + refetchOnMount: false, // Don't refetch on every mount + refetchOnWindowFocus: false, // Don't refetch when window regains focus ...options, }, queryClient, diff --git a/frontend/app/api/queries/useGetNudgesQuery.ts b/frontend/app/api/queries/useGetNudgesQuery.ts index 45ef61e7..05c97bde 100644 --- a/frontend/app/api/queries/useGetNudgesQuery.ts +++ b/frontend/app/api/queries/useGetNudgesQuery.ts @@ -34,7 +34,7 @@ export const useGetNudgesQuery = ( }); } - async function getNudges(): Promise { + async function getNudges(context: { signal?: AbortSignal }): Promise { try { const requestBody: { filters?: NudgeFilters; @@ -58,6 +58,7 @@ export const useGetNudgesQuery = ( "Content-Type": "application/json", }, body: JSON.stringify(requestBody), + signal: context.signal, }); const data = await response.json(); @@ -67,6 +68,10 @@ export const useGetNudgesQuery = ( return DEFAULT_NUDGES; } catch (error) { + // Ignore abort errors - these are expected when requests are cancelled + if (error instanceof Error && error.name === 'AbortError') { + return DEFAULT_NUDGES; + } console.error("Error getting nudges", error); return DEFAULT_NUDGES; } @@ -76,6 +81,10 @@ export const useGetNudgesQuery = ( { queryKey: ["nudges", chatId, filters, limit, scoreThreshold], queryFn: getNudges, + staleTime: 10000, // Consider data fresh for 10 seconds to prevent rapid refetching + networkMode: 'always', // Ensure requests can be cancelled + refetchOnMount: false, // Don't refetch on every mount + refetchOnWindowFocus: false, // Don't refetch when window regains focus refetchInterval: (query) => { // If data is empty, refetch every 5 seconds const data = query.state.data; diff --git a/frontend/app/api/queries/useProviderHealthQuery.ts b/frontend/app/api/queries/useProviderHealthQuery.ts index 82ca2db2..5cd86450 100644 --- a/frontend/app/api/queries/useProviderHealthQuery.ts +++ b/frontend/app/api/queries/useProviderHealthQuery.ts @@ -96,9 +96,9 @@ export const useProviderHealthQuery = ( // If healthy, check every 30 seconds; otherwise check every 3 seconds return query.state.data?.status === "healthy" ? 
30000 : 3000; }, - refetchOnWindowFocus: true, + refetchOnWindowFocus: false, // Disabled to reduce unnecessary calls on tab switches refetchOnMount: true, - staleTime: 30000, // Consider data stale after 25 seconds + staleTime: 30000, // Consider data fresh for 30 seconds enabled: !!settings?.edited && options?.enabled !== false, // Only run after onboarding is complete ...options, }, diff --git a/frontend/app/chat/page.tsx b/frontend/app/chat/page.tsx index 9594a0ea..358424f3 100644 --- a/frontend/app/chat/page.tsx +++ b/frontend/app/chat/page.tsx @@ -110,6 +110,13 @@ function ChatPage() { } else { refreshConversationsSilent(); } + + // Save filter association for this response + if (conversationFilter && typeof window !== "undefined") { + const newKey = `conversation_filter_${responseId}`; + localStorage.setItem(newKey, conversationFilter.id); + console.log("[CHAT] Saved filter association:", newKey, "=", conversationFilter.id); + } } }, onError: (error) => { @@ -696,11 +703,18 @@ function ChatPage() { // Use passed previousResponseId if available, otherwise fall back to state const responseIdToUse = previousResponseId || previousResponseIds[endpoint]; + console.log("[CHAT] Sending streaming message:", { + conversationFilter: conversationFilter?.id, + currentConversationId, + responseIdToUse, + }); + // Use the hook to send the message await sendStreamingMessage({ prompt: userMessage.content, previousResponseId: responseIdToUse || undefined, filters: processedFilters, + filter_id: conversationFilter?.id, // ✅ Add filter_id for this conversation limit: parsedFilterData?.limit ?? 10, scoreThreshold: parsedFilterData?.scoreThreshold ?? 0, }); @@ -781,6 +795,19 @@ function ChatPage() { requestBody.previous_response_id = currentResponseId; } + // Add filter_id if a filter is selected for this conversation + if (conversationFilter) { + requestBody.filter_id = conversationFilter.id; + } + + // Debug logging + console.log("[DEBUG] Sending message with:", { + previous_response_id: requestBody.previous_response_id, + filter_id: requestBody.filter_id, + currentConversationId, + previousResponseIds, + }); + const response = await fetch(apiEndpoint, { method: "POST", headers: { @@ -804,6 +831,8 @@ function ChatPage() { // Store the response ID if present for this endpoint if (result.response_id) { + console.log("[DEBUG] Received response_id:", result.response_id, "currentConversationId:", currentConversationId); + setPreviousResponseIds((prev) => ({ ...prev, [endpoint]: result.response_id, @@ -811,12 +840,21 @@ function ChatPage() { // If this is a new conversation (no currentConversationId), set it now if (!currentConversationId) { + console.log("[DEBUG] Setting currentConversationId to:", result.response_id); setCurrentConversationId(result.response_id); refreshConversations(true); } else { + console.log("[DEBUG] Existing conversation, doing silent refresh"); // For existing conversations, do a silent refresh to keep backend in sync refreshConversationsSilent(); } + + // Carry forward the filter association to the new response_id + if (conversationFilter && typeof window !== "undefined") { + const newKey = `conversation_filter_${result.response_id}`; + localStorage.setItem(newKey, conversationFilter.id); + console.log("[DEBUG] Saved filter association:", newKey, "=", conversationFilter.id); + } } } else { console.error("Chat failed:", result.error); diff --git a/frontend/app/onboarding/_components/onboarding-card.tsx b/frontend/app/onboarding/_components/onboarding-card.tsx index 
7ac2e85c..7c257088 100644 --- a/frontend/app/onboarding/_components/onboarding-card.tsx +++ b/frontend/app/onboarding/_components/onboarding-card.tsx @@ -209,6 +209,16 @@ const OnboardingCard = ({ const onboardingMutation = useOnboardingMutation({ onSuccess: (data) => { console.log("Onboarding completed successfully", data); + + // Save OpenRAG docs filter ID if sample data was ingested + if (data.openrag_docs_filter_id && typeof window !== "undefined") { + localStorage.setItem( + "onboarding_openrag_docs_filter_id", + data.openrag_docs_filter_id + ); + console.log("Saved OpenRAG docs filter ID:", data.openrag_docs_filter_id); + } + // Update provider health cache to healthy since backend just validated const provider = (isEmbedding ? settings.embedding_provider : settings.llm_provider) || diff --git a/frontend/app/onboarding/_components/onboarding-content.tsx b/frontend/app/onboarding/_components/onboarding-content.tsx index 699c8723..7473a916 100644 --- a/frontend/app/onboarding/_components/onboarding-content.tsx +++ b/frontend/app/onboarding/_components/onboarding-content.tsx @@ -2,11 +2,13 @@ import { useEffect, useRef, useState } from "react"; import { StickToBottom } from "use-stick-to-bottom"; +import { getFilterById } from "@/app/api/queries/useGetFilterByIdQuery"; import { AssistantMessage } from "@/app/chat/_components/assistant-message"; import Nudges from "@/app/chat/_components/nudges"; import { UserMessage } from "@/app/chat/_components/user-message"; import type { Message, SelectedFilters } from "@/app/chat/_types/types"; import OnboardingCard from "@/app/onboarding/_components/onboarding-card"; +import { useChat } from "@/contexts/chat-context"; import { useChatStreaming } from "@/hooks/useChatStreaming"; import { ONBOARDING_ASSISTANT_MESSAGE_KEY, @@ -33,6 +35,7 @@ export function OnboardingContent({ handleStepBack: () => void; currentStep: number; }) { + const { setConversationFilter, setCurrentConversationId } = useChat(); const parseFailedRef = useRef(false); const [responseId, setResponseId] = useState(null); const [selectedNudge, setSelectedNudge] = useState(() => { @@ -78,7 +81,7 @@ export function OnboardingContent({ }, [handleStepBack, currentStep]); const { streamingMessage, isLoading, sendMessage } = useChatStreaming({ - onComplete: (message, newResponseId) => { + onComplete: async (message, newResponseId) => { setAssistantMessage(message); // Save assistant message to localStorage when complete if (typeof window !== "undefined") { @@ -96,6 +99,26 @@ export function OnboardingContent({ } if (newResponseId) { setResponseId(newResponseId); + + // Set the current conversation ID + setCurrentConversationId(newResponseId); + + // Save the filter association for this conversation + const openragDocsFilterId = localStorage.getItem(ONBOARDING_OPENRAG_DOCS_FILTER_ID_KEY); + if (openragDocsFilterId) { + try { + // Load the filter and set it in the context with explicit responseId + // This ensures the filter is saved to localStorage with the correct conversation ID + const filter = await getFilterById(openragDocsFilterId); + if (filter) { + // Pass explicit newResponseId to ensure correct localStorage association + setConversationFilter(filter, newResponseId); + console.log("[ONBOARDING] Saved filter association:", `conversation_filter_${newResponseId}`, "=", openragDocsFilterId); + } + } catch (error) { + console.error("Failed to associate filter with conversation:", error); + } + } } }, onError: (error) => { @@ -124,15 +147,35 @@ export function OnboardingContent({ } 
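// --- Editor's sketch (not part of this patch): the localStorage convention that the
// --- "sticky filter" hunks in this file, chat/page.tsx and chat-context.tsx all share.
// --- Helper names below are hypothetical; the keys and values come from the patch itself.
function conversationFilterKey(responseId: string): string {
  // Per-conversation association, written when a streaming or non-streaming response
  // completes and read back by loadConversation() in chat-context.tsx.
  return `conversation_filter_${responseId}`;
}

function rememberFilterForConversation(responseId: string, filterId: string): void {
  if (typeof window === "undefined") return;
  localStorage.setItem(conversationFilterKey(responseId), filterId);
}

function consumeDefaultFilterId(): string | null {
  // One-shot default stored as "default_conversation_filter_id" after onboarding and
  // cleared by startNewConversation() the first time it is used.
  if (typeof window === "undefined") return null;
  const id = localStorage.getItem("default_conversation_filter_id");
  if (id) localStorage.removeItem("default_conversation_filter_id");
  return id;
}
// --- End editor's sketch ---------------------------------------------------------------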
setTimeout(async () => { // Check if we have the OpenRAG docs filter ID (sample data was ingested) - const hasOpenragDocsFilter = - typeof window !== "undefined" && - localStorage.getItem(ONBOARDING_OPENRAG_DOCS_FILTER_ID_KEY); + const openragDocsFilterId = + typeof window !== "undefined" + ? localStorage.getItem(ONBOARDING_OPENRAG_DOCS_FILTER_ID_KEY) + : null; + // Load and set the OpenRAG docs filter if available + let filterToUse = null; + console.log("[ONBOARDING] openragDocsFilterId:", openragDocsFilterId); + if (openragDocsFilterId) { + try { + const filter = await getFilterById(openragDocsFilterId); + console.log("[ONBOARDING] Loaded filter:", filter); + if (filter) { + // Pass null to skip localStorage save - no conversation exists yet + setConversationFilter(filter, null); + filterToUse = filter; + } + } catch (error) { + console.error("Failed to load OpenRAG docs filter:", error); + } + } + + console.log("[ONBOARDING] Sending message with filter_id:", filterToUse?.id); await sendMessage({ prompt: nudge, previousResponseId: responseId || undefined, - // Use OpenRAG docs filter if sample data was ingested - filters: hasOpenragDocsFilter ? OPENRAG_DOCS_FILTERS : undefined, + // Send both filter_id and filters (selections) + filter_id: filterToUse?.id, + filters: openragDocsFilterId ? OPENRAG_DOCS_FILTERS : undefined, }); }, 1500); }; diff --git a/frontend/app/onboarding/_components/onboarding-upload.tsx b/frontend/app/onboarding/_components/onboarding-upload.tsx index 7855ec0a..263af7b7 100644 --- a/frontend/app/onboarding/_components/onboarding-upload.tsx +++ b/frontend/app/onboarding/_components/onboarding-upload.tsx @@ -1,5 +1,7 @@ import { AnimatePresence, motion } from "motion/react"; import { type ChangeEvent, useEffect, useRef, useState } from "react"; +import { toast } from "sonner"; +import { useCreateFilter } from "@/app/api/mutations/useCreateFilter"; import { useGetNudgesQuery } from "@/app/api/queries/useGetNudgesQuery"; import { useGetTasksQuery } from "@/app/api/queries/useGetTasksQuery"; import { AnimatedProviderSteps } from "@/app/onboarding/_components/animated-provider-steps"; @@ -18,6 +20,11 @@ const OnboardingUpload = ({ onComplete }: OnboardingUploadProps) => { const fileInputRef = useRef(null); const [isUploading, setIsUploading] = useState(false); const [currentStep, setCurrentStep] = useState(null); + const [uploadedFilename, setUploadedFilename] = useState(null); + const [shouldCreateFilter, setShouldCreateFilter] = useState(false); + const [isCreatingFilter, setIsCreatingFilter] = useState(false); + + const createFilterMutation = useCreateFilter(); const STEP_LIST = [ "Uploading your document", @@ -56,6 +63,60 @@ const OnboardingUpload = ({ onComplete }: OnboardingUploadProps) => { // Set to final step to show "Done" setCurrentStep(STEP_LIST.length); + // Create knowledge filter for uploaded document if requested + // Guard against race condition: only create if not already creating + if (shouldCreateFilter && uploadedFilename && !isCreatingFilter) { + // Reset flags immediately (synchronously) to prevent duplicate creation + setShouldCreateFilter(false); + const filename = uploadedFilename; + setUploadedFilename(null); + setIsCreatingFilter(true); + + // Get display name from filename (remove extension for cleaner name) + const displayName = filename.includes(".") + ? 
filename.substring(0, filename.lastIndexOf(".")) + : filename; + + const queryData = JSON.stringify({ + query: "", + filters: { + data_sources: [filename], + document_types: ["*"], + owners: ["*"], + connector_types: ["*"], + }, + limit: 10, + scoreThreshold: 0, + color: "green", + icon: "file", + }); + + createFilterMutation + .mutateAsync({ + name: displayName, + description: `Filter for ${filename}`, + queryData: queryData, + }) + .then((result) => { + if (result.filter?.id && typeof window !== "undefined") { + localStorage.setItem( + ONBOARDING_USER_DOC_FILTER_ID_KEY, + result.filter.id, + ); + console.log( + "Created knowledge filter for uploaded document", + result.filter.id, + ); + } + }) + .catch((error) => { + console.error("Failed to create knowledge filter:", error); + }) + .finally(() => { + setIsCreatingFilter(false); + }); + } + // Refetch nudges to get new ones refetchNudges(); @@ -64,7 +125,7 @@ const OnboardingUpload = ({ onComplete }: OnboardingUploadProps) => { onComplete(); }, 1000); } - }, [tasks, currentStep, onComplete, refetchNudges]); + }, [tasks, currentStep, onComplete, refetchNudges, shouldCreateFilter, uploadedFilename]); const resetFileInput = () => { if (fileInputRef.current) { @@ -83,12 +144,10 @@ const OnboardingUpload = ({ onComplete }: OnboardingUploadProps) => { const result = await uploadFile(file, true, true); // Pass createFilter=true console.log("Document upload task started successfully"); - // Store user document filter ID if returned - if (result.userDocFilterId && typeof window !== "undefined") { - localStorage.setItem( - ONBOARDING_USER_DOC_FILTER_ID_KEY, - result.userDocFilterId - ); + // Store filename and createFilter flag in state to create filter after ingestion succeeds + if (result.createFilter && result.filename) { + setUploadedFilename(result.filename); + setShouldCreateFilter(true); } // Move to processing step - task monitoring will handle completion @@ -96,7 +155,15 @@ const OnboardingUpload = ({ onComplete }: OnboardingUploadProps) => { setCurrentStep(1); }, 1500); } catch (error) { - console.error("Upload failed", (error as Error).message); + const errorMessage = error instanceof Error ? 
error.message : "Upload failed"; + console.error("Upload failed", errorMessage); + + // Show error toast notification + toast.error("Document upload failed", { + description: errorMessage, + duration: 5000, + }); + // Reset on error setCurrentStep(null); } finally { diff --git a/frontend/components/chat-renderer.tsx b/frontend/components/chat-renderer.tsx index 45841299..6804b065 100644 --- a/frontend/components/chat-renderer.tsx +++ b/frontend/components/chat-renderer.tsx @@ -1,7 +1,7 @@ "use client"; import { motion } from "framer-motion"; -import { usePathname } from "next/navigation"; +import { usePathname, useRouter } from "next/navigation"; import { useCallback, useEffect, useState } from "react"; import { type ChatConversation, @@ -39,6 +39,7 @@ export function ChatRenderer({ children: React.ReactNode; }) { const pathname = usePathname(); + const router = useRouter(); const { isAuthenticated, isNoAuthMode } = useAuth(); const { endpoint, @@ -46,6 +47,8 @@ export function ChatRenderer({ refreshConversations, startNewConversation, setConversationFilter, + setCurrentConversationId, + setPreviousResponseIds, } = useChat(); // Initialize onboarding state based on local storage and settings @@ -75,38 +78,74 @@ export function ChatRenderer({ startNewConversation(); }; - // Helper to set the default filter after onboarding transition - const setOnboardingFilter = useCallback( + // Navigate to /chat when onboarding is active so animation reveals chat underneath + useEffect(() => { + if (!showLayout && pathname !== "/chat" && pathname !== "/") { + router.push("/chat"); + } + }, [showLayout, pathname, router]); + + // Helper to store default filter ID for new conversations after onboarding + const storeDefaultFilterForNewConversations = useCallback( async (preferUserDoc: boolean) => { if (typeof window === "undefined") return; + // Check if we already have a default filter set + const existingDefault = localStorage.getItem("default_conversation_filter_id"); + if (existingDefault) { + console.log("[FILTER] Default filter already set:", existingDefault); + // Try to apply it to context state (don't save to localStorage to avoid overwriting) + try { + const filter = await getFilterById(existingDefault); + if (filter) { + // Pass null to skip localStorage save + setConversationFilter(filter, null); + return; // Successfully loaded and set, we're done + } + } catch (error) { + console.error("Failed to load existing default filter, will set new one:", error); + // Filter doesn't exist anymore, clear it and continue to set a new one + localStorage.removeItem("default_conversation_filter_id"); + } + } + // Try to get the appropriate filter ID let filterId: string | null = null; if (preferUserDoc) { // Completed full onboarding - prefer user document filter filterId = localStorage.getItem(ONBOARDING_USER_DOC_FILTER_ID_KEY); + console.log("[FILTER] User doc filter ID:", filterId); } // Fall back to OpenRAG docs filter if (!filterId) { filterId = localStorage.getItem(ONBOARDING_OPENRAG_DOCS_FILTER_ID_KEY); + console.log("[FILTER] OpenRAG docs filter ID:", filterId); } + console.log("[FILTER] Final filter ID to use:", filterId); + if (filterId) { + // Store this as the default filter for new conversations + localStorage.setItem("default_conversation_filter_id", filterId); + + // Apply filter to context state only (don't save to localStorage since there's no conversation yet) + // The default_conversation_filter_id will be used when a new conversation is started try { const filter = await 
getFilterById(filterId); + console.log("[FILTER] Loaded filter:", filter); if (filter) { - setConversationFilter(filter); + // Pass null to skip localStorage save - this prevents overwriting existing conversation filters + setConversationFilter(filter, null); + console.log("[FILTER] Set conversation filter (no save):", filter.id); } } catch (error) { console.error("Failed to set onboarding filter:", error); } + } else { + console.log("[FILTER] No filter ID found, not setting default"); } - - // Clean up onboarding filter IDs from localStorage - localStorage.removeItem(ONBOARDING_OPENRAG_DOCS_FILTER_ID_KEY); - localStorage.removeItem(ONBOARDING_USER_DOC_FILTER_ID_KEY); }, [setConversationFilter] ); @@ -118,7 +157,7 @@ export function ChatRenderer({ } }, [currentStep, showLayout]); - const handleStepComplete = () => { + const handleStepComplete = async () => { if (currentStep < TOTAL_ONBOARDING_STEPS - 1) { setCurrentStep(currentStep + 1); } else { @@ -130,8 +169,20 @@ export function ChatRenderer({ localStorage.removeItem(ONBOARDING_CARD_STEPS_KEY); localStorage.removeItem(ONBOARDING_UPLOAD_STEPS_KEY); } - // Set the user document filter as active (completed full onboarding) - setOnboardingFilter(true); + + // Clear ALL conversation state so next message starts fresh + await startNewConversation(); + + // Store the user document filter as default for new conversations and load it + await storeDefaultFilterForNewConversations(true); + + // Clean up onboarding filter IDs now that we've set the default + if (typeof window !== "undefined") { + localStorage.removeItem(ONBOARDING_OPENRAG_DOCS_FILTER_ID_KEY); + localStorage.removeItem(ONBOARDING_USER_DOC_FILTER_ID_KEY); + console.log("[FILTER] Cleaned up onboarding filter IDs"); + } + setShowLayout(true); } }; @@ -151,8 +202,8 @@ export function ChatRenderer({ localStorage.removeItem(ONBOARDING_CARD_STEPS_KEY); localStorage.removeItem(ONBOARDING_UPLOAD_STEPS_KEY); } - // Set the OpenRAG docs filter as active (skipped onboarding - no user doc) - setOnboardingFilter(false); + // Store the OpenRAG docs filter as default for new conversations + storeDefaultFilterForNewConversations(false); setShowLayout(true); }; diff --git a/frontend/components/navigation.tsx b/frontend/components/navigation.tsx index b172779e..68ffa4e6 100644 --- a/frontend/components/navigation.tsx +++ b/frontend/components/navigation.tsx @@ -289,7 +289,7 @@ export function Navigation({ handleNewConversation(); } else if (activeConvo) { loadConversation(activeConvo); - refreshConversations(); + // Don't call refreshConversations here - it causes unnecessary refetches } else if ( conversations.length > 0 && currentConversationId === null && @@ -473,7 +473,7 @@ export function Navigation({ onClick={() => { if (loading || isConversationsLoading) return; loadConversation(conversation); - refreshConversations(); + // Don't refresh - just loading an existing conversation }} disabled={loading || isConversationsLoading} > diff --git a/frontend/contexts/chat-context.tsx b/frontend/contexts/chat-context.tsx index 8c203003..bee05b98 100644 --- a/frontend/contexts/chat-context.tsx +++ b/frontend/contexts/chat-context.tsx @@ -65,7 +65,7 @@ interface ChatContextType { refreshConversationsSilent: () => Promise; refreshTrigger: number; refreshTriggerSilent: number; - loadConversation: (conversation: ConversationData) => void; + loadConversation: (conversation: ConversationData) => Promise; startNewConversation: () => void; conversationData: ConversationData | null; forkFromResponse: (responseId: 
string) => void; @@ -77,7 +77,8 @@ interface ChatContextType { conversationLoaded: boolean; setConversationLoaded: (loaded: boolean) => void; conversationFilter: KnowledgeFilter | null; - setConversationFilter: (filter: KnowledgeFilter | null) => void; + // responseId: undefined = use currentConversationId, null = don't save to localStorage + setConversationFilter: (filter: KnowledgeFilter | null, responseId?: string | null) => void; } const ChatContext = createContext(undefined); @@ -112,6 +113,8 @@ export function ChatProvider({ children }: ChatProviderProps) { const refreshTimeoutRef = useRef(null); const refreshConversations = useCallback((force = false) => { + console.log("[REFRESH] refreshConversations called, force:", force); + if (force) { // Immediate refresh for important updates like new conversations setRefreshTrigger((prev) => prev + 1); @@ -145,22 +148,59 @@ export function ChatProvider({ children }: ChatProviderProps) { }, []); const loadConversation = useCallback( - (conversation: ConversationData) => { + async (conversation: ConversationData) => { + console.log("[CONVERSATION] Loading conversation:", { + conversationId: conversation.response_id, + title: conversation.title, + endpoint: conversation.endpoint, + }); + setCurrentConversationId(conversation.response_id); setEndpoint(conversation.endpoint); // Store the full conversation data for the chat page to use setConversationData(conversation); + // Load the filter if one exists for this conversation - // Only update the filter if this is a different conversation (to preserve user's filter selection) - setConversationFilterState((currentFilter) => { - // If we're loading a different conversation, load its filter - // Otherwise keep the current filter (don't reset it when conversation refreshes) - const isDifferentConversation = - conversation.response_id !== conversationData?.response_id; - return isDifferentConversation - ? 
conversation.filter || null - : currentFilter; - }); + // Always update the filter to match the conversation being loaded + const isDifferentConversation = + conversation.response_id !== conversationData?.response_id; + + if (isDifferentConversation && typeof window !== "undefined") { + // Try to load the saved filter from localStorage + const savedFilterId = localStorage.getItem(`conversation_filter_${conversation.response_id}`); + console.log("[CONVERSATION] Looking for filter:", { + conversationId: conversation.response_id, + savedFilterId, + }); + + if (savedFilterId) { + // Import getFilterById dynamically to avoid circular dependency + const { getFilterById } = await import("@/app/api/queries/useGetFilterByIdQuery"); + try { + const filter = await getFilterById(savedFilterId); + + if (filter) { + console.log("[CONVERSATION] Loaded filter:", filter.name, filter.id); + setConversationFilterState(filter); + // Update conversation data with the loaded filter + setConversationData((prev) => { + if (!prev) return prev; + return { ...prev, filter }; + }); + } + } catch (error) { + console.error("[CONVERSATION] Failed to load filter:", error); + // Filter was deleted, clean up localStorage + localStorage.removeItem(`conversation_filter_${conversation.response_id}`); + setConversationFilterState(null); + } + } else { + // No saved filter in localStorage, clear the current filter + console.log("[CONVERSATION] No filter found for this conversation"); + setConversationFilterState(null); + } + } + // Clear placeholder when loading a real conversation setPlaceholderConversation(null); setConversationLoaded(true); @@ -170,15 +210,48 @@ export function ChatProvider({ children }: ChatProviderProps) { [conversationData?.response_id], ); - const startNewConversation = useCallback(() => { + const startNewConversation = useCallback(async () => { + console.log("[CONVERSATION] Starting new conversation"); + // Clear current conversation data and reset state setCurrentConversationId(null); setPreviousResponseIds({ chat: null, langflow: null }); setConversationData(null); setConversationDocs([]); setConversationLoaded(false); - // Clear the filter when starting a new conversation - setConversationFilterState(null); + + // Load default filter if available (and clear it after first use) + if (typeof window !== "undefined") { + const defaultFilterId = localStorage.getItem("default_conversation_filter_id"); + console.log("[CONVERSATION] Default filter ID:", defaultFilterId); + + if (defaultFilterId) { + // Clear the default filter now so it's only used once + localStorage.removeItem("default_conversation_filter_id"); + console.log("[CONVERSATION] Cleared default filter (used once)"); + + try { + const { getFilterById } = await import("@/app/api/queries/useGetFilterByIdQuery"); + const filter = await getFilterById(defaultFilterId); + + if (filter) { + console.log("[CONVERSATION] Loaded default filter:", filter.name, filter.id); + setConversationFilterState(filter); + } else { + // Default filter was deleted + setConversationFilterState(null); + } + } catch (error) { + console.error("[CONVERSATION] Failed to load default filter:", error); + setConversationFilterState(null); + } + } else { + console.log("[CONVERSATION] No default filter set"); + setConversationFilterState(null); + } + } else { + setConversationFilterState(null); + } // Create a temporary placeholder conversation to show in sidebar const placeholderConversation: ConversationData = { @@ -230,7 +303,7 @@ export function ChatProvider({ children }: 
ChatProviderProps) { ); const setConversationFilter = useCallback( - (filter: KnowledgeFilter | null) => { + (filter: KnowledgeFilter | null, responseId?: string | null) => { setConversationFilterState(filter); // Update the conversation data to include the filter setConversationData((prev) => { @@ -240,8 +313,24 @@ export function ChatProvider({ children }: ChatProviderProps) { filter, }; }); + + // Determine which conversation ID to use for saving + // - undefined: use currentConversationId (default behavior) + // - null: explicitly skip saving to localStorage + // - string: use the provided responseId + const targetId = responseId === undefined ? currentConversationId : responseId; + + // Save filter association for the target conversation + if (typeof window !== "undefined" && targetId) { + const key = `conversation_filter_${targetId}`; + if (filter) { + localStorage.setItem(key, filter.id); + } else { + localStorage.removeItem(key); + } + } }, - [], + [currentConversationId], ); const value = useMemo( diff --git a/frontend/hooks/useChatStreaming.ts b/frontend/hooks/useChatStreaming.ts index b2877fd0..6836ed4b 100644 --- a/frontend/hooks/useChatStreaming.ts +++ b/frontend/hooks/useChatStreaming.ts @@ -15,6 +15,7 @@ interface SendMessageOptions { prompt: string; previousResponseId?: string; filters?: SelectedFilters; + filter_id?: string; limit?: number; scoreThreshold?: number; } @@ -35,6 +36,7 @@ export function useChatStreaming({ prompt, previousResponseId, filters, + filter_id, limit = 10, scoreThreshold = 0, }: SendMessageOptions) => { @@ -73,6 +75,7 @@ export function useChatStreaming({ stream: boolean; previous_response_id?: string; filters?: SelectedFilters; + filter_id?: string; limit?: number; scoreThreshold?: number; } = { @@ -90,6 +93,12 @@ export function useChatStreaming({ requestBody.filters = filters; } + if (filter_id) { + requestBody.filter_id = filter_id; + } + + console.log("[useChatStreaming] Sending request:", { filter_id, requestBody }); + const response = await fetch(endpoint, { method: "POST", headers: { diff --git a/src/agent.py b/src/agent.py index 74da24e7..bd4d257f 100644 --- a/src/agent.py +++ b/src/agent.py @@ -1,3 +1,5 @@ +from http.client import HTTPException + from utils.logging_config import get_logger logger = get_logger(__name__) @@ -67,6 +69,7 @@ def store_conversation_thread(user_id: str, response_id: str, conversation_state "created_at": conversation_state.get("created_at"), "last_activity": conversation_state.get("last_activity"), "previous_response_id": conversation_state.get("previous_response_id"), + "filter_id": conversation_state.get("filter_id"), "total_messages": len( [msg for msg in messages if msg.get("role") in ["user", "assistant"]] ), @@ -219,15 +222,26 @@ async def async_response( response = await client.responses.create(**request_params) - response_text = response.output_text - logger.info("Response generated", log_prefix=log_prefix, response=response_text) + # Check if response has output_text using getattr to avoid issues with special objects + output_text = getattr(response, "output_text", None) + if output_text is not None: + response_text = output_text + logger.info("Response generated", log_prefix=log_prefix, response=response_text) - # Extract and store response_id if available - response_id = getattr(response, "id", None) or getattr( - response, "response_id", None - ) + # Extract and store response_id if available + response_id = getattr(response, "id", None) or getattr( + response, "response_id", None + ) - return 
response_text, response_id, response + return response_text, response_id, response + else: + msg = "Nudge response missing output_text" + error = getattr(response, "error", None) + if error: + error_msg = getattr(error, "message", None) + if error_msg: + msg = error_msg + raise ValueError(msg) except Exception as e: logger.error("Exception in non-streaming response", error=str(e)) import traceback @@ -314,6 +328,7 @@ async def async_chat( user_id: str, model: str = "gpt-4.1-mini", previous_response_id: str = None, + filter_id: str = None, ): logger.debug( "async_chat called", user_id=user_id, previous_response_id=previous_response_id @@ -334,6 +349,10 @@ async def async_chat( "Added user message", message_count=len(conversation_state["messages"]) ) + # Store filter_id in conversation state if provided + if filter_id: + conversation_state["filter_id"] = filter_id + response_text, response_id, response_obj = await async_response( async_client, prompt, @@ -389,6 +408,7 @@ async def async_chat_stream( user_id: str, model: str = "gpt-4.1-mini", previous_response_id: str = None, + filter_id: str = None, ): # Get the specific conversation thread (or create new one) conversation_state = get_conversation_thread(user_id, previous_response_id) @@ -399,6 +419,10 @@ async def async_chat_stream( user_message = {"role": "user", "content": prompt, "timestamp": datetime.now()} conversation_state["messages"].append(user_message) + # Store filter_id in conversation state if provided + if filter_id: + conversation_state["filter_id"] = filter_id + full_response = "" response_id = None async for chunk in async_stream( @@ -452,6 +476,7 @@ async def async_langflow_chat( extra_headers: dict = None, previous_response_id: str = None, store_conversation: bool = True, + filter_id: str = None, ): logger.debug( "async_langflow_chat called", @@ -478,6 +503,10 @@ async def async_langflow_chat( message_count=len(conversation_state["messages"]), ) + # Store filter_id in conversation state if provided + if filter_id: + conversation_state["filter_id"] = filter_id + response_text, response_id, response_obj = await async_response( langflow_client, prompt, @@ -562,6 +591,7 @@ async def async_langflow_chat_stream( user_id: str, extra_headers: dict = None, previous_response_id: str = None, + filter_id: str = None, ): logger.debug( "async_langflow_chat_stream called", @@ -578,6 +608,10 @@ async def async_langflow_chat_stream( user_message = {"role": "user", "content": prompt, "timestamp": datetime.now()} conversation_state["messages"].append(user_message) + # Store filter_id in conversation state if provided + if filter_id: + conversation_state["filter_id"] = filter_id + full_response = "" response_id = None collected_chunks = [] # Store all chunks for function call data diff --git a/src/api/chat.py b/src/api/chat.py index 58492118..56da5b2d 100644 --- a/src/api/chat.py +++ b/src/api/chat.py @@ -14,6 +14,7 @@ async def chat_endpoint(request: Request, chat_service, session_manager): filters = data.get("filters") limit = data.get("limit", 10) score_threshold = data.get("scoreThreshold", 0) + filter_id = data.get("filter_id") user = request.state.user user_id = user.user_id @@ -42,6 +43,7 @@ async def chat_endpoint(request: Request, chat_service, session_manager): jwt_token, previous_response_id=previous_response_id, stream=True, + filter_id=filter_id, ), media_type="text/event-stream", headers={ @@ -58,6 +60,7 @@ async def chat_endpoint(request: Request, chat_service, session_manager): jwt_token, 
previous_response_id=previous_response_id, stream=False, + filter_id=filter_id, ) return JSONResponse(result) @@ -71,6 +74,7 @@ async def langflow_endpoint(request: Request, chat_service, session_manager): filters = data.get("filters") limit = data.get("limit", 10) score_threshold = data.get("scoreThreshold", 0) + filter_id = data.get("filter_id") user = request.state.user user_id = user.user_id @@ -100,6 +104,7 @@ async def langflow_endpoint(request: Request, chat_service, session_manager): jwt_token, previous_response_id=previous_response_id, stream=True, + filter_id=filter_id, ), media_type="text/event-stream", headers={ @@ -116,6 +121,7 @@ async def langflow_endpoint(request: Request, chat_service, session_manager): jwt_token, previous_response_id=previous_response_id, stream=False, + filter_id=filter_id, ) return JSONResponse(result) diff --git a/src/api/router.py b/src/api/router.py index e8fb924d..79d03df5 100644 --- a/src/api/router.py +++ b/src/api/router.py @@ -179,34 +179,13 @@ async def langflow_upload_ingest_task( logger.debug("Langflow upload task created successfully", task_id=task_id) - # Create knowledge filter for the uploaded document if requested - user_doc_filter_id = None - if create_filter and len(original_filenames) == 1: - try: - from api.settings import _create_user_document_filter - user_doc_filter_id = await _create_user_document_filter( - request, session_manager, original_filenames[0] - ) - if user_doc_filter_id: - logger.info( - "Created knowledge filter for uploaded document", - filter_id=user_doc_filter_id, - filename=original_filenames[0], - ) - except Exception as e: - logger.error( - "Failed to create knowledge filter for uploaded document", - error=str(e), - ) - # Don't fail the upload if filter creation fails - response_data = { "task_id": task_id, "message": f"Langflow upload task created for {len(upload_files)} file(s)", "file_count": len(upload_files), + "create_filter": create_filter, # Pass flag back to frontend + "filename": original_filenames[0] if len(original_filenames) == 1 else None, # Pass filename for filter creation } - if user_doc_filter_id: - response_data["user_doc_filter_id"] = user_doc_filter_id return JSONResponse(response_data, status_code=202) # 202 Accepted for async processing diff --git a/src/api/settings.py b/src/api/settings.py index e5a26de5..73c673be 100644 --- a/src/api/settings.py +++ b/src/api/settings.py @@ -556,7 +556,7 @@ async def update_settings(request, session_manager): "ollama_endpoint" ] - await clients.refresh_patched_clients() + await clients.refresh_patched_client() if any(key in body for key in provider_fields_to_check): try: @@ -926,12 +926,13 @@ async def onboarding(request, flows_service, session_manager=None): {"error": "Failed to save configuration"}, status_code=500 ) - # Refresh cached patched clients so latest credentials take effect immediately - await clients.refresh_patched_clients() + # Refresh cached patched client so latest credentials take effect immediately + await clients.refresh_patched_client() # Create OpenRAG Docs knowledge filter if sample data was ingested + # Only create on embedding step to avoid duplicates (both LLM and embedding cards submit with sample_data) openrag_docs_filter_id = None - if should_ingest_sample_data: + if should_ingest_sample_data and ("embedding_provider" in body or "embedding_model" in body): try: openrag_docs_filter_id = await _create_openrag_docs_filter( request, session_manager @@ -1031,77 +1032,6 @@ async def _create_openrag_docs_filter(request, 
session_manager): return None -async def _create_user_document_filter(request, session_manager, filename): - """Create a knowledge filter for a user-uploaded document during onboarding""" - import uuid - import json - from datetime import datetime - - # Get knowledge filter service from app state - app = request.scope.get("app") - if not app or not hasattr(app.state, "services"): - logger.error("Could not access services for knowledge filter creation") - return None - - knowledge_filter_service = app.state.services.get("knowledge_filter_service") - if not knowledge_filter_service: - logger.error("Knowledge filter service not available") - return None - - # Get user and JWT token from request - user = request.state.user - jwt_token = session_manager.get_effective_jwt_token(user.user_id, request.state.jwt_token) - - # In no-auth mode, set owner to None so filter is visible to all users - # In auth mode, use the actual user as owner - if is_no_auth_mode(): - owner_user_id = None - else: - owner_user_id = user.user_id - - # Create the filter document - filter_id = str(uuid.uuid4()) - - # Get a display name from the filename (remove extension for cleaner name) - display_name = filename.rsplit(".", 1)[0] if "." in filename else filename - - query_data = json.dumps({ - "query": "", - "filters": { - "data_sources": [filename], - "document_types": ["*"], - "owners": ["*"], - "connector_types": ["*"], - }, - "limit": 10, - "scoreThreshold": 0, - "color": "green", - "icon": "file", - }) - - filter_doc = { - "id": filter_id, - "name": display_name, - "description": f"Filter for {filename}", - "query_data": query_data, - "owner": owner_user_id, - "allowed_users": [], - "allowed_groups": [], - "created_at": datetime.utcnow().isoformat(), - "updated_at": datetime.utcnow().isoformat(), - } - - result = await knowledge_filter_service.create_knowledge_filter( - filter_doc, user_id=user.user_id, jwt_token=jwt_token - ) - - if result.get("success"): - return filter_id - else: - logger.error("Failed to create user document filter", error=result.get("error")) - return None - - def _get_flows_service(): """Helper function to get flows service instance""" from services.flows_service import FlowsService diff --git a/src/config/settings.py b/src/config/settings.py index b7c94936..75b09f09 100644 --- a/src/config/settings.py +++ b/src/config/settings.py @@ -297,8 +297,7 @@ class AppClients: self.opensearch = None self.langflow_client = None self.langflow_http_client = None - self._patched_llm_client = None # Private attribute - self._patched_embedding_client = None # Private attribute + self._patched_async_client = None # Private attribute - single client for all providers self._client_init_lock = __import__('threading').Lock() # Lock for thread-safe initialization self.converter = None @@ -377,192 +376,157 @@ class AppClients: self.langflow_client = None return self.langflow_client - def _build_provider_env(self, provider_type: str): + @property + def patched_async_client(self): """ - Build environment overrides for the requested provider type ("llm" or "embedding"). - This is used to support different credentials for LLM and embedding providers. + Property that ensures OpenAI client is initialized on first access. + This allows lazy initialization so the app can start without an API key. + + The client is patched with LiteLLM support to handle multiple providers. + All provider credentials are loaded into environment for LiteLLM routing. 
+ + Note: The client is a long-lived singleton that should be closed via cleanup(). + Thread-safe via lock to prevent concurrent initialization attempts. """ - config = get_openrag_config() - - if provider_type == "llm": - provider = (config.agent.llm_provider or "openai").lower() - else: - provider = (config.knowledge.embedding_provider or "openai").lower() - - env_overrides = {} - - if provider == "openai": - api_key = config.providers.openai.api_key or os.getenv("OPENAI_API_KEY") - if api_key: - env_overrides["OPENAI_API_KEY"] = api_key - elif provider == "anthropic": - api_key = config.providers.anthropic.api_key or os.getenv("ANTHROPIC_API_KEY") - if api_key: - env_overrides["ANTHROPIC_API_KEY"] = api_key - elif provider == "watsonx": - api_key = config.providers.watsonx.api_key or os.getenv("WATSONX_API_KEY") - endpoint = config.providers.watsonx.endpoint or os.getenv("WATSONX_ENDPOINT") - project_id = config.providers.watsonx.project_id or os.getenv("WATSONX_PROJECT_ID") - if api_key: - env_overrides["WATSONX_API_KEY"] = api_key - if endpoint: - env_overrides["WATSONX_ENDPOINT"] = endpoint - if project_id: - env_overrides["WATSONX_PROJECT_ID"] = project_id - elif provider == "ollama": - endpoint = config.providers.ollama.endpoint or os.getenv("OLLAMA_ENDPOINT") - if endpoint: - env_overrides["OLLAMA_ENDPOINT"] = endpoint - env_overrides["OLLAMA_BASE_URL"] = endpoint - - return env_overrides, provider - - def _apply_env_overrides(self, env_overrides: dict): - """Apply non-empty environment overrides.""" - for key, value in (env_overrides or {}).items(): - if value is None: - continue - os.environ[key] = str(value) - - def _initialize_patched_client(self, cache_attr: str, provider_type: str): - """ - Initialize a patched AsyncOpenAI client for the specified provider type. - Uses HTTP/2 probe only when an OpenAI key is present; otherwise falls back directly. - """ - # Quick path - cached_client = getattr(self, cache_attr) - if cached_client is not None: - return cached_client + # Quick check without lock + if self._patched_async_client is not None: + return self._patched_async_client + # Use lock to ensure only one thread initializes with self._client_init_lock: - cached_client = getattr(self, cache_attr) - if cached_client is not None: - return cached_client + # Double-check after acquiring lock + if self._patched_async_client is not None: + return self._patched_async_client - env_overrides, provider_name = self._build_provider_env(provider_type) - self._apply_env_overrides(env_overrides) - - # Decide whether to run the HTTP/2 probe (only meaningful for OpenAI endpoints) - has_openai_key = bool(env_overrides.get("OPENAI_API_KEY") or os.getenv("OPENAI_API_KEY")) - if provider_name == "openai" and not has_openai_key: - raise ValueError("OPENAI_API_KEY is required for OpenAI provider") + # Load all provider credentials into environment for LiteLLM + # LiteLLM routes based on model name prefixes (openai/, ollama/, watsonx/, etc.) 
+ try: + config = get_openrag_config() + + # Set OpenAI credentials + if config.providers.openai.api_key: + os.environ["OPENAI_API_KEY"] = config.providers.openai.api_key + logger.debug("Loaded OpenAI API key from config") + + # Set Anthropic credentials + if config.providers.anthropic.api_key: + os.environ["ANTHROPIC_API_KEY"] = config.providers.anthropic.api_key + logger.debug("Loaded Anthropic API key from config") + + # Set WatsonX credentials + if config.providers.watsonx.api_key: + os.environ["WATSONX_API_KEY"] = config.providers.watsonx.api_key + if config.providers.watsonx.endpoint: + os.environ["WATSONX_ENDPOINT"] = config.providers.watsonx.endpoint + os.environ["WATSONX_API_BASE"] = config.providers.watsonx.endpoint # LiteLLM expects this name + if config.providers.watsonx.project_id: + os.environ["WATSONX_PROJECT_ID"] = config.providers.watsonx.project_id + if config.providers.watsonx.api_key: + logger.debug("Loaded WatsonX credentials from config") + + # Set Ollama endpoint + if config.providers.ollama.endpoint: + os.environ["OLLAMA_BASE_URL"] = config.providers.ollama.endpoint + os.environ["OLLAMA_ENDPOINT"] = config.providers.ollama.endpoint + logger.debug("Loaded Ollama endpoint from config") + + except Exception as e: + logger.debug("Could not load provider credentials from config", error=str(e)) + # Try to initialize the client - AsyncOpenAI() will read from environment + # We'll try HTTP/2 first with a probe, then fall back to HTTP/1.1 if it times out import asyncio import concurrent.futures + import threading - async def build_client(skip_probe: bool = False): - if not has_openai_key: - # No OpenAI key present; create a basic patched client without probing - return patch_openai_with_mcp(AsyncOpenAI(http_client=httpx.AsyncClient())) - - if skip_probe: - http_client = httpx.AsyncClient(http2=False, timeout=httpx.Timeout(60.0, connect=10.0)) - return patch_openai_with_mcp(AsyncOpenAI(http_client=http_client)) - + async def probe_and_initialize(): + # Try HTTP/2 first (default) client_http2 = patch_openai_with_mcp(AsyncOpenAI()) - logger.info("Probing patched OpenAI client with HTTP/2...") + logger.info("Probing OpenAI client with HTTP/2...") try: + # Probe with a small embedding and short timeout await asyncio.wait_for( client_http2.embeddings.create( - model="text-embedding-3-small", - input=["test"], + model='text-embedding-3-small', + input=['test'] ), - timeout=5.0, + timeout=5.0 ) - logger.info("Patched OpenAI client initialized with HTTP/2 (probe successful)") + logger.info("OpenAI client initialized with HTTP/2 (probe successful)") return client_http2 except (asyncio.TimeoutError, Exception) as probe_error: logger.warning("HTTP/2 probe failed, falling back to HTTP/1.1", error=str(probe_error)) + # Close the HTTP/2 client try: await client_http2.close() except Exception: pass + # Fall back to HTTP/1.1 with explicit timeout settings http_client = httpx.AsyncClient( - http2=False, timeout=httpx.Timeout(60.0, connect=10.0) + http2=False, + timeout=httpx.Timeout(60.0, connect=10.0) ) - client_http1 = patch_openai_with_mcp(AsyncOpenAI(http_client=http_client)) - logger.info("Patched OpenAI client initialized with HTTP/1.1 (fallback)") + client_http1 = patch_openai_with_mcp( + AsyncOpenAI(http_client=http_client) + ) + logger.info("OpenAI client initialized with HTTP/1.1 (fallback)") return client_http1 - def run_builder(skip_probe=False): + def run_probe_in_thread(): + """Run the async probe in a new thread with its own event loop""" loop = asyncio.new_event_loop() 
asyncio.set_event_loop(loop) try: - return loop.run_until_complete(build_client(skip_probe=skip_probe)) + return loop.run_until_complete(probe_and_initialize()) finally: loop.close() try: + # Run the probe in a separate thread with its own event loop with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: - future = executor.submit(run_builder, False) - client = future.result(timeout=15 if has_openai_key else 10) - setattr(self, cache_attr, client) - logger.info("Successfully initialized patched client", provider_type=provider_type) - return client + future = executor.submit(run_probe_in_thread) + self._patched_async_client = future.result(timeout=15) + logger.info("Successfully initialized OpenAI client") except Exception as e: - logger.error( - f"Failed to initialize patched client: {e.__class__.__name__}: {str(e)}", - provider_type=provider_type, - ) - raise ValueError( - f"Failed to initialize patched client for {provider_type}: {str(e)}. " - "Please ensure provider credentials are set." - ) + logger.error(f"Failed to initialize OpenAI client: {e.__class__.__name__}: {str(e)}") + raise ValueError(f"Failed to initialize OpenAI client: {str(e)}. Please complete onboarding or set OPENAI_API_KEY environment variable.") + + return self._patched_async_client @property def patched_llm_client(self): - """Patched client for LLM provider.""" - return self._initialize_patched_client("_patched_llm_client", "llm") + """Alias for patched_async_client - for backward compatibility with code expecting separate clients.""" + return self.patched_async_client @property def patched_embedding_client(self): - """Patched client for embedding provider.""" - return self._initialize_patched_client("_patched_embedding_client", "embedding") + """Alias for patched_async_client - for backward compatibility with code expecting separate clients.""" + return self.patched_async_client - @property - def patched_async_client(self): - """Backward-compatibility alias for LLM client.""" - return self.patched_llm_client - - async def refresh_patched_clients(self): - """Reset patched clients so next use picks up updated provider credentials.""" - clients_to_close = [] - with self._client_init_lock: - if self._patched_llm_client is not None: - clients_to_close.append(self._patched_llm_client) - self._patched_llm_client = None - if self._patched_embedding_client is not None: - clients_to_close.append(self._patched_embedding_client) - self._patched_embedding_client = None - - for client in clients_to_close: + async def refresh_patched_client(self): + """Reset patched client so next use picks up updated provider credentials.""" + if self._patched_async_client is not None: try: - await client.close() + await self._patched_async_client.close() + logger.info("Closed patched client for refresh") except Exception as e: logger.warning("Failed to close patched client during refresh", error=str(e)) + finally: + self._patched_async_client = None async def cleanup(self): """Cleanup resources - should be called on application shutdown""" # Close AsyncOpenAI client if it was created - if self._patched_llm_client is not None: + if self._patched_async_client is not None: try: - await self._patched_llm_client.close() - logger.info("Closed LLM patched client") + await self._patched_async_client.close() + logger.info("Closed AsyncOpenAI client") except Exception as e: - logger.error("Failed to close LLM patched client", error=str(e)) + logger.error("Failed to close AsyncOpenAI client", error=str(e)) finally: - 
self._patched_llm_client = None - - if self._patched_embedding_client is not None: - try: - await self._patched_embedding_client.close() - logger.info("Closed embedding patched client") - except Exception as e: - logger.error("Failed to close embedding patched client", error=str(e)) - finally: - self._patched_embedding_client = None + self._patched_async_client = None # Close Langflow HTTP client if it exists if self.langflow_http_client is not None: diff --git a/src/services/chat_service.py b/src/services/chat_service.py index cb697cf0..e965623c 100644 --- a/src/services/chat_service.py +++ b/src/services/chat_service.py @@ -15,6 +15,7 @@ class ChatService: jwt_token: str = None, previous_response_id: str = None, stream: bool = False, + filter_id: str = None, ): """Handle chat requests using the patched OpenAI client""" if not prompt: @@ -30,6 +31,7 @@ class ChatService: prompt, user_id, previous_response_id=previous_response_id, + filter_id=filter_id, ) else: response_text, response_id = await async_chat( @@ -37,6 +39,7 @@ class ChatService: prompt, user_id, previous_response_id=previous_response_id, + filter_id=filter_id, ) response_data = {"response": response_text} if response_id: @@ -50,6 +53,7 @@ class ChatService: jwt_token: str = None, previous_response_id: str = None, stream: bool = False, + filter_id: str = None, ): """Handle Langflow chat requests""" if not prompt: @@ -147,6 +151,7 @@ class ChatService: user_id, extra_headers=extra_headers, previous_response_id=previous_response_id, + filter_id=filter_id, ) else: from agent import async_langflow_chat @@ -158,6 +163,7 @@ class ChatService: user_id, extra_headers=extra_headers, previous_response_id=previous_response_id, + filter_id=filter_id, ) response_data = {"response": response_text} if response_id: @@ -429,6 +435,7 @@ class ChatService: "previous_response_id": conversation_state.get( "previous_response_id" ), + "filter_id": conversation_state.get("filter_id"), "total_messages": len(messages), "source": "in_memory", } @@ -447,6 +454,7 @@ class ChatService: "created_at": metadata.get("created_at"), "last_activity": metadata.get("last_activity"), "previous_response_id": metadata.get("previous_response_id"), + "filter_id": metadata.get("filter_id"), "total_messages": metadata.get("total_messages", 0), "source": "metadata_only", } @@ -545,6 +553,7 @@ class ChatService: or conversation.get("created_at"), "last_activity": metadata.get("last_activity") or conversation.get("last_activity"), + "filter_id": metadata.get("filter_id"), "total_messages": len(messages), "source": "langflow_enhanced", "langflow_session_id": session_id, diff --git a/src/services/search_service.py b/src/services/search_service.py index 07c1a796..4d0caafe 100644 --- a/src/services/search_service.py +++ b/src/services/search_service.py @@ -147,13 +147,41 @@ class SearchService: attempts = 0 last_exception = None + # Format model name for LiteLLM compatibility + # The patched client routes through LiteLLM for non-OpenAI providers + formatted_model = model_name + + # Skip if already has a provider prefix + if not any(model_name.startswith(prefix + "/") for prefix in ["openai", "ollama", "watsonx", "anthropic"]): + # Detect provider from model name characteristics: + # - Ollama: contains ":" (e.g., "nomic-embed-text:latest") + # - WatsonX: starts with "ibm/" or known third-party models + # - OpenAI: everything else (no prefix needed) + + if ":" in model_name: + # Ollama models use tags with colons + formatted_model = f"ollama/{model_name}" + 
logger.debug(f"Formatted Ollama model: {model_name} -> {formatted_model}") + elif model_name.startswith("ibm/") or model_name in [ + "intfloat/multilingual-e5-large", + "sentence-transformers/all-minilm-l6-v2" + ]: + # WatsonX embedding models + formatted_model = f"watsonx/{model_name}" + logger.debug(f"Formatted WatsonX model: {model_name} -> {formatted_model}") + # else: OpenAI models don't need a prefix + while attempts < MAX_EMBED_RETRIES: attempts += 1 try: resp = await clients.patched_embedding_client.embeddings.create( - model=model_name, input=[query] + model=formatted_model, input=[query] ) - return model_name, resp.data[0].embedding + # Try to get embedding - some providers return .embedding, others return ['embedding'] + embedding = getattr(resp.data[0], 'embedding', None) + if embedding is None: + embedding = resp.data[0]['embedding'] + return model_name, embedding except Exception as e: last_exception = e if attempts >= MAX_EMBED_RETRIES: