diff --git a/frontend/app/api/mutations/useOnboardingMutation.ts b/frontend/app/api/mutations/useOnboardingMutation.ts
index 6c8e2335..42b95236 100644
--- a/frontend/app/api/mutations/useOnboardingMutation.ts
+++ b/frontend/app/api/mutations/useOnboardingMutation.ts
@@ -3,6 +3,7 @@ import {
   useMutation,
   useQueryClient,
 } from "@tanstack/react-query";
+import { ONBOARDING_OPENRAG_DOCS_FILTER_ID_KEY } from "@/lib/constants";
 
 export interface OnboardingVariables {
   // Provider selection
@@ -28,6 +29,7 @@ export interface OnboardingVariables {
 interface OnboardingResponse {
   message: string;
   edited: boolean;
+  openrag_docs_filter_id?: string;
 }
 
 export const useOnboardingMutation = (
@@ -59,6 +61,15 @@ export const useOnboardingMutation = (
 
   return useMutation({
     mutationFn: submitOnboarding,
+    onSuccess: (data) => {
+      // Store OpenRAG Docs filter ID if returned
+      if (data.openrag_docs_filter_id && typeof window !== "undefined") {
+        localStorage.setItem(
+          ONBOARDING_OPENRAG_DOCS_FILTER_ID_KEY,
+          data.openrag_docs_filter_id
+        );
+      }
+    },
     onSettled: () => {
       // Invalidate settings query to refetch updated data
       queryClient.invalidateQueries({ queryKey: ["settings"] });
diff --git a/frontend/app/api/queries/useGetFilterByIdQuery.ts b/frontend/app/api/queries/useGetFilterByIdQuery.ts
new file mode 100644
index 00000000..353b3153
--- /dev/null
+++ b/frontend/app/api/queries/useGetFilterByIdQuery.ts
@@ -0,0 +1,21 @@
+import type { KnowledgeFilter } from "./useGetFiltersSearchQuery";
+
+export async function getFilterById(
+  filterId: string
+): Promise<KnowledgeFilter | null> {
+  try {
+    const response = await fetch(`/api/knowledge-filter/${filterId}`, {
+      method: "GET",
+      headers: { "Content-Type": "application/json" },
+    });
+
+    const json = await response.json();
+    if (!response.ok || !json.success) {
+      return null;
+    }
+    return json.filter as KnowledgeFilter;
+  } catch (error) {
+    console.error("Failed to fetch filter by ID:", error);
+    return null;
+  }
+}
diff --git a/frontend/app/onboarding/_components/onboarding-content.tsx b/frontend/app/onboarding/_components/onboarding-content.tsx
index ee47f347..699c8723 100644
--- a/frontend/app/onboarding/_components/onboarding-content.tsx
+++ b/frontend/app/onboarding/_components/onboarding-content.tsx
@@ -5,17 +5,25 @@ import { StickToBottom } from "use-stick-to-bottom";
 import { AssistantMessage } from "@/app/chat/_components/assistant-message";
 import Nudges from "@/app/chat/_components/nudges";
 import { UserMessage } from "@/app/chat/_components/user-message";
-import type { Message } from "@/app/chat/_types/types";
+import type { Message, SelectedFilters } from "@/app/chat/_types/types";
 import OnboardingCard from "@/app/onboarding/_components/onboarding-card";
 import { useChatStreaming } from "@/hooks/useChatStreaming";
 import {
   ONBOARDING_ASSISTANT_MESSAGE_KEY,
+  ONBOARDING_OPENRAG_DOCS_FILTER_ID_KEY,
   ONBOARDING_SELECTED_NUDGE_KEY,
 } from "@/lib/constants";
 import { OnboardingStep } from "./onboarding-step";
 import OnboardingUpload from "./onboarding-upload";
 
 
+// Filters for OpenRAG documentation
+const OPENRAG_DOCS_FILTERS: SelectedFilters = {
+  data_sources: ["openrag-documentation.pdf"],
+  document_types: [],
+  owners: [],
+};
+
 export function OnboardingContent({
   handleStepComplete,
   handleStepBack,
@@ -115,9 +123,16 @@ export function OnboardingContent({
       localStorage.removeItem(ONBOARDING_ASSISTANT_MESSAGE_KEY);
     }
     setTimeout(async () => {
+      // Check if we have the OpenRAG docs filter ID (sample data was ingested)
+      const hasOpenragDocsFilter =
+        typeof window !== "undefined" &&
"undefined" && + localStorage.getItem(ONBOARDING_OPENRAG_DOCS_FILTER_ID_KEY); + await sendMessage({ prompt: nudge, previousResponseId: responseId || undefined, + // Use OpenRAG docs filter if sample data was ingested + filters: hasOpenragDocsFilter ? OPENRAG_DOCS_FILTERS : undefined, }); }, 1500); }; diff --git a/frontend/app/onboarding/_components/onboarding-upload.tsx b/frontend/app/onboarding/_components/onboarding-upload.tsx index 3855ff83..7855ec0a 100644 --- a/frontend/app/onboarding/_components/onboarding-upload.tsx +++ b/frontend/app/onboarding/_components/onboarding-upload.tsx @@ -4,7 +4,10 @@ import { useGetNudgesQuery } from "@/app/api/queries/useGetNudgesQuery"; import { useGetTasksQuery } from "@/app/api/queries/useGetTasksQuery"; import { AnimatedProviderSteps } from "@/app/onboarding/_components/animated-provider-steps"; import { Button } from "@/components/ui/button"; -import { ONBOARDING_UPLOAD_STEPS_KEY } from "@/lib/constants"; +import { + ONBOARDING_UPLOAD_STEPS_KEY, + ONBOARDING_USER_DOC_FILTER_ID_KEY, +} from "@/lib/constants"; import { uploadFile } from "@/lib/upload-utils"; interface OnboardingUploadProps { @@ -77,8 +80,17 @@ const OnboardingUpload = ({ onComplete }: OnboardingUploadProps) => { setIsUploading(true); try { setCurrentStep(0); - await uploadFile(file, true); + const result = await uploadFile(file, true, true); // Pass createFilter=true console.log("Document upload task started successfully"); + + // Store user document filter ID if returned + if (result.userDocFilterId && typeof window !== "undefined") { + localStorage.setItem( + ONBOARDING_USER_DOC_FILTER_ID_KEY, + result.userDocFilterId + ); + } + // Move to processing step - task monitoring will handle completion setTimeout(() => { setCurrentStep(1); diff --git a/frontend/app/onboarding/_components/openai-onboarding.tsx b/frontend/app/onboarding/_components/openai-onboarding.tsx index db676553..d01cb64a 100644 --- a/frontend/app/onboarding/_components/openai-onboarding.tsx +++ b/frontend/app/onboarding/_components/openai-onboarding.tsx @@ -50,7 +50,12 @@ export function OpenAIOnboarding({ : debouncedApiKey ? { apiKey: debouncedApiKey } : undefined, - { enabled: debouncedApiKey !== "" || getFromEnv || alreadyConfigured }, + { + // Only validate when the user opts in (env) or provides a key. + // If a key was previously configured, let the user decide to reuse or replace it + // without triggering an immediate validation error. + enabled: debouncedApiKey !== "" || getFromEnv, + }, ); // Use custom hook for model selection logic const { @@ -134,11 +139,12 @@ export function OpenAIOnboarding({ } value={apiKey} onChange={(e) => setApiKey(e.target.value)} - disabled={alreadyConfigured} + // Even if a key exists, allow replacing it to avoid getting stuck on stale creds. + disabled={false} /> {alreadyConfigured && (

- Reusing key from model provider selection. + Existing OpenAI key detected. You can reuse it or enter a new one.

       )}
       {isLoadingModels && (
diff --git a/frontend/components/chat-renderer.tsx b/frontend/components/chat-renderer.tsx
index 01a5ca75..45841299 100644
--- a/frontend/components/chat-renderer.tsx
+++ b/frontend/components/chat-renderer.tsx
@@ -2,11 +2,12 @@
 
 import { motion } from "framer-motion";
 import { usePathname } from "next/navigation";
-import { useEffect, useState } from "react";
+import { useCallback, useEffect, useState } from "react";
 import {
   type ChatConversation,
   useGetConversationsQuery,
 } from "@/app/api/queries/useGetConversationsQuery";
+import { getFilterById } from "@/app/api/queries/useGetFilterByIdQuery";
 import type { Settings } from "@/app/api/queries/useGetSettingsQuery";
 import { OnboardingContent } from "@/app/onboarding/_components/onboarding-content";
 import { ProgressBar } from "@/app/onboarding/_components/progress-bar";
@@ -20,9 +21,11 @@ import {
   HEADER_HEIGHT,
   ONBOARDING_ASSISTANT_MESSAGE_KEY,
   ONBOARDING_CARD_STEPS_KEY,
+  ONBOARDING_OPENRAG_DOCS_FILTER_ID_KEY,
   ONBOARDING_SELECTED_NUDGE_KEY,
   ONBOARDING_STEP_KEY,
   ONBOARDING_UPLOAD_STEPS_KEY,
+  ONBOARDING_USER_DOC_FILTER_ID_KEY,
   SIDEBAR_WIDTH,
   TOTAL_ONBOARDING_STEPS,
 } from "@/lib/constants";
@@ -42,6 +45,7 @@ export function ChatRenderer({
     refreshTrigger,
     refreshConversations,
     startNewConversation,
+    setConversationFilter,
   } = useChat();
 
   // Initialize onboarding state based on local storage and settings
@@ -71,6 +75,42 @@ export function ChatRenderer({
     startNewConversation();
   };
 
+  // Helper to set the default filter after onboarding transition
+  const setOnboardingFilter = useCallback(
+    async (preferUserDoc: boolean) => {
+      if (typeof window === "undefined") return;
+
+      // Try to get the appropriate filter ID
+      let filterId: string | null = null;
+
+      if (preferUserDoc) {
+        // Completed full onboarding - prefer user document filter
+        filterId = localStorage.getItem(ONBOARDING_USER_DOC_FILTER_ID_KEY);
+      }
+
+      // Fall back to OpenRAG docs filter
+      if (!filterId) {
+        filterId = localStorage.getItem(ONBOARDING_OPENRAG_DOCS_FILTER_ID_KEY);
+      }
+
+      if (filterId) {
+        try {
+          const filter = await getFilterById(filterId);
+          if (filter) {
+            setConversationFilter(filter);
+          }
+        } catch (error) {
+          console.error("Failed to set onboarding filter:", error);
+        }
+      }
+
+      // Clean up onboarding filter IDs from localStorage
+      localStorage.removeItem(ONBOARDING_OPENRAG_DOCS_FILTER_ID_KEY);
+      localStorage.removeItem(ONBOARDING_USER_DOC_FILTER_ID_KEY);
+    },
+    [setConversationFilter]
+  );
+
   // Save current step to local storage whenever it changes
   useEffect(() => {
     if (typeof window !== "undefined" && !showLayout) {
@@ -90,6 +130,8 @@
       localStorage.removeItem(ONBOARDING_CARD_STEPS_KEY);
       localStorage.removeItem(ONBOARDING_UPLOAD_STEPS_KEY);
     }
+    // Set the user document filter as active (completed full onboarding)
+    setOnboardingFilter(true);
     setShowLayout(true);
   }
 };
@@ -109,6 +151,8 @@
       localStorage.removeItem(ONBOARDING_CARD_STEPS_KEY);
       localStorage.removeItem(ONBOARDING_UPLOAD_STEPS_KEY);
     }
+    // Set the OpenRAG docs filter as active (skipped onboarding - no user doc)
+    setOnboardingFilter(false);
     setShowLayout(true);
   };
 
diff --git a/frontend/lib/constants.ts b/frontend/lib/constants.ts
index cc5d2bdb..88baf8d0 100644
--- a/frontend/lib/constants.ts
+++ b/frontend/lib/constants.ts
@@ -45,6 +45,8 @@ export const ONBOARDING_ASSISTANT_MESSAGE_KEY = "onboarding_assistant_message";
 export const ONBOARDING_SELECTED_NUDGE_KEY = "onboarding_selected_nudge";
 export const ONBOARDING_CARD_STEPS_KEY = "onboarding_card_steps";
 export const ONBOARDING_UPLOAD_STEPS_KEY = "onboarding_upload_steps";
+export const ONBOARDING_OPENRAG_DOCS_FILTER_ID_KEY = "onboarding_openrag_docs_filter_id";
+export const ONBOARDING_USER_DOC_FILTER_ID_KEY = "onboarding_user_doc_filter_id";
 
 export const FILES_REGEX =
   /(?<=I'm uploading a document called ['"])[^'"]+\.[^.]+(?=['"]\. Here is its content:)/;
diff --git a/frontend/lib/upload-utils.ts b/frontend/lib/upload-utils.ts
index 8a16f3b9..8d6aae9a 100644
--- a/frontend/lib/upload-utils.ts
+++ b/frontend/lib/upload-utils.ts
@@ -10,6 +10,7 @@ export interface UploadFileResult {
   deletion: unknown;
   unified: boolean;
   raw: unknown;
+  userDocFilterId?: string;
 }
 
 export async function duplicateCheck(
@@ -120,11 +121,15 @@ export async function uploadFileForContext(
 export async function uploadFile(
   file: File,
   replace = false,
+  createFilter = false,
 ): Promise<UploadFileResult> {
   try {
     const formData = new FormData();
     formData.append("file", file);
     formData.append("replace_duplicates", replace.toString());
+    if (createFilter) {
+      formData.append("create_filter", "true");
+    }
 
     const uploadResponse = await fetch("/api/router/upload_ingest", {
       method: "POST",
@@ -177,6 +182,9 @@ export async function uploadFile(
       );
     }
 
+    const userDocFilterId = (uploadIngestJson as { user_doc_filter_id?: string })
+      .user_doc_filter_id;
+
    const result: UploadFileResult = {
      fileId,
      filePath,
@@ -184,6 +192,7 @@
       deletion: deletionJson,
       unified: true,
       raw: uploadIngestJson,
+      userDocFilterId,
     };
 
     return result;
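The frontend half of this patch now forms a complete loop: uploadFile() asks the router to create a knowledge filter, the filter ID round-trips through localStorage, and getFilterById() resolves it back into a full KnowledgeFilter. A minimal sketch of that flow outside onboarding, assuming only the APIs defined in this patch (the uploadAndResolveFilter wrapper itself is hypothetical, not part of the diff):

// Illustrative glue code, not part of the patch.
import { getFilterById } from "@/app/api/queries/useGetFilterByIdQuery";
import { ONBOARDING_USER_DOC_FILTER_ID_KEY } from "@/lib/constants";
import { uploadFile } from "@/lib/upload-utils";

async function uploadAndResolveFilter(file: File) {
  // replace=true overwrites duplicates; createFilter=true makes the router
  // return user_doc_filter_id alongside the 202 task response.
  const result = await uploadFile(file, true, true);
  if (!result.userDocFilterId) return null;

  localStorage.setItem(ONBOARDING_USER_DOC_FILTER_ID_KEY, result.userDocFilterId);
  // GET /api/knowledge-filter/:id -> KnowledgeFilter | null
  return getFilterById(result.userDocFilterId);
}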
{len(upload_files)} file(s)", + "file_count": len(upload_files), + } + if user_doc_filter_id: + response_data["user_doc_filter_id"] = user_doc_filter_id + + return JSONResponse(response_data, status_code=202) # 202 Accepted for async processing except Exception: # Clean up temp files on error diff --git a/src/api/settings.py b/src/api/settings.py index c8e443cf..0a23fabb 100644 --- a/src/api/settings.py +++ b/src/api/settings.py @@ -508,7 +508,7 @@ async def update_settings(request, session_manager): # Update provider-specific settings if "openai_api_key" in body and body["openai_api_key"].strip(): - current_config.providers.openai.api_key = body["openai_api_key"] + current_config.providers.openai.api_key = body["openai_api_key"].strip() current_config.providers.openai.configured = True config_updated = True @@ -555,6 +555,9 @@ async def update_settings(request, session_manager): "watsonx_api_key", "watsonx_endpoint", "watsonx_project_id", "ollama_endpoint" ] + + await clients.refresh_patched_clients() + if any(key in body for key in provider_fields_to_check): try: flows_service = _get_flows_service() @@ -562,8 +565,11 @@ async def update_settings(request, session_manager): # Update global variables await _update_langflow_global_variables(current_config) + # Update LLM client credentials when embedding selection changes if "embedding_provider" in body or "embedding_model" in body: - await _update_mcp_servers_with_provider_credentials(current_config) + await _update_mcp_servers_with_provider_credentials( + current_config, session_manager + ) # Update model values if provider or model changed if "llm_provider" in body or "llm_model" in body or "embedding_provider" in body or "embedding_model" in body: @@ -574,6 +580,7 @@ async def update_settings(request, session_manager): # Don't fail the entire settings update if Langflow update fails # The config was still saved + logger.info( "Configuration updated successfully", updated_fields=list(body.keys()) ) @@ -689,7 +696,7 @@ async def onboarding(request, flows_service, session_manager=None): # Update provider-specific credentials if "openai_api_key" in body and body["openai_api_key"].strip(): - current_config.providers.openai.api_key = body["openai_api_key"] + current_config.providers.openai.api_key = body["openai_api_key"].strip() current_config.providers.openai.configured = True config_updated = True @@ -919,11 +926,33 @@ async def onboarding(request, flows_service, session_manager=None): {"error": "Failed to save configuration"}, status_code=500 ) + # Refresh cached patched clients so latest credentials take effect immediately + await clients.refresh_patched_clients() + + # Create OpenRAG Docs knowledge filter if sample data was ingested + openrag_docs_filter_id = None + if should_ingest_sample_data: + try: + openrag_docs_filter_id = await _create_openrag_docs_filter( + request, session_manager + ) + if openrag_docs_filter_id: + logger.info( + "Created OpenRAG Docs knowledge filter", + filter_id=openrag_docs_filter_id, + ) + except Exception as e: + logger.error( + "Failed to create OpenRAG Docs knowledge filter", error=str(e) + ) + # Don't fail onboarding if filter creation fails + return JSONResponse( { "message": "Onboarding configuration updated successfully", "edited": True, # Confirm that config is now marked as edited "sample_data_ingested": should_ingest_sample_data, + "openrag_docs_filter_id": openrag_docs_filter_id, } ) @@ -935,6 +964,132 @@ async def onboarding(request, flows_service, session_manager=None): ) +async def 
+async def _create_openrag_docs_filter(request, session_manager):
+    """Create the OpenRAG Docs knowledge filter for onboarding"""
+    import uuid
+    import json
+    from datetime import datetime
+    from session_manager import AnonymousUser
+
+    # Get knowledge filter service from app state
+    app = request.scope.get("app")
+    if not app or not hasattr(app.state, "services"):
+        logger.error("Could not access services for knowledge filter creation")
+        return None
+
+    knowledge_filter_service = app.state.services.get("knowledge_filter_service")
+    if not knowledge_filter_service:
+        logger.error("Knowledge filter service not available")
+        return None
+
+    # Use anonymous user for no-auth mode compatibility
+    user = AnonymousUser()
+    jwt_token = session_manager.get_effective_jwt_token(user.user_id, None)
+
+    # Create the filter document
+    filter_id = str(uuid.uuid4())
+    query_data = json.dumps({
+        "query": "",
+        "filters": {
+            "data_sources": ["openrag-documentation.pdf"],
+            "document_types": ["*"],
+            "owners": ["*"],
+            "connector_types": ["*"],
+        },
+        "limit": 10,
+        "scoreThreshold": 0,
+        "color": "blue",
+        "icon": "book",
+    })
+
+    filter_doc = {
+        "id": filter_id,
+        "name": "OpenRAG Docs",
+        "description": "Filter for OpenRAG documentation",
+        "query_data": query_data,
+        "owner": user.user_id,
+        "allowed_users": [],
+        "allowed_groups": [],
+        "created_at": datetime.utcnow().isoformat(),
+        "updated_at": datetime.utcnow().isoformat(),
+    }
+
+    result = await knowledge_filter_service.create_knowledge_filter(
+        filter_doc, user_id=user.user_id, jwt_token=jwt_token
+    )
+
+    if result.get("success"):
+        return filter_id
+    else:
+        logger.error("Failed to create OpenRAG Docs filter", error=result.get("error"))
+        return None
+
+
+async def _create_user_document_filter(request, session_manager, filename):
+    """Create a knowledge filter for a user-uploaded document during onboarding"""
+    import uuid
+    import json
+    from datetime import datetime
+    from session_manager import AnonymousUser
+
+    # Get knowledge filter service from app state
+    app = request.scope.get("app")
+    if not app or not hasattr(app.state, "services"):
+        logger.error("Could not access services for knowledge filter creation")
+        return None
+
+    knowledge_filter_service = app.state.services.get("knowledge_filter_service")
+    if not knowledge_filter_service:
+        logger.error("Knowledge filter service not available")
+        return None
+
+    # Use anonymous user for no-auth mode compatibility
+    user = AnonymousUser()
+    jwt_token = session_manager.get_effective_jwt_token(user.user_id, None)
+
+    # Create the filter document
+    filter_id = str(uuid.uuid4())
+
+    # Get a display name from the filename (remove extension for cleaner name)
+    display_name = filename.rsplit(".", 1)[0] if "." in filename else filename
+    query_data = json.dumps({
+        "query": "",
+        "filters": {
+            "data_sources": [filename],
+            "document_types": ["*"],
+            "owners": ["*"],
+            "connector_types": ["*"],
+        },
+        "limit": 10,
+        "scoreThreshold": 0,
+        "color": "green",
+        "icon": "file",
+    })
+
+    filter_doc = {
+        "id": filter_id,
+        "name": display_name,
+        "description": f"Filter for {filename}",
+        "query_data": query_data,
+        "owner": user.user_id,
+        "allowed_users": [],
+        "allowed_groups": [],
+        "created_at": datetime.utcnow().isoformat(),
+        "updated_at": datetime.utcnow().isoformat(),
+    }
+
+    result = await knowledge_filter_service.create_knowledge_filter(
+        filter_doc, user_id=user.user_id, jwt_token=jwt_token
+    )
+
+    if result.get("success"):
+        return filter_id
+    else:
+        logger.error("Failed to create user document filter", error=result.get("error"))
+        return None
+
 
 def _get_flows_service():
     """Helper function to get flows service instance"""
     from services.flows_service import FlowsService
diff --git a/src/config/settings.py b/src/config/settings.py
index df221986..b7c94936 100644
--- a/src/config/settings.py
+++ b/src/config/settings.py
@@ -165,18 +165,36 @@ async def generate_langflow_api_key(modify: bool = False):
             if validation_response.status_code == 200:
                 logger.debug("Cached API key is valid", key_prefix=LANGFLOW_KEY[:8])
                 return LANGFLOW_KEY
-            else:
+            elif validation_response.status_code in (401, 403):
                 logger.warning(
-                    "Cached API key is invalid, generating fresh key",
+                    "Cached API key is unauthorized, generating fresh key",
                     status_code=validation_response.status_code,
                 )
                 LANGFLOW_KEY = None  # Clear invalid key
-        except Exception as e:
+            else:
+                logger.warning(
+                    "Cached API key validation returned non-access error; keeping existing key",
+                    status_code=validation_response.status_code,
+                )
+                return LANGFLOW_KEY
+        except requests.exceptions.Timeout as e:
             logger.warning(
-                "Cached API key validation failed, generating fresh key",
+                "Cached API key validation timed out; keeping existing key",
                 error=str(e),
             )
-            LANGFLOW_KEY = None  # Clear invalid key
+            return LANGFLOW_KEY
+        except requests.exceptions.RequestException as e:
+            logger.warning(
+                "Cached API key validation failed due to request error; keeping existing key",
+                error=str(e),
+            )
+            return LANGFLOW_KEY
+        except Exception as e:
+            logger.warning(
+                "Unexpected error during cached API key validation; keeping existing key",
+                error=str(e),
+            )
+            return LANGFLOW_KEY
 
     # Use default langflow/langflow credentials if auto-login is enabled and credentials not set
     username = LANGFLOW_SUPERUSER
@@ -279,7 +297,8 @@
         self.opensearch = None
         self.langflow_client = None
         self.langflow_http_client = None
-        self._patched_async_client = None  # Private attribute
+        self._patched_llm_client = None  # Private attribute
+        self._patched_embedding_client = None  # Private attribute
         self._client_init_lock = __import__('threading').Lock()  # Lock for thread-safe initialization
         self.converter = None
 
@@ -358,114 +377,192 @@
             self.langflow_client = None
         return self.langflow_client
 
-    @property
-    def patched_async_client(self):
+    def _build_provider_env(self, provider_type: str):
         """
-        Property that ensures OpenAI client is initialized on first access.
-        This allows lazy initialization so the app can start without an API key.
-
-        Note: The client is a long-lived singleton that should be closed via cleanup().
-        Thread-safe via lock to prevent concurrent initialization attempts.
+        Build environment overrides for the requested provider type ("llm" or "embedding").
+        This is used to support different credentials for LLM and embedding providers.
         """
-        # Quick check without lock
-        if self._patched_async_client is not None:
-            return self._patched_async_client
+        config = get_openrag_config()
+
+        if provider_type == "llm":
+            provider = (config.agent.llm_provider or "openai").lower()
+        else:
+            provider = (config.knowledge.embedding_provider or "openai").lower()
+
+        env_overrides = {}
+
+        if provider == "openai":
+            api_key = config.providers.openai.api_key or os.getenv("OPENAI_API_KEY")
+            if api_key:
+                env_overrides["OPENAI_API_KEY"] = api_key
+        elif provider == "anthropic":
+            api_key = config.providers.anthropic.api_key or os.getenv("ANTHROPIC_API_KEY")
+            if api_key:
+                env_overrides["ANTHROPIC_API_KEY"] = api_key
+        elif provider == "watsonx":
+            api_key = config.providers.watsonx.api_key or os.getenv("WATSONX_API_KEY")
+            endpoint = config.providers.watsonx.endpoint or os.getenv("WATSONX_ENDPOINT")
+            project_id = config.providers.watsonx.project_id or os.getenv("WATSONX_PROJECT_ID")
+            if api_key:
+                env_overrides["WATSONX_API_KEY"] = api_key
+            if endpoint:
+                env_overrides["WATSONX_ENDPOINT"] = endpoint
+            if project_id:
+                env_overrides["WATSONX_PROJECT_ID"] = project_id
+        elif provider == "ollama":
+            endpoint = config.providers.ollama.endpoint or os.getenv("OLLAMA_ENDPOINT")
+            if endpoint:
+                env_overrides["OLLAMA_ENDPOINT"] = endpoint
+                env_overrides["OLLAMA_BASE_URL"] = endpoint
+
+        return env_overrides, provider
+
+    def _apply_env_overrides(self, env_overrides: dict):
+        """Apply non-empty environment overrides."""
+        for key, value in (env_overrides or {}).items():
+            if value is None:
+                continue
+            os.environ[key] = str(value)
+
+    def _initialize_patched_client(self, cache_attr: str, provider_type: str):
+        """
+        Initialize a patched AsyncOpenAI client for the specified provider type.
+        Uses HTTP/2 probe only when an OpenAI key is present; otherwise falls back directly.
+ """ + # Quick path + cached_client = getattr(self, cache_attr) + if cached_client is not None: + return cached_client - # Use lock to ensure only one thread initializes with self._client_init_lock: - # Double-check after acquiring lock - if self._patched_async_client is not None: - return self._patched_async_client + cached_client = getattr(self, cache_attr) + if cached_client is not None: + return cached_client - # Try to initialize the client on-demand - # First check if OPENAI_API_KEY is in environment - openai_key = os.getenv("OPENAI_API_KEY") + env_overrides, provider_name = self._build_provider_env(provider_type) + self._apply_env_overrides(env_overrides) - if not openai_key: - # Try to get from config (in case it was set during onboarding) - try: - config = get_openrag_config() - if config and config.provider and config.provider.api_key: - openai_key = config.provider.api_key - # Set it in environment so AsyncOpenAI can pick it up - os.environ["OPENAI_API_KEY"] = openai_key - logger.info("Loaded OpenAI API key from config file") - except Exception as e: - logger.debug("Could not load OpenAI key from config", error=str(e)) + # Decide whether to run the HTTP/2 probe (only meaningful for OpenAI endpoints) + has_openai_key = bool(env_overrides.get("OPENAI_API_KEY") or os.getenv("OPENAI_API_KEY")) + if provider_name == "openai" and not has_openai_key: + raise ValueError("OPENAI_API_KEY is required for OpenAI provider") - # Try to initialize the client - AsyncOpenAI() will read from environment - # We'll try HTTP/2 first with a probe, then fall back to HTTP/1.1 if it times out import asyncio import concurrent.futures - import threading - async def probe_and_initialize(): - # Try HTTP/2 first (default) + async def build_client(skip_probe: bool = False): + if not has_openai_key: + # No OpenAI key present; create a basic patched client without probing + return patch_openai_with_mcp(AsyncOpenAI(http_client=httpx.AsyncClient())) + + if skip_probe: + http_client = httpx.AsyncClient(http2=False, timeout=httpx.Timeout(60.0, connect=10.0)) + return patch_openai_with_mcp(AsyncOpenAI(http_client=http_client)) + client_http2 = patch_openai_with_mcp(AsyncOpenAI()) - logger.info("Probing OpenAI client with HTTP/2...") + logger.info("Probing patched OpenAI client with HTTP/2...") try: - # Probe with a small embedding and short timeout await asyncio.wait_for( client_http2.embeddings.create( - model='text-embedding-3-small', - input=['test'] + model="text-embedding-3-small", + input=["test"], ), - timeout=5.0 + timeout=5.0, ) - logger.info("OpenAI client initialized with HTTP/2 (probe successful)") + logger.info("Patched OpenAI client initialized with HTTP/2 (probe successful)") return client_http2 except (asyncio.TimeoutError, Exception) as probe_error: logger.warning("HTTP/2 probe failed, falling back to HTTP/1.1", error=str(probe_error)) - # Close the HTTP/2 client try: await client_http2.close() except Exception: pass - # Fall back to HTTP/1.1 with explicit timeout settings http_client = httpx.AsyncClient( - http2=False, - timeout=httpx.Timeout(60.0, connect=10.0) + http2=False, timeout=httpx.Timeout(60.0, connect=10.0) ) - client_http1 = patch_openai_with_mcp( - AsyncOpenAI(http_client=http_client) - ) - logger.info("OpenAI client initialized with HTTP/1.1 (fallback)") + client_http1 = patch_openai_with_mcp(AsyncOpenAI(http_client=http_client)) + logger.info("Patched OpenAI client initialized with HTTP/1.1 (fallback)") return client_http1 - def run_probe_in_thread(): - """Run the async probe in a 
+            def run_builder(skip_probe=False):
                 loop = asyncio.new_event_loop()
                 asyncio.set_event_loop(loop)
                 try:
-                    return loop.run_until_complete(probe_and_initialize())
+                    return loop.run_until_complete(build_client(skip_probe=skip_probe))
                 finally:
                     loop.close()
 
             try:
-                # Run the probe in a separate thread with its own event loop
                 with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
-                    future = executor.submit(run_probe_in_thread)
-                    self._patched_async_client = future.result(timeout=15)
-                logger.info("Successfully initialized OpenAI client")
+                    future = executor.submit(run_builder, False)
+                    client = future.result(timeout=15 if has_openai_key else 10)
+                setattr(self, cache_attr, client)
+                logger.info("Successfully initialized patched client", provider_type=provider_type)
+                return client
             except Exception as e:
-                logger.error(f"Failed to initialize OpenAI client: {e.__class__.__name__}: {str(e)}")
-                raise ValueError(f"Failed to initialize OpenAI client: {str(e)}. Please complete onboarding or set OPENAI_API_KEY environment variable.")
+                logger.error(
+                    f"Failed to initialize patched client: {e.__class__.__name__}: {str(e)}",
+                    provider_type=provider_type,
+                )
+                raise ValueError(
+                    f"Failed to initialize patched client for {provider_type}: {str(e)}. "
+                    "Please ensure provider credentials are set."
+                )
 
-        return self._patched_async_client
+    @property
+    def patched_llm_client(self):
+        """Patched client for LLM provider."""
+        return self._initialize_patched_client("_patched_llm_client", "llm")
+
+    @property
+    def patched_embedding_client(self):
+        """Patched client for embedding provider."""
+        return self._initialize_patched_client("_patched_embedding_client", "embedding")
+
+    @property
+    def patched_async_client(self):
+        """Backward-compatibility alias for LLM client."""
+        return self.patched_llm_client
+
+    async def refresh_patched_clients(self):
+        """Reset patched clients so next use picks up updated provider credentials."""
+        clients_to_close = []
+        with self._client_init_lock:
+            if self._patched_llm_client is not None:
+                clients_to_close.append(self._patched_llm_client)
+                self._patched_llm_client = None
+            if self._patched_embedding_client is not None:
+                clients_to_close.append(self._patched_embedding_client)
+                self._patched_embedding_client = None
+
+        for client in clients_to_close:
+            try:
+                await client.close()
+            except Exception as e:
+                logger.warning("Failed to close patched client during refresh", error=str(e))
 
     async def cleanup(self):
         """Cleanup resources - should be called on application shutdown"""
         # Close AsyncOpenAI client if it was created
-        if self._patched_async_client is not None:
+        if self._patched_llm_client is not None:
             try:
-                await self._patched_async_client.close()
-                logger.info("Closed AsyncOpenAI client")
+                await self._patched_llm_client.close()
+                logger.info("Closed LLM patched client")
             except Exception as e:
-                logger.error("Failed to close AsyncOpenAI client", error=str(e))
+                logger.error("Failed to close LLM patched client", error=str(e))
             finally:
-                self._patched_async_client = None
+                self._patched_llm_client = None
+
+        if self._patched_embedding_client is not None:
+            try:
+                await self._patched_embedding_client.close()
+                logger.info("Closed embedding patched client")
+            except Exception as e:
+                logger.error("Failed to close embedding patched client", error=str(e))
+            finally:
+                self._patched_embedding_client = None
 
         # Close Langflow HTTP client if it exists
         if self.langflow_http_client is not None:
@@ -750,4 +847,4 @@ def get_agent_config():
 
 def get_embedding_model() -> str:
     """Return the currently configured embedding model."""
-    return get_openrag_config().knowledge.embedding_model or EMBED_MODEL if DISABLE_INGEST_WITH_LANGFLOW else ""
\ No newline at end of file
+    return get_openrag_config().knowledge.embedding_model or EMBED_MODEL if DISABLE_INGEST_WITH_LANGFLOW else ""
diff --git a/src/models/processors.py b/src/models/processors.py
index 7edbc475..9731adb7 100644
--- a/src/models/processors.py
+++ b/src/models/processors.py
@@ -209,7 +209,7 @@ class TaskProcessor:
 
         embeddings = []
         for batch in text_batches:
-            resp = await clients.patched_async_client.embeddings.create(
+            resp = await clients.patched_embedding_client.embeddings.create(
                 model=embedding_model, input=batch
             )
             embeddings.extend([d.embedding for d in resp.data])
diff --git a/src/services/chat_service.py b/src/services/chat_service.py
index 040f03d8..cb697cf0 100644
--- a/src/services/chat_service.py
+++ b/src/services/chat_service.py
@@ -26,14 +26,14 @@ class ChatService:
 
         if stream:
             return async_chat_stream(
-                clients.patched_async_client,
+                clients.patched_llm_client,
                 prompt,
                 user_id,
                 previous_response_id=previous_response_id,
             )
         else:
             response_text, response_id = await async_chat(
-                clients.patched_async_client,
+                clients.patched_llm_client,
                 prompt,
                 user_id,
                 previous_response_id=previous_response_id,
@@ -344,7 +344,7 @@ class ChatService:
         if user_id and jwt_token:
             set_auth_context(user_id, jwt_token)
         response_text, response_id = await async_chat(
-            clients.patched_async_client,
+            clients.patched_llm_client,
             document_prompt,
             user_id,
             previous_response_id=previous_response_id,
@@ -632,4 +632,3 @@ class ChatService:
         except Exception as e:
             logger.error(f"Error deleting session {session_id} from Langflow: {e}")
             return False
-
diff --git a/src/services/models_service.py b/src/services/models_service.py
index f26d0594..979bcec2 100644
--- a/src/services/models_service.py
+++ b/src/services/models_service.py
@@ -108,7 +108,7 @@ class ModelsService:
             else:
                 logger.error(f"Failed to fetch OpenAI models: {response.status_code}")
                 raise Exception(
-                    f"OpenAI API returned status code {response.status_code}"
+                    f"OpenAI API returned status code {response.status_code}: {response.text}"
                 )
 
         except Exception as e:
diff --git a/src/services/search_service.py b/src/services/search_service.py
index 3261511d..07c1a796 100644
--- a/src/services/search_service.py
+++ b/src/services/search_service.py
@@ -150,7 +150,7 @@ class SearchService:
         while attempts < MAX_EMBED_RETRIES:
             attempts += 1
             try:
-                resp = await clients.patched_async_client.embeddings.create(
+                resp = await clients.patched_embedding_client.embeddings.create(
                     model=model_name, input=[query]
                 )
                 return model_name, resp.data[0].embedding
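For reference, both backend helpers serialize the same query_data payload. A TypeScript rendering of that shape, inferred from the json.dumps() calls above (the interface name and field types are ours; nothing in the codebase defines or exports this):

// Inferred shape of the query_data JSON stored on a knowledge filter.
interface KnowledgeFilterQueryData {
  query: string;
  filters: {
    data_sources: string[]; // e.g. ["openrag-documentation.pdf"] or the uploaded filename
    document_types: string[]; // ["*"] acts as a wildcard
    owners: string[];
    connector_types: string[];
  };
  limit: number; // both helpers use 10
  scoreThreshold: number; // both helpers use 0
  color: string; // "blue" for OpenRAG Docs, "green" for user documents
  icon: string; // "book" for OpenRAG Docs, "file" for user documents
}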