onboarding filter creation + sticky filters
parent c5e27b636f · commit 0fc0be855c
19 changed files with 571 additions and 296 deletions
@ -60,9 +60,9 @@ export const useDoclingHealthQuery = (
|
|||
// If healthy, check every 30 seconds; otherwise check every 3 seconds
|
||||
return query.state.data?.status === "healthy" ? 30000 : 3000;
|
||||
},
|
||||
refetchOnWindowFocus: true,
|
||||
refetchOnWindowFocus: false, // Disabled to reduce unnecessary calls on tab switches
|
||||
refetchOnMount: true,
|
||||
staleTime: 30000, // Consider data stale after 25 seconds
|
||||
staleTime: 30000, // Consider data fresh for 30 seconds
|
||||
...options,
|
||||
},
|
||||
queryClient,
|
||||
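The hunk above switches the Docling health check to a state-dependent poll: fast while unhealthy, slow once healthy, with window-focus refetches disabled. A minimal TanStack Query sketch of that pattern; the endpoint, `Health` type, and hook name are illustrative, not from this repo:

```ts
import { useQuery } from "@tanstack/react-query";

// Poll every 3s while unhealthy, back off to 30s once healthy,
// and skip refetches on tab focus.
type Health = { status: "healthy" | "unhealthy" };

async function fetchHealth(): Promise<Health> {
  const res = await fetch("/api/health");
  if (!res.ok) throw new Error(`Health check failed: ${res.status}`);
  return res.json();
}

export function useHealthPoll() {
  return useQuery({
    queryKey: ["health"],
    queryFn: fetchHealth,
    refetchInterval: (query) =>
      query.state.data?.status === "healthy" ? 30_000 : 3_000,
    refetchOnWindowFocus: false, // tab switches should not trigger extra calls
    refetchOnMount: true,
    staleTime: 30_000, // data considered fresh for 30 seconds
  });
}
```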
|
@ -51,13 +51,15 @@ export const useGetConversationsQuery = (
|
|||
) => {
|
||||
const queryClient = useQueryClient();
|
||||
|
||||
async function getConversations(): Promise<ChatConversation[]> {
|
||||
async function getConversations(context: { signal?: AbortSignal }): Promise<ChatConversation[]> {
|
||||
try {
|
||||
// Fetch from the selected endpoint only
|
||||
const apiEndpoint =
|
||||
endpoint === "chat" ? "/api/chat/history" : "/api/langflow/history";
|
||||
|
||||
const response = await fetch(apiEndpoint);
|
||||
const response = await fetch(apiEndpoint, {
|
||||
signal: context.signal,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
console.error(`Failed to fetch conversations: ${response.status}`);
|
||||
|
|
@ -84,6 +86,10 @@ export const useGetConversationsQuery = (
|
|||
|
||||
return conversations;
|
||||
} catch (error) {
|
||||
// Ignore abort errors - these are expected when requests are cancelled
|
||||
if (error instanceof Error && error.name === 'AbortError') {
|
||||
return [];
|
||||
}
|
||||
console.error(`Failed to fetch ${endpoint} conversations:`, error);
|
||||
return [];
|
||||
}
|
||||
|
|
@ -94,8 +100,11 @@ export const useGetConversationsQuery = (
|
|||
queryKey: ["conversations", endpoint, refreshTrigger],
|
||||
placeholderData: (prev) => prev,
|
||||
queryFn: getConversations,
|
||||
staleTime: 0, // Always consider data stale to ensure fresh data on trigger changes
|
||||
staleTime: 5000, // Consider data fresh for 5 seconds to prevent excessive refetching
|
||||
gcTime: 5 * 60 * 1000, // Keep in cache for 5 minutes
|
||||
networkMode: 'always', // Ensure requests can be cancelled
|
||||
refetchOnMount: false, // Don't refetch on every mount
|
||||
refetchOnWindowFocus: false, // Don't refetch when window regains focus
|
||||
...options,
|
||||
},
|
||||
queryClient,
|
||||
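The conversations query now forwards TanStack Query's AbortSignal to fetch and treats AbortError as an expected outcome rather than a failure. A self-contained sketch of that cancellation pattern, with placeholder endpoint, type, and hook names:

```ts
import { useQuery } from "@tanstack/react-query";

// TanStack Query hands the queryFn an AbortSignal; forwarding it to fetch lets
// superseded requests be cancelled without logging them as errors.
type Conversation = { id: string; title: string };

async function getConversations({ signal }: { signal?: AbortSignal }): Promise<Conversation[]> {
  try {
    const res = await fetch("/api/chat/history", { signal });
    if (!res.ok) return [];
    return await res.json();
  } catch (error) {
    if (error instanceof Error && error.name === "AbortError") return []; // cancelled, not an error
    console.error("Failed to fetch conversations:", error);
    return [];
  }
}

export const useConversations = () =>
  useQuery({
    queryKey: ["conversations"],
    queryFn: getConversations,
    staleTime: 5_000, // fresh for 5 seconds to avoid refetch storms
    placeholderData: (prev) => prev, // keep showing the old list while refetching
  });
```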
|
@ -34,7 +34,7 @@ export const useGetNudgesQuery = (
|
|||
});
|
||||
}
|
||||
|
||||
async function getNudges(): Promise<Nudge[]> {
|
||||
async function getNudges(context: { signal?: AbortSignal }): Promise<Nudge[]> {
|
||||
try {
|
||||
const requestBody: {
|
||||
filters?: NudgeFilters;
|
||||
|
|
@ -58,6 +58,7 @@ export const useGetNudgesQuery = (
|
|||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify(requestBody),
|
||||
signal: context.signal,
|
||||
});
|
||||
const data = await response.json();
|
||||
|
||||
|
|
@ -67,6 +68,10 @@ export const useGetNudgesQuery = (
|
|||
|
||||
return DEFAULT_NUDGES;
|
||||
} catch (error) {
|
||||
// Ignore abort errors - these are expected when requests are cancelled
|
||||
if (error instanceof Error && error.name === 'AbortError') {
|
||||
return DEFAULT_NUDGES;
|
||||
}
|
||||
console.error("Error getting nudges", error);
|
||||
return DEFAULT_NUDGES;
|
||||
}
|
||||
|
|
@ -76,6 +81,10 @@ export const useGetNudgesQuery = (
|
|||
{
|
||||
queryKey: ["nudges", chatId, filters, limit, scoreThreshold],
|
||||
queryFn: getNudges,
|
||||
staleTime: 10000, // Consider data fresh for 10 seconds to prevent rapid refetching
|
||||
networkMode: 'always', // Ensure requests can be cancelled
|
||||
refetchOnMount: false, // Don't refetch on every mount
|
||||
refetchOnWindowFocus: false, // Don't refetch when window regains focus
|
||||
refetchInterval: (query) => {
|
||||
// If data is empty, refetch every 5 seconds
|
||||
const data = query.state.data;
|
||||
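The nudges query combines the same signal forwarding with a poll-until-populated interval: keep refetching every few seconds while the list is empty, then stop. A sketch assuming TanStack Query v5; the query key and injected fetcher are placeholders:

```ts
import { useQuery } from "@tanstack/react-query";

// Refetch every 5s while no nudges have arrived, then stop polling.
export const useNudgesPoll = (
  fetchNudges: (ctx: { signal?: AbortSignal }) => Promise<string[]>,
) =>
  useQuery({
    queryKey: ["nudges"],
    queryFn: fetchNudges,
    staleTime: 10_000,
    refetchInterval: (query) => {
      const data = query.state.data;
      return !data || data.length === 0 ? 5_000 : false; // stop once nudges exist
    },
  });
```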
|
@ -96,9 +96,9 @@ export const useProviderHealthQuery = (
|
|||
// If healthy, check every 30 seconds; otherwise check every 3 seconds
|
||||
return query.state.data?.status === "healthy" ? 30000 : 3000;
|
||||
},
|
||||
refetchOnWindowFocus: true,
|
||||
refetchOnWindowFocus: false, // Disabled to reduce unnecessary calls on tab switches
|
||||
refetchOnMount: true,
|
||||
staleTime: 30000, // Consider data stale after 25 seconds
|
||||
staleTime: 30000, // Consider data fresh for 30 seconds
|
||||
enabled: !!settings?.edited && options?.enabled !== false, // Only run after onboarding is complete
|
||||
...options,
|
||||
},
|
||||
|
@ -110,6 +110,13 @@ function ChatPage() {
|
|||
} else {
|
||||
refreshConversationsSilent();
|
||||
}
|
||||
|
||||
// Save filter association for this response
|
||||
if (conversationFilter && typeof window !== "undefined") {
|
||||
const newKey = `conversation_filter_${responseId}`;
|
||||
localStorage.setItem(newKey, conversationFilter.id);
|
||||
console.log("[CHAT] Saved filter association:", newKey, "=", conversationFilter.id);
|
||||
}
|
||||
}
|
||||
},
|
||||
onError: (error) => {
|
||||
|
|
@ -696,11 +703,18 @@ function ChatPage() {
|
|||
// Use passed previousResponseId if available, otherwise fall back to state
|
||||
const responseIdToUse = previousResponseId || previousResponseIds[endpoint];
|
||||
|
||||
console.log("[CHAT] Sending streaming message:", {
|
||||
conversationFilter: conversationFilter?.id,
|
||||
currentConversationId,
|
||||
responseIdToUse,
|
||||
});
|
||||
|
||||
// Use the hook to send the message
|
||||
await sendStreamingMessage({
|
||||
prompt: userMessage.content,
|
||||
previousResponseId: responseIdToUse || undefined,
|
||||
filters: processedFilters,
|
||||
filter_id: conversationFilter?.id, // ✅ Add filter_id for this conversation
|
||||
limit: parsedFilterData?.limit ?? 10,
|
||||
scoreThreshold: parsedFilterData?.scoreThreshold ?? 0,
|
||||
});
|
||||
|
|
@ -781,6 +795,19 @@ function ChatPage() {
|
|||
requestBody.previous_response_id = currentResponseId;
|
||||
}
|
||||
|
||||
// Add filter_id if a filter is selected for this conversation
|
||||
if (conversationFilter) {
|
||||
requestBody.filter_id = conversationFilter.id;
|
||||
}
|
||||
|
||||
// Debug logging
|
||||
console.log("[DEBUG] Sending message with:", {
|
||||
previous_response_id: requestBody.previous_response_id,
|
||||
filter_id: requestBody.filter_id,
|
||||
currentConversationId,
|
||||
previousResponseIds,
|
||||
});
|
||||
|
||||
const response = await fetch(apiEndpoint, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
|
|
@ -804,6 +831,8 @@ function ChatPage() {
|
|||
|
||||
// Store the response ID if present for this endpoint
|
||||
if (result.response_id) {
|
||||
console.log("[DEBUG] Received response_id:", result.response_id, "currentConversationId:", currentConversationId);
|
||||
|
||||
setPreviousResponseIds((prev) => ({
|
||||
...prev,
|
||||
[endpoint]: result.response_id,
|
||||
|
|
@ -811,12 +840,21 @@ function ChatPage() {
|
|||
|
||||
// If this is a new conversation (no currentConversationId), set it now
|
||||
if (!currentConversationId) {
|
||||
console.log("[DEBUG] Setting currentConversationId to:", result.response_id);
|
||||
setCurrentConversationId(result.response_id);
|
||||
refreshConversations(true);
|
||||
} else {
|
||||
console.log("[DEBUG] Existing conversation, doing silent refresh");
|
||||
// For existing conversations, do a silent refresh to keep backend in sync
|
||||
refreshConversationsSilent();
|
||||
}
|
||||
|
||||
// Carry forward the filter association to the new response_id
|
||||
if (conversationFilter && typeof window !== "undefined") {
|
||||
const newKey = `conversation_filter_${result.response_id}`;
|
||||
localStorage.setItem(newKey, conversationFilter.id);
|
||||
console.log("[DEBUG] Saved filter association:", newKey, "=", conversationFilter.id);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
console.error("Chat failed:", result.error);
|
||||
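Both the streaming and non-streaming chat paths above persist a per-conversation filter association keyed by the response id, so reopening a conversation restores the same knowledge filter. The key prefix matches the diff; these small helpers are an illustrative consolidation, not repo code:

```ts
// Each conversation (keyed by its response_id) remembers the knowledge filter
// it was created with.
const FILTER_KEY_PREFIX = "conversation_filter_";

export function saveFilterAssociation(responseId: string, filterId: string): void {
  if (typeof window === "undefined") return; // SSR guard
  localStorage.setItem(`${FILTER_KEY_PREFIX}${responseId}`, filterId);
}

export function loadFilterAssociation(responseId: string): string | null {
  if (typeof window === "undefined") return null;
  return localStorage.getItem(`${FILTER_KEY_PREFIX}${responseId}`);
}

export function clearFilterAssociation(responseId: string): void {
  if (typeof window === "undefined") return;
  localStorage.removeItem(`${FILTER_KEY_PREFIX}${responseId}`);
}
```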
|
@ -209,6 +209,16 @@ const OnboardingCard = ({
|
|||
const onboardingMutation = useOnboardingMutation({
|
||||
onSuccess: (data) => {
|
||||
console.log("Onboarding completed successfully", data);
|
||||
|
||||
// Save OpenRAG docs filter ID if sample data was ingested
|
||||
if (data.openrag_docs_filter_id && typeof window !== "undefined") {
|
||||
localStorage.setItem(
|
||||
"onboarding_openrag_docs_filter_id",
|
||||
data.openrag_docs_filter_id
|
||||
);
|
||||
console.log("Saved OpenRAG docs filter ID:", data.openrag_docs_filter_id);
|
||||
}
|
||||
|
||||
// Update provider health cache to healthy since backend just validated
|
||||
const provider =
|
||||
(isEmbedding ? settings.embedding_provider : settings.llm_provider) ||
|
||||
|
@ -2,11 +2,13 @@
|
|||
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
import { StickToBottom } from "use-stick-to-bottom";
|
||||
import { getFilterById } from "@/app/api/queries/useGetFilterByIdQuery";
|
||||
import { AssistantMessage } from "@/app/chat/_components/assistant-message";
|
||||
import Nudges from "@/app/chat/_components/nudges";
|
||||
import { UserMessage } from "@/app/chat/_components/user-message";
|
||||
import type { Message, SelectedFilters } from "@/app/chat/_types/types";
|
||||
import OnboardingCard from "@/app/onboarding/_components/onboarding-card";
|
||||
import { useChat } from "@/contexts/chat-context";
|
||||
import { useChatStreaming } from "@/hooks/useChatStreaming";
|
||||
import {
|
||||
ONBOARDING_ASSISTANT_MESSAGE_KEY,
|
||||
|
|
@ -33,6 +35,7 @@ export function OnboardingContent({
|
|||
handleStepBack: () => void;
|
||||
currentStep: number;
|
||||
}) {
|
||||
const { setConversationFilter, setCurrentConversationId } = useChat();
|
||||
const parseFailedRef = useRef(false);
|
||||
const [responseId, setResponseId] = useState<string | null>(null);
|
||||
const [selectedNudge, setSelectedNudge] = useState<string>(() => {
|
||||
|
|
@ -78,7 +81,7 @@ export function OnboardingContent({
|
|||
}, [handleStepBack, currentStep]);
|
||||
|
||||
const { streamingMessage, isLoading, sendMessage } = useChatStreaming({
|
||||
onComplete: (message, newResponseId) => {
|
||||
onComplete: async (message, newResponseId) => {
|
||||
setAssistantMessage(message);
|
||||
// Save assistant message to localStorage when complete
|
||||
if (typeof window !== "undefined") {
|
||||
|
|
@ -96,6 +99,26 @@ export function OnboardingContent({
|
|||
}
|
||||
if (newResponseId) {
|
||||
setResponseId(newResponseId);
|
||||
|
||||
// Set the current conversation ID
|
||||
setCurrentConversationId(newResponseId);
|
||||
|
||||
// Save the filter association for this conversation
|
||||
const openragDocsFilterId = localStorage.getItem(ONBOARDING_OPENRAG_DOCS_FILTER_ID_KEY);
|
||||
if (openragDocsFilterId) {
|
||||
try {
|
||||
// Load the filter and set it in the context with explicit responseId
|
||||
// This ensures the filter is saved to localStorage with the correct conversation ID
|
||||
const filter = await getFilterById(openragDocsFilterId);
|
||||
if (filter) {
|
||||
// Pass explicit newResponseId to ensure correct localStorage association
|
||||
setConversationFilter(filter, newResponseId);
|
||||
console.log("[ONBOARDING] Saved filter association:", `conversation_filter_${newResponseId}`, "=", openragDocsFilterId);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Failed to associate filter with conversation:", error);
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
onError: (error) => {
|
||||
|
|
@ -124,15 +147,35 @@ export function OnboardingContent({
|
|||
}
|
||||
setTimeout(async () => {
|
||||
// Check if we have the OpenRAG docs filter ID (sample data was ingested)
|
||||
const hasOpenragDocsFilter =
|
||||
typeof window !== "undefined" &&
|
||||
localStorage.getItem(ONBOARDING_OPENRAG_DOCS_FILTER_ID_KEY);
|
||||
const openragDocsFilterId =
|
||||
typeof window !== "undefined"
|
||||
? localStorage.getItem(ONBOARDING_OPENRAG_DOCS_FILTER_ID_KEY)
|
||||
: null;
|
||||
|
||||
// Load and set the OpenRAG docs filter if available
|
||||
let filterToUse = null;
|
||||
console.log("[ONBOARDING] openragDocsFilterId:", openragDocsFilterId);
|
||||
if (openragDocsFilterId) {
|
||||
try {
|
||||
const filter = await getFilterById(openragDocsFilterId);
|
||||
console.log("[ONBOARDING] Loaded filter:", filter);
|
||||
if (filter) {
|
||||
// Pass null to skip localStorage save - no conversation exists yet
|
||||
setConversationFilter(filter, null);
|
||||
filterToUse = filter;
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Failed to load OpenRAG docs filter:", error);
|
||||
}
|
||||
}
|
||||
|
||||
console.log("[ONBOARDING] Sending message with filter_id:", filterToUse?.id);
|
||||
await sendMessage({
|
||||
prompt: nudge,
|
||||
previousResponseId: responseId || undefined,
|
||||
// Use OpenRAG docs filter if sample data was ingested
|
||||
filters: hasOpenragDocsFilter ? OPENRAG_DOCS_FILTERS : undefined,
|
||||
// Send both filter_id and filters (selections)
|
||||
filter_id: filterToUse?.id,
|
||||
filters: openragDocsFilterId ? OPENRAG_DOCS_FILTERS : undefined,
|
||||
});
|
||||
}, 1500);
|
||||
};
|
||||
|
@ -1,5 +1,7 @@
|
|||
import { AnimatePresence, motion } from "motion/react";
|
||||
import { type ChangeEvent, useEffect, useRef, useState } from "react";
|
||||
import { toast } from "sonner";
|
||||
import { useCreateFilter } from "@/app/api/mutations/useCreateFilter";
|
||||
import { useGetNudgesQuery } from "@/app/api/queries/useGetNudgesQuery";
|
||||
import { useGetTasksQuery } from "@/app/api/queries/useGetTasksQuery";
|
||||
import { AnimatedProviderSteps } from "@/app/onboarding/_components/animated-provider-steps";
|
||||
|
|
@ -18,6 +20,11 @@ const OnboardingUpload = ({ onComplete }: OnboardingUploadProps) => {
|
|||
const fileInputRef = useRef<HTMLInputElement>(null);
|
||||
const [isUploading, setIsUploading] = useState(false);
|
||||
const [currentStep, setCurrentStep] = useState<number | null>(null);
|
||||
const [uploadedFilename, setUploadedFilename] = useState<string | null>(null);
|
||||
const [shouldCreateFilter, setShouldCreateFilter] = useState(false);
|
||||
const [isCreatingFilter, setIsCreatingFilter] = useState(false);
|
||||
|
||||
const createFilterMutation = useCreateFilter();
|
||||
|
||||
const STEP_LIST = [
|
||||
"Uploading your document",
|
||||
|
|
@ -56,6 +63,60 @@ const OnboardingUpload = ({ onComplete }: OnboardingUploadProps) => {
|
|||
// Set to final step to show "Done"
|
||||
setCurrentStep(STEP_LIST.length);
|
||||
|
||||
// Create knowledge filter for uploaded document if requested
|
||||
// Guard against race condition: only create if not already creating
|
||||
if (shouldCreateFilter && uploadedFilename && !isCreatingFilter) {
|
||||
// Reset flags immediately (synchronously) to prevent duplicate creation
|
||||
setShouldCreateFilter(false);
|
||||
const filename = uploadedFilename;
|
||||
setUploadedFilename(null);
|
||||
setIsCreatingFilter(true);
|
||||
|
||||
// Get display name from filename (remove extension for cleaner name)
|
||||
const displayName = filename.includes(".")
|
||||
? filename.substring(0, filename.lastIndexOf("."))
|
||||
: filename;
|
||||
|
||||
const queryData = JSON.stringify({
|
||||
query: "",
|
||||
filters: {
|
||||
data_sources: [filename],
|
||||
document_types: ["*"],
|
||||
owners: ["*"],
|
||||
connector_types: ["*"],
|
||||
},
|
||||
limit: 10,
|
||||
scoreThreshold: 0,
|
||||
color: "green",
|
||||
icon: "file",
|
||||
});
|
||||
|
||||
createFilterMutation
|
||||
.mutateAsync({
|
||||
name: displayName,
|
||||
description: `Filter for ${filename}`,
|
||||
queryData: queryData,
|
||||
})
|
||||
.then((result) => {
|
||||
if (result.filter?.id && typeof window !== "undefined") {
|
||||
localStorage.setItem(
|
||||
ONBOARDING_USER_DOC_FILTER_ID_KEY,
|
||||
result.filter.id,
|
||||
);
|
||||
console.log(
|
||||
"Created knowledge filter for uploaded document",
|
||||
result.filter.id,
|
||||
);
|
||||
}
|
||||
})
|
||||
.catch((error) => {
|
||||
console.error("Failed to create knowledge filter:", error);
|
||||
})
|
||||
.finally(() => {
|
||||
setIsCreatingFilter(false);
|
||||
});
|
||||
}
|
||||
|
||||
// Refetch nudges to get new ones
|
||||
refetchNudges();
|
||||
|
||||
|
|
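The filter-creation effect above guards against creating the same filter twice when the tasks effect re-runs: the trigger flags are cleared before the async mutation starts and a busy flag blocks re-entry. A reduced sketch of that guard; the hook, `createFilter`, and the injected state accessors are stand-ins for the real mutation and component state:

```ts
import { useEffect, useRef } from "react";

// Clear the trigger synchronously, then block re-entry with a ref while the
// create-filter request is in flight.
export function useCreateFilterOnce(
  ingestionDone: boolean,
  pendingFilename: string | null,
  clearPending: () => void,
  createFilter: (name: string) => Promise<void>,
) {
  const creatingRef = useRef(false);

  useEffect(() => {
    if (!ingestionDone || !pendingFilename || creatingRef.current) return;
    const name = pendingFilename;
    clearPending();              // reset the trigger before awaiting anything
    creatingRef.current = true;  // block re-entry for the duration of the request
    createFilter(name)
      .catch((error) => console.error("Failed to create knowledge filter:", error))
      .finally(() => {
        creatingRef.current = false;
      });
  }, [ingestionDone, pendingFilename, clearPending, createFilter]);
}
```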
@ -64,7 +125,7 @@ const OnboardingUpload = ({ onComplete }: OnboardingUploadProps) => {
|
|||
onComplete();
|
||||
}, 1000);
|
||||
}
|
||||
}, [tasks, currentStep, onComplete, refetchNudges]);
|
||||
}, [tasks, currentStep, onComplete, refetchNudges, shouldCreateFilter, uploadedFilename]);
|
||||
|
||||
const resetFileInput = () => {
|
||||
if (fileInputRef.current) {
|
||||
|
|
@ -83,12 +144,10 @@ const OnboardingUpload = ({ onComplete }: OnboardingUploadProps) => {
|
|||
const result = await uploadFile(file, true, true); // Pass createFilter=true
|
||||
console.log("Document upload task started successfully");
|
||||
|
||||
// Store user document filter ID if returned
|
||||
if (result.userDocFilterId && typeof window !== "undefined") {
|
||||
localStorage.setItem(
|
||||
ONBOARDING_USER_DOC_FILTER_ID_KEY,
|
||||
result.userDocFilterId
|
||||
);
|
||||
// Store filename and createFilter flag in state to create filter after ingestion succeeds
|
||||
if (result.createFilter && result.filename) {
|
||||
setUploadedFilename(result.filename);
|
||||
setShouldCreateFilter(true);
|
||||
}
|
||||
|
||||
// Move to processing step - task monitoring will handle completion
|
||||
|
|
@ -96,7 +155,15 @@ const OnboardingUpload = ({ onComplete }: OnboardingUploadProps) => {
|
|||
setCurrentStep(1);
|
||||
}, 1500);
|
||||
} catch (error) {
|
||||
console.error("Upload failed", (error as Error).message);
|
||||
const errorMessage = error instanceof Error ? error.message : "Upload failed";
|
||||
console.error("Upload failed", errorMessage);
|
||||
|
||||
// Show error toast notification
|
||||
toast.error("Document upload failed", {
|
||||
description: errorMessage,
|
||||
duration: 5000,
|
||||
});
|
||||
|
||||
// Reset on error
|
||||
setCurrentStep(null);
|
||||
} finally {
|
||||
|
@ -1,7 +1,7 @@
|
|||
"use client";
|
||||
|
||||
import { motion } from "framer-motion";
|
||||
import { usePathname } from "next/navigation";
|
||||
import { usePathname, useRouter } from "next/navigation";
|
||||
import { useCallback, useEffect, useState } from "react";
|
||||
import {
|
||||
type ChatConversation,
|
||||
|
|
@ -39,6 +39,7 @@ export function ChatRenderer({
|
|||
children: React.ReactNode;
|
||||
}) {
|
||||
const pathname = usePathname();
|
||||
const router = useRouter();
|
||||
const { isAuthenticated, isNoAuthMode } = useAuth();
|
||||
const {
|
||||
endpoint,
|
||||
|
|
@ -46,6 +47,8 @@ export function ChatRenderer({
|
|||
refreshConversations,
|
||||
startNewConversation,
|
||||
setConversationFilter,
|
||||
setCurrentConversationId,
|
||||
setPreviousResponseIds,
|
||||
} = useChat();
|
||||
|
||||
// Initialize onboarding state based on local storage and settings
|
||||
|
|
@ -75,38 +78,74 @@ export function ChatRenderer({
|
|||
startNewConversation();
|
||||
};
|
||||
|
||||
// Helper to set the default filter after onboarding transition
|
||||
const setOnboardingFilter = useCallback(
|
||||
// Navigate to /chat when onboarding is active so animation reveals chat underneath
|
||||
useEffect(() => {
|
||||
if (!showLayout && pathname !== "/chat" && pathname !== "/") {
|
||||
router.push("/chat");
|
||||
}
|
||||
}, [showLayout, pathname, router]);
|
||||
|
||||
// Helper to store default filter ID for new conversations after onboarding
|
||||
const storeDefaultFilterForNewConversations = useCallback(
|
||||
async (preferUserDoc: boolean) => {
|
||||
if (typeof window === "undefined") return;
|
||||
|
||||
// Check if we already have a default filter set
|
||||
const existingDefault = localStorage.getItem("default_conversation_filter_id");
|
||||
if (existingDefault) {
|
||||
console.log("[FILTER] Default filter already set:", existingDefault);
|
||||
// Try to apply it to context state (don't save to localStorage to avoid overwriting)
|
||||
try {
|
||||
const filter = await getFilterById(existingDefault);
|
||||
if (filter) {
|
||||
// Pass null to skip localStorage save
|
||||
setConversationFilter(filter, null);
|
||||
return; // Successfully loaded and set, we're done
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Failed to load existing default filter, will set new one:", error);
|
||||
// Filter doesn't exist anymore, clear it and continue to set a new one
|
||||
localStorage.removeItem("default_conversation_filter_id");
|
||||
}
|
||||
}
|
||||
|
||||
// Try to get the appropriate filter ID
|
||||
let filterId: string | null = null;
|
||||
|
||||
if (preferUserDoc) {
|
||||
// Completed full onboarding - prefer user document filter
|
||||
filterId = localStorage.getItem(ONBOARDING_USER_DOC_FILTER_ID_KEY);
|
||||
console.log("[FILTER] User doc filter ID:", filterId);
|
||||
}
|
||||
|
||||
// Fall back to OpenRAG docs filter
|
||||
if (!filterId) {
|
||||
filterId = localStorage.getItem(ONBOARDING_OPENRAG_DOCS_FILTER_ID_KEY);
|
||||
console.log("[FILTER] OpenRAG docs filter ID:", filterId);
|
||||
}
|
||||
|
||||
console.log("[FILTER] Final filter ID to use:", filterId);
|
||||
|
||||
if (filterId) {
|
||||
// Store this as the default filter for new conversations
|
||||
localStorage.setItem("default_conversation_filter_id", filterId);
|
||||
|
||||
// Apply filter to context state only (don't save to localStorage since there's no conversation yet)
|
||||
// The default_conversation_filter_id will be used when a new conversation is started
|
||||
try {
|
||||
const filter = await getFilterById(filterId);
|
||||
console.log("[FILTER] Loaded filter:", filter);
|
||||
if (filter) {
|
||||
setConversationFilter(filter);
|
||||
// Pass null to skip localStorage save - this prevents overwriting existing conversation filters
|
||||
setConversationFilter(filter, null);
|
||||
console.log("[FILTER] Set conversation filter (no save):", filter.id);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Failed to set onboarding filter:", error);
|
||||
}
|
||||
} else {
|
||||
console.log("[FILTER] No filter ID found, not setting default");
|
||||
}
|
||||
|
||||
// Clean up onboarding filter IDs from localStorage
|
||||
localStorage.removeItem(ONBOARDING_OPENRAG_DOCS_FILTER_ID_KEY);
|
||||
localStorage.removeItem(ONBOARDING_USER_DOC_FILTER_ID_KEY);
|
||||
},
|
||||
[setConversationFilter]
|
||||
);
|
||||
|
|
@ -118,7 +157,7 @@ export function ChatRenderer({
|
|||
}
|
||||
}, [currentStep, showLayout]);
|
||||
|
||||
const handleStepComplete = () => {
|
||||
const handleStepComplete = async () => {
|
||||
if (currentStep < TOTAL_ONBOARDING_STEPS - 1) {
|
||||
setCurrentStep(currentStep + 1);
|
||||
} else {
|
||||
|
|
@ -130,8 +169,20 @@ export function ChatRenderer({
|
|||
localStorage.removeItem(ONBOARDING_CARD_STEPS_KEY);
|
||||
localStorage.removeItem(ONBOARDING_UPLOAD_STEPS_KEY);
|
||||
}
|
||||
// Set the user document filter as active (completed full onboarding)
|
||||
setOnboardingFilter(true);
|
||||
|
||||
// Clear ALL conversation state so next message starts fresh
|
||||
await startNewConversation();
|
||||
|
||||
// Store the user document filter as default for new conversations and load it
|
||||
await storeDefaultFilterForNewConversations(true);
|
||||
|
||||
// Clean up onboarding filter IDs now that we've set the default
|
||||
if (typeof window !== "undefined") {
|
||||
localStorage.removeItem(ONBOARDING_OPENRAG_DOCS_FILTER_ID_KEY);
|
||||
localStorage.removeItem(ONBOARDING_USER_DOC_FILTER_ID_KEY);
|
||||
console.log("[FILTER] Cleaned up onboarding filter IDs");
|
||||
}
|
||||
|
||||
setShowLayout(true);
|
||||
}
|
||||
};
|
||||
|
|
@ -151,8 +202,8 @@ export function ChatRenderer({
|
|||
localStorage.removeItem(ONBOARDING_CARD_STEPS_KEY);
|
||||
localStorage.removeItem(ONBOARDING_UPLOAD_STEPS_KEY);
|
||||
}
|
||||
// Set the OpenRAG docs filter as active (skipped onboarding - no user doc)
|
||||
setOnboardingFilter(false);
|
||||
// Store the OpenRAG docs filter as default for new conversations
|
||||
storeDefaultFilterForNewConversations(false);
|
||||
setShowLayout(true);
|
||||
};
|
||||
|
||||
|
|
@ -289,7 +289,7 @@ export function Navigation({
|
|||
handleNewConversation();
|
||||
} else if (activeConvo) {
|
||||
loadConversation(activeConvo);
|
||||
refreshConversations();
|
||||
// Don't call refreshConversations here - it causes unnecessary refetches
|
||||
} else if (
|
||||
conversations.length > 0 &&
|
||||
currentConversationId === null &&
|
||||
|
|
@ -473,7 +473,7 @@ export function Navigation({
|
|||
onClick={() => {
|
||||
if (loading || isConversationsLoading) return;
|
||||
loadConversation(conversation);
|
||||
refreshConversations();
|
||||
// Don't refresh - just loading an existing conversation
|
||||
}}
|
||||
disabled={loading || isConversationsLoading}
|
||||
>
|
||||
|
@ -65,7 +65,7 @@ interface ChatContextType {
|
|||
refreshConversationsSilent: () => Promise<void>;
|
||||
refreshTrigger: number;
|
||||
refreshTriggerSilent: number;
|
||||
loadConversation: (conversation: ConversationData) => void;
|
||||
loadConversation: (conversation: ConversationData) => Promise<void>;
|
||||
startNewConversation: () => void;
|
||||
conversationData: ConversationData | null;
|
||||
forkFromResponse: (responseId: string) => void;
|
||||
|
|
@ -77,7 +77,8 @@ interface ChatContextType {
|
|||
conversationLoaded: boolean;
|
||||
setConversationLoaded: (loaded: boolean) => void;
|
||||
conversationFilter: KnowledgeFilter | null;
|
||||
setConversationFilter: (filter: KnowledgeFilter | null) => void;
|
||||
// responseId: undefined = use currentConversationId, null = don't save to localStorage
|
||||
setConversationFilter: (filter: KnowledgeFilter | null, responseId?: string | null) => void;
|
||||
}
|
||||
|
||||
const ChatContext = createContext<ChatContextType | undefined>(undefined);
|
||||
|
|
@ -112,6 +113,8 @@ export function ChatProvider({ children }: ChatProviderProps) {
|
|||
const refreshTimeoutRef = useRef<NodeJS.Timeout | null>(null);
|
||||
|
||||
const refreshConversations = useCallback((force = false) => {
|
||||
console.log("[REFRESH] refreshConversations called, force:", force);
|
||||
|
||||
if (force) {
|
||||
// Immediate refresh for important updates like new conversations
|
||||
setRefreshTrigger((prev) => prev + 1);
|
||||
|
|
@ -145,22 +148,59 @@ export function ChatProvider({ children }: ChatProviderProps) {
|
|||
}, []);
|
||||
|
||||
const loadConversation = useCallback(
|
||||
(conversation: ConversationData) => {
|
||||
async (conversation: ConversationData) => {
|
||||
console.log("[CONVERSATION] Loading conversation:", {
|
||||
conversationId: conversation.response_id,
|
||||
title: conversation.title,
|
||||
endpoint: conversation.endpoint,
|
||||
});
|
||||
|
||||
setCurrentConversationId(conversation.response_id);
|
||||
setEndpoint(conversation.endpoint);
|
||||
// Store the full conversation data for the chat page to use
|
||||
setConversationData(conversation);
|
||||
|
||||
// Load the filter if one exists for this conversation
|
||||
// Only update the filter if this is a different conversation (to preserve user's filter selection)
|
||||
setConversationFilterState((currentFilter) => {
|
||||
// If we're loading a different conversation, load its filter
|
||||
// Otherwise keep the current filter (don't reset it when conversation refreshes)
|
||||
const isDifferentConversation =
|
||||
conversation.response_id !== conversationData?.response_id;
|
||||
return isDifferentConversation
|
||||
? conversation.filter || null
|
||||
: currentFilter;
|
||||
});
|
||||
// Always update the filter to match the conversation being loaded
|
||||
const isDifferentConversation =
|
||||
conversation.response_id !== conversationData?.response_id;
|
||||
|
||||
if (isDifferentConversation && typeof window !== "undefined") {
|
||||
// Try to load the saved filter from localStorage
|
||||
const savedFilterId = localStorage.getItem(`conversation_filter_${conversation.response_id}`);
|
||||
console.log("[CONVERSATION] Looking for filter:", {
|
||||
conversationId: conversation.response_id,
|
||||
savedFilterId,
|
||||
});
|
||||
|
||||
if (savedFilterId) {
|
||||
// Import getFilterById dynamically to avoid circular dependency
|
||||
const { getFilterById } = await import("@/app/api/queries/useGetFilterByIdQuery");
|
||||
try {
|
||||
const filter = await getFilterById(savedFilterId);
|
||||
|
||||
if (filter) {
|
||||
console.log("[CONVERSATION] Loaded filter:", filter.name, filter.id);
|
||||
setConversationFilterState(filter);
|
||||
// Update conversation data with the loaded filter
|
||||
setConversationData((prev) => {
|
||||
if (!prev) return prev;
|
||||
return { ...prev, filter };
|
||||
});
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("[CONVERSATION] Failed to load filter:", error);
|
||||
// Filter was deleted, clean up localStorage
|
||||
localStorage.removeItem(`conversation_filter_${conversation.response_id}`);
|
||||
setConversationFilterState(null);
|
||||
}
|
||||
} else {
|
||||
// No saved filter in localStorage, clear the current filter
|
||||
console.log("[CONVERSATION] No filter found for this conversation");
|
||||
setConversationFilterState(null);
|
||||
}
|
||||
}
|
||||
|
||||
// Clear placeholder when loading a real conversation
|
||||
setPlaceholderConversation(null);
|
||||
setConversationLoaded(true);
|
||||
|
|
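loadConversation now restores the saved filter for the conversation being opened, lazily importing the query helper to avoid a circular dependency and pruning the key if the filter has since been deleted. A condensed sketch; the import path mirrors the diff, the standalone function itself is illustrative:

```ts
// Restore the filter association recorded for this conversation, if any.
async function restoreConversationFilter(
  responseId: string,
  setFilter: (filter: unknown | null) => void,
): Promise<void> {
  if (typeof window === "undefined") return;
  const savedFilterId = localStorage.getItem(`conversation_filter_${responseId}`);
  if (!savedFilterId) {
    setFilter(null); // no association recorded for this conversation
    return;
  }
  // Lazy import avoids a circular dependency between context and query modules.
  const { getFilterById } = await import("@/app/api/queries/useGetFilterByIdQuery");
  try {
    const filter = await getFilterById(savedFilterId);
    setFilter(filter ?? null);
  } catch (error) {
    console.error("Failed to load saved filter:", error);
    localStorage.removeItem(`conversation_filter_${responseId}`); // filter was deleted
    setFilter(null);
  }
}
```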
@ -170,15 +210,48 @@ export function ChatProvider({ children }: ChatProviderProps) {
|
|||
[conversationData?.response_id],
|
||||
);
|
||||
|
||||
const startNewConversation = useCallback(() => {
|
||||
const startNewConversation = useCallback(async () => {
|
||||
console.log("[CONVERSATION] Starting new conversation");
|
||||
|
||||
// Clear current conversation data and reset state
|
||||
setCurrentConversationId(null);
|
||||
setPreviousResponseIds({ chat: null, langflow: null });
|
||||
setConversationData(null);
|
||||
setConversationDocs([]);
|
||||
setConversationLoaded(false);
|
||||
// Clear the filter when starting a new conversation
|
||||
setConversationFilterState(null);
|
||||
|
||||
// Load default filter if available (and clear it after first use)
|
||||
if (typeof window !== "undefined") {
|
||||
const defaultFilterId = localStorage.getItem("default_conversation_filter_id");
|
||||
console.log("[CONVERSATION] Default filter ID:", defaultFilterId);
|
||||
|
||||
if (defaultFilterId) {
|
||||
// Clear the default filter now so it's only used once
|
||||
localStorage.removeItem("default_conversation_filter_id");
|
||||
console.log("[CONVERSATION] Cleared default filter (used once)");
|
||||
|
||||
try {
|
||||
const { getFilterById } = await import("@/app/api/queries/useGetFilterByIdQuery");
|
||||
const filter = await getFilterById(defaultFilterId);
|
||||
|
||||
if (filter) {
|
||||
console.log("[CONVERSATION] Loaded default filter:", filter.name, filter.id);
|
||||
setConversationFilterState(filter);
|
||||
} else {
|
||||
// Default filter was deleted
|
||||
setConversationFilterState(null);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("[CONVERSATION] Failed to load default filter:", error);
|
||||
setConversationFilterState(null);
|
||||
}
|
||||
} else {
|
||||
console.log("[CONVERSATION] No default filter set");
|
||||
setConversationFilterState(null);
|
||||
}
|
||||
} else {
|
||||
setConversationFilterState(null);
|
||||
}
|
||||
|
||||
// Create a temporary placeholder conversation to show in sidebar
|
||||
const placeholderConversation: ConversationData = {
|
||||
|
|
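startNewConversation consumes a one-shot default filter id that onboarding leaves behind under `default_conversation_filter_id`: the id is read, removed immediately so only the first new conversation picks it up, and then resolved to a filter. A sketch of that hand-off; the helper name and injected `getFilterById` are illustrative:

```ts
const DEFAULT_FILTER_KEY = "default_conversation_filter_id";

// Read and immediately clear the default filter id, then resolve it.
export async function consumeDefaultFilter<T>(
  getFilterById: (id: string) => Promise<T | null>,
): Promise<T | null> {
  if (typeof window === "undefined") return null;
  const defaultFilterId = localStorage.getItem(DEFAULT_FILTER_KEY);
  if (!defaultFilterId) return null;
  localStorage.removeItem(DEFAULT_FILTER_KEY); // used once, then discarded
  try {
    return await getFilterById(defaultFilterId);
  } catch (error) {
    console.error("Failed to load default filter:", error);
    return null;
  }
}
```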
@ -230,7 +303,7 @@ export function ChatProvider({ children }: ChatProviderProps) {
|
|||
);
|
||||
|
||||
const setConversationFilter = useCallback(
|
||||
(filter: KnowledgeFilter | null) => {
|
||||
(filter: KnowledgeFilter | null, responseId?: string | null) => {
|
||||
setConversationFilterState(filter);
|
||||
// Update the conversation data to include the filter
|
||||
setConversationData((prev) => {
|
||||
|
|
@ -240,8 +313,24 @@ export function ChatProvider({ children }: ChatProviderProps) {
|
|||
filter,
|
||||
};
|
||||
});
|
||||
|
||||
// Determine which conversation ID to use for saving
|
||||
// - undefined: use currentConversationId (default behavior)
|
||||
// - null: explicitly skip saving to localStorage
|
||||
// - string: use the provided responseId
|
||||
const targetId = responseId === undefined ? currentConversationId : responseId;
|
||||
|
||||
// Save filter association for the target conversation
|
||||
if (typeof window !== "undefined" && targetId) {
|
||||
const key = `conversation_filter_${targetId}`;
|
||||
if (filter) {
|
||||
localStorage.setItem(key, filter.id);
|
||||
} else {
|
||||
localStorage.removeItem(key);
|
||||
}
|
||||
}
|
||||
},
|
||||
[],
|
||||
[currentConversationId],
|
||||
);
|
||||
|
||||
const value = useMemo<ChatContextType>(
|
||||
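setConversationFilter now takes an optional responseId whose three states decide where the association is written: undefined targets the current conversation, null updates context state only, and a string targets that explicit response id. A standalone sketch of just the persistence step; the type and function name are illustrative:

```ts
// responseId semantics:
//   undefined -> associate the filter with the current conversation
//   null      -> update context state only, skip localStorage
//   string    -> associate with that explicit response id
type KnowledgeFilter = { id: string };

function persistFilterAssociation(
  filter: KnowledgeFilter | null,
  responseId: string | null | undefined,
  currentConversationId: string | null,
): void {
  const targetId = responseId === undefined ? currentConversationId : responseId;
  if (typeof window === "undefined" || !targetId) return; // null target => explicitly skip saving
  const key = `conversation_filter_${targetId}`;
  if (filter) {
    localStorage.setItem(key, filter.id);
  } else {
    localStorage.removeItem(key);
  }
}
```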
|
@ -15,6 +15,7 @@ interface SendMessageOptions {
|
|||
prompt: string;
|
||||
previousResponseId?: string;
|
||||
filters?: SelectedFilters;
|
||||
filter_id?: string;
|
||||
limit?: number;
|
||||
scoreThreshold?: number;
|
||||
}
|
||||
|
|
@ -35,6 +36,7 @@ export function useChatStreaming({
|
|||
prompt,
|
||||
previousResponseId,
|
||||
filters,
|
||||
filter_id,
|
||||
limit = 10,
|
||||
scoreThreshold = 0,
|
||||
}: SendMessageOptions) => {
|
||||
|
|
@ -73,6 +75,7 @@ export function useChatStreaming({
|
|||
stream: boolean;
|
||||
previous_response_id?: string;
|
||||
filters?: SelectedFilters;
|
||||
filter_id?: string;
|
||||
limit?: number;
|
||||
scoreThreshold?: number;
|
||||
} = {
|
||||
|
|
@ -90,6 +93,12 @@ export function useChatStreaming({
|
|||
requestBody.filters = filters;
|
||||
}
|
||||
|
||||
if (filter_id) {
|
||||
requestBody.filter_id = filter_id;
|
||||
}
|
||||
|
||||
console.log("[useChatStreaming] Sending request:", { filter_id, requestBody });
|
||||
|
||||
const response = await fetch(endpoint, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
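useChatStreaming only attaches previous_response_id, filters, and filter_id to the request body when they are set, so the backend can tell "not provided" apart from an explicit empty value. A sketch of that body construction; field names follow the diff, while `SelectedFilters` and the builder function are illustrative:

```ts
type SelectedFilters = Record<string, string[]>;

interface ChatRequestBody {
  prompt: string;
  stream: boolean;
  previous_response_id?: string;
  filters?: SelectedFilters;
  filter_id?: string;
  limit?: number;
  scoreThreshold?: number;
}

// Attach optional fields only when they are present.
function buildChatRequestBody(opts: {
  prompt: string;
  previousResponseId?: string;
  filters?: SelectedFilters;
  filter_id?: string;
  limit?: number;
  scoreThreshold?: number;
}): ChatRequestBody {
  const body: ChatRequestBody = {
    prompt: opts.prompt,
    stream: true,
    limit: opts.limit ?? 10,
    scoreThreshold: opts.scoreThreshold ?? 0,
  };
  if (opts.previousResponseId) body.previous_response_id = opts.previousResponseId;
  if (opts.filters) body.filters = opts.filters;
  if (opts.filter_id) body.filter_id = opts.filter_id;
  return body;
}
```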
|
|
|
|||
src/agent.py
|
|
@ -1,3 +1,5 @@
|
|||
from http.client import HTTPException
|
||||
|
||||
from utils.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
|
@ -67,6 +69,7 @@ def store_conversation_thread(user_id: str, response_id: str, conversation_state
|
|||
"created_at": conversation_state.get("created_at"),
|
||||
"last_activity": conversation_state.get("last_activity"),
|
||||
"previous_response_id": conversation_state.get("previous_response_id"),
|
||||
"filter_id": conversation_state.get("filter_id"),
|
||||
"total_messages": len(
|
||||
[msg for msg in messages if msg.get("role") in ["user", "assistant"]]
|
||||
),
|
||||
|
|
@ -219,15 +222,26 @@ async def async_response(
|
|||
|
||||
response = await client.responses.create(**request_params)
|
||||
|
||||
response_text = response.output_text
|
||||
logger.info("Response generated", log_prefix=log_prefix, response=response_text)
|
||||
# Check if response has output_text using getattr to avoid issues with special objects
|
||||
output_text = getattr(response, "output_text", None)
|
||||
if output_text is not None:
|
||||
response_text = output_text
|
||||
logger.info("Response generated", log_prefix=log_prefix, response=response_text)
|
||||
|
||||
# Extract and store response_id if available
|
||||
response_id = getattr(response, "id", None) or getattr(
|
||||
response, "response_id", None
|
||||
)
|
||||
# Extract and store response_id if available
|
||||
response_id = getattr(response, "id", None) or getattr(
|
||||
response, "response_id", None
|
||||
)
|
||||
|
||||
return response_text, response_id, response
|
||||
return response_text, response_id, response
|
||||
else:
|
||||
msg = "Nudge response missing output_text"
|
||||
error = getattr(response, "error", None)
|
||||
if error:
|
||||
error_msg = getattr(error, "message", None)
|
||||
if error_msg:
|
||||
msg = error_msg
|
||||
raise ValueError(msg)
|
||||
except Exception as e:
|
||||
logger.error("Exception in non-streaming response", error=str(e))
|
||||
import traceback
|
||||
|
|
@ -314,6 +328,7 @@ async def async_chat(
|
|||
user_id: str,
|
||||
model: str = "gpt-4.1-mini",
|
||||
previous_response_id: str = None,
|
||||
filter_id: str = None,
|
||||
):
|
||||
logger.debug(
|
||||
"async_chat called", user_id=user_id, previous_response_id=previous_response_id
|
||||
|
|
@ -334,6 +349,10 @@ async def async_chat(
|
|||
"Added user message", message_count=len(conversation_state["messages"])
|
||||
)
|
||||
|
||||
# Store filter_id in conversation state if provided
|
||||
if filter_id:
|
||||
conversation_state["filter_id"] = filter_id
|
||||
|
||||
response_text, response_id, response_obj = await async_response(
|
||||
async_client,
|
||||
prompt,
|
||||
|
|
@ -389,6 +408,7 @@ async def async_chat_stream(
|
|||
user_id: str,
|
||||
model: str = "gpt-4.1-mini",
|
||||
previous_response_id: str = None,
|
||||
filter_id: str = None,
|
||||
):
|
||||
# Get the specific conversation thread (or create new one)
|
||||
conversation_state = get_conversation_thread(user_id, previous_response_id)
|
||||
|
|
@ -399,6 +419,10 @@ async def async_chat_stream(
|
|||
user_message = {"role": "user", "content": prompt, "timestamp": datetime.now()}
|
||||
conversation_state["messages"].append(user_message)
|
||||
|
||||
# Store filter_id in conversation state if provided
|
||||
if filter_id:
|
||||
conversation_state["filter_id"] = filter_id
|
||||
|
||||
full_response = ""
|
||||
response_id = None
|
||||
async for chunk in async_stream(
|
||||
|
|
@ -452,6 +476,7 @@ async def async_langflow_chat(
|
|||
extra_headers: dict = None,
|
||||
previous_response_id: str = None,
|
||||
store_conversation: bool = True,
|
||||
filter_id: str = None,
|
||||
):
|
||||
logger.debug(
|
||||
"async_langflow_chat called",
|
||||
|
|
@ -478,6 +503,10 @@ async def async_langflow_chat(
|
|||
message_count=len(conversation_state["messages"]),
|
||||
)
|
||||
|
||||
# Store filter_id in conversation state if provided
|
||||
if filter_id:
|
||||
conversation_state["filter_id"] = filter_id
|
||||
|
||||
response_text, response_id, response_obj = await async_response(
|
||||
langflow_client,
|
||||
prompt,
|
||||
|
|
@ -562,6 +591,7 @@ async def async_langflow_chat_stream(
|
|||
user_id: str,
|
||||
extra_headers: dict = None,
|
||||
previous_response_id: str = None,
|
||||
filter_id: str = None,
|
||||
):
|
||||
logger.debug(
|
||||
"async_langflow_chat_stream called",
|
||||
|
|
@ -578,6 +608,10 @@ async def async_langflow_chat_stream(
|
|||
user_message = {"role": "user", "content": prompt, "timestamp": datetime.now()}
|
||||
conversation_state["messages"].append(user_message)
|
||||
|
||||
# Store filter_id in conversation state if provided
|
||||
if filter_id:
|
||||
conversation_state["filter_id"] = filter_id
|
||||
|
||||
full_response = ""
|
||||
response_id = None
|
||||
collected_chunks = [] # Store all chunks for function call data
|
||||
|
@ -14,6 +14,7 @@ async def chat_endpoint(request: Request, chat_service, session_manager):
|
|||
filters = data.get("filters")
|
||||
limit = data.get("limit", 10)
|
||||
score_threshold = data.get("scoreThreshold", 0)
|
||||
filter_id = data.get("filter_id")
|
||||
|
||||
user = request.state.user
|
||||
user_id = user.user_id
|
||||
|
|
@ -42,6 +43,7 @@ async def chat_endpoint(request: Request, chat_service, session_manager):
|
|||
jwt_token,
|
||||
previous_response_id=previous_response_id,
|
||||
stream=True,
|
||||
filter_id=filter_id,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
|
|
@ -58,6 +60,7 @@ async def chat_endpoint(request: Request, chat_service, session_manager):
|
|||
jwt_token,
|
||||
previous_response_id=previous_response_id,
|
||||
stream=False,
|
||||
filter_id=filter_id,
|
||||
)
|
||||
return JSONResponse(result)
|
||||
|
||||
|
|
@ -71,6 +74,7 @@ async def langflow_endpoint(request: Request, chat_service, session_manager):
|
|||
filters = data.get("filters")
|
||||
limit = data.get("limit", 10)
|
||||
score_threshold = data.get("scoreThreshold", 0)
|
||||
filter_id = data.get("filter_id")
|
||||
|
||||
user = request.state.user
|
||||
user_id = user.user_id
|
||||
|
|
@ -100,6 +104,7 @@ async def langflow_endpoint(request: Request, chat_service, session_manager):
|
|||
jwt_token,
|
||||
previous_response_id=previous_response_id,
|
||||
stream=True,
|
||||
filter_id=filter_id,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
|
|
@ -116,6 +121,7 @@ async def langflow_endpoint(request: Request, chat_service, session_manager):
|
|||
jwt_token,
|
||||
previous_response_id=previous_response_id,
|
||||
stream=False,
|
||||
filter_id=filter_id,
|
||||
)
|
||||
return JSONResponse(result)
|
||||
|
||||
|
@ -179,34 +179,13 @@ async def langflow_upload_ingest_task(
|
|||
|
||||
logger.debug("Langflow upload task created successfully", task_id=task_id)
|
||||
|
||||
# Create knowledge filter for the uploaded document if requested
|
||||
user_doc_filter_id = None
|
||||
if create_filter and len(original_filenames) == 1:
|
||||
try:
|
||||
from api.settings import _create_user_document_filter
|
||||
user_doc_filter_id = await _create_user_document_filter(
|
||||
request, session_manager, original_filenames[0]
|
||||
)
|
||||
if user_doc_filter_id:
|
||||
logger.info(
|
||||
"Created knowledge filter for uploaded document",
|
||||
filter_id=user_doc_filter_id,
|
||||
filename=original_filenames[0],
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to create knowledge filter for uploaded document",
|
||||
error=str(e),
|
||||
)
|
||||
# Don't fail the upload if filter creation fails
|
||||
|
||||
response_data = {
|
||||
"task_id": task_id,
|
||||
"message": f"Langflow upload task created for {len(upload_files)} file(s)",
|
||||
"file_count": len(upload_files),
|
||||
"create_filter": create_filter, # Pass flag back to frontend
|
||||
"filename": original_filenames[0] if len(original_filenames) == 1 else None, # Pass filename for filter creation
|
||||
}
|
||||
if user_doc_filter_id:
|
||||
response_data["user_doc_filter_id"] = user_doc_filter_id
|
||||
|
||||
return JSONResponse(response_data, status_code=202) # 202 Accepted for async processing
|
||||
|
||||
|
@ -556,7 +556,7 @@ async def update_settings(request, session_manager):
|
|||
"ollama_endpoint"
|
||||
]
|
||||
|
||||
await clients.refresh_patched_clients()
|
||||
await clients.refresh_patched_client()
|
||||
|
||||
if any(key in body for key in provider_fields_to_check):
|
||||
try:
|
||||
|
|
@ -926,12 +926,13 @@ async def onboarding(request, flows_service, session_manager=None):
|
|||
{"error": "Failed to save configuration"}, status_code=500
|
||||
)
|
||||
|
||||
# Refresh cached patched clients so latest credentials take effect immediately
|
||||
await clients.refresh_patched_clients()
|
||||
# Refresh cached patched client so latest credentials take effect immediately
|
||||
await clients.refresh_patched_client()
|
||||
|
||||
# Create OpenRAG Docs knowledge filter if sample data was ingested
|
||||
# Only create on embedding step to avoid duplicates (both LLM and embedding cards submit with sample_data)
|
||||
openrag_docs_filter_id = None
|
||||
if should_ingest_sample_data:
|
||||
if should_ingest_sample_data and ("embedding_provider" in body or "embedding_model" in body):
|
||||
try:
|
||||
openrag_docs_filter_id = await _create_openrag_docs_filter(
|
||||
request, session_manager
|
||||
|
|
@ -1031,77 +1032,6 @@ async def _create_openrag_docs_filter(request, session_manager):
|
|||
return None
|
||||
|
||||
|
||||
async def _create_user_document_filter(request, session_manager, filename):
|
||||
"""Create a knowledge filter for a user-uploaded document during onboarding"""
|
||||
import uuid
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
# Get knowledge filter service from app state
|
||||
app = request.scope.get("app")
|
||||
if not app or not hasattr(app.state, "services"):
|
||||
logger.error("Could not access services for knowledge filter creation")
|
||||
return None
|
||||
|
||||
knowledge_filter_service = app.state.services.get("knowledge_filter_service")
|
||||
if not knowledge_filter_service:
|
||||
logger.error("Knowledge filter service not available")
|
||||
return None
|
||||
|
||||
# Get user and JWT token from request
|
||||
user = request.state.user
|
||||
jwt_token = session_manager.get_effective_jwt_token(user.user_id, request.state.jwt_token)
|
||||
|
||||
# In no-auth mode, set owner to None so filter is visible to all users
|
||||
# In auth mode, use the actual user as owner
|
||||
if is_no_auth_mode():
|
||||
owner_user_id = None
|
||||
else:
|
||||
owner_user_id = user.user_id
|
||||
|
||||
# Create the filter document
|
||||
filter_id = str(uuid.uuid4())
|
||||
|
||||
# Get a display name from the filename (remove extension for cleaner name)
|
||||
display_name = filename.rsplit(".", 1)[0] if "." in filename else filename
|
||||
|
||||
query_data = json.dumps({
|
||||
"query": "",
|
||||
"filters": {
|
||||
"data_sources": [filename],
|
||||
"document_types": ["*"],
|
||||
"owners": ["*"],
|
||||
"connector_types": ["*"],
|
||||
},
|
||||
"limit": 10,
|
||||
"scoreThreshold": 0,
|
||||
"color": "green",
|
||||
"icon": "file",
|
||||
})
|
||||
|
||||
filter_doc = {
|
||||
"id": filter_id,
|
||||
"name": display_name,
|
||||
"description": f"Filter for {filename}",
|
||||
"query_data": query_data,
|
||||
"owner": owner_user_id,
|
||||
"allowed_users": [],
|
||||
"allowed_groups": [],
|
||||
"created_at": datetime.utcnow().isoformat(),
|
||||
"updated_at": datetime.utcnow().isoformat(),
|
||||
}
|
||||
|
||||
result = await knowledge_filter_service.create_knowledge_filter(
|
||||
filter_doc, user_id=user.user_id, jwt_token=jwt_token
|
||||
)
|
||||
|
||||
if result.get("success"):
|
||||
return filter_id
|
||||
else:
|
||||
logger.error("Failed to create user document filter", error=result.get("error"))
|
||||
return None
|
||||
|
||||
|
||||
def _get_flows_service():
|
||||
"""Helper function to get flows service instance"""
|
||||
from services.flows_service import FlowsService
|
||||
|
@ -297,8 +297,7 @@ class AppClients:
|
|||
self.opensearch = None
|
||||
self.langflow_client = None
|
||||
self.langflow_http_client = None
|
||||
self._patched_llm_client = None # Private attribute
|
||||
self._patched_embedding_client = None # Private attribute
|
||||
self._patched_async_client = None # Private attribute - single client for all providers
|
||||
self._client_init_lock = __import__('threading').Lock() # Lock for thread-safe initialization
|
||||
self.converter = None
|
||||
|
||||
|
|
@ -377,192 +376,157 @@ class AppClients:
|
|||
self.langflow_client = None
|
||||
return self.langflow_client
|
||||
|
||||
def _build_provider_env(self, provider_type: str):
|
||||
@property
|
||||
def patched_async_client(self):
|
||||
"""
|
||||
Build environment overrides for the requested provider type ("llm" or "embedding").
|
||||
This is used to support different credentials for LLM and embedding providers.
|
||||
Property that ensures OpenAI client is initialized on first access.
|
||||
This allows lazy initialization so the app can start without an API key.
|
||||
|
||||
The client is patched with LiteLLM support to handle multiple providers.
|
||||
All provider credentials are loaded into environment for LiteLLM routing.
|
||||
|
||||
Note: The client is a long-lived singleton that should be closed via cleanup().
|
||||
Thread-safe via lock to prevent concurrent initialization attempts.
|
||||
"""
|
||||
config = get_openrag_config()
|
||||
|
||||
if provider_type == "llm":
|
||||
provider = (config.agent.llm_provider or "openai").lower()
|
||||
else:
|
||||
provider = (config.knowledge.embedding_provider or "openai").lower()
|
||||
|
||||
env_overrides = {}
|
||||
|
||||
if provider == "openai":
|
||||
api_key = config.providers.openai.api_key or os.getenv("OPENAI_API_KEY")
|
||||
if api_key:
|
||||
env_overrides["OPENAI_API_KEY"] = api_key
|
||||
elif provider == "anthropic":
|
||||
api_key = config.providers.anthropic.api_key or os.getenv("ANTHROPIC_API_KEY")
|
||||
if api_key:
|
||||
env_overrides["ANTHROPIC_API_KEY"] = api_key
|
||||
elif provider == "watsonx":
|
||||
api_key = config.providers.watsonx.api_key or os.getenv("WATSONX_API_KEY")
|
||||
endpoint = config.providers.watsonx.endpoint or os.getenv("WATSONX_ENDPOINT")
|
||||
project_id = config.providers.watsonx.project_id or os.getenv("WATSONX_PROJECT_ID")
|
||||
if api_key:
|
||||
env_overrides["WATSONX_API_KEY"] = api_key
|
||||
if endpoint:
|
||||
env_overrides["WATSONX_ENDPOINT"] = endpoint
|
||||
if project_id:
|
||||
env_overrides["WATSONX_PROJECT_ID"] = project_id
|
||||
elif provider == "ollama":
|
||||
endpoint = config.providers.ollama.endpoint or os.getenv("OLLAMA_ENDPOINT")
|
||||
if endpoint:
|
||||
env_overrides["OLLAMA_ENDPOINT"] = endpoint
|
||||
env_overrides["OLLAMA_BASE_URL"] = endpoint
|
||||
|
||||
return env_overrides, provider
|
||||
|
||||
def _apply_env_overrides(self, env_overrides: dict):
|
||||
"""Apply non-empty environment overrides."""
|
||||
for key, value in (env_overrides or {}).items():
|
||||
if value is None:
|
||||
continue
|
||||
os.environ[key] = str(value)
|
||||
|
||||
def _initialize_patched_client(self, cache_attr: str, provider_type: str):
|
||||
"""
|
||||
Initialize a patched AsyncOpenAI client for the specified provider type.
|
||||
Uses HTTP/2 probe only when an OpenAI key is present; otherwise falls back directly.
|
||||
"""
|
||||
# Quick path
|
||||
cached_client = getattr(self, cache_attr)
|
||||
if cached_client is not None:
|
||||
return cached_client
|
||||
# Quick check without lock
|
||||
if self._patched_async_client is not None:
|
||||
return self._patched_async_client
|
||||
|
||||
# Use lock to ensure only one thread initializes
|
||||
with self._client_init_lock:
|
||||
cached_client = getattr(self, cache_attr)
|
||||
if cached_client is not None:
|
||||
return cached_client
|
||||
# Double-check after acquiring lock
|
||||
if self._patched_async_client is not None:
|
||||
return self._patched_async_client
|
||||
|
||||
env_overrides, provider_name = self._build_provider_env(provider_type)
|
||||
self._apply_env_overrides(env_overrides)
|
||||
|
||||
# Decide whether to run the HTTP/2 probe (only meaningful for OpenAI endpoints)
|
||||
has_openai_key = bool(env_overrides.get("OPENAI_API_KEY") or os.getenv("OPENAI_API_KEY"))
|
||||
if provider_name == "openai" and not has_openai_key:
|
||||
raise ValueError("OPENAI_API_KEY is required for OpenAI provider")
|
||||
# Load all provider credentials into environment for LiteLLM
|
||||
# LiteLLM routes based on model name prefixes (openai/, ollama/, watsonx/, etc.)
|
||||
try:
|
||||
config = get_openrag_config()
|
||||
|
||||
# Set OpenAI credentials
|
||||
if config.providers.openai.api_key:
|
||||
os.environ["OPENAI_API_KEY"] = config.providers.openai.api_key
|
||||
logger.debug("Loaded OpenAI API key from config")
|
||||
|
||||
# Set Anthropic credentials
|
||||
if config.providers.anthropic.api_key:
|
||||
os.environ["ANTHROPIC_API_KEY"] = config.providers.anthropic.api_key
|
||||
logger.debug("Loaded Anthropic API key from config")
|
||||
|
||||
# Set WatsonX credentials
|
||||
if config.providers.watsonx.api_key:
|
||||
os.environ["WATSONX_API_KEY"] = config.providers.watsonx.api_key
|
||||
if config.providers.watsonx.endpoint:
|
||||
os.environ["WATSONX_ENDPOINT"] = config.providers.watsonx.endpoint
|
||||
os.environ["WATSONX_API_BASE"] = config.providers.watsonx.endpoint # LiteLLM expects this name
|
||||
if config.providers.watsonx.project_id:
|
||||
os.environ["WATSONX_PROJECT_ID"] = config.providers.watsonx.project_id
|
||||
if config.providers.watsonx.api_key:
|
||||
logger.debug("Loaded WatsonX credentials from config")
|
||||
|
||||
# Set Ollama endpoint
|
||||
if config.providers.ollama.endpoint:
|
||||
os.environ["OLLAMA_BASE_URL"] = config.providers.ollama.endpoint
|
||||
os.environ["OLLAMA_ENDPOINT"] = config.providers.ollama.endpoint
|
||||
logger.debug("Loaded Ollama endpoint from config")
|
||||
|
||||
except Exception as e:
|
||||
logger.debug("Could not load provider credentials from config", error=str(e))
|
||||
|
||||
# Try to initialize the client - AsyncOpenAI() will read from environment
|
||||
# We'll try HTTP/2 first with a probe, then fall back to HTTP/1.1 if it times out
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import threading
|
||||
|
||||
async def build_client(skip_probe: bool = False):
|
||||
if not has_openai_key:
|
||||
# No OpenAI key present; create a basic patched client without probing
|
||||
return patch_openai_with_mcp(AsyncOpenAI(http_client=httpx.AsyncClient()))
|
||||
|
||||
if skip_probe:
|
||||
http_client = httpx.AsyncClient(http2=False, timeout=httpx.Timeout(60.0, connect=10.0))
|
||||
return patch_openai_with_mcp(AsyncOpenAI(http_client=http_client))

async def probe_and_initialize():
# Try HTTP/2 first (default)
client_http2 = patch_openai_with_mcp(AsyncOpenAI())
logger.info("Probing patched OpenAI client with HTTP/2...")
logger.info("Probing OpenAI client with HTTP/2...")

try:
# Probe with a small embedding and short timeout
await asyncio.wait_for(
client_http2.embeddings.create(
model="text-embedding-3-small",
input=["test"],
model='text-embedding-3-small',
input=['test']
),
timeout=5.0,
timeout=5.0
)
logger.info("Patched OpenAI client initialized with HTTP/2 (probe successful)")
logger.info("OpenAI client initialized with HTTP/2 (probe successful)")
return client_http2
except (asyncio.TimeoutError, Exception) as probe_error:
logger.warning("HTTP/2 probe failed, falling back to HTTP/1.1", error=str(probe_error))
# Close the HTTP/2 client
try:
await client_http2.close()
except Exception:
pass

# Fall back to HTTP/1.1 with explicit timeout settings
http_client = httpx.AsyncClient(
http2=False, timeout=httpx.Timeout(60.0, connect=10.0)
http2=False,
timeout=httpx.Timeout(60.0, connect=10.0)
)
client_http1 = patch_openai_with_mcp(AsyncOpenAI(http_client=http_client))
logger.info("Patched OpenAI client initialized with HTTP/1.1 (fallback)")
client_http1 = patch_openai_with_mcp(
AsyncOpenAI(http_client=http_client)
)
logger.info("OpenAI client initialized with HTTP/1.1 (fallback)")
return client_http1

def run_builder(skip_probe=False):
def run_probe_in_thread():
"""Run the async probe in a new thread with its own event loop"""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(build_client(skip_probe=skip_probe))
return loop.run_until_complete(probe_and_initialize())
finally:
loop.close()

try:
# Run the probe in a separate thread with its own event loop
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(run_builder, False)
client = future.result(timeout=15 if has_openai_key else 10)
setattr(self, cache_attr, client)
logger.info("Successfully initialized patched client", provider_type=provider_type)
return client
future = executor.submit(run_probe_in_thread)
self._patched_async_client = future.result(timeout=15)
logger.info("Successfully initialized OpenAI client")
except Exception as e:
logger.error(
f"Failed to initialize patched client: {e.__class__.__name__}: {str(e)}",
provider_type=provider_type,
)
raise ValueError(
f"Failed to initialize patched client for {provider_type}: {str(e)}. "
"Please ensure provider credentials are set."
)
logger.error(f"Failed to initialize OpenAI client: {e.__class__.__name__}: {str(e)}")
raise ValueError(f"Failed to initialize OpenAI client: {str(e)}. Please complete onboarding or set OPENAI_API_KEY environment variable.")

return self._patched_async_client

@property
def patched_llm_client(self):
"""Patched client for LLM provider."""
return self._initialize_patched_client("_patched_llm_client", "llm")
"""Alias for patched_async_client - for backward compatibility with code expecting separate clients."""
return self.patched_async_client

@property
def patched_embedding_client(self):
"""Patched client for embedding provider."""
return self._initialize_patched_client("_patched_embedding_client", "embedding")
"""Alias for patched_async_client - for backward compatibility with code expecting separate clients."""
return self.patched_async_client

@property
def patched_async_client(self):
"""Backward-compatibility alias for LLM client."""
return self.patched_llm_client

async def refresh_patched_clients(self):
"""Reset patched clients so next use picks up updated provider credentials."""
clients_to_close = []
with self._client_init_lock:
if self._patched_llm_client is not None:
clients_to_close.append(self._patched_llm_client)
self._patched_llm_client = None
if self._patched_embedding_client is not None:
clients_to_close.append(self._patched_embedding_client)
self._patched_embedding_client = None

for client in clients_to_close:
async def refresh_patched_client(self):
"""Reset patched client so next use picks up updated provider credentials."""
if self._patched_async_client is not None:
try:
await client.close()
await self._patched_async_client.close()
logger.info("Closed patched client for refresh")
except Exception as e:
logger.warning("Failed to close patched client during refresh", error=str(e))
finally:
self._patched_async_client = None

async def cleanup(self):
"""Cleanup resources - should be called on application shutdown"""
# Close AsyncOpenAI client if it was created
if self._patched_llm_client is not None:
if self._patched_async_client is not None:
try:
await self._patched_llm_client.close()
logger.info("Closed LLM patched client")
await self._patched_async_client.close()
logger.info("Closed AsyncOpenAI client")
except Exception as e:
logger.error("Failed to close LLM patched client", error=str(e))
logger.error("Failed to close AsyncOpenAI client", error=str(e))
finally:
self._patched_llm_client = None

if self._patched_embedding_client is not None:
try:
await self._patched_embedding_client.close()
logger.info("Closed embedding patched client")
except Exception as e:
logger.error("Failed to close embedding patched client", error=str(e))
finally:
self._patched_embedding_client = None
self._patched_async_client = None

# Close Langflow HTTP client if it exists
if self.langflow_http_client is not None:
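For context, the hunk above probes a default AsyncOpenAI client with a tiny embedding call and drops back to an explicit HTTP/1.1 httpx client when that probe fails or times out. Below is a minimal, standalone sketch of that probe-and-fallback pattern; it deliberately omits the project's patch_openai_with_mcp wrapper and structured logger so it stays self-contained, and it assumes OPENAI_API_KEY is set in the environment.

import asyncio

import httpx
from openai import AsyncOpenAI


async def build_client_with_probe() -> AsyncOpenAI:
    """Return an AsyncOpenAI client, preferring the default transport, falling back to HTTP/1.1."""
    client = AsyncOpenAI()  # default client (the diff treats this as the HTTP/2 path)
    try:
        # Cheap probe: a one-item embedding with a short timeout.
        await asyncio.wait_for(
            client.embeddings.create(model="text-embedding-3-small", input=["test"]),
            timeout=5.0,
        )
        return client
    except Exception:
        # Probe failed (timeout, TLS/proxy issues, etc.) - close it and retry over plain HTTP/1.1.
        try:
            await client.close()
        except Exception:
            pass
        http1 = httpx.AsyncClient(http2=False, timeout=httpx.Timeout(60.0, connect=10.0))
        return AsyncOpenAI(http_client=http1)


if __name__ == "__main__":
    asyncio.run(build_client_with_probe())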
@ -15,6 +15,7 @@ class ChatService:
jwt_token: str = None,
previous_response_id: str = None,
stream: bool = False,
filter_id: str = None,
):
"""Handle chat requests using the patched OpenAI client"""
if not prompt:

@ -30,6 +31,7 @@ class ChatService:
prompt,
user_id,
previous_response_id=previous_response_id,
filter_id=filter_id,
)
else:
response_text, response_id = await async_chat(

@ -37,6 +39,7 @@ class ChatService:
prompt,
user_id,
previous_response_id=previous_response_id,
filter_id=filter_id,
)
response_data = {"response": response_text}
if response_id:

@ -50,6 +53,7 @@ class ChatService:
jwt_token: str = None,
previous_response_id: str = None,
stream: bool = False,
filter_id: str = None,
):
"""Handle Langflow chat requests"""
if not prompt:

@ -147,6 +151,7 @@ class ChatService:
user_id,
extra_headers=extra_headers,
previous_response_id=previous_response_id,
filter_id=filter_id,
)
else:
from agent import async_langflow_chat

@ -158,6 +163,7 @@ class ChatService:
user_id,
extra_headers=extra_headers,
previous_response_id=previous_response_id,
filter_id=filter_id,
)
response_data = {"response": response_text}
if response_id:

@ -429,6 +435,7 @@ class ChatService:
"previous_response_id": conversation_state.get(
"previous_response_id"
),
"filter_id": conversation_state.get("filter_id"),
"total_messages": len(messages),
"source": "in_memory",
}

@ -447,6 +454,7 @@ class ChatService:
"created_at": metadata.get("created_at"),
"last_activity": metadata.get("last_activity"),
"previous_response_id": metadata.get("previous_response_id"),
"filter_id": metadata.get("filter_id"),
"total_messages": metadata.get("total_messages", 0),
"source": "metadata_only",
}

@ -545,6 +553,7 @@ class ChatService:
or conversation.get("created_at"),
"last_activity": metadata.get("last_activity")
or conversation.get("last_activity"),
"filter_id": metadata.get("filter_id"),
"total_messages": len(messages),
"source": "langflow_enhanced",
"langflow_session_id": session_id,
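The ChatService hunks above thread a new filter_id through both the OpenAI and Langflow chat paths and persist it next to previous_response_id in the conversation metadata, which is what makes the commit's "sticky filters" possible. The rough sketch below shows one way that sticky behaviour could work; the ConversationStore class and its dict-backed storage are illustrative stand-ins under that assumption, not the project's actual implementation.

from typing import Optional


class ConversationStore:
    """Illustrative in-memory store; the real service keeps richer metadata."""

    def __init__(self) -> None:
        self._metadata: dict[str, dict] = {}

    def resolve_filter_id(self, conversation_id: str, filter_id: Optional[str]) -> Optional[str]:
        """Return the filter to use: an explicit one wins, otherwise reuse the stored one."""
        state = self._metadata.setdefault(conversation_id, {})
        if filter_id is not None:
            state["filter_id"] = filter_id  # a new selection becomes sticky for later turns
        return state.get("filter_id")

    def record_turn(self, conversation_id: str, response_id: Optional[str]) -> None:
        """Remember the last response id so follow-up turns can chain from it."""
        state = self._metadata.setdefault(conversation_id, {})
        if response_id:
            state["previous_response_id"] = response_id


store = ConversationStore()
store.resolve_filter_id("conv-1", "filter-abc")                  # first turn selects a filter
assert store.resolve_filter_id("conv-1", None) == "filter-abc"   # later turns inherit it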
@ -147,13 +147,41 @@ class SearchService:
attempts = 0
last_exception = None

# Format model name for LiteLLM compatibility
# The patched client routes through LiteLLM for non-OpenAI providers
formatted_model = model_name

# Skip if already has a provider prefix
if not any(model_name.startswith(prefix + "/") for prefix in ["openai", "ollama", "watsonx", "anthropic"]):
# Detect provider from model name characteristics:
# - Ollama: contains ":" (e.g., "nomic-embed-text:latest")
# - WatsonX: starts with "ibm/" or known third-party models
# - OpenAI: everything else (no prefix needed)

if ":" in model_name:
# Ollama models use tags with colons
formatted_model = f"ollama/{model_name}"
logger.debug(f"Formatted Ollama model: {model_name} -> {formatted_model}")
elif model_name.startswith("ibm/") or model_name in [
"intfloat/multilingual-e5-large",
"sentence-transformers/all-minilm-l6-v2"
]:
# WatsonX embedding models
formatted_model = f"watsonx/{model_name}"
logger.debug(f"Formatted WatsonX model: {model_name} -> {formatted_model}")
# else: OpenAI models don't need a prefix

while attempts < MAX_EMBED_RETRIES:
attempts += 1
try:
resp = await clients.patched_embedding_client.embeddings.create(
model=model_name, input=[query]
model=formatted_model, input=[query]
)
return model_name, resp.data[0].embedding
# Try to get embedding - some providers return .embedding, others return ['embedding']
embedding = getattr(resp.data[0], 'embedding', None)
if embedding is None:
embedding = resp.data[0]['embedding']
return model_name, embedding
except Exception as e:
last_exception = e
if attempts >= MAX_EMBED_RETRIES:
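A condensed sketch of the two SearchService changes above - provider-prefix detection for LiteLLM-style routing and tolerant extraction of the embedding field - is shown below. The helper names are illustrative, not the project's API; the heuristics mirror the comments in the hunk.

from typing import Any


def format_model_for_litellm(model_name: str) -> str:
    """Prefix a bare model name with its provider, mirroring the heuristics above."""
    known_prefixes = ("openai/", "ollama/", "watsonx/", "anthropic/")
    if model_name.startswith(known_prefixes):
        return model_name                       # already provider-qualified
    if ":" in model_name:                       # Ollama tags, e.g. "nomic-embed-text:latest"
        return f"ollama/{model_name}"
    if model_name.startswith("ibm/") or model_name in (
        "intfloat/multilingual-e5-large",
        "sentence-transformers/all-minilm-l6-v2",
    ):                                          # WatsonX embedding models
        return f"watsonx/{model_name}"
    return model_name                           # OpenAI models need no prefix


def extract_embedding(item: Any) -> list:
    """Handle providers that return an object with .embedding vs. a plain dict."""
    embedding = getattr(item, "embedding", None)
    if embedding is None:
        embedding = item["embedding"]
    return embedding


assert format_model_for_litellm("nomic-embed-text:latest") == "ollama/nomic-embed-text:latest"
assert format_model_for_litellm("text-embedding-3-small") == "text-embedding-3-small"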