phact 2025-12-01 16:36:35 -05:00
parent 459c676c02
commit c106438d75
15 changed files with 489 additions and 94 deletions

View file

@ -3,6 +3,7 @@ import {
useMutation,
useQueryClient,
} from "@tanstack/react-query";
import { ONBOARDING_OPENRAG_DOCS_FILTER_ID_KEY } from "@/lib/constants";
export interface OnboardingVariables {
// Provider selection
@ -28,6 +29,7 @@ export interface OnboardingVariables {
interface OnboardingResponse {
message: string;
edited: boolean;
openrag_docs_filter_id?: string;
}
export const useOnboardingMutation = (
@ -59,6 +61,15 @@ export const useOnboardingMutation = (
return useMutation({
mutationFn: submitOnboarding,
onSuccess: (data) => {
// Store OpenRAG Docs filter ID if returned
if (data.openrag_docs_filter_id && typeof window !== "undefined") {
localStorage.setItem(
ONBOARDING_OPENRAG_DOCS_FILTER_ID_KEY,
data.openrag_docs_filter_id
);
}
},
onSettled: () => {
// Invalidate settings query to refetch updated data
queryClient.invalidateQueries({ queryKey: ["settings"] });

View file

@ -0,0 +1,21 @@
import type { KnowledgeFilter } from "./useGetFiltersSearchQuery";
export async function getFilterById(
filterId: string
): Promise<KnowledgeFilter | null> {
try {
const response = await fetch(`/api/knowledge-filter/${filterId}`, {
method: "GET",
headers: { "Content-Type": "application/json" },
});
const json = await response.json();
if (!response.ok || !json.success) {
return null;
}
return json.filter as KnowledgeFilter;
} catch (error) {
console.error("Failed to fetch filter by ID:", error);
return null;
}
}
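For orientation, a minimal sketch of how a caller might consume this helper; the wrapper function below is hypothetical and not part of this commit, but the null-on-failure contract matches the code above.

import { getFilterById } from "@/app/api/queries/useGetFilterByIdQuery";

// Hypothetical caller: resolve a stored filter ID, treating null as "filter unavailable".
async function resolveStoredFilter(storedId: string | null) {
  if (!storedId) return undefined;
  const filter = await getFilterById(storedId);
  return filter ?? undefined;
}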

View file

@ -5,17 +5,25 @@ import { StickToBottom } from "use-stick-to-bottom";
import { AssistantMessage } from "@/app/chat/_components/assistant-message";
import Nudges from "@/app/chat/_components/nudges";
import { UserMessage } from "@/app/chat/_components/user-message";
import type { Message } from "@/app/chat/_types/types";
import type { Message, SelectedFilters } from "@/app/chat/_types/types";
import OnboardingCard from "@/app/onboarding/_components/onboarding-card";
import { useChatStreaming } from "@/hooks/useChatStreaming";
import {
ONBOARDING_ASSISTANT_MESSAGE_KEY,
ONBOARDING_OPENRAG_DOCS_FILTER_ID_KEY,
ONBOARDING_SELECTED_NUDGE_KEY,
} from "@/lib/constants";
import { OnboardingStep } from "./onboarding-step";
import OnboardingUpload from "./onboarding-upload";
// Filters for OpenRAG documentation
const OPENRAG_DOCS_FILTERS: SelectedFilters = {
data_sources: ["openrag-documentation.pdf"],
document_types: [],
owners: [],
};
export function OnboardingContent({
handleStepComplete,
handleStepBack,
@ -115,9 +123,16 @@ export function OnboardingContent({
localStorage.removeItem(ONBOARDING_ASSISTANT_MESSAGE_KEY);
}
setTimeout(async () => {
// Check if we have the OpenRAG docs filter ID (sample data was ingested)
const hasOpenragDocsFilter =
typeof window !== "undefined" &&
localStorage.getItem(ONBOARDING_OPENRAG_DOCS_FILTER_ID_KEY);
await sendMessage({
prompt: nudge,
previousResponseId: responseId || undefined,
// Use OpenRAG docs filter if sample data was ingested
filters: hasOpenragDocsFilter ? OPENRAG_DOCS_FILTERS : undefined,
});
}, 1500);
};

View file

@ -4,7 +4,10 @@ import { useGetNudgesQuery } from "@/app/api/queries/useGetNudgesQuery";
import { useGetTasksQuery } from "@/app/api/queries/useGetTasksQuery";
import { AnimatedProviderSteps } from "@/app/onboarding/_components/animated-provider-steps";
import { Button } from "@/components/ui/button";
import { ONBOARDING_UPLOAD_STEPS_KEY } from "@/lib/constants";
import {
ONBOARDING_UPLOAD_STEPS_KEY,
ONBOARDING_USER_DOC_FILTER_ID_KEY,
} from "@/lib/constants";
import { uploadFile } from "@/lib/upload-utils";
interface OnboardingUploadProps {
@ -77,8 +80,17 @@ const OnboardingUpload = ({ onComplete }: OnboardingUploadProps) => {
setIsUploading(true);
try {
setCurrentStep(0);
await uploadFile(file, true);
const result = await uploadFile(file, true, true); // Pass createFilter=true
console.log("Document upload task started successfully");
// Store user document filter ID if returned
if (result.userDocFilterId && typeof window !== "undefined") {
localStorage.setItem(
ONBOARDING_USER_DOC_FILTER_ID_KEY,
result.userDocFilterId
);
}
// Move to processing step - task monitoring will handle completion
setTimeout(() => {
setCurrentStep(1);

View file

@ -50,7 +50,12 @@ export function OpenAIOnboarding({
: debouncedApiKey
? { apiKey: debouncedApiKey }
: undefined,
{ enabled: debouncedApiKey !== "" || getFromEnv || alreadyConfigured },
{
// Only validate when the user opts in (env) or provides a key.
// If a key was previously configured, let the user decide to reuse or replace it
// without triggering an immediate validation error.
enabled: debouncedApiKey !== "" || getFromEnv,
},
);
// Use custom hook for model selection logic
const {
@ -134,11 +139,12 @@ export function OpenAIOnboarding({
}
value={apiKey}
onChange={(e) => setApiKey(e.target.value)}
disabled={alreadyConfigured}
// Even if a key exists, allow replacing it to avoid getting stuck on stale creds.
disabled={false}
/>
{alreadyConfigured && (
<p className="text-mmd text-muted-foreground">
Reusing key from model provider selection.
Existing OpenAI key detected. You can reuse it or enter a new one.
</p>
)}
{isLoadingModels && (

View file

@ -2,11 +2,12 @@
import { motion } from "framer-motion";
import { usePathname } from "next/navigation";
import { useEffect, useState } from "react";
import { useCallback, useEffect, useState } from "react";
import {
type ChatConversation,
useGetConversationsQuery,
} from "@/app/api/queries/useGetConversationsQuery";
import { getFilterById } from "@/app/api/queries/useGetFilterByIdQuery";
import type { Settings } from "@/app/api/queries/useGetSettingsQuery";
import { OnboardingContent } from "@/app/onboarding/_components/onboarding-content";
import { ProgressBar } from "@/app/onboarding/_components/progress-bar";
@ -20,9 +21,11 @@ import {
HEADER_HEIGHT,
ONBOARDING_ASSISTANT_MESSAGE_KEY,
ONBOARDING_CARD_STEPS_KEY,
ONBOARDING_OPENRAG_DOCS_FILTER_ID_KEY,
ONBOARDING_SELECTED_NUDGE_KEY,
ONBOARDING_STEP_KEY,
ONBOARDING_UPLOAD_STEPS_KEY,
ONBOARDING_USER_DOC_FILTER_ID_KEY,
SIDEBAR_WIDTH,
TOTAL_ONBOARDING_STEPS,
} from "@/lib/constants";
@ -42,6 +45,7 @@ export function ChatRenderer({
refreshTrigger,
refreshConversations,
startNewConversation,
setConversationFilter,
} = useChat();
// Initialize onboarding state based on local storage and settings
@ -71,6 +75,42 @@ export function ChatRenderer({
startNewConversation();
};
// Helper to set the default filter after onboarding transition
const setOnboardingFilter = useCallback(
async (preferUserDoc: boolean) => {
if (typeof window === "undefined") return;
// Try to get the appropriate filter ID
let filterId: string | null = null;
if (preferUserDoc) {
// Completed full onboarding - prefer user document filter
filterId = localStorage.getItem(ONBOARDING_USER_DOC_FILTER_ID_KEY);
}
// Fall back to OpenRAG docs filter
if (!filterId) {
filterId = localStorage.getItem(ONBOARDING_OPENRAG_DOCS_FILTER_ID_KEY);
}
if (filterId) {
try {
const filter = await getFilterById(filterId);
if (filter) {
setConversationFilter(filter);
}
} catch (error) {
console.error("Failed to set onboarding filter:", error);
}
}
// Clean up onboarding filter IDs from localStorage
localStorage.removeItem(ONBOARDING_OPENRAG_DOCS_FILTER_ID_KEY);
localStorage.removeItem(ONBOARDING_USER_DOC_FILTER_ID_KEY);
},
[setConversationFilter]
);
// Save current step to local storage whenever it changes
useEffect(() => {
if (typeof window !== "undefined" && !showLayout) {
@ -90,6 +130,8 @@ export function ChatRenderer({
localStorage.removeItem(ONBOARDING_CARD_STEPS_KEY);
localStorage.removeItem(ONBOARDING_UPLOAD_STEPS_KEY);
}
// Set the user document filter as active (completed full onboarding)
setOnboardingFilter(true);
setShowLayout(true);
}
};
@ -109,6 +151,8 @@ export function ChatRenderer({
localStorage.removeItem(ONBOARDING_CARD_STEPS_KEY);
localStorage.removeItem(ONBOARDING_UPLOAD_STEPS_KEY);
}
// Set the OpenRAG docs filter as active (skipped onboarding - no user doc)
setOnboardingFilter(false);
setShowLayout(true);
};

View file

@ -45,6 +45,8 @@ export const ONBOARDING_ASSISTANT_MESSAGE_KEY = "onboarding_assistant_message";
export const ONBOARDING_SELECTED_NUDGE_KEY = "onboarding_selected_nudge";
export const ONBOARDING_CARD_STEPS_KEY = "onboarding_card_steps";
export const ONBOARDING_UPLOAD_STEPS_KEY = "onboarding_upload_steps";
export const ONBOARDING_OPENRAG_DOCS_FILTER_ID_KEY = "onboarding_openrag_docs_filter_id";
export const ONBOARDING_USER_DOC_FILTER_ID_KEY = "onboarding_user_doc_filter_id";
export const FILES_REGEX =
/(?<=I'm uploading a document called ['"])[^'"]+\.[^.]+(?=['"]\. Here is its content:)/;

View file

@ -10,6 +10,7 @@ export interface UploadFileResult {
deletion: unknown;
unified: boolean;
raw: unknown;
userDocFilterId?: string;
}
export async function duplicateCheck(
@ -120,11 +121,15 @@ export async function uploadFileForContext(
export async function uploadFile(
file: File,
replace = false,
createFilter = false,
): Promise<UploadFileResult> {
try {
const formData = new FormData();
formData.append("file", file);
formData.append("replace_duplicates", replace.toString());
if (createFilter) {
formData.append("create_filter", "true");
}
const uploadResponse = await fetch("/api/router/upload_ingest", {
method: "POST",
@ -177,6 +182,9 @@ export async function uploadFile(
);
}
const userDocFilterId = (uploadIngestJson as { user_doc_filter_id?: string })
.user_doc_filter_id;
const result: UploadFileResult = {
fileId,
filePath,
@ -184,6 +192,7 @@ export async function uploadFile(
deletion: deletionJson,
unified: true,
raw: uploadIngestJson,
userDocFilterId,
};
return result;
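For reference, the 202 Accepted body that uploadFile now reads could be typed roughly as follows; the interface name is illustrative, while the field names mirror the upload router change later in this commit.

// Illustrative response shape for /api/router/upload_ingest (202 Accepted).
interface UploadIngestAccepted {
  task_id: string;
  message: string;
  file_count: number;
  // Present only when create_filter was requested, a single file was uploaded,
  // and filter creation succeeded.
  user_doc_filter_id?: string;
}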

View file

@ -37,6 +37,7 @@ async def upload_ingest_router(
# Route based on configuration
if DISABLE_INGEST_WITH_LANGFLOW:
# Route to traditional OpenRAG upload
# Note: onboarding filter creation is only supported in the Langflow path
logger.debug("Routing to traditional OpenRAG upload")
return await traditional_upload(request, document_service, session_manager)
else:
@ -77,6 +78,7 @@ async def langflow_upload_ingest_task(
tweaks_json = form.get("tweaks")
delete_after_ingest = form.get("delete_after_ingest", "true").lower() == "true"
replace_duplicates = form.get("replace_duplicates", "false").lower() == "true"
create_filter = form.get("create_filter", "false").lower() == "true"
# Parse JSON fields if provided
settings = None
@ -177,14 +179,36 @@ async def langflow_upload_ingest_task(
logger.debug("Langflow upload task created successfully", task_id=task_id)
return JSONResponse(
{
"task_id": task_id,
"message": f"Langflow upload task created for {len(upload_files)} file(s)",
"file_count": len(upload_files),
},
status_code=202,
) # 202 Accepted for async processing
# Create knowledge filter for the uploaded document if requested
user_doc_filter_id = None
if create_filter and len(original_filenames) == 1:
try:
from api.settings import _create_user_document_filter
user_doc_filter_id = await _create_user_document_filter(
request, session_manager, original_filenames[0]
)
if user_doc_filter_id:
logger.info(
"Created knowledge filter for uploaded document",
filter_id=user_doc_filter_id,
filename=original_filenames[0],
)
except Exception as e:
logger.error(
"Failed to create knowledge filter for uploaded document",
error=str(e),
)
# Don't fail the upload if filter creation fails
response_data = {
"task_id": task_id,
"message": f"Langflow upload task created for {len(upload_files)} file(s)",
"file_count": len(upload_files),
}
if user_doc_filter_id:
response_data["user_doc_filter_id"] = user_doc_filter_id
return JSONResponse(response_data, status_code=202) # 202 Accepted for async processing
except Exception:
# Clean up temp files on error

View file

@ -508,7 +508,7 @@ async def update_settings(request, session_manager):
# Update provider-specific settings
if "openai_api_key" in body and body["openai_api_key"].strip():
current_config.providers.openai.api_key = body["openai_api_key"]
current_config.providers.openai.api_key = body["openai_api_key"].strip()
current_config.providers.openai.configured = True
config_updated = True
@ -555,6 +555,9 @@ async def update_settings(request, session_manager):
"watsonx_api_key", "watsonx_endpoint", "watsonx_project_id",
"ollama_endpoint"
]
await clients.refresh_patched_clients()
if any(key in body for key in provider_fields_to_check):
try:
flows_service = _get_flows_service()
@ -562,8 +565,11 @@ async def update_settings(request, session_manager):
# Update global variables
await _update_langflow_global_variables(current_config)
# Update MCP server provider credentials when the embedding selection changes
if "embedding_provider" in body or "embedding_model" in body:
await _update_mcp_servers_with_provider_credentials(current_config)
await _update_mcp_servers_with_provider_credentials(
current_config, session_manager
)
# Update model values if provider or model changed
if "llm_provider" in body or "llm_model" in body or "embedding_provider" in body or "embedding_model" in body:
@ -574,6 +580,7 @@ async def update_settings(request, session_manager):
# Don't fail the entire settings update if Langflow update fails
# The config was still saved
logger.info(
"Configuration updated successfully", updated_fields=list(body.keys())
)
@ -689,7 +696,7 @@ async def onboarding(request, flows_service, session_manager=None):
# Update provider-specific credentials
if "openai_api_key" in body and body["openai_api_key"].strip():
current_config.providers.openai.api_key = body["openai_api_key"]
current_config.providers.openai.api_key = body["openai_api_key"].strip()
current_config.providers.openai.configured = True
config_updated = True
@ -919,11 +926,33 @@ async def onboarding(request, flows_service, session_manager=None):
{"error": "Failed to save configuration"}, status_code=500
)
# Refresh cached patched clients so latest credentials take effect immediately
await clients.refresh_patched_clients()
# Create OpenRAG Docs knowledge filter if sample data was ingested
openrag_docs_filter_id = None
if should_ingest_sample_data:
try:
openrag_docs_filter_id = await _create_openrag_docs_filter(
request, session_manager
)
if openrag_docs_filter_id:
logger.info(
"Created OpenRAG Docs knowledge filter",
filter_id=openrag_docs_filter_id,
)
except Exception as e:
logger.error(
"Failed to create OpenRAG Docs knowledge filter", error=str(e)
)
# Don't fail onboarding if filter creation fails
return JSONResponse(
{
"message": "Onboarding configuration updated successfully",
"edited": True, # Confirm that config is now marked as edited
"sample_data_ingested": should_ingest_sample_data,
"openrag_docs_filter_id": openrag_docs_filter_id,
}
)
@ -935,6 +964,132 @@ async def onboarding(request, flows_service, session_manager=None):
)
async def _create_openrag_docs_filter(request, session_manager):
"""Create the OpenRAG Docs knowledge filter for onboarding"""
import uuid
import json
from datetime import datetime
from session_manager import AnonymousUser
# Get knowledge filter service from app state
app = request.scope.get("app")
if not app or not hasattr(app.state, "services"):
logger.error("Could not access services for knowledge filter creation")
return None
knowledge_filter_service = app.state.services.get("knowledge_filter_service")
if not knowledge_filter_service:
logger.error("Knowledge filter service not available")
return None
# Use anonymous user for no-auth mode compatibility
user = AnonymousUser()
jwt_token = session_manager.get_effective_jwt_token(user.user_id, None)
# Create the filter document
filter_id = str(uuid.uuid4())
query_data = json.dumps({
"query": "",
"filters": {
"data_sources": ["openrag-documentation.pdf"],
"document_types": ["*"],
"owners": ["*"],
"connector_types": ["*"],
},
"limit": 10,
"scoreThreshold": 0,
"color": "blue",
"icon": "book",
})
filter_doc = {
"id": filter_id,
"name": "OpenRAG Docs",
"description": "Filter for OpenRAG documentation",
"query_data": query_data,
"owner": user.user_id,
"allowed_users": [],
"allowed_groups": [],
"created_at": datetime.utcnow().isoformat(),
"updated_at": datetime.utcnow().isoformat(),
}
result = await knowledge_filter_service.create_knowledge_filter(
filter_doc, user_id=user.user_id, jwt_token=jwt_token
)
if result.get("success"):
return filter_id
else:
logger.error("Failed to create OpenRAG Docs filter", error=result.get("error"))
return None
async def _create_user_document_filter(request, session_manager, filename):
"""Create a knowledge filter for a user-uploaded document during onboarding"""
import uuid
import json
from datetime import datetime
from session_manager import AnonymousUser
# Get knowledge filter service from app state
app = request.scope.get("app")
if not app or not hasattr(app.state, "services"):
logger.error("Could not access services for knowledge filter creation")
return None
knowledge_filter_service = app.state.services.get("knowledge_filter_service")
if not knowledge_filter_service:
logger.error("Knowledge filter service not available")
return None
# Use anonymous user for no-auth mode compatibility
user = AnonymousUser()
jwt_token = session_manager.get_effective_jwt_token(user.user_id, None)
# Create the filter document
filter_id = str(uuid.uuid4())
# Get a display name from the filename (remove the extension for a cleaner name)
display_name = filename.rsplit(".", 1)[0] if "." in filename else filename
query_data = json.dumps({
"query": "",
"filters": {
"data_sources": [filename],
"document_types": ["*"],
"owners": ["*"],
"connector_types": ["*"],
},
"limit": 10,
"scoreThreshold": 0,
"color": "green",
"icon": "file",
})
filter_doc = {
"id": filter_id,
"name": display_name,
"description": f"Filter for {filename}",
"query_data": query_data,
"owner": user.user_id,
"allowed_users": [],
"allowed_groups": [],
"created_at": datetime.utcnow().isoformat(),
"updated_at": datetime.utcnow().isoformat(),
}
result = await knowledge_filter_service.create_knowledge_filter(
filter_doc, user_id=user.user_id, jwt_token=jwt_token
)
if result.get("success"):
return filter_id
else:
logger.error("Failed to create user document filter", error=result.get("error"))
return None
def _get_flows_service():
"""Helper function to get flows service instance"""
from services.flows_service import FlowsService
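To relate the query_data written by the filter helpers above to the SelectedFilters shape used on the frontend earlier in this commit, a rough client-side parsing sketch follows; the helper is hypothetical, and the wildcard entries ("*") may need handling that this sketch does not attempt.

import type { SelectedFilters } from "@/app/chat/_types/types";

// Hypothetical parse of a knowledge filter's query_data string (stored as JSON by the backend above).
function parseFilterQueryData(queryData: string): SelectedFilters | undefined {
  try {
    const parsed = JSON.parse(queryData) as {
      filters?: { data_sources?: string[]; document_types?: string[]; owners?: string[] };
    };
    return {
      data_sources: parsed.filters?.data_sources ?? [],
      document_types: parsed.filters?.document_types ?? [],
      owners: parsed.filters?.owners ?? [],
    };
  } catch {
    // Malformed query_data: fall back to "no filter" rather than throwing.
    return undefined;
  }
}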

View file

@ -165,18 +165,36 @@ async def generate_langflow_api_key(modify: bool = False):
if validation_response.status_code == 200:
logger.debug("Cached API key is valid", key_prefix=LANGFLOW_KEY[:8])
return LANGFLOW_KEY
else:
elif validation_response.status_code in (401, 403):
logger.warning(
"Cached API key is invalid, generating fresh key",
"Cached API key is unauthorized, generating fresh key",
status_code=validation_response.status_code,
)
LANGFLOW_KEY = None # Clear invalid key
except Exception as e:
else:
logger.warning(
"Cached API key validation returned non-access error; keeping existing key",
status_code=validation_response.status_code,
)
return LANGFLOW_KEY
except requests.exceptions.Timeout as e:
logger.warning(
"Cached API key validation failed, generating fresh key",
"Cached API key validation timed out; keeping existing key",
error=str(e),
)
LANGFLOW_KEY = None # Clear invalid key
return LANGFLOW_KEY
except requests.exceptions.RequestException as e:
logger.warning(
"Cached API key validation failed due to request error; keeping existing key",
error=str(e),
)
return LANGFLOW_KEY
except Exception as e:
logger.warning(
"Unexpected error during cached API key validation; keeping existing key",
error=str(e),
)
return LANGFLOW_KEY
# Use default langflow/langflow credentials if auto-login is enabled and credentials not set
username = LANGFLOW_SUPERUSER
@ -279,7 +297,8 @@ class AppClients:
self.opensearch = None
self.langflow_client = None
self.langflow_http_client = None
self._patched_async_client = None # Private attribute
self._patched_llm_client = None # Private attribute
self._patched_embedding_client = None # Private attribute
self._client_init_lock = __import__('threading').Lock() # Lock for thread-safe initialization
self.converter = None
@ -358,114 +377,192 @@ class AppClients:
self.langflow_client = None
return self.langflow_client
@property
def patched_async_client(self):
def _build_provider_env(self, provider_type: str):
"""
Property that ensures OpenAI client is initialized on first access.
This allows lazy initialization so the app can start without an API key.
Note: The client is a long-lived singleton that should be closed via cleanup().
Thread-safe via lock to prevent concurrent initialization attempts.
Build environment overrides for the requested provider type ("llm" or "embedding").
This is used to support different credentials for LLM and embedding providers.
"""
# Quick check without lock
if self._patched_async_client is not None:
return self._patched_async_client
config = get_openrag_config()
if provider_type == "llm":
provider = (config.agent.llm_provider or "openai").lower()
else:
provider = (config.knowledge.embedding_provider or "openai").lower()
env_overrides = {}
if provider == "openai":
api_key = config.providers.openai.api_key or os.getenv("OPENAI_API_KEY")
if api_key:
env_overrides["OPENAI_API_KEY"] = api_key
elif provider == "anthropic":
api_key = config.providers.anthropic.api_key or os.getenv("ANTHROPIC_API_KEY")
if api_key:
env_overrides["ANTHROPIC_API_KEY"] = api_key
elif provider == "watsonx":
api_key = config.providers.watsonx.api_key or os.getenv("WATSONX_API_KEY")
endpoint = config.providers.watsonx.endpoint or os.getenv("WATSONX_ENDPOINT")
project_id = config.providers.watsonx.project_id or os.getenv("WATSONX_PROJECT_ID")
if api_key:
env_overrides["WATSONX_API_KEY"] = api_key
if endpoint:
env_overrides["WATSONX_ENDPOINT"] = endpoint
if project_id:
env_overrides["WATSONX_PROJECT_ID"] = project_id
elif provider == "ollama":
endpoint = config.providers.ollama.endpoint or os.getenv("OLLAMA_ENDPOINT")
if endpoint:
env_overrides["OLLAMA_ENDPOINT"] = endpoint
env_overrides["OLLAMA_BASE_URL"] = endpoint
return env_overrides, provider
def _apply_env_overrides(self, env_overrides: dict):
"""Apply non-empty environment overrides."""
for key, value in (env_overrides or {}).items():
if value is None:
continue
os.environ[key] = str(value)
def _initialize_patched_client(self, cache_attr: str, provider_type: str):
"""
Initialize a patched AsyncOpenAI client for the specified provider type.
Uses an HTTP/2 probe only when an OpenAI key is present; otherwise the client is created without probing.
"""
# Quick path
cached_client = getattr(self, cache_attr)
if cached_client is not None:
return cached_client
# Use lock to ensure only one thread initializes
with self._client_init_lock:
# Double-check after acquiring lock
if self._patched_async_client is not None:
return self._patched_async_client
cached_client = getattr(self, cache_attr)
if cached_client is not None:
return cached_client
# Try to initialize the client on-demand
# First check if OPENAI_API_KEY is in environment
openai_key = os.getenv("OPENAI_API_KEY")
env_overrides, provider_name = self._build_provider_env(provider_type)
self._apply_env_overrides(env_overrides)
if not openai_key:
# Try to get from config (in case it was set during onboarding)
try:
config = get_openrag_config()
if config and config.provider and config.provider.api_key:
openai_key = config.provider.api_key
# Set it in environment so AsyncOpenAI can pick it up
os.environ["OPENAI_API_KEY"] = openai_key
logger.info("Loaded OpenAI API key from config file")
except Exception as e:
logger.debug("Could not load OpenAI key from config", error=str(e))
# Decide whether to run the HTTP/2 probe (only meaningful for OpenAI endpoints)
has_openai_key = bool(env_overrides.get("OPENAI_API_KEY") or os.getenv("OPENAI_API_KEY"))
if provider_name == "openai" and not has_openai_key:
raise ValueError("OPENAI_API_KEY is required for OpenAI provider")
# Try to initialize the client - AsyncOpenAI() will read from environment
# We'll try HTTP/2 first with a probe, then fall back to HTTP/1.1 if it times out
import asyncio
import concurrent.futures
import threading
async def probe_and_initialize():
# Try HTTP/2 first (default)
async def build_client(skip_probe: bool = False):
if not has_openai_key:
# No OpenAI key present; create a basic patched client without probing
return patch_openai_with_mcp(AsyncOpenAI(http_client=httpx.AsyncClient()))
if skip_probe:
http_client = httpx.AsyncClient(http2=False, timeout=httpx.Timeout(60.0, connect=10.0))
return patch_openai_with_mcp(AsyncOpenAI(http_client=http_client))
client_http2 = patch_openai_with_mcp(AsyncOpenAI())
logger.info("Probing OpenAI client with HTTP/2...")
logger.info("Probing patched OpenAI client with HTTP/2...")
try:
# Probe with a small embedding and short timeout
await asyncio.wait_for(
client_http2.embeddings.create(
model='text-embedding-3-small',
input=['test']
model="text-embedding-3-small",
input=["test"],
),
timeout=5.0
timeout=5.0,
)
logger.info("OpenAI client initialized with HTTP/2 (probe successful)")
logger.info("Patched OpenAI client initialized with HTTP/2 (probe successful)")
return client_http2
except (asyncio.TimeoutError, Exception) as probe_error:
logger.warning("HTTP/2 probe failed, falling back to HTTP/1.1", error=str(probe_error))
# Close the HTTP/2 client
try:
await client_http2.close()
except Exception:
pass
# Fall back to HTTP/1.1 with explicit timeout settings
http_client = httpx.AsyncClient(
http2=False,
timeout=httpx.Timeout(60.0, connect=10.0)
http2=False, timeout=httpx.Timeout(60.0, connect=10.0)
)
client_http1 = patch_openai_with_mcp(
AsyncOpenAI(http_client=http_client)
)
logger.info("OpenAI client initialized with HTTP/1.1 (fallback)")
client_http1 = patch_openai_with_mcp(AsyncOpenAI(http_client=http_client))
logger.info("Patched OpenAI client initialized with HTTP/1.1 (fallback)")
return client_http1
def run_probe_in_thread():
"""Run the async probe in a new thread with its own event loop"""
def run_builder(skip_probe=False):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(probe_and_initialize())
return loop.run_until_complete(build_client(skip_probe=skip_probe))
finally:
loop.close()
try:
# Run the probe in a separate thread with its own event loop
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(run_probe_in_thread)
self._patched_async_client = future.result(timeout=15)
logger.info("Successfully initialized OpenAI client")
future = executor.submit(run_builder, False)
client = future.result(timeout=15 if has_openai_key else 10)
setattr(self, cache_attr, client)
logger.info("Successfully initialized patched client", provider_type=provider_type)
return client
except Exception as e:
logger.error(f"Failed to initialize OpenAI client: {e.__class__.__name__}: {str(e)}")
raise ValueError(f"Failed to initialize OpenAI client: {str(e)}. Please complete onboarding or set OPENAI_API_KEY environment variable.")
logger.error(
f"Failed to initialize patched client: {e.__class__.__name__}: {str(e)}",
provider_type=provider_type,
)
raise ValueError(
f"Failed to initialize patched client for {provider_type}: {str(e)}. "
"Please ensure provider credentials are set."
)
return self._patched_async_client
@property
def patched_llm_client(self):
"""Patched client for LLM provider."""
return self._initialize_patched_client("_patched_llm_client", "llm")
@property
def patched_embedding_client(self):
"""Patched client for embedding provider."""
return self._initialize_patched_client("_patched_embedding_client", "embedding")
@property
def patched_async_client(self):
"""Backward-compatibility alias for LLM client."""
return self.patched_llm_client
async def refresh_patched_clients(self):
"""Reset patched clients so next use picks up updated provider credentials."""
clients_to_close = []
with self._client_init_lock:
if self._patched_llm_client is not None:
clients_to_close.append(self._patched_llm_client)
self._patched_llm_client = None
if self._patched_embedding_client is not None:
clients_to_close.append(self._patched_embedding_client)
self._patched_embedding_client = None
for client in clients_to_close:
try:
await client.close()
except Exception as e:
logger.warning("Failed to close patched client during refresh", error=str(e))
async def cleanup(self):
"""Cleanup resources - should be called on application shutdown"""
# Close AsyncOpenAI client if it was created
if self._patched_async_client is not None:
if self._patched_llm_client is not None:
try:
await self._patched_async_client.close()
logger.info("Closed AsyncOpenAI client")
await self._patched_llm_client.close()
logger.info("Closed LLM patched client")
except Exception as e:
logger.error("Failed to close AsyncOpenAI client", error=str(e))
logger.error("Failed to close LLM patched client", error=str(e))
finally:
self._patched_async_client = None
self._patched_llm_client = None
if self._patched_embedding_client is not None:
try:
await self._patched_embedding_client.close()
logger.info("Closed embedding patched client")
except Exception as e:
logger.error("Failed to close embedding patched client", error=str(e))
finally:
self._patched_embedding_client = None
# Close Langflow HTTP client if it exists
if self.langflow_http_client is not None:
@ -750,4 +847,4 @@ def get_agent_config():
def get_embedding_model() -> str:
"""Return the currently configured embedding model."""
return get_openrag_config().knowledge.embedding_model or EMBED_MODEL if DISABLE_INGEST_WITH_LANGFLOW else ""
return get_openrag_config().knowledge.embedding_model or EMBED_MODEL if DISABLE_INGEST_WITH_LANGFLOW else ""

View file

@ -209,7 +209,7 @@ class TaskProcessor:
embeddings = []
for batch in text_batches:
resp = await clients.patched_async_client.embeddings.create(
resp = await clients.patched_embedding_client.embeddings.create(
model=embedding_model, input=batch
)
embeddings.extend([d.embedding for d in resp.data])

View file

@ -26,14 +26,14 @@ class ChatService:
if stream:
return async_chat_stream(
clients.patched_async_client,
clients.patched_llm_client,
prompt,
user_id,
previous_response_id=previous_response_id,
)
else:
response_text, response_id = await async_chat(
clients.patched_async_client,
clients.patched_llm_client,
prompt,
user_id,
previous_response_id=previous_response_id,
@ -344,7 +344,7 @@ class ChatService:
if user_id and jwt_token:
set_auth_context(user_id, jwt_token)
response_text, response_id = await async_chat(
clients.patched_async_client,
clients.patched_llm_client,
document_prompt,
user_id,
previous_response_id=previous_response_id,
@ -632,4 +632,3 @@ class ChatService:
except Exception as e:
logger.error(f"Error deleting session {session_id} from Langflow: {e}")
return False

View file

@ -108,7 +108,7 @@ class ModelsService:
else:
logger.error(f"Failed to fetch OpenAI models: {response.status_code}")
raise Exception(
f"OpenAI API returned status code {response.status_code}"
f"OpenAI API returned status code {response.status_code}, {response.text}"
)
except Exception as e:

View file

@ -150,7 +150,7 @@ class SearchService:
while attempts < MAX_EMBED_RETRIES:
attempts += 1
try:
resp = await clients.patched_async_client.embeddings.create(
resp = await clients.patched_embedding_client.embeddings.create(
model=model_name, input=[query]
)
return model_name, resp.data[0].embedding