Merge pull request #254 from langflow-ai/multi-embedding-support
Multi embedding support
Commit 5f7fc66655
36 changed files with 3524 additions and 297 deletions
.github/workflows/test-integration.yml (5 changes, vendored)
@@ -16,6 +16,11 @@ jobs:
# Prefer repository/environment variable first, then secret, then a sane fallback
OPENSEARCH_PASSWORD: ${{ vars.OPENSEARCH_PASSWORD || secrets.OPENSEARCH_PASSWORD || 'OpenRag#2025!' }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
LANGFLOW_AUTO_LOGIN: "True"
LANGFLOW_NEW_USER_IS_ACTIVE: "True"
LANGFLOW_ENABLE_SUPERUSER_CLI: "True"
LANGFLOW_CHAT_FLOW_ID: ${{ vars.LANGFLOW_CHAT_FLOW_ID || '1098eea1-6649-4e1d-aed1-b77249fb8dd0' }}
NUDGES_FLOW_ID: ${{ vars.NUDGES_FLOW_ID || 'ebc01d31-1976-46ce-a385-b0240327226c' }}

steps:
- run: df -h
@@ -53,10 +53,11 @@ RUN echo y | opensearch-plugin install repository-s3
# Create a script to apply security configuration after OpenSearch starts
RUN echo '#!/bin/bash' > /usr/share/opensearch/setup-security.sh && \
echo 'echo "Waiting for OpenSearch to start..."' >> /usr/share/opensearch/setup-security.sh && \
echo 'until curl -s -k -u admin:${OPENSEARCH_INITIAL_ADMIN_PASSWORD} https://localhost:9200; do sleep 1; done' >> /usr/share/opensearch/setup-security.sh && \
echo 'echo "Generating admin hash from OPENSEARCH_INITIAL_ADMIN_PASSWORD..."' >> /usr/share/opensearch/setup-security.sh && \
echo 'if [ -z "${OPENSEARCH_INITIAL_ADMIN_PASSWORD}" ]; then echo "[ERROR] OPENSEARCH_INITIAL_ADMIN_PASSWORD not set"; exit 1; fi' >> /usr/share/opensearch/setup-security.sh && \
echo 'HASH=$(/usr/share/opensearch/plugins/opensearch-security/tools/hash.sh -p "${OPENSEARCH_INITIAL_ADMIN_PASSWORD}")' >> /usr/share/opensearch/setup-security.sh && \
echo 'PASSWORD=${OPENSEARCH_INITIAL_ADMIN_PASSWORD:-${OPENSEARCH_PASSWORD}}' >> /usr/share/opensearch/setup-security.sh && \
echo 'if [ -z "$PASSWORD" ]; then echo "[ERROR] OPENSEARCH_INITIAL_ADMIN_PASSWORD or OPENSEARCH_PASSWORD must be set"; exit 1; fi' >> /usr/share/opensearch/setup-security.sh && \
echo 'until curl -s -k -u admin:$PASSWORD https://localhost:9200; do sleep 1; done' >> /usr/share/opensearch/setup-security.sh && \
echo 'echo "Generating admin hash from configured password..."' >> /usr/share/opensearch/setup-security.sh && \
echo 'HASH=$(/usr/share/opensearch/plugins/opensearch-security/tools/hash.sh -p "$PASSWORD")' >> /usr/share/opensearch/setup-security.sh && \
echo 'if [ -z "$HASH" ]; then echo "[ERROR] Failed to generate admin hash"; exit 1; fi' >> /usr/share/opensearch/setup-security.sh && \
echo 'sed -i "s|^ hash: \".*\"| hash: \"$HASH\"|" /usr/share/opensearch/securityconfig/internal_users.yml' >> /usr/share/opensearch/setup-security.sh && \
echo 'echo "Updated internal_users.yml with runtime-generated admin hash"' >> /usr/share/opensearch/setup-security.sh && \
Makefile (2 changes)

@@ -206,6 +206,8 @@ test-ci:
docker compose -f docker-compose-cpu.yml down -v 2>/dev/null || true; \
echo "Pulling latest images..."; \
docker compose -f docker-compose-cpu.yml pull; \
echo "Building OpenSearch image override..."; \
docker build --no-cache -t phact/openrag-opensearch:latest -f Dockerfile .; \
echo "Starting infra (OpenSearch + Dashboards + Langflow) with CPU containers"; \
docker compose -f docker-compose-cpu.yml up -d opensearch dashboards langflow; \
echo "Starting docling-serve..."; \
flows/components/opensearch.py (1291 changes, new file)
File diff suppressed because it is too large
File diff suppressed because one or more lines are too long
@ -34,20 +34,25 @@ export const EmbeddingModelInput = ({
|
|||
modelsData,
|
||||
currentProvider = "openai",
|
||||
}: EmbeddingModelInputProps) => {
|
||||
const isDisabled = Boolean(disabled);
|
||||
const tooltipMessage = isDisabled
|
||||
? "Locked to keep embeddings consistent"
|
||||
: "Choose the embedding model for ingest and retrieval";
|
||||
|
||||
return (
|
||||
<LabelWrapper
|
||||
helperText="Model used for knowledge ingest and retrieval"
|
||||
id="embedding-model-select"
|
||||
label="Embedding model"
|
||||
>
|
||||
<Select disabled={disabled} value={value} onValueChange={onChange}>
|
||||
<Select disabled={isDisabled} value={value} onValueChange={onChange}>
|
||||
<Tooltip delayDuration={0}>
|
||||
<TooltipTrigger asChild>
|
||||
<SelectTrigger disabled id="embedding-model-select">
|
||||
<SelectTrigger disabled={isDisabled} id="embedding-model-select">
|
||||
<SelectValue placeholder="Select an embedding model" />
|
||||
</SelectTrigger>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent>Locked to keep embeddings consistent</TooltipContent>
|
||||
<TooltipContent>{tooltipMessage}</TooltipContent>
|
||||
</Tooltip>
|
||||
<SelectContent>
|
||||
<ModelSelectItems
|
||||
|
|
|
|||
|
|
@ -29,6 +29,8 @@ export interface ChunkResult {
|
|||
owner_email?: string;
|
||||
file_size?: number;
|
||||
connector_type?: string;
|
||||
embedding_model?: string;
|
||||
embedding_dimensions?: number;
|
||||
index?: number;
|
||||
}
|
||||
|
||||
|
|
@ -43,6 +45,8 @@ export interface File {
|
|||
owner_email?: string;
|
||||
size: number;
|
||||
connector_type: string;
|
||||
embedding_model?: string;
|
||||
embedding_dimensions?: number;
|
||||
status?:
|
||||
| "processing"
|
||||
| "active"
|
||||
|
|
@ -50,6 +54,7 @@ export interface File {
|
|||
| "failed"
|
||||
| "hidden"
|
||||
| "sync";
|
||||
error?: string;
|
||||
chunks?: ChunkResult[];
|
||||
}
|
||||
|
||||
|
|
@ -133,6 +138,8 @@ export const useGetSearchQuery = (
|
|||
owner_email?: string;
|
||||
file_size?: number;
|
||||
connector_type?: string;
|
||||
embedding_model?: string;
|
||||
embedding_dimensions?: number;
|
||||
}
|
||||
>();
|
||||
|
||||
|
|
@ -141,6 +148,15 @@ export const useGetSearchQuery = (
|
|||
if (existing) {
|
||||
existing.chunks.push(chunk);
|
||||
existing.totalScore += chunk.score;
|
||||
if (!existing.embedding_model && chunk.embedding_model) {
|
||||
existing.embedding_model = chunk.embedding_model;
|
||||
}
|
||||
if (
|
||||
existing.embedding_dimensions == null &&
|
||||
typeof chunk.embedding_dimensions === "number"
|
||||
) {
|
||||
existing.embedding_dimensions = chunk.embedding_dimensions;
|
||||
}
|
||||
} else {
|
||||
fileMap.set(chunk.filename, {
|
||||
filename: chunk.filename,
|
||||
|
|
@ -153,6 +169,8 @@ export const useGetSearchQuery = (
|
|||
owner_email: chunk.owner_email,
|
||||
file_size: chunk.file_size,
|
||||
connector_type: chunk.connector_type,
|
||||
embedding_model: chunk.embedding_model,
|
||||
embedding_dimensions: chunk.embedding_dimensions,
|
||||
});
|
||||
}
|
||||
});
|
||||
|
|
@ -168,6 +186,8 @@ export const useGetSearchQuery = (
|
|||
owner_email: file.owner_email || "",
|
||||
size: file.file_size || 0,
|
||||
connector_type: file.connector_type || "local",
|
||||
embedding_model: file.embedding_model,
|
||||
embedding_dimensions: file.embedding_dimensions,
|
||||
chunks: file.chunks,
|
||||
}));
|
||||
|
||||
|
|
|
|||
|
|
@ -4,6 +4,26 @@ import {
|
|||
useQueryClient,
|
||||
} from "@tanstack/react-query";
|
||||
|
||||
export interface TaskFileEntry {
|
||||
status?:
|
||||
| "pending"
|
||||
| "running"
|
||||
| "processing"
|
||||
| "completed"
|
||||
| "failed"
|
||||
| "error";
|
||||
result?: unknown;
|
||||
error?: string;
|
||||
retry_count?: number;
|
||||
created_at?: string;
|
||||
updated_at?: string;
|
||||
duration_seconds?: number;
|
||||
filename?: string;
|
||||
embedding_model?: string;
|
||||
embedding_dimensions?: number;
|
||||
[key: string]: unknown;
|
||||
}
|
||||
|
||||
export interface Task {
|
||||
task_id: string;
|
||||
status:
|
||||
|
|
@ -24,7 +44,7 @@ export interface Task {
|
|||
duration_seconds?: number;
|
||||
result?: Record<string, unknown>;
|
||||
error?: string;
|
||||
files?: Record<string, Record<string, unknown>>;
|
||||
files?: Record<string, TaskFileEntry>;
|
||||
}
|
||||
|
||||
export interface TasksResponse {
|
||||
|
|
|
|||
|
|
@@ -168,7 +168,7 @@
}

.header-notifications {
@apply absolute right-[0px] top-[-4px] h-1 w-1 rounded-full bg-destructive;
@apply absolute right-1 top-1 h-2 w-2 rounded-full bg-destructive;
}

.header-menu-bar {
|||
|
|
@ -26,6 +26,14 @@ import GoogleDriveIcon from "../settings/icons/google-drive-icon";
|
|||
import OneDriveIcon from "../settings/icons/one-drive-icon";
|
||||
import SharePointIcon from "../settings/icons/share-point-icon";
|
||||
import { KnowledgeSearchInput } from "@/components/knowledge-search-input";
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
DialogTrigger,
|
||||
} from "@/components/ui/dialog";
|
||||
|
||||
// Function to get the appropriate icon for a connector type
|
||||
function getSourceIcon(connectorType?: string) {
|
||||
|
|
@ -77,6 +85,9 @@ function SearchPage() {
|
|||
size: taskFile.size,
|
||||
connector_type: taskFile.connector_type,
|
||||
status: taskFile.status,
|
||||
error: taskFile.error,
|
||||
embedding_model: taskFile.embedding_model,
|
||||
embedding_dimensions: taskFile.embedding_dimensions,
|
||||
};
|
||||
});
|
||||
|
||||
|
|
@ -115,7 +126,7 @@ function SearchPage() {
|
|||
|
||||
const gridRef = useRef<AgGridReact>(null);
|
||||
|
||||
const columnDefs = [
|
||||
const columnDefs: ColDef<File>[] = [
|
||||
{
|
||||
field: "filename",
|
||||
headerName: "Source",
|
||||
|
|
@ -128,7 +139,6 @@ function SearchPage() {
|
|||
// Read status directly from data on each render
|
||||
const status = data?.status || "active";
|
||||
const isActive = status === "active";
|
||||
console.log(data?.filename, status, "a");
|
||||
return (
|
||||
<div className="flex items-center overflow-hidden w-full">
|
||||
<div
|
||||
|
|
@ -192,13 +202,63 @@ function SearchPage() {
|
|||
);
|
||||
},
|
||||
},
|
||||
{
|
||||
field: "embedding_model",
|
||||
headerName: "Embedding model",
|
||||
minWidth: 200,
|
||||
cellRenderer: ({ data }: CustomCellRendererProps<File>) => (
|
||||
<span className="text-xs text-muted-foreground">
|
||||
{data?.embedding_model || "—"}
|
||||
</span>
|
||||
),
|
||||
},
|
||||
{
|
||||
field: "embedding_dimensions",
|
||||
headerName: "Dimensions",
|
||||
width: 110,
|
||||
cellRenderer: ({ data }: CustomCellRendererProps<File>) => (
|
||||
<span className="text-xs text-muted-foreground">
|
||||
{typeof data?.embedding_dimensions === "number"
|
||||
? data.embedding_dimensions.toString()
|
||||
: "—"}
|
||||
</span>
|
||||
),
|
||||
},
|
||||
{
|
||||
field: "status",
|
||||
headerName: "Status",
|
||||
cellRenderer: ({ data }: CustomCellRendererProps<File>) => {
|
||||
console.log(data?.filename, data?.status, "b");
|
||||
// Default to 'active' status if no status is provided
|
||||
const status = data?.status || "active";
|
||||
const error =
|
||||
typeof data?.error === "string" && data.error.trim().length > 0
|
||||
? data.error.trim()
|
||||
: undefined;
|
||||
if (status === "failed" && error) {
|
||||
return (
|
||||
<Dialog>
|
||||
<DialogTrigger asChild>
|
||||
<button
|
||||
type="button"
|
||||
className="inline-flex items-center gap-1 text-red-500 transition hover:text-red-400"
|
||||
aria-label="View ingestion error"
|
||||
>
|
||||
<StatusBadge status={status} className="pointer-events-none" />
|
||||
</button>
|
||||
</DialogTrigger>
|
||||
<DialogContent>
|
||||
<DialogHeader>
|
||||
<DialogTitle>Ingestion failed</DialogTitle>
|
||||
<DialogDescription className="text-sm text-muted-foreground">
|
||||
{data?.filename || "Unknown file"}
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
<div className="rounded-md border border-destructive/20 bg-destructive/10 p-4 text-sm text-destructive">
|
||||
{error}
|
||||
</div>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
return <StatusBadge status={status} />;
|
||||
},
|
||||
},
|
||||
|
|
|
|||
|
|
@@ -77,7 +77,7 @@ export function AdvancedOnboarding({
{(hasLanguageModels || hasEmbeddingModels) && <Separator />}
<LabelWrapper
label="Sample dataset"
description="Load 2 sample PDFs to chat with data immediately."
description="Load sample data to chat with immediately."
id="sample-dataset"
flex
>
|
||||
|
|
|
|||
|
|
@ -243,6 +243,8 @@ function KnowledgeSourcesPage() {
|
|||
updateFlowSettingMutation.mutate({ embedding_model: newModel });
|
||||
};
|
||||
|
||||
const isEmbeddingModelSelectDisabled = updateFlowSettingMutation.isPending;
|
||||
|
||||
// Update chunk size setting with debounce
|
||||
const handleChunkSizeChange = (value: string) => {
|
||||
const numValue = Math.max(0, parseInt(value) || 0);
|
||||
|
|
@ -1029,8 +1031,7 @@ function KnowledgeSourcesPage() {
|
|||
label="Embedding model"
|
||||
>
|
||||
<Select
|
||||
// Disabled until API supports multiple embedding models
|
||||
disabled={true}
|
||||
disabled={isEmbeddingModelSelectDisabled}
|
||||
value={
|
||||
settings.knowledge?.embedding_model ||
|
||||
modelsData?.embedding_models?.find((m) => m.default)
|
||||
|
|
@ -1041,12 +1042,17 @@ function KnowledgeSourcesPage() {
|
|||
>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<SelectTrigger disabled id="embedding-model-select">
|
||||
<SelectTrigger
|
||||
disabled={isEmbeddingModelSelectDisabled}
|
||||
id="embedding-model-select"
|
||||
>
|
||||
<SelectValue placeholder="Select an embedding model" />
|
||||
</SelectTrigger>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent>
|
||||
Locked to keep embeddings consistent
|
||||
{isEmbeddingModelSelectDisabled
|
||||
? "Please wait while we update your settings"
|
||||
: "Choose the embedding model used for new ingests"}
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
<SelectContent>
|
||||
|
|
|
|||
|
|
@ -21,8 +21,16 @@ import {
|
|||
TooltipTrigger,
|
||||
} from "@/components/ui/tooltip";
|
||||
import { ModelSelectItems } from "@/app/settings/helpers/model-select-item";
|
||||
import { getFallbackModels } from "@/app/settings/helpers/model-helpers";
|
||||
import { getFallbackModels, type ModelProvider } from "@/app/settings/helpers/model-helpers";
|
||||
import { NumberInput } from "@/components/ui/inputs/number-input";
|
||||
import { useGetSettingsQuery } from "@/app/api/queries/useGetSettingsQuery";
|
||||
import {
|
||||
useGetOpenAIModelsQuery,
|
||||
useGetOllamaModelsQuery,
|
||||
useGetIBMModelsQuery,
|
||||
} from "@/app/api/queries/useGetModelsQuery";
|
||||
import { useAuth } from "@/contexts/auth-context";
|
||||
import { useEffect } from "react";
|
||||
|
||||
interface IngestSettingsProps {
|
||||
isOpen: boolean;
|
||||
|
|
@ -37,18 +45,65 @@ export const IngestSettings = ({
|
|||
settings,
|
||||
onSettingsChange,
|
||||
}: IngestSettingsProps) => {
|
||||
// Default settings
|
||||
const { isAuthenticated, isNoAuthMode } = useAuth();
|
||||
|
||||
// Fetch settings from API to get current embedding model
|
||||
const { data: apiSettings = {} } = useGetSettingsQuery({
|
||||
enabled: isAuthenticated || isNoAuthMode,
|
||||
});
|
||||
|
||||
// Get the current provider from API settings
|
||||
const currentProvider = (apiSettings.provider?.model_provider ||
|
||||
"openai") as ModelProvider;
|
||||
|
||||
// Fetch available models based on provider
|
||||
const { data: openaiModelsData } = useGetOpenAIModelsQuery(undefined, {
|
||||
enabled: (isAuthenticated || isNoAuthMode) && currentProvider === "openai",
|
||||
});
|
||||
|
||||
const { data: ollamaModelsData } = useGetOllamaModelsQuery(undefined, {
|
||||
enabled: (isAuthenticated || isNoAuthMode) && currentProvider === "ollama",
|
||||
});
|
||||
|
||||
const { data: ibmModelsData } = useGetIBMModelsQuery(undefined, {
|
||||
enabled: (isAuthenticated || isNoAuthMode) && currentProvider === "watsonx",
|
||||
});
|
||||
|
||||
// Select the appropriate models data based on provider
|
||||
const modelsData =
|
||||
currentProvider === "openai"
|
||||
? openaiModelsData
|
||||
: currentProvider === "ollama"
|
||||
? ollamaModelsData
|
||||
: currentProvider === "watsonx"
|
||||
? ibmModelsData
|
||||
: openaiModelsData;
|
||||
|
||||
// Get embedding model from API settings
|
||||
const apiEmbeddingModel =
|
||||
apiSettings.knowledge?.embedding_model ||
|
||||
modelsData?.embedding_models?.find((m) => m.default)?.value ||
|
||||
"text-embedding-3-small";
|
||||
|
||||
// Default settings - use API embedding model
|
||||
const defaultSettings: IngestSettingsType = {
|
||||
chunkSize: 1000,
|
||||
chunkOverlap: 200,
|
||||
ocr: false,
|
||||
pictureDescriptions: false,
|
||||
embeddingModel: "text-embedding-3-small",
|
||||
embeddingModel: apiEmbeddingModel,
|
||||
};
|
||||
|
||||
// Use provided settings or defaults
|
||||
const currentSettings = settings || defaultSettings;
|
||||
|
||||
// Update settings when API embedding model changes
|
||||
useEffect(() => {
|
||||
if (apiEmbeddingModel && (!settings || settings.embeddingModel !== apiEmbeddingModel)) {
|
||||
onSettingsChange?.({ ...currentSettings, embeddingModel: apiEmbeddingModel });
|
||||
}
|
||||
}, [apiEmbeddingModel]);
|
||||
|
||||
const handleSettingsChange = (newSettings: Partial<IngestSettingsType>) => {
|
||||
const updatedSettings = { ...currentSettings, ...newSettings };
|
||||
onSettingsChange?.(updatedSettings);
|
||||
|
|
@ -73,38 +128,32 @@ export const IngestSettings = ({
|
|||
|
||||
<CollapsibleContent className="data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 data-[state=closed]:slide-up-2 data-[state=open]:slide-down-2">
|
||||
<div className="mt-6">
|
||||
{/* Embedding model selection - currently disabled */}
|
||||
{/* Embedding model selection */}
|
||||
<LabelWrapper
|
||||
helperText="Model used for knowledge ingest and retrieval"
|
||||
id="embedding-model-select"
|
||||
label="Embedding model"
|
||||
>
|
||||
<Select
|
||||
// Disabled until API supports multiple embedding models
|
||||
disabled={true}
|
||||
disabled={false}
|
||||
value={currentSettings.embeddingModel}
|
||||
onValueChange={() => {}}
|
||||
onValueChange={(value) => handleSettingsChange({ embeddingModel: value })}
|
||||
>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<SelectTrigger disabled id="embedding-model-select">
|
||||
<SelectTrigger id="embedding-model-select">
|
||||
<SelectValue placeholder="Select an embedding model" />
|
||||
</SelectTrigger>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent>
|
||||
Locked to keep embeddings consistent
|
||||
Choose the embedding model for this upload
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
<SelectContent>
|
||||
<ModelSelectItems
|
||||
models={[
|
||||
{
|
||||
value: "text-embedding-3-small",
|
||||
label: "text-embedding-3-small",
|
||||
},
|
||||
]}
|
||||
fallbackModels={getFallbackModels("openai").embedding}
|
||||
provider={"openai"}
|
||||
models={modelsData?.embedding_models}
|
||||
fallbackModels={getFallbackModels(currentProvider).embedding}
|
||||
provider={currentProvider}
|
||||
/>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
|
|
|
|||
|
|
@@ -129,7 +129,7 @@ export function LayoutWrapper({ children }: { children: React.ReactNode }) {
{/* Task Notification Bell */}
<button
onClick={toggleMenu}
className="h-8 w-8 hover:bg-muted rounded-lg flex items-center justify-center"
className="relative h-8 w-8 hover:bg-muted rounded-lg flex items-center justify-center"
>
<Bell size={16} className="text-muted-foreground" />
{activeTasks.length > 0 && (
|
||||
|
|
|
|||
|
|
@ -169,95 +169,87 @@ export function TaskNotificationMenu() {
|
|||
{activeTasks.length > 0 && (
|
||||
<div className="p-4 space-y-3">
|
||||
<h4 className="text-sm font-medium text-muted-foreground">Active Tasks</h4>
|
||||
{activeTasks.map((task) => (
|
||||
<Card key={task.task_id} className="bg-card/50">
|
||||
<CardHeader className="pb-2">
|
||||
<div className="flex items-center justify-between">
|
||||
<CardTitle className="text-sm flex items-center gap-2">
|
||||
{getTaskIcon(task.status)}
|
||||
Task {task.task_id.substring(0, 8)}...
|
||||
</CardTitle>
|
||||
</div>
|
||||
<CardDescription className="text-xs">
|
||||
Started {formatRelativeTime(task.created_at)}
|
||||
{formatDuration(task.duration_seconds) && (
|
||||
<span className="ml-2 text-muted-foreground">
|
||||
• {formatDuration(task.duration_seconds)}
|
||||
</span>
|
||||
)}
|
||||
</CardDescription>
|
||||
</CardHeader>
|
||||
{formatTaskProgress(task) && (
|
||||
<CardContent className="pt-0">
|
||||
<div className="space-y-2">
|
||||
<div className="text-xs text-muted-foreground">
|
||||
Progress: {formatTaskProgress(task)?.basic}
|
||||
</div>
|
||||
{formatTaskProgress(task)?.detailed && (
|
||||
<div className="grid grid-cols-2 gap-2 text-xs">
|
||||
<div className="flex items-center gap-1">
|
||||
<div className="w-2 h-2 bg-green-500 rounded-full"></div>
|
||||
<span className="text-green-600">
|
||||
{formatTaskProgress(task)?.detailed.successful} success
|
||||
</span>
|
||||
</div>
|
||||
<div className="flex items-center gap-1">
|
||||
<div className="w-2 h-2 bg-red-500 rounded-full"></div>
|
||||
<span className="text-red-600">
|
||||
{formatTaskProgress(task)?.detailed.failed} failed
|
||||
</span>
|
||||
</div>
|
||||
<div className="flex items-center gap-1">
|
||||
<div className="w-2 h-2 bg-blue-500 rounded-full"></div>
|
||||
<span className="text-blue-600">
|
||||
{formatTaskProgress(task)?.detailed.running} running
|
||||
</span>
|
||||
</div>
|
||||
<div className="flex items-center gap-1">
|
||||
<div className="w-2 h-2 bg-yellow-500 rounded-full"></div>
|
||||
<span className="text-yellow-600">
|
||||
{formatTaskProgress(task)?.detailed.pending} pending
|
||||
</span>
|
||||
{activeTasks.map((task) => {
|
||||
const progress = formatTaskProgress(task)
|
||||
const showCancel =
|
||||
task.status === 'pending' ||
|
||||
task.status === 'running' ||
|
||||
task.status === 'processing'
|
||||
|
||||
return (
|
||||
<Card key={task.task_id} className="bg-card/50">
|
||||
<CardHeader className="pb-2">
|
||||
<div className="flex items-center justify-between">
|
||||
<CardTitle className="text-sm flex items-center gap-2">
|
||||
{getTaskIcon(task.status)}
|
||||
Task {task.task_id.substring(0, 8)}...
|
||||
</CardTitle>
|
||||
</div>
|
||||
<CardDescription className="text-xs">
|
||||
Started {formatRelativeTime(task.created_at)}
|
||||
{formatDuration(task.duration_seconds) && (
|
||||
<span className="ml-2 text-muted-foreground">
|
||||
• {formatDuration(task.duration_seconds)}
|
||||
</span>
|
||||
)}
|
||||
</CardDescription>
|
||||
</CardHeader>
|
||||
{(progress || showCancel) && (
|
||||
<CardContent className="pt-0">
|
||||
{progress && (
|
||||
<div className="space-y-2">
|
||||
<div className="text-xs text-muted-foreground">
|
||||
Progress: {progress.basic}
|
||||
</div>
|
||||
{progress.detailed && (
|
||||
<div className="grid grid-cols-2 gap-2 text-xs">
|
||||
<div className="flex items-center gap-1">
|
||||
<div className="w-2 h-2 bg-green-500 rounded-full"></div>
|
||||
<span className="text-green-600">
|
||||
{progress.detailed.successful} success
|
||||
</span>
|
||||
</div>
|
||||
<div className="flex items-center gap-1">
|
||||
<div className="w-2 h-2 bg-red-500 rounded-full"></div>
|
||||
<span className="text-red-600">
|
||||
{progress.detailed.failed} failed
|
||||
</span>
|
||||
</div>
|
||||
<div className="flex items-center gap-1">
|
||||
<div className="w-2 h-2 bg-blue-500 rounded-full"></div>
|
||||
<span className="text-blue-600">
|
||||
{progress.detailed.running} running
|
||||
</span>
|
||||
</div>
|
||||
<div className="flex items-center gap-1">
|
||||
<div className="w-2 h-2 bg-yellow-500 rounded-full"></div>
|
||||
<span className="text-yellow-600">
|
||||
{progress.detailed.pending} pending
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
{/* Cancel button in bottom right */}
|
||||
{(task.status === 'pending' || task.status === 'running' || task.status === 'processing') && (
|
||||
<div className="flex justify-end mt-3">
|
||||
<Button
|
||||
variant="destructive"
|
||||
size="sm"
|
||||
onClick={() => cancelTask(task.task_id)}
|
||||
className="h-7 px-3 text-xs"
|
||||
title="Cancel task"
|
||||
>
|
||||
<X className="h-3 w-3 mr-1" />
|
||||
Cancel
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
</CardContent>
|
||||
)}
|
||||
{/* Cancel button for tasks without progress */}
|
||||
{!formatTaskProgress(task) && (task.status === 'pending' || task.status === 'running' || task.status === 'processing') && (
|
||||
<CardContent className="pt-0">
|
||||
<div className="flex justify-end">
|
||||
<Button
|
||||
variant="destructive"
|
||||
size="sm"
|
||||
onClick={() => cancelTask(task.task_id)}
|
||||
className="h-7 px-3 text-xs"
|
||||
title="Cancel task"
|
||||
>
|
||||
<X className="h-3 w-3 mr-1" />
|
||||
Cancel
|
||||
</Button>
|
||||
</div>
|
||||
</CardContent>
|
||||
)}
|
||||
</Card>
|
||||
))}
|
||||
{showCancel && (
|
||||
<div className={`flex justify-end ${progress ? 'mt-3' : ''}`}>
|
||||
<Button
|
||||
variant="destructive"
|
||||
size="sm"
|
||||
onClick={() => cancelTask(task.task_id)}
|
||||
className="h-7 px-3 text-xs"
|
||||
title="Cancel task"
|
||||
>
|
||||
<X className="h-3 w-3 mr-1" />
|
||||
Cancel
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
</CardContent>
|
||||
)}
|
||||
</Card>
|
||||
)
|
||||
})}
|
||||
</div>
|
||||
)}
|
||||
|
||||
|
|
@ -282,43 +274,47 @@ export function TaskNotificationMenu() {
|
|||
|
||||
{isExpanded && (
|
||||
<div className="space-y-2 transition-all duration-200">
|
||||
{recentTasks.map((task) => (
|
||||
<div
|
||||
key={task.task_id}
|
||||
className="flex items-center gap-3 p-2 rounded-lg hover:bg-muted/50 transition-colors"
|
||||
>
|
||||
{getTaskIcon(task.status)}
|
||||
<div className="flex-1 min-w-0">
|
||||
<div className="text-xs font-medium truncate">
|
||||
Task {task.task_id.substring(0, 8)}...
|
||||
</div>
|
||||
<div className="text-xs text-muted-foreground">
|
||||
{formatRelativeTime(task.updated_at)}
|
||||
{formatDuration(task.duration_seconds) && (
|
||||
<span className="ml-2">
|
||||
• {formatDuration(task.duration_seconds)}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
{/* Show final results for completed tasks */}
|
||||
{task.status === 'completed' && formatTaskProgress(task)?.detailed && (
|
||||
<div className="text-xs text-muted-foreground mt-1">
|
||||
{formatTaskProgress(task)?.detailed.successful} success, {' '}
|
||||
{formatTaskProgress(task)?.detailed.failed} failed
|
||||
{(formatTaskProgress(task)?.detailed.running || 0) > 0 && (
|
||||
<span>, {formatTaskProgress(task)?.detailed.running} running</span>
|
||||
{recentTasks.map((task) => {
|
||||
const progress = formatTaskProgress(task)
|
||||
|
||||
return (
|
||||
<div
|
||||
key={task.task_id}
|
||||
className="flex items-center gap-3 p-2 rounded-lg hover:bg-muted/50 transition-colors"
|
||||
>
|
||||
{getTaskIcon(task.status)}
|
||||
<div className="flex-1 min-w-0">
|
||||
<div className="text-xs font-medium truncate">
|
||||
Task {task.task_id.substring(0, 8)}...
|
||||
</div>
|
||||
<div className="text-xs text-muted-foreground">
|
||||
{formatRelativeTime(task.updated_at)}
|
||||
{formatDuration(task.duration_seconds) && (
|
||||
<span className="ml-2">
|
||||
• {formatDuration(task.duration_seconds)}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
{task.status === 'failed' && task.error && (
|
||||
<div className="text-xs text-red-600 mt-1 truncate">
|
||||
{task.error}
|
||||
</div>
|
||||
)}
|
||||
{/* Show final results for completed tasks */}
|
||||
{task.status === 'completed' && progress?.detailed && (
|
||||
<div className="text-xs text-muted-foreground mt-1">
|
||||
{progress.detailed.successful} success,{' '}
|
||||
{progress.detailed.failed} failed
|
||||
{(progress.detailed.running || 0) > 0 && (
|
||||
<span>, {progress.detailed.running} running</span>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
{task.status === 'failed' && task.error && (
|
||||
<div className="text-xs text-red-600 mt-1 truncate">
|
||||
{task.error}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
{getStatusBadge(task.status)}
|
||||
</div>
|
||||
{getStatusBadge(task.status)}
|
||||
</div>
|
||||
))}
|
||||
)
|
||||
})}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
|
@ -338,4 +334,4 @@ export function TaskNotificationMenu() {
|
|||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ import { toast } from "sonner";
|
|||
import { useCancelTaskMutation } from "@/app/api/mutations/useCancelTaskMutation";
|
||||
import {
|
||||
type Task,
|
||||
type TaskFileEntry,
|
||||
useGetTasksQuery,
|
||||
} from "@/app/api/queries/useGetTasksQuery";
|
||||
import { useAuth } from "@/contexts/auth-context";
|
||||
|
|
@ -31,6 +32,9 @@ export interface TaskFile {
|
|||
task_id: string;
|
||||
created_at: string;
|
||||
updated_at: string;
|
||||
error?: string;
|
||||
embedding_model?: string;
|
||||
embedding_dimensions?: number;
|
||||
}
|
||||
interface TaskContextType {
|
||||
tasks: Task[];
|
||||
|
|
@ -105,6 +109,9 @@ export function TaskProvider({ children }: { children: React.ReactNode }) {
|
|||
task_id: taskId,
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
error: file.error,
|
||||
embedding_model: file.embedding_model,
|
||||
embedding_dimensions: file.embedding_dimensions,
|
||||
}));
|
||||
|
||||
setFiles((prevFiles) => [...prevFiles, ...filesToAdd]);
|
||||
|
|
@ -138,12 +145,13 @@ export function TaskProvider({ children }: { children: React.ReactNode }) {
|
|||
|
||||
taskFileEntries.forEach(([filePath, fileInfo]) => {
|
||||
if (typeof fileInfo === "object" && fileInfo) {
|
||||
const fileInfoEntry = fileInfo as TaskFileEntry;
|
||||
// Use the filename from backend if available, otherwise extract from path
|
||||
const fileName =
|
||||
(fileInfo as any).filename ||
|
||||
fileInfoEntry.filename ||
|
||||
filePath.split("/").pop() ||
|
||||
filePath;
|
||||
const fileStatus = fileInfo.status as string;
|
||||
const fileStatus = fileInfoEntry.status ?? "processing";
|
||||
|
||||
// Map backend file status to our TaskFile status
|
||||
let mappedStatus: TaskFile["status"];
|
||||
|
|
@ -162,6 +170,23 @@ export function TaskProvider({ children }: { children: React.ReactNode }) {
|
|||
mappedStatus = "processing";
|
||||
}
|
||||
|
||||
const fileError = (() => {
|
||||
if (
|
||||
typeof fileInfoEntry.error === "string" &&
|
||||
fileInfoEntry.error.trim().length > 0
|
||||
) {
|
||||
return fileInfoEntry.error.trim();
|
||||
}
|
||||
if (
|
||||
mappedStatus === "failed" &&
|
||||
typeof currentTask.error === "string" &&
|
||||
currentTask.error.trim().length > 0
|
||||
) {
|
||||
return currentTask.error.trim();
|
||||
}
|
||||
return undefined;
|
||||
})();
|
||||
|
||||
setFiles((prevFiles) => {
|
||||
const existingFileIndex = prevFiles.findIndex(
|
||||
(f) =>
|
||||
|
|
@ -185,13 +210,22 @@ export function TaskProvider({ children }: { children: React.ReactNode }) {
|
|||
status: mappedStatus,
|
||||
task_id: currentTask.task_id,
|
||||
created_at:
|
||||
typeof fileInfo.created_at === "string"
|
||||
? fileInfo.created_at
|
||||
typeof fileInfoEntry.created_at === "string"
|
||||
? fileInfoEntry.created_at
|
||||
: now,
|
||||
updated_at:
|
||||
typeof fileInfo.updated_at === "string"
|
||||
? fileInfo.updated_at
|
||||
typeof fileInfoEntry.updated_at === "string"
|
||||
? fileInfoEntry.updated_at
|
||||
: now,
|
||||
error: fileError,
|
||||
embedding_model:
|
||||
typeof fileInfoEntry.embedding_model === "string"
|
||||
? fileInfoEntry.embedding_model
|
||||
: undefined,
|
||||
embedding_dimensions:
|
||||
typeof fileInfoEntry.embedding_dimensions === "number"
|
||||
? fileInfoEntry.embedding_dimensions
|
||||
: undefined,
|
||||
};
|
||||
|
||||
if (existingFileIndex >= 0) {
|
||||
|
|
|
|||
scripts/extract_flow_component.py (174 changes, new file)

@@ -0,0 +1,174 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Extract embedded component code from a Langflow JSON flow.
|
||||
|
||||
Example:
|
||||
python scripts/extract_flow_component.py \\
|
||||
--flow-file flows/ingestion_flow.json \\
|
||||
--display-name "OpenSearch (Multi-Model)" \\
|
||||
--output flows/components/opensearch_multimodel.py
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def should_select_component(
|
||||
node: dict,
|
||||
*,
|
||||
display_name: Optional[str],
|
||||
metadata_module: Optional[str],
|
||||
) -> bool:
|
||||
"""Return True if the node matches the requested component filters."""
|
||||
node_data = node.get("data", {})
|
||||
component = node_data.get("node", {})
|
||||
|
||||
if display_name and component.get("display_name") != display_name:
|
||||
return False
|
||||
|
||||
if metadata_module:
|
||||
metadata = component.get("metadata", {})
|
||||
if metadata.get("module") != metadata_module:
|
||||
return False
|
||||
|
||||
template = component.get("template", {})
|
||||
code_entry = template.get("code")
|
||||
return isinstance(code_entry, dict) and "value" in code_entry
|
||||
|
||||
|
||||
def extract_code_from_flow(
|
||||
flow_path: Path,
|
||||
*,
|
||||
display_name: Optional[str],
|
||||
metadata_module: Optional[str],
|
||||
match_index: int,
|
||||
) -> str:
|
||||
"""Fetch the embedded code string from the matching component node."""
|
||||
try:
|
||||
flow_data = json.loads(flow_path.read_text(encoding="utf-8"))
|
||||
except json.JSONDecodeError as exc:
|
||||
raise SystemExit(f"[error] failed to parse {flow_path}: {exc}") from exc
|
||||
|
||||
matches = []
|
||||
for node in flow_data.get("data", {}).get("nodes", []):
|
||||
if should_select_component(
|
||||
node,
|
||||
display_name=display_name,
|
||||
metadata_module=metadata_module,
|
||||
):
|
||||
matches.append(node)
|
||||
|
||||
if not matches:
|
||||
raise SystemExit(
|
||||
"[error] no component found matching the supplied filters "
|
||||
f"in {flow_path}"
|
||||
)
|
||||
|
||||
if match_index < 0 or match_index >= len(matches):
|
||||
raise SystemExit(
|
||||
f"[error] match index {match_index} out of range "
|
||||
f"(found {len(matches)} matches)"
|
||||
)
|
||||
|
||||
target = matches[match_index]
|
||||
return target["data"]["node"]["template"]["code"]["value"]
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Extract component code from a Langflow JSON flow."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--flow-file",
|
||||
required=True,
|
||||
type=Path,
|
||||
help="Path to the flow JSON file.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--display-name",
|
||||
help="Component display_name to match (e.g. 'OpenSearch (Multi-Model)').",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--metadata-module",
|
||||
help="Component metadata.module value to match.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--match-index",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Index of the matched component when multiple exist (default: 0).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=Path,
|
||||
help="Destination file for the extracted code (stdout if omitted).",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.display_name and not args.metadata_module:
|
||||
# Offer an interactive selection of component display names
|
||||
if not args.flow_file.exists():
|
||||
parser.error(f"Flow file not found: {args.flow_file}")
|
||||
|
||||
try:
|
||||
flow_data = json.loads(args.flow_file.read_text(encoding="utf-8"))
|
||||
except json.JSONDecodeError as exc:
|
||||
raise SystemExit(f"[error] failed to parse {args.flow_file}: {exc}") from exc
|
||||
|
||||
nodes = flow_data.get("data", {}).get("nodes", [])
|
||||
display_names = sorted(
|
||||
{
|
||||
node.get("data", {})
|
||||
.get("node", {})
|
||||
.get("display_name", "<unknown>")
|
||||
for node in nodes
|
||||
}
|
||||
)
|
||||
|
||||
if not display_names:
|
||||
parser.error(
|
||||
"Unable to locate any components in the flow; supply --metadata-module instead."
|
||||
)
|
||||
|
||||
print("Select a component display name:")
|
||||
for idx, name in enumerate(display_names):
|
||||
print(f" [{idx}] {name}")
|
||||
|
||||
while True:
|
||||
choice = input(f"Enter choice (0-{len(display_names)-1}): ").strip() or "0"
|
||||
if choice.isdigit():
|
||||
index = int(choice)
|
||||
if 0 <= index < len(display_names):
|
||||
args.display_name = display_names[index]
|
||||
break
|
||||
print("Invalid selection, please try again.")
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
|
||||
if not args.flow_file.exists():
|
||||
raise SystemExit(f"[error] flow file not found: {args.flow_file}")
|
||||
|
||||
code = extract_code_from_flow(
|
||||
args.flow_file,
|
||||
display_name=args.display_name,
|
||||
metadata_module=args.metadata_module,
|
||||
match_index=args.match_index,
|
||||
)
|
||||
|
||||
if args.output:
|
||||
args.output.write_text(code, encoding="utf-8")
|
||||
else:
|
||||
print(code, end="")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
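For tooling that prefers to skip the CLI, the same extraction can be driven programmatically; a minimal sketch, assuming the scripts directory is on the import path (the module name `extract_flow_component` is how the file would import when run from the repo root, not a packaged module):

```python
from pathlib import Path

# Hypothetical programmatic use of the extractor defined above; assumes
# scripts/ is importable (e.g. sys.path tweak or running from the repo root).
from extract_flow_component import extract_code_from_flow

code = extract_code_from_flow(
    Path("flows/ingestion_flow.json"),        # flow file from the docstring example
    display_name="OpenSearch (Multi-Model)",  # component node to pull out
    metadata_module=None,                     # no metadata.module filter
    match_index=0,                            # first match if several nodes share the name
)
Path("flows/components/opensearch_multimodel.py").write_text(code, encoding="utf-8")
```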
scripts/migrate_embedding_model_field.py (379 changes, new file)

@@ -0,0 +1,379 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Migration script to migrate legacy embeddings to multi-model setup.
|
||||
|
||||
This script migrates documents from the legacy single-field embedding system
|
||||
to the new multi-model system with dynamic field names.
|
||||
|
||||
Legacy format:
|
||||
{
|
||||
"chunk_embedding": [0.1, 0.2, ...],
|
||||
// no embedding_model field
|
||||
}
|
||||
|
||||
New format:
|
||||
{
|
||||
"chunk_embedding_text_embedding_3_small": [0.1, 0.2, ...],
|
||||
"embedding_model": "text-embedding-3-small"
|
||||
}
|
||||
|
||||
Usage:
|
||||
uv run python scripts/migrate_embedding_model_field.py --model <model_name>
|
||||
|
||||
Example:
|
||||
uv run python scripts/migrate_embedding_model_field.py --model text-embedding-3-small
|
||||
|
||||
Options:
|
||||
--model MODEL The embedding model name to assign to legacy embeddings
|
||||
(e.g., "text-embedding-3-small", "nomic-embed-text")
|
||||
--batch-size SIZE Number of documents to process per batch (default: 100)
|
||||
--dry-run Show what would be migrated without making changes
|
||||
--index INDEX Index name (default: documents)
|
||||
|
||||
What it does:
|
||||
1. Finds all documents with legacy "chunk_embedding" field but no "embedding_model" field
|
||||
2. For each document:
|
||||
- Copies the vector from "chunk_embedding" to "chunk_embedding_{model_name}"
|
||||
- Adds "embedding_model" field with the specified model name
|
||||
- Optionally removes the legacy "chunk_embedding" field
|
||||
3. Uses bulk updates for efficiency
|
||||
|
||||
Note: This script does NOT re-embed documents. It simply tags existing embeddings
|
||||
with the model name you specify. Make sure to specify the correct model that was
|
||||
actually used to create those embeddings.
|
||||
"""
|
||||
import asyncio
|
||||
import sys
|
||||
import os
|
||||
import argparse
|
||||
from typing import List, Dict, Any
|
||||
|
||||
from opensearchpy import AsyncOpenSearch, helpers
|
||||
from opensearchpy._async.http_aiohttp import AIOHttpConnection
|
||||
|
||||
# Add src directory to path to import config
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
|
||||
|
||||
from config.settings import (
|
||||
OPENSEARCH_HOST,
|
||||
OPENSEARCH_PORT,
|
||||
OPENSEARCH_USERNAME,
|
||||
OPENSEARCH_PASSWORD,
|
||||
INDEX_NAME,
|
||||
)
|
||||
from utils.logging_config import get_logger
|
||||
from utils.embedding_fields import get_embedding_field_name
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
async def ensure_new_field_exists(
|
||||
client: AsyncOpenSearch,
|
||||
index_name: str,
|
||||
field_name: str,
|
||||
dimensions: int
|
||||
) -> None:
|
||||
"""Ensure the new embedding field exists in the index."""
|
||||
mapping = {
|
||||
"properties": {
|
||||
field_name: {
|
||||
"type": "knn_vector",
|
||||
"dimension": dimensions,
|
||||
"method": {
|
||||
"name": "disk_ann",
|
||||
"engine": "jvector",
|
||||
"space_type": "l2",
|
||||
"parameters": {"ef_construction": 100, "m": 16},
|
||||
},
|
||||
},
|
||||
"embedding_model": {
|
||||
"type": "keyword"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
try:
|
||||
await client.indices.put_mapping(index=index_name, body=mapping)
|
||||
logger.info(f"Ensured field exists: {field_name}")
|
||||
except Exception as e:
|
||||
error_msg = str(e).lower()
|
||||
if "already" in error_msg or "exists" in error_msg:
|
||||
logger.debug(f"Field already exists: {field_name}")
|
||||
else:
|
||||
logger.error(f"Failed to add field mapping: {e}")
|
||||
raise
|
||||
|
||||
|
||||
async def find_legacy_documents(
|
||||
client: AsyncOpenSearch,
|
||||
index_name: str,
|
||||
batch_size: int = 100
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Find all documents with legacy chunk_embedding but no embedding_model field."""
|
||||
query = {
|
||||
"query": {
|
||||
"bool": {
|
||||
"must": [
|
||||
{"exists": {"field": "chunk_embedding"}}
|
||||
],
|
||||
"must_not": [
|
||||
{"exists": {"field": "embedding_model"}}
|
||||
]
|
||||
}
|
||||
},
|
||||
"size": batch_size,
|
||||
"_source": True
|
||||
}
|
||||
|
||||
try:
|
||||
response = await client.search(index=index_name, body=query, scroll='5m')
|
||||
scroll_id = response['_scroll_id']
|
||||
hits = response['hits']['hits']
|
||||
|
||||
all_docs = hits
|
||||
|
||||
# Continue scrolling until no more results
|
||||
while len(hits) > 0:
|
||||
response = await client.scroll(scroll_id=scroll_id, scroll='5m')
|
||||
scroll_id = response['_scroll_id']
|
||||
hits = response['hits']['hits']
|
||||
all_docs.extend(hits)
|
||||
|
||||
# Clean up scroll
|
||||
await client.clear_scroll(scroll_id=scroll_id)
|
||||
|
||||
return all_docs
|
||||
except Exception as e:
|
||||
logger.error(f"Error finding legacy documents: {e}")
|
||||
raise
|
||||
|
||||
|
||||
async def migrate_documents(
|
||||
client: AsyncOpenSearch,
|
||||
index_name: str,
|
||||
documents: List[Dict[str, Any]],
|
||||
model_name: str,
|
||||
new_field_name: str,
|
||||
dry_run: bool = False
|
||||
) -> Dict[str, int]:
|
||||
"""Migrate legacy documents to new format."""
|
||||
if not documents:
|
||||
return {"migrated": 0, "errors": 0}
|
||||
|
||||
if dry_run:
|
||||
logger.info(f"DRY RUN: Would migrate {len(documents)} documents")
|
||||
for doc in documents[:5]: # Show first 5 as sample
|
||||
doc_id = doc['_id']
|
||||
has_legacy = 'chunk_embedding' in doc['_source']
|
||||
logger.info(f" Document {doc_id}: has_legacy={has_legacy}")
|
||||
if len(documents) > 5:
|
||||
logger.info(f" ... and {len(documents) - 5} more documents")
|
||||
return {"migrated": len(documents), "errors": 0}
|
||||
|
||||
# Prepare bulk update actions
|
||||
actions = []
|
||||
for doc in documents:
|
||||
doc_id = doc['_id']
|
||||
source = doc['_source']
|
||||
|
||||
# Copy the legacy embedding to the new field
|
||||
legacy_embedding = source.get('chunk_embedding')
|
||||
if not legacy_embedding:
|
||||
logger.warning(f"Document {doc_id} missing chunk_embedding, skipping")
|
||||
continue
|
||||
|
||||
# Build update document
|
||||
update_doc = {
|
||||
new_field_name: legacy_embedding,
|
||||
"embedding_model": model_name
|
||||
}
|
||||
|
||||
action = {
|
||||
"_op_type": "update",
|
||||
"_index": index_name,
|
||||
"_id": doc_id,
|
||||
"doc": update_doc
|
||||
}
|
||||
actions.append(action)
|
||||
|
||||
# Execute bulk update
|
||||
migrated = 0
|
||||
errors = 0
|
||||
|
||||
try:
|
||||
success, failed = await helpers.async_bulk(
|
||||
client,
|
||||
actions,
|
||||
raise_on_error=False,
|
||||
raise_on_exception=False
|
||||
)
|
||||
migrated = success
|
||||
errors = len(failed) if isinstance(failed, list) else 0
|
||||
|
||||
if errors > 0:
|
||||
logger.error(f"Failed to migrate {errors} documents")
|
||||
for failure in (failed if isinstance(failed, list) else [])[:5]:
|
||||
logger.error(f" Error: {failure}")
|
||||
|
||||
logger.info(f"Successfully migrated {migrated} documents")
|
||||
except Exception as e:
|
||||
logger.error(f"Bulk migration failed: {e}")
|
||||
raise
|
||||
|
||||
return {"migrated": migrated, "errors": errors}
|
||||
|
||||
|
||||
async def migrate_legacy_embeddings(
|
||||
model_name: str,
|
||||
batch_size: int = 100,
|
||||
dry_run: bool = False,
|
||||
index_name: str = None
|
||||
) -> bool:
|
||||
"""Main migration function."""
|
||||
if index_name is None:
|
||||
index_name = INDEX_NAME
|
||||
|
||||
new_field_name = get_embedding_field_name(model_name)
|
||||
|
||||
logger.info("=" * 60)
|
||||
logger.info("Legacy Embedding Migration")
|
||||
logger.info("=" * 60)
|
||||
logger.info(f"Index: {index_name}")
|
||||
logger.info(f"Model: {model_name}")
|
||||
logger.info(f"New field: {new_field_name}")
|
||||
logger.info(f"Batch size: {batch_size}")
|
||||
logger.info(f"Dry run: {dry_run}")
|
||||
logger.info("=" * 60)
|
||||
|
||||
# Create admin OpenSearch client
|
||||
client = AsyncOpenSearch(
|
||||
hosts=[{"host": OPENSEARCH_HOST, "port": OPENSEARCH_PORT}],
|
||||
connection_class=AIOHttpConnection,
|
||||
scheme="https",
|
||||
use_ssl=True,
|
||||
verify_certs=False,
|
||||
ssl_assert_fingerprint=None,
|
||||
http_auth=(OPENSEARCH_USERNAME, OPENSEARCH_PASSWORD),
|
||||
http_compress=True,
|
||||
)
|
||||
|
||||
try:
|
||||
# Check if index exists
|
||||
exists = await client.indices.exists(index=index_name)
|
||||
if not exists:
|
||||
logger.error(f"Index '{index_name}' does not exist")
|
||||
return False
|
||||
|
||||
# Find legacy documents
|
||||
logger.info("Searching for legacy documents...")
|
||||
legacy_docs = await find_legacy_documents(client, index_name, batch_size)
|
||||
|
||||
if not legacy_docs:
|
||||
logger.info("No legacy documents found. Migration not needed.")
|
||||
return True
|
||||
|
||||
logger.info(f"Found {len(legacy_docs)} legacy documents to migrate")
|
||||
|
||||
# Get vector dimension from first document
|
||||
first_doc = legacy_docs[0]
|
||||
legacy_embedding = first_doc['_source'].get('chunk_embedding', [])
|
||||
dimensions = len(legacy_embedding)
|
||||
logger.info(f"Detected vector dimensions: {dimensions}")
|
||||
|
||||
# Ensure new field exists
|
||||
if not dry_run:
|
||||
logger.info(f"Ensuring new field exists: {new_field_name}")
|
||||
await ensure_new_field_exists(client, index_name, new_field_name, dimensions)
|
||||
|
||||
# Migrate documents
|
||||
logger.info("Starting migration...")
|
||||
result = await migrate_documents(
|
||||
client,
|
||||
index_name,
|
||||
legacy_docs,
|
||||
model_name,
|
||||
new_field_name,
|
||||
dry_run
|
||||
)
|
||||
|
||||
logger.info("=" * 60)
|
||||
logger.info("Migration Summary")
|
||||
logger.info("=" * 60)
|
||||
logger.info(f"Total documents: {len(legacy_docs)}")
|
||||
logger.info(f"Successfully migrated: {result['migrated']}")
|
||||
logger.info(f"Errors: {result['errors']}")
|
||||
logger.info("=" * 60)
|
||||
|
||||
if result['errors'] > 0:
|
||||
logger.warning("Migration completed with errors")
|
||||
return False
|
||||
|
||||
if dry_run:
|
||||
logger.info("DRY RUN completed. No changes were made.")
|
||||
logger.info(f"Run without --dry-run to perform the migration")
|
||||
else:
|
||||
logger.info("Migration completed successfully!")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Migration failed: {e}")
|
||||
return False
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Migrate legacy embeddings to multi-model setup",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Examples:
|
||||
# Dry run to see what would be migrated
|
||||
uv run python scripts/migrate_embedding_model_field.py --model text-embedding-3-small --dry-run
|
||||
|
||||
# Perform actual migration
|
||||
uv run python scripts/migrate_embedding_model_field.py --model text-embedding-3-small
|
||||
|
||||
# Migrate with custom batch size
|
||||
uv run python scripts/migrate_embedding_model_field.py --model nomic-embed-text --batch-size 500
|
||||
"""
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--model',
|
||||
required=True,
|
||||
help='Embedding model name to assign to legacy embeddings (e.g., "text-embedding-3-small")'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--batch-size',
|
||||
type=int,
|
||||
default=100,
|
||||
help='Number of documents to process per batch (default: 100)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--dry-run',
|
||||
action='store_true',
|
||||
help='Show what would be migrated without making changes'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--index',
|
||||
default=None,
|
||||
help=f'Index name (default: {INDEX_NAME})'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Run migration
|
||||
success = asyncio.run(migrate_legacy_embeddings(
|
||||
model_name=args.model,
|
||||
batch_size=args.batch_size,
|
||||
dry_run=args.dry_run,
|
||||
index_name=args.index
|
||||
))
|
||||
|
||||
sys.exit(0 if success else 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
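The dynamic field name comes from `get_embedding_field_name`, imported from `utils.embedding_fields` but not shown in this diff. Going by the docstring's example mapping (model "text-embedding-3-small" becomes field "chunk_embedding_text_embedding_3_small"), a plausible sketch of that helper, assuming it only normalizes the model name into a field-safe suffix:

```python
import re

def get_embedding_field_name(model_name: str) -> str:
    """Sketch of the helper from utils.embedding_fields (implementation assumed,
    not part of this diff): normalize the model name into a field-safe suffix."""
    # "text-embedding-3-small" -> "text_embedding_3_small"
    safe = re.sub(r"[^0-9A-Za-z]+", "_", model_name).strip("_").lower()
    return f"chunk_embedding_{safe}"

assert get_embedding_field_name("text-embedding-3-small") == "chunk_embedding_text_embedding_3_small"
```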
scripts/update_flow_components.py (138 changes, new file)

@@ -0,0 +1,138 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Utility to sync embedded component code inside Langflow JSON files.
|
||||
|
||||
Given a Python source file (e.g. the OpenSearch component implementation) and
|
||||
a target selector, this script updates every flow definition in ``./flows`` so
|
||||
that the component's ``template.code.value`` matches the supplied file.
|
||||
|
||||
Example:
|
||||
python scripts/update_flow_components.py \\
|
||||
--code-file flows/components/opensearch_multimodel.py \\
|
||||
--display-name \"OpenSearch (Multi-Model)\"
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Iterable
|
||||
|
||||
|
||||
def load_code(source_path: Path) -> str:
|
||||
try:
|
||||
return source_path.read_text(encoding="utf-8")
|
||||
except FileNotFoundError as exc:
|
||||
raise SystemExit(f"[error] code file not found: {source_path}") from exc
|
||||
|
||||
|
||||
def should_update_component(node: dict, *, display_name: str | None, metadata_module: str | None) -> bool:
|
||||
node_data = node.get("data", {})
|
||||
component = node_data.get("node", {})
|
||||
|
||||
if display_name and component.get("display_name") != display_name:
|
||||
return False
|
||||
|
||||
if metadata_module:
|
||||
metadata = component.get("metadata", {})
|
||||
module_name = metadata.get("module")
|
||||
if module_name != metadata_module:
|
||||
return False
|
||||
|
||||
template = component.get("template", {})
|
||||
code_entry = template.get("code")
|
||||
return isinstance(code_entry, dict) and "value" in code_entry
|
||||
|
||||
|
||||
def update_flow(flow_path: Path, code: str, *, display_name: str | None, metadata_module: str | None, dry_run: bool) -> bool:
|
||||
with flow_path.open(encoding="utf-8") as fh:
|
||||
try:
|
||||
data = json.load(fh)
|
||||
except json.JSONDecodeError as exc:
|
||||
raise SystemExit(f"[error] failed to parse {flow_path}: {exc}") from exc
|
||||
|
||||
changed = False
|
||||
|
||||
for node in data.get("data", {}).get("nodes", []):
|
||||
if not should_update_component(node, display_name=display_name, metadata_module=metadata_module):
|
||||
continue
|
||||
|
||||
template = node["data"]["node"]["template"]
|
||||
if template["code"]["value"] != code:
|
||||
if dry_run:
|
||||
changed = True
|
||||
else:
|
||||
template["code"]["value"] = code
|
||||
changed = True
|
||||
|
||||
if changed and not dry_run:
|
||||
flow_path.write_text(
|
||||
json.dumps(data, indent=2, ensure_ascii=False) + "\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
return changed
|
||||
|
||||
|
||||
def iter_flow_files(flows_dir: Path) -> Iterable[Path]:
|
||||
for path in sorted(flows_dir.glob("*.json")):
|
||||
if path.is_file():
|
||||
yield path
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description="Update embedded component code in Langflow JSON files.")
|
||||
parser.add_argument("--code-file", required=True, type=Path, help="Path to the Python file containing the component code.")
|
||||
parser.add_argument("--flows-dir", type=Path, default=Path("flows"), help="Directory containing Langflow JSON files.")
|
||||
parser.add_argument("--display-name", help="Component display_name to match (e.g. 'OpenSearch (Multi-Model)').")
|
||||
parser.add_argument("--metadata-module", help="Component metadata.module value to match.")
|
||||
parser.add_argument("--dry-run", action="store_true", help="Report which files would change without modifying them.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.display_name and not args.metadata_module:
|
||||
parser.error("At least one of --display-name or --metadata-module must be provided.")
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
|
||||
flows_dir: Path = args.flows_dir
|
||||
if not flows_dir.exists():
|
||||
raise SystemExit(f"[error] flows directory not found: {flows_dir}")
|
||||
|
||||
code = load_code(args.code_file)
|
||||
|
||||
updated_files = []
|
||||
for flow_path in iter_flow_files(flows_dir):
|
||||
changed = update_flow(
|
||||
flow_path,
|
||||
code,
|
||||
display_name=args.display_name,
|
||||
metadata_module=args.metadata_module,
|
||||
dry_run=args.dry_run,
|
||||
)
|
||||
if changed:
|
||||
updated_files.append(flow_path)
|
||||
|
||||
if args.dry_run:
|
||||
if updated_files:
|
||||
print("[dry-run] files that would be updated:")
|
||||
for path in updated_files:
|
||||
print(f" - {path}")
|
||||
else:
|
||||
print("[dry-run] no files would change.")
|
||||
else:
|
||||
if updated_files:
|
||||
print("Updated component code in:")
|
||||
for path in updated_files:
|
||||
print(f" - {path}")
|
||||
else:
|
||||
print("No updates were necessary.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@@ -21,10 +21,13 @@ openrag_user_role:
allowed_actions:
- crud
- create_index
- indices:admin/mappings/get
- indices:admin/mappings/put
- indices:admin/exists
- indices:admin/get
dls: >
{"bool":{"should":[
{"term":{"owner":"${user.name}"}},
{"term":{"allowed_users":"${user.name}"}},
{"bool":{"must_not":{"exists":{"field":"owner"}}}}
],"minimum_should_match":1}}
|
||||
|
||||
|
|
|
|||
|
|
@@ -241,19 +241,49 @@ async def update_settings(request, session_manager):
                {"error": "embedding_model must be a non-empty string"},
                status_code=400,
            )
        current_config.knowledge.embedding_model = body["embedding_model"].strip()
        new_embedding_model = body["embedding_model"].strip()
        current_config.knowledge.embedding_model = new_embedding_model
        config_updated = True

        # Also update the ingest flow with the new embedding model
        try:
            flows_service = _get_flows_service()
            await flows_service.update_ingest_flow_embedding_model(
                body["embedding_model"].strip(),
                new_embedding_model,
                current_config.provider.model_provider.lower()
            )
            logger.info(
                f"Successfully updated ingest flow embedding model to '{body['embedding_model'].strip()}'"
            )

            provider = (
                current_config.provider.model_provider.lower()
                if current_config.provider.model_provider
                else "openai"
            )
            endpoint = current_config.provider.endpoint or None
            llm_model = current_config.agent.llm_model

            change_result = await flows_service.change_langflow_model_value(
                provider=provider,
                embedding_model=new_embedding_model,
                llm_model=llm_model,
                endpoint=endpoint,
            )

            if not change_result.get("success", False):
                logger.warning(
                    "Change embedding model across flows completed with issues",
                    provider=provider,
                    embedding_model=new_embedding_model,
                    change_result=change_result,
                )
            else:
                logger.info(
                    "Successfully updated embedding model across Langflow flows",
                    provider=provider,
                    embedding_model=new_embedding_model,
                )
        except Exception as e:
            logger.error(f"Failed to update ingest flow embedding model: {str(e)}")
            # Don't fail the entire settings update if flow update fails
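A hedged sketch of exercising the handler above over HTTP; the base URL is an assumption (main.py binds port 8000), and the payload mirrors what the integration tests later post to /settings:

import asyncio
import httpx

async def switch_embedding_model() -> None:
    async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
        resp = await client.post(
            "/settings",
            json={"embedding_model": "text-embedding-3-large"},
        )
        resp.raise_for_status()
        print(resp.json())

asyncio.run(switch_embedding_model())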
@@ -107,6 +107,8 @@ INDEX_BODY = {
            "mimetype": {"type": "keyword"},
            "page": {"type": "integer"},
            "text": {"type": "text"},
            # Legacy field - kept for backward compatibility
            # New documents will use chunk_embedding_{model_name} fields
            "chunk_embedding": {
                "type": "knn_vector",
                "dimension": VECTOR_DIM,
@@ -117,6 +119,8 @@ INDEX_BODY = {
                    "parameters": {"ef_construction": 100, "m": 16},
                },
            },
            # Track which embedding model was used for this chunk
            "embedding_model": {"type": "keyword"},
            "source_url": {"type": "keyword"},
            "connector_type": {"type": "keyword"},
            "owner": {"type": "keyword"},
@@ -322,7 +326,7 @@ class AppClients:

        # Initialize Langflow HTTP client
        self.langflow_http_client = httpx.AsyncClient(
            base_url=LANGFLOW_URL, timeout=60.0
            base_url=LANGFLOW_URL, timeout=300.0
        )

        return self
@@ -591,3 +595,8 @@ def get_knowledge_config():
def get_agent_config():
    """Get agent configuration."""
    return get_openrag_config().agent


def get_embedding_model() -> str:
    """Return the currently configured embedding model."""
    return get_openrag_config().knowledge.embedding_model
@ -271,6 +271,7 @@ class ConnectorService:
|
|||
|
||||
# Create custom processor for connector files
|
||||
from models.processors import ConnectorFileProcessor
|
||||
from services.document_service import DocumentService
|
||||
|
||||
processor = ConnectorFileProcessor(
|
||||
self,
|
||||
|
|
@ -280,6 +281,11 @@ class ConnectorService:
|
|||
jwt_token=jwt_token,
|
||||
owner_name=owner_name,
|
||||
owner_email=owner_email,
|
||||
document_service=(
|
||||
self.task_service.document_service
|
||||
if self.task_service and self.task_service.document_service
|
||||
else DocumentService(session_manager=self.session_manager)
|
||||
),
|
||||
)
|
||||
|
||||
# Use file IDs as items (no more fake file paths!)
|
||||
|
|
@ -366,6 +372,7 @@ class ConnectorService:
|
|||
|
||||
# Create custom processor for specific connector files
|
||||
from models.processors import ConnectorFileProcessor
|
||||
from services.document_service import DocumentService
|
||||
|
||||
# Use expanded_file_ids which has folders already expanded
|
||||
processor = ConnectorFileProcessor(
|
||||
|
|
@ -376,6 +383,11 @@ class ConnectorService:
|
|||
jwt_token=jwt_token,
|
||||
owner_name=owner_name,
|
||||
owner_email=owner_email,
|
||||
document_service=(
|
||||
self.task_service.document_service
|
||||
if self.task_service and self.task_service.document_service
|
||||
else DocumentService(session_manager=self.session_manager)
|
||||
),
|
||||
)
|
||||
|
||||
# Create custom task using TaskService
|
||||
|
|
18 src/main.py
@ -53,11 +53,11 @@ from auth_middleware import optional_auth, require_auth
|
|||
# Configuration and setup
|
||||
from config.settings import (
|
||||
DISABLE_INGEST_WITH_LANGFLOW,
|
||||
EMBED_MODEL,
|
||||
INDEX_BODY,
|
||||
INDEX_NAME,
|
||||
SESSION_SECRET,
|
||||
clients,
|
||||
get_embedding_model,
|
||||
is_no_auth_mode,
|
||||
get_openrag_config,
|
||||
)
|
||||
|
|
@ -505,7 +505,7 @@ async def initialize_services():
|
|||
openrag_connector_service = ConnectorService(
|
||||
patched_async_client=clients.patched_async_client,
|
||||
process_pool=process_pool,
|
||||
embed_model=EMBED_MODEL,
|
||||
embed_model=get_embedding_model(),
|
||||
index_name=INDEX_NAME,
|
||||
task_service=task_service,
|
||||
session_manager=session_manager,
|
||||
|
|
@ -567,18 +567,6 @@ async def create_app():
|
|||
|
||||
# Create route handlers with service dependencies injected
|
||||
routes = [
|
||||
# Upload endpoints
|
||||
Route(
|
||||
"/upload",
|
||||
require_auth(services["session_manager"])(
|
||||
partial(
|
||||
upload.upload,
|
||||
document_service=services["document_service"],
|
||||
session_manager=services["session_manager"],
|
||||
)
|
||||
),
|
||||
methods=["POST"],
|
||||
),
|
||||
# Langflow Files endpoints
|
||||
Route(
|
||||
"/langflow/files/upload",
|
||||
|
|
@ -1228,4 +1216,4 @@ if __name__ == "__main__":
|
|||
host="0.0.0.0",
|
||||
port=8000,
|
||||
reload=False, # Disable reload since we're running from main
|
||||
)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -156,15 +156,24 @@ class TaskProcessor:
|
|||
owner_email: str = None,
|
||||
file_size: int = None,
|
||||
connector_type: str = "local",
|
||||
embedding_model: str = None,
|
||||
):
|
||||
"""
|
||||
Standard processing pipeline for non-Langflow processors:
|
||||
docling conversion + embeddings + OpenSearch indexing.
|
||||
|
||||
Args:
|
||||
embedding_model: Embedding model to use (defaults to the current
|
||||
embedding model from settings)
|
||||
"""
|
||||
import datetime
|
||||
from config.settings import INDEX_NAME, EMBED_MODEL, clients
|
||||
from config.settings import INDEX_NAME, clients, get_embedding_model
|
||||
from services.document_service import chunk_texts_for_embeddings
|
||||
from utils.document_processing import extract_relevant
|
||||
from utils.embedding_fields import get_embedding_field_name, ensure_embedding_field_exists
|
||||
|
||||
# Use provided embedding model or fall back to default
|
||||
embedding_model = embedding_model or get_embedding_model()
|
||||
|
||||
# Get user's OpenSearch client with JWT for OIDC auth
|
||||
opensearch_client = self.document_service.session_manager.get_user_opensearch_client(
|
||||
|
|
@ -175,6 +184,18 @@ class TaskProcessor:
|
|||
if await self.check_document_exists(file_hash, opensearch_client):
|
||||
return {"status": "unchanged", "id": file_hash}
|
||||
|
||||
# Ensure the embedding field exists for this model
|
||||
embedding_field_name = await ensure_embedding_field_exists(
|
||||
opensearch_client, embedding_model, INDEX_NAME
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Processing document with embedding model",
|
||||
embedding_model=embedding_model,
|
||||
embedding_field=embedding_field_name,
|
||||
file_hash=file_hash,
|
||||
)
|
||||
|
||||
# Convert and extract
|
||||
result = clients.converter.convert(file_path)
|
||||
full_doc = result.document.export_to_dict()
|
||||
|
|
@ -188,7 +209,7 @@ class TaskProcessor:
|
|||
|
||||
for batch in text_batches:
|
||||
resp = await clients.patched_async_client.embeddings.create(
|
||||
model=EMBED_MODEL, input=batch
|
||||
model=embedding_model, input=batch
|
||||
)
|
||||
embeddings.extend([d.embedding for d in resp.data])
|
||||
|
||||
|
|
@ -202,7 +223,11 @@ class TaskProcessor:
|
|||
"mimetype": slim_doc["mimetype"],
|
||||
"page": chunk["page"],
|
||||
"text": chunk["text"],
|
||||
"chunk_embedding": vect,
|
||||
# Store embedding in model-specific field
|
||||
embedding_field_name: vect,
|
||||
# Track which model was used
|
||||
"embedding_model": embedding_model,
|
||||
"embedding_dimensions": len(vect),
|
||||
"file_size": file_size,
|
||||
"connector_type": connector_type,
|
||||
"indexed_time": datetime.datetime.now().isoformat(),
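The hunk above ends mid-dictionary and the actual indexing call is outside this diff; a hedged sketch of one plausible per-chunk document, using the model-specific field name produced by get_embedding_field_name (vector values and dimensions are toy numbers, not the author's code):

chunk_doc = {
    "mimetype": "text/markdown",
    "page": 1,
    "text": "example chunk text",
    "chunk_embedding": [0.1, 0.2],                          # legacy field kept for old readers
    "chunk_embedding_text_embedding_3_small": [0.1, 0.2],   # model-specific field
    "embedding_model": "text-embedding-3-small",
    "embedding_dimensions": 2,                               # toy dimension for illustration
    "connector_type": "local",
}
# Assumption: written roughly like
#   await opensearch_client.index(index=INDEX_NAME, id=f"{file_hash}_0", body=chunk_doc)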
@ -331,8 +356,9 @@ class ConnectorFileProcessor(TaskProcessor):
|
|||
jwt_token: str = None,
|
||||
owner_name: str = None,
|
||||
owner_email: str = None,
|
||||
document_service=None,
|
||||
):
|
||||
super().__init__()
|
||||
super().__init__(document_service=document_service)
|
||||
self.connector_service = connector_service
|
||||
self.connection_id = connection_id
|
||||
self.files_to_process = files_to_process
|
||||
|
|
@ -550,7 +576,7 @@ class S3FileProcessor(TaskProcessor):
|
|||
import time
|
||||
import asyncio
|
||||
import datetime
|
||||
from config.settings import INDEX_NAME, EMBED_MODEL, clients
|
||||
from config.settings import INDEX_NAME, clients, get_embedding_model
|
||||
from services.document_service import chunk_texts_for_embeddings
|
||||
from utils.document_processing import process_document_sync
|
||||
|
||||
|
|
@ -740,4 +766,4 @@ class LangflowFileProcessor(TaskProcessor):
|
|||
file_task.error_message = str(e)
|
||||
file_task.updated_at = time.time()
|
||||
upload_task.failed_files += 1
|
||||
raise
|
||||
raise
|
||||
|
|
|
|||
|
|
@ -12,12 +12,13 @@ from utils.logging_config import get_logger
|
|||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
from config.settings import clients, INDEX_NAME, EMBED_MODEL
|
||||
from config.settings import clients, INDEX_NAME, get_embedding_model
|
||||
from utils.document_processing import extract_relevant, process_document_sync
|
||||
|
||||
|
||||
def get_token_count(text: str, model: str = EMBED_MODEL) -> int:
|
||||
def get_token_count(text: str, model: str = None) -> int:
|
||||
"""Get accurate token count using tiktoken"""
|
||||
model = model or get_embedding_model()
|
||||
try:
|
||||
encoding = tiktoken.encoding_for_model(model)
|
||||
return len(encoding.encode(text))
|
||||
|
|
@ -28,12 +29,14 @@ def get_token_count(text: str, model: str = EMBED_MODEL) -> int:
|
|||
|
||||
|
||||
def chunk_texts_for_embeddings(
|
||||
texts: List[str], max_tokens: int = None, model: str = EMBED_MODEL
|
||||
texts: List[str], max_tokens: int = None, model: str = None
|
||||
) -> List[List[str]]:
|
||||
"""
|
||||
Split texts into batches that won't exceed token limits.
|
||||
If max_tokens is None, returns texts as single batch (no splitting).
|
||||
"""
|
||||
model = model or get_embedding_model()
|
||||
|
||||
if max_tokens is None:
|
||||
return [texts]
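A short usage sketch of the helpers above (assumptions: an OpenAI-compatible async client and an 8191-token per-request limit, neither of which comes from this diff):

async def embed_chunks(client, texts: list[str]) -> list[list[float]]:
    # chunk_texts_for_embeddings and get_embedding_model are the functions shown above.
    embeddings: list[list[float]] = []
    for batch in chunk_texts_for_embeddings(texts, max_tokens=8191):
        resp = await client.embeddings.create(model=get_embedding_model(), input=batch)
        embeddings.extend(d.embedding for d in resp.data)
    return embeddings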
@ -1,27 +1,50 @@
|
|||
import copy
|
||||
from typing import Any, Dict
|
||||
from agentd.tool_decorator import tool
|
||||
from config.settings import clients, INDEX_NAME, EMBED_MODEL
|
||||
from config.settings import clients, INDEX_NAME, get_embedding_model
|
||||
from auth_context import get_auth_context
|
||||
from utils.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
MAX_EMBED_RETRIES = 3
|
||||
EMBED_RETRY_INITIAL_DELAY = 1.0
|
||||
EMBED_RETRY_MAX_DELAY = 8.0
|
||||
|
||||
|
||||
class SearchService:
|
||||
def __init__(self, session_manager=None):
|
||||
self.session_manager = session_manager
|
||||
|
||||
@tool
|
||||
async def search_tool(self, query: str) -> Dict[str, Any]:
|
||||
async def search_tool(self, query: str, embedding_model: str = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Use this tool to search for documents relevant to the query.
|
||||
|
||||
Args:
|
||||
query (str): query string to search the corpus
|
||||
embedding_model (str): Optional override for embedding model.
|
||||
If not provided, uses the current embedding
|
||||
model from configuration.
|
||||
|
||||
Returns:
|
||||
dict (str, Any): {"results": [chunks]} on success
|
||||
"""
|
||||
from utils.embedding_fields import get_embedding_field_name
|
||||
|
||||
# Strategy: Use provided model, or default to the configured embedding
|
||||
# model. This assumes documents are embedded with that model by default.
|
||||
# Future enhancement: Could auto-detect available models in corpus
|
||||
embedding_model = embedding_model or get_embedding_model()
|
||||
embedding_field_name = get_embedding_field_name(embedding_model)
|
||||
|
||||
logger.info(
|
||||
"Search with embedding model",
|
||||
embedding_model=embedding_model,
|
||||
embedding_field=embedding_field_name,
|
||||
query_preview=query[:50] if query else None,
|
||||
)
|
||||
|
||||
# Get authentication context from the current async context
|
||||
user_id, jwt_token = get_auth_context()
|
||||
# Get search filters, limit, and score threshold from context
|
||||
|
|
@ -37,40 +60,176 @@ class SearchService:
|
|||
# Detect wildcard request ("*") to return global facets/stats without semantic search
|
||||
is_wildcard_match_all = isinstance(query, str) and query.strip() == "*"
|
||||
|
||||
# Only embed when not doing match_all
|
||||
# Get available embedding models from corpus
|
||||
query_embeddings = {}
|
||||
available_models = []
|
||||
|
||||
opensearch_client = self.session_manager.get_user_opensearch_client(
|
||||
user_id, jwt_token
|
||||
)
|
||||
|
||||
if not is_wildcard_match_all:
|
||||
resp = await clients.patched_async_client.embeddings.create(
|
||||
model=EMBED_MODEL, input=[query]
|
||||
)
|
||||
query_embedding = resp.data[0].embedding
|
||||
# Build filter clauses first so we can use them in model detection
|
||||
filter_clauses = []
|
||||
if filters:
|
||||
# Map frontend filter names to backend field names
|
||||
field_mapping = {
|
||||
"data_sources": "filename",
|
||||
"document_types": "mimetype",
|
||||
"owners": "owner_name.keyword",
|
||||
"connector_types": "connector_type",
|
||||
}
|
||||
|
||||
# Build filter clauses
|
||||
filter_clauses = []
|
||||
if filters:
|
||||
# Map frontend filter names to backend field names
|
||||
field_mapping = {
|
||||
"data_sources": "filename",
|
||||
"document_types": "mimetype",
|
||||
"owners": "owner_name.keyword",
|
||||
"connector_types": "connector_type",
|
||||
}
|
||||
for filter_key, values in filters.items():
|
||||
if values is not None and isinstance(values, list):
|
||||
# Map frontend key to backend field name
|
||||
field_name = field_mapping.get(filter_key, filter_key)
|
||||
|
||||
for filter_key, values in filters.items():
|
||||
if values is not None and isinstance(values, list):
|
||||
# Map frontend key to backend field name
|
||||
field_name = field_mapping.get(filter_key, filter_key)
|
||||
if len(values) == 0:
|
||||
# Empty array means "match nothing" - use impossible filter
|
||||
filter_clauses.append(
|
||||
{"term": {field_name: "__IMPOSSIBLE_VALUE__"}}
|
||||
)
|
||||
elif len(values) == 1:
|
||||
# Single value filter
|
||||
filter_clauses.append({"term": {field_name: values[0]}})
|
||||
else:
|
||||
# Multiple values filter
|
||||
filter_clauses.append({"terms": {field_name: values}})
|
||||
|
||||
if len(values) == 0:
|
||||
# Empty array means "match nothing" - use impossible filter
|
||||
filter_clauses.append(
|
||||
{"term": {field_name: "__IMPOSSIBLE_VALUE__"}}
|
||||
try:
|
||||
# Build aggregation query with filters applied
|
||||
agg_query = {
|
||||
"size": 0,
|
||||
"aggs": {
|
||||
"embedding_models": {
|
||||
"terms": {
|
||||
"field": "embedding_model",
|
||||
"size": 10
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Apply filters to model detection if any exist
|
||||
if filter_clauses:
|
||||
agg_query["query"] = {
|
||||
"bool": {
|
||||
"filter": filter_clauses
|
||||
}
|
||||
}
|
||||
|
||||
agg_result = await opensearch_client.search(
|
||||
index=INDEX_NAME, body=agg_query, params={"terminate_after": 0}
|
||||
)
|
||||
buckets = agg_result.get("aggregations", {}).get("embedding_models", {}).get("buckets", [])
|
||||
available_models = [b["key"] for b in buckets if b["key"]]
|
||||
|
||||
if not available_models:
|
||||
# Fallback to configured model if no documents indexed yet
|
||||
available_models = [embedding_model]
|
||||
|
||||
logger.info(
|
||||
"Detected embedding models in corpus",
|
||||
available_models=available_models,
|
||||
model_counts={b["key"]: b["doc_count"] for b in buckets},
|
||||
with_filters=len(filter_clauses) > 0
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to detect embedding models, using configured model", error=str(e))
|
||||
available_models = [embedding_model]
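For reference, an illustrative shape of the aggregation response that the detection code above parses; model names and counts are made up:

agg_result_example = {
    "aggregations": {
        "embedding_models": {
            "buckets": [
                {"key": "text-embedding-3-small", "doc_count": 120},
                {"key": "text-embedding-3-large", "doc_count": 40},
            ]
        }
    }
}
buckets = agg_result_example["aggregations"]["embedding_models"]["buckets"]
assert [b["key"] for b in buckets] == ["text-embedding-3-small", "text-embedding-3-large"]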
# Parallelize embedding generation for all models
|
||||
import asyncio
|
||||
|
||||
async def embed_with_model(model_name):
|
||||
delay = EMBED_RETRY_INITIAL_DELAY
|
||||
attempts = 0
|
||||
last_exception = None
|
||||
|
||||
while attempts < MAX_EMBED_RETRIES:
|
||||
attempts += 1
|
||||
try:
|
||||
resp = await clients.patched_async_client.embeddings.create(
|
||||
model=model_name, input=[query]
|
||||
)
|
||||
elif len(values) == 1:
|
||||
# Single value filter
|
||||
filter_clauses.append({"term": {field_name: values[0]}})
|
||||
else:
|
||||
# Multiple values filter
|
||||
filter_clauses.append({"terms": {field_name: values}})
|
||||
return model_name, resp.data[0].embedding
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
if attempts >= MAX_EMBED_RETRIES:
|
||||
logger.error(
|
||||
"Failed to embed with model after retries",
|
||||
model=model_name,
|
||||
attempts=attempts,
|
||||
error=str(e),
|
||||
)
|
||||
raise RuntimeError(
|
||||
f"Failed to embed with model {model_name}"
|
||||
) from e
|
||||
|
||||
logger.warning(
|
||||
"Retrying embedding generation",
|
||||
model=model_name,
|
||||
attempt=attempts,
|
||||
max_attempts=MAX_EMBED_RETRIES,
|
||||
error=str(e),
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
delay = min(delay * 2, EMBED_RETRY_MAX_DELAY)
|
||||
|
||||
# Should not reach here, but guard in case
|
||||
raise RuntimeError(
|
||||
f"Failed to embed with model {model_name}"
|
||||
) from last_exception
|
||||
|
||||
# Run all embeddings in parallel
|
||||
try:
|
||||
embedding_results = await asyncio.gather(
|
||||
*[embed_with_model(model) for model in available_models]
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Embedding generation failed", error=str(e))
|
||||
raise
|
||||
|
||||
# Collect successful embeddings
|
||||
for result in embedding_results:
|
||||
if isinstance(result, tuple) and result[1] is not None:
|
||||
model_name, embedding = result
|
||||
query_embeddings[model_name] = embedding
|
||||
|
||||
logger.info(
|
||||
"Generated query embeddings",
|
||||
models=list(query_embeddings.keys()),
|
||||
query_preview=query[:50]
|
||||
)
|
||||
else:
|
||||
# Wildcard query - no embedding needed
|
||||
filter_clauses = []
|
||||
if filters:
|
||||
# Map frontend filter names to backend field names
|
||||
field_mapping = {
|
||||
"data_sources": "filename",
|
||||
"document_types": "mimetype",
|
||||
"owners": "owner_name.keyword",
|
||||
"connector_types": "connector_type",
|
||||
}
|
||||
|
||||
for filter_key, values in filters.items():
|
||||
if values is not None and isinstance(values, list):
|
||||
# Map frontend key to backend field name
|
||||
field_name = field_mapping.get(filter_key, filter_key)
|
||||
|
||||
if len(values) == 0:
|
||||
# Empty array means "match nothing" - use impossible filter
|
||||
filter_clauses.append(
|
||||
{"term": {field_name: "__IMPOSSIBLE_VALUE__"}}
|
||||
)
|
||||
elif len(values) == 1:
|
||||
# Single value filter
|
||||
filter_clauses.append({"term": {field_name: values[0]}})
|
||||
else:
|
||||
# Multiple values filter
|
||||
filter_clauses.append({"terms": {field_name: values}})
|
||||
|
||||
# Build query body
|
||||
if is_wildcard_match_all:
|
||||
|
|
@ -80,17 +239,51 @@ class SearchService:
|
|||
else:
|
||||
query_block = {"match_all": {}}
|
||||
else:
|
||||
# Build multi-model KNN queries
|
||||
knn_queries = []
|
||||
embedding_fields_to_check = []
|
||||
|
||||
for model_name, embedding_vector in query_embeddings.items():
|
||||
field_name = get_embedding_field_name(model_name)
|
||||
embedding_fields_to_check.append(field_name)
|
||||
knn_queries.append({
|
||||
"knn": {
|
||||
field_name: {
|
||||
"vector": embedding_vector,
|
||||
"k": 50,
|
||||
"num_candidates": 1000,
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
# Build exists filter - doc must have at least one embedding field
|
||||
exists_any_embedding = {
|
||||
"bool": {
|
||||
"should": [{"exists": {"field": f}} for f in embedding_fields_to_check],
|
||||
"minimum_should_match": 1
|
||||
}
|
||||
}
|
||||
|
||||
# Add exists filter to existing filters
|
||||
all_filters = [*filter_clauses, exists_any_embedding]
|
||||
|
||||
logger.debug(
|
||||
"Building hybrid query with filters",
|
||||
user_filters_count=len(filter_clauses),
|
||||
total_filters_count=len(all_filters),
|
||||
filter_types=[type(f).__name__ for f in all_filters]
|
||||
)
|
||||
|
||||
# Hybrid search query structure (semantic + keyword)
|
||||
# Use dis_max to pick best score across multiple embedding fields
|
||||
query_block = {
|
||||
"bool": {
|
||||
"should": [
|
||||
{
|
||||
"knn": {
|
||||
"chunk_embedding": {
|
||||
"vector": query_embedding,
|
||||
"k": 10,
|
||||
"boost": 0.7,
|
||||
}
|
||||
"dis_max": {
|
||||
"tie_breaker": 0.0, # Take only the best match, no blending
|
||||
"boost": 0.7, # 70% weight for semantic search
|
||||
"queries": knn_queries
|
||||
}
|
||||
},
|
||||
{
|
||||
|
|
@ -99,12 +292,12 @@ class SearchService:
|
|||
"fields": ["text^2", "filename^1.5"],
|
||||
"type": "best_fields",
|
||||
"fuzziness": "AUTO",
|
||||
"boost": 0.3,
|
||||
"boost": 0.3, # 30% weight for keyword search
|
||||
}
|
||||
},
|
||||
],
|
||||
"minimum_should_match": 1,
|
||||
**({"filter": filter_clauses} if filter_clauses else {}),
|
||||
"filter": all_filters,
|
||||
}
|
||||
}
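A rough worked example of how the weights above combine for one document (toy numbers; real OpenSearch scores are not normalized like this):

knn_scores = {"text-embedding-3-small": 0.82, "text-embedding-3-large": 0.78}
keyword_score = 0.45

semantic = max(knn_scores.values())                  # dis_max with tie_breaker=0.0 keeps only the best field
final_score = 0.7 * semantic + 0.3 * keyword_score   # the bool "should" sums the two boosted clauses
print(final_score)                                   # 0.709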
@ -115,6 +308,7 @@ class SearchService:
|
|||
"document_types": {"terms": {"field": "mimetype", "size": 10}},
|
||||
"owners": {"terms": {"field": "owner_name.keyword", "size": 10}},
|
||||
"connector_types": {"terms": {"field": "connector_type", "size": 10}},
|
||||
"embedding_models": {"terms": {"field": "embedding_model", "size": 10}},
|
||||
},
|
||||
"_source": [
|
||||
"filename",
|
||||
|
|
@ -127,6 +321,8 @@ class SearchService:
|
|||
"owner_email",
|
||||
"file_size",
|
||||
"connector_type",
|
||||
"embedding_model", # Include embedding model in results
|
||||
"embedding_dimensions",
|
||||
"allowed_users",
|
||||
"allowed_groups",
|
||||
],
|
||||
|
|
@ -137,6 +333,23 @@ class SearchService:
|
|||
if not is_wildcard_match_all and score_threshold > 0:
|
||||
search_body["min_score"] = score_threshold
|
||||
|
||||
# Prepare fallback search body without num_candidates for clusters that don't support it
|
||||
fallback_search_body = None
|
||||
if not is_wildcard_match_all:
|
||||
try:
|
||||
fallback_search_body = copy.deepcopy(search_body)
|
||||
knn_query_blocks = (
|
||||
fallback_search_body["query"]["bool"]["should"][0]["dis_max"]["queries"]
|
||||
)
|
||||
for query_candidate in knn_query_blocks:
|
||||
knn_section = query_candidate.get("knn")
|
||||
if isinstance(knn_section, dict):
|
||||
for params in knn_section.values():
|
||||
if isinstance(params, dict):
|
||||
params.pop("num_candidates", None)
|
||||
except (KeyError, IndexError, AttributeError, TypeError):
|
||||
fallback_search_body = None
|
||||
|
||||
# Authentication required - DLS will handle document filtering automatically
|
||||
logger.debug(
|
||||
"search_service authentication info",
|
||||
|
|
@ -152,8 +365,41 @@ class SearchService:
|
|||
user_id, jwt_token
|
||||
)
|
||||
|
||||
from opensearchpy.exceptions import RequestError
|
||||
|
||||
search_params = {"terminate_after": 0}
|
||||
|
||||
try:
|
||||
results = await opensearch_client.search(index=INDEX_NAME, body=search_body)
|
||||
results = await opensearch_client.search(
|
||||
index=INDEX_NAME, body=search_body, params=search_params
|
||||
)
|
||||
except RequestError as e:
|
||||
error_message = str(e)
|
||||
if (
|
||||
fallback_search_body is not None
|
||||
and "unknown field [num_candidates]" in error_message.lower()
|
||||
):
|
||||
logger.warning(
|
||||
"OpenSearch cluster does not support num_candidates; retrying without it"
|
||||
)
|
||||
try:
|
||||
results = await opensearch_client.search(
|
||||
index=INDEX_NAME,
|
||||
body=fallback_search_body,
|
||||
params=search_params,
|
||||
)
|
||||
except RequestError as retry_error:
|
||||
logger.error(
|
||||
"OpenSearch retry without num_candidates failed",
|
||||
error=str(retry_error),
|
||||
search_body=fallback_search_body,
|
||||
)
|
||||
raise
|
||||
else:
|
||||
logger.error(
|
||||
"OpenSearch query failed", error=error_message, search_body=search_body
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"OpenSearch query failed", error=str(e), search_body=search_body
|
||||
|
|
@ -177,6 +423,8 @@ class SearchService:
|
|||
"owner_email": hit["_source"].get("owner_email"),
|
||||
"file_size": hit["_source"].get("file_size"),
|
||||
"connector_type": hit["_source"].get("connector_type"),
|
||||
"embedding_model": hit["_source"].get("embedding_model"), # Include in results
|
||||
"embedding_dimensions": hit["_source"].get("embedding_dimensions"),
|
||||
}
|
||||
)
|
||||
|
||||
|
|
@ -199,8 +447,14 @@ class SearchService:
|
|||
filters: Dict[str, Any] = None,
|
||||
limit: int = 10,
|
||||
score_threshold: float = 0,
|
||||
embedding_model: str = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Public search method for API endpoints"""
|
||||
"""Public search method for API endpoints
|
||||
|
||||
Args:
|
||||
embedding_model: Embedding model to use for search (defaults to the
|
||||
currently configured embedding model)
|
||||
"""
|
||||
# Set auth context if provided (for direct API calls)
|
||||
from config.settings import is_no_auth_mode
|
||||
|
||||
|
|
@ -220,4 +474,4 @@ class SearchService:
|
|||
set_search_limit(limit)
|
||||
set_score_threshold(score_threshold)
|
||||
|
||||
return await self.search_tool(query)
|
||||
return await self.search_tool(query, embedding_model=embedding_model)
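A hedged usage sketch of the public entry point above; the method name `search` and its leading positional query parameter are assumptions, since the hunk only shows the trailing keyword arguments:

async def run_search(session_manager):
    service = SearchService(session_manager=session_manager)
    result = await service.search(
        "vector databases",                          # query text (parameter name assumed)
        limit=5,
        score_threshold=0,
        embedding_model="text-embedding-3-large",    # optional override; defaults to the configured model
    )
    return result["results"]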
@ -505,36 +505,116 @@ class ContainerManager:
|
|||
digests[image] = stdout.strip().splitlines()[0]
|
||||
return digests
|
||||
|
||||
def _parse_compose_images(self) -> list[str]:
|
||||
"""Best-effort parse of image names from compose files without YAML dependency."""
|
||||
def _extract_images_from_compose_config(self, text: str, tried_json: bool) -> set[str]:
|
||||
"""
|
||||
Try JSON first (if we asked for it or it looks like JSON), then YAML if available.
|
||||
Returns a set of image names.
|
||||
"""
|
||||
images: set[str] = set()
|
||||
for compose in [self.compose_file, self.cpu_compose_file]:
|
||||
|
||||
# Try JSON parse
|
||||
if tried_json or (text.lstrip().startswith("{") and text.rstrip().endswith("}")):
|
||||
try:
|
||||
if not compose.exists():
|
||||
cfg = json.loads(text)
|
||||
services = cfg.get("services", {})
|
||||
for _, svc in services.items():
|
||||
image = svc.get("image")
|
||||
if image:
|
||||
images.add(str(image))
|
||||
if images:
|
||||
return images
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Try YAML (if available) - import here to avoid hard dependency
|
||||
try:
|
||||
import yaml
|
||||
cfg = yaml.safe_load(text) or {}
|
||||
services = cfg.get("services", {})
|
||||
if isinstance(services, dict):
|
||||
for _, svc in services.items():
|
||||
if isinstance(svc, dict):
|
||||
image = svc.get("image")
|
||||
if image:
|
||||
images.add(str(image))
|
||||
if images:
|
||||
return images
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return images
|
||||
|
||||
async def _parse_compose_images(self) -> list[str]:
|
||||
"""Get resolved image names from compose files using docker/podman compose, with robust fallbacks."""
|
||||
images: set[str] = set()
|
||||
|
||||
compose_files = [self.compose_file, self.cpu_compose_file]
|
||||
for compose_file in compose_files:
|
||||
try:
|
||||
if not compose_file or not compose_file.exists():
|
||||
continue
|
||||
for line in compose.read_text().splitlines():
|
||||
line = line.strip()
|
||||
if not line or line.startswith("#"):
|
||||
|
||||
cpu_mode = (compose_file == self.cpu_compose_file)
|
||||
|
||||
# Try JSON format first
|
||||
success, stdout, _ = await self._run_compose_command(
|
||||
["config", "--format", "json"],
|
||||
cpu_mode=cpu_mode
|
||||
)
|
||||
|
||||
if success and stdout.strip():
|
||||
from_cfg = self._extract_images_from_compose_config(stdout, tried_json=True)
|
||||
if from_cfg:
|
||||
images.update(from_cfg)
|
||||
continue # this compose file succeeded; move to next file
|
||||
|
||||
# Fallback to YAML output (for older compose versions)
|
||||
success, stdout, _ = await self._run_compose_command(
|
||||
["config"],
|
||||
cpu_mode=cpu_mode
|
||||
)
|
||||
|
||||
if success and stdout.strip():
|
||||
from_cfg = self._extract_images_from_compose_config(stdout, tried_json=False)
|
||||
if from_cfg:
|
||||
images.update(from_cfg)
|
||||
continue
|
||||
if line.startswith("image:"):
|
||||
# image: repo/name:tag
|
||||
val = line.split(":", 1)[1].strip()
|
||||
# Remove quotes if present
|
||||
if (val.startswith('"') and val.endswith('"')) or (
|
||||
val.startswith("'") and val.endswith("'")
|
||||
):
|
||||
val = val[1:-1]
|
||||
images.add(val)
|
||||
|
||||
except Exception:
|
||||
# Keep behavior resilient—just continue to next file
|
||||
continue
|
||||
return list(images)
|
||||
|
||||
# Fallback: manual parsing if compose config didn't work
|
||||
if not images:
|
||||
for compose in compose_files:
|
||||
try:
|
||||
if not compose.exists():
|
||||
continue
|
||||
for line in compose.read_text().splitlines():
|
||||
line = line.strip()
|
||||
if not line or line.startswith("#"):
|
||||
continue
|
||||
if line.startswith("image:"):
|
||||
# image: repo/name:tag
|
||||
val = line.split(":", 1)[1].strip()
|
||||
# Remove quotes if present
|
||||
if (val.startswith('"') and val.endswith('"')) or (
|
||||
val.startswith("'") and val.endswith("'")
|
||||
):
|
||||
val = val[1:-1]
|
||||
if val:
|
||||
images.add(val)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return sorted(images)
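For reference, a sketch of the `docker compose config --format json` payload shape the extractor above assumes; service and image names are made up:

import json

sample = json.dumps({
    "services": {
        "opensearch": {"image": "example/opensearch:2.19"},
        "langflow": {"image": "example/langflow:latest"},
        "backend": {"build": {"context": "."}},      # services without an image are skipped
    }
})

images = {
    svc["image"]
    for svc in json.loads(sample)["services"].values()
    if isinstance(svc, dict) and svc.get("image")
}
print(sorted(images))                                # ['example/langflow:latest', 'example/opensearch:2.19']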
async def get_project_images_info(self) -> list[tuple[str, str]]:
|
||||
"""
|
||||
Return list of (image, digest_or_id) for images referenced by compose files.
|
||||
If an image isn't present locally, returns '-' for its digest.
|
||||
"""
|
||||
expected = self._parse_compose_images()
|
||||
expected = await self._parse_compose_images()
|
||||
results: list[tuple[str, str]] = []
|
||||
for image in expected:
|
||||
digest = "-"
|
||||
|
|
|
|||
178 src/utils/embedding_fields.py Normal file
|
|
@ -0,0 +1,178 @@
|
|||
"""
|
||||
Utility functions for managing dynamic embedding field names in OpenSearch.
|
||||
|
||||
This module provides helpers for:
|
||||
- Normalizing embedding model names to valid OpenSearch field names
|
||||
- Generating dynamic field names based on embedding models
|
||||
- Ensuring embedding fields exist in the OpenSearch index
|
||||
"""
|
||||
|
||||
from typing import Dict, Any
|
||||
|
||||
from utils.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def normalize_model_name(model_name: str) -> str:
|
||||
"""
|
||||
Convert an embedding model name to a valid OpenSearch field suffix.
|
||||
|
||||
Examples:
|
||||
- "text-embedding-3-small" -> "text_embedding_3_small"
|
||||
- "nomic-embed-text:latest" -> "nomic_embed_text_latest"
|
||||
- "ibm/slate-125m-english-rtrvr" -> "ibm_slate_125m_english_rtrvr"
|
||||
|
||||
Args:
|
||||
model_name: The embedding model name (e.g., from OpenAI, Ollama, Watsonx)
|
||||
|
||||
Returns:
|
||||
Normalized string safe for use as OpenSearch field name suffix
|
||||
"""
|
||||
normalized = model_name.lower()
|
||||
# Replace common separators with underscores
|
||||
normalized = normalized.replace("-", "_")
|
||||
normalized = normalized.replace(":", "_")
|
||||
normalized = normalized.replace("/", "_")
|
||||
normalized = normalized.replace(".", "_")
|
||||
# Remove any other non-alphanumeric characters
|
||||
normalized = "".join(c if c.isalnum() or c == "_" else "_" for c in normalized)
|
||||
# Remove duplicate underscores
|
||||
while "__" in normalized:
|
||||
normalized = normalized.replace("__", "_")
|
||||
# Remove leading/trailing underscores
|
||||
normalized = normalized.strip("_")
|
||||
|
||||
return normalized
|
||||
|
||||
|
||||
def get_embedding_field_name(model_name: str) -> str:
|
||||
"""
|
||||
Get the OpenSearch field name for storing embeddings from a specific model.
|
||||
|
||||
Args:
|
||||
model_name: The embedding model name
|
||||
|
||||
Returns:
|
||||
Field name in format: chunk_embedding_{normalized_model_name}
|
||||
|
||||
Examples:
|
||||
>>> get_embedding_field_name("text-embedding-3-small")
|
||||
'chunk_embedding_text_embedding_3_small'
|
||||
>>> get_embedding_field_name("nomic-embed-text")
|
||||
'chunk_embedding_nomic_embed_text'
|
||||
"""
|
||||
normalized = normalize_model_name(model_name)
|
||||
return f"chunk_embedding_{normalized}"
|
||||
|
||||
|
||||
async def ensure_embedding_field_exists(
|
||||
opensearch_client,
|
||||
model_name: str,
|
||||
index_name: str = None,
|
||||
) -> str:
|
||||
"""
|
||||
Ensure that an embedding field for the specified model exists in the OpenSearch index.
|
||||
If the field doesn't exist, it will be added dynamically using PUT mapping API.
|
||||
|
||||
Args:
|
||||
opensearch_client: OpenSearch client instance
|
||||
model_name: The embedding model name
|
||||
index_name: OpenSearch index name (defaults to INDEX_NAME from settings)
|
||||
|
||||
Returns:
|
||||
The field name that was ensured to exist
|
||||
|
||||
Raises:
|
||||
Exception: If unable to add the field mapping
|
||||
"""
|
||||
from config.settings import INDEX_NAME
|
||||
from utils.embeddings import get_embedding_dimensions
|
||||
|
||||
if index_name is None:
|
||||
index_name = INDEX_NAME
|
||||
|
||||
field_name = get_embedding_field_name(model_name)
|
||||
dimensions = await get_embedding_dimensions(model_name)
|
||||
|
||||
logger.info(
|
||||
"Ensuring embedding field exists",
|
||||
field_name=field_name,
|
||||
model_name=model_name,
|
||||
dimensions=dimensions,
|
||||
)
|
||||
|
||||
async def _get_field_definition() -> Dict[str, Any]:
|
||||
try:
|
||||
mapping = await opensearch_client.indices.get_mapping(index=index_name)
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
"Failed to fetch mapping before ensuring embedding field",
|
||||
index=index_name,
|
||||
error=str(e),
|
||||
)
|
||||
return {}
|
||||
|
||||
properties = mapping.get(index_name, {}).get("mappings", {}).get("properties", {})
|
||||
return properties.get(field_name, {}) if isinstance(properties, dict) else {}
|
||||
|
||||
existing_definition = await _get_field_definition()
|
||||
if existing_definition:
|
||||
if existing_definition.get("type") != "knn_vector":
|
||||
raise RuntimeError(
|
||||
f"Field '{field_name}' already exists with incompatible type '{existing_definition.get('type')}'"
|
||||
)
|
||||
return field_name
|
||||
|
||||
# Define the field mapping for both the vector field and the tracking field
|
||||
mapping = {
|
||||
"properties": {
|
||||
field_name: {
|
||||
"type": "knn_vector",
|
||||
"dimension": dimensions,
|
||||
"method": {
|
||||
"name": "disk_ann",
|
||||
"engine": "jvector",
|
||||
"space_type": "l2",
|
||||
"parameters": {"ef_construction": 100, "m": 16},
|
||||
},
|
||||
},
|
||||
# Also ensure the embedding_model tracking field exists as keyword
|
||||
"embedding_model": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"embedding_dimensions": {
|
||||
"type": "integer"
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
try:
|
||||
# Try to add the mapping
|
||||
# OpenSearch will ignore if field already exists
|
||||
await opensearch_client.indices.put_mapping(
|
||||
index=index_name,
|
||||
body=mapping
|
||||
)
|
||||
logger.info(
|
||||
"Successfully ensured embedding field exists",
|
||||
field_name=field_name,
|
||||
model_name=model_name,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to add embedding field mapping",
|
||||
field_name=field_name,
|
||||
model_name=model_name,
|
||||
error=str(e),
|
||||
)
|
||||
raise
|
||||
|
||||
# Verify mapping was applied correctly
|
||||
new_definition = await _get_field_definition()
|
||||
if new_definition.get("type") != "knn_vector":
|
||||
raise RuntimeError(
|
||||
f"Failed to ensure '{field_name}' is mapped as knn_vector. Current definition: {new_definition}"
|
||||
)
|
||||
|
||||
return field_name
|
@ -167,6 +167,8 @@ async def create_dynamic_index_body(
|
|||
"mimetype": {"type": "keyword"},
|
||||
"page": {"type": "integer"},
|
||||
"text": {"type": "text"},
|
||||
# Legacy field - kept for backward compatibility
|
||||
# New documents will use chunk_embedding_{model_name} fields
|
||||
"chunk_embedding": {
|
||||
"type": "knn_vector",
|
||||
"dimension": dimensions,
|
||||
|
|
@ -177,6 +179,9 @@ async def create_dynamic_index_body(
|
|||
"parameters": {"ef_construction": 100, "m": 16},
|
||||
},
|
||||
},
|
||||
# Track which embedding model was used for this chunk
|
||||
"embedding_model": {"type": "keyword"},
|
||||
"embedding_dimensions": {"type": "integer"},
|
||||
"source_url": {"type": "keyword"},
|
||||
"connector_type": {"type": "keyword"},
|
||||
"owner": {"type": "keyword"},
|
||||
|
|
|
|||
|
|
@ -20,6 +20,55 @@ from src.session_manager import SessionManager
|
|||
from src.main import generate_jwt_keys
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session", autouse=True)
|
||||
async def onboard_system():
|
||||
"""Perform initial onboarding once for all tests in the session.
|
||||
|
||||
This ensures the OpenRAG config is marked as edited and properly initialized
|
||||
so that tests can use the /settings endpoint.
|
||||
"""
|
||||
from pathlib import Path
|
||||
|
||||
# Delete any existing config to ensure clean onboarding
|
||||
config_file = Path("config/config.yaml")
|
||||
if config_file.exists():
|
||||
config_file.unlink()
|
||||
|
||||
# Initialize clients
|
||||
await clients.initialize()
|
||||
|
||||
# Create app and perform onboarding via API
|
||||
from src.main import create_app, startup_tasks
|
||||
import httpx
|
||||
|
||||
app = await create_app()
|
||||
await startup_tasks(app.state.services)
|
||||
|
||||
transport = httpx.ASGITransport(app=app)
|
||||
async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client:
|
||||
onboarding_payload = {
|
||||
"model_provider": "openai",
|
||||
"embedding_model": "text-embedding-3-small",
|
||||
"llm_model": "gpt-4o-mini",
|
||||
"endpoint": "https://api.openai.com/v1",
|
||||
"sample_data": False,
|
||||
}
|
||||
resp = await client.post("/onboarding", json=onboarding_payload)
|
||||
if resp.status_code not in (200, 204):
|
||||
# If it fails, it might already be onboarded, which is fine
|
||||
print(f"[DEBUG] Onboarding returned {resp.status_code}: {resp.text}")
|
||||
else:
|
||||
print(f"[DEBUG] Session onboarding completed successfully")
|
||||
|
||||
yield
|
||||
|
||||
# Cleanup after all tests
|
||||
try:
|
||||
await clients.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def event_loop():
|
||||
"""Create an instance of the default event loop for the test session."""
|
||||
|
|
|
|||
|
|
@ -1,11 +1,43 @@
|
|||
import asyncio
|
||||
import os
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
|
||||
def dump_docker_logs(container_name_pattern: str = "langflow", tail: int = 100):
|
||||
"""Dump Docker container logs for debugging."""
|
||||
try:
|
||||
# Find container ID by name pattern
|
||||
find_cmd = ["docker", "ps", "-a", "--filter", f"name={container_name_pattern}", "--format", "{{.ID}}"]
|
||||
result = subprocess.run(find_cmd, capture_output=True, text=True, timeout=5)
|
||||
container_ids = result.stdout.strip().split('\n')
|
||||
|
||||
if not container_ids or not container_ids[0]:
|
||||
print(f"[DEBUG] No Docker containers found matching pattern: {container_name_pattern}")
|
||||
return
|
||||
|
||||
for container_id in container_ids:
|
||||
if not container_id:
|
||||
continue
|
||||
print(f"\n{'='*80}")
|
||||
print(f"[DEBUG] Docker logs for container {container_id} (last {tail} lines):")
|
||||
print(f"{'='*80}")
|
||||
|
||||
logs_cmd = ["docker", "logs", "--tail", str(tail), container_id]
|
||||
logs_result = subprocess.run(logs_cmd, capture_output=True, text=True, timeout=10)
|
||||
print(logs_result.stdout)
|
||||
if logs_result.stderr:
|
||||
print("[STDERR]:", logs_result.stderr)
|
||||
print(f"{'='*80}\n")
|
||||
except subprocess.TimeoutExpired:
|
||||
print(f"[DEBUG] Timeout while fetching docker logs for {container_name_pattern}")
|
||||
except Exception as e:
|
||||
print(f"[DEBUG] Failed to fetch docker logs for {container_name_pattern}: {e}")
|
||||
|
||||
|
||||
async def wait_for_service_ready(client: httpx.AsyncClient, timeout_s: float = 30.0):
|
||||
"""Poll existing endpoints until the app and OpenSearch are ready.
|
||||
|
||||
|
|
@ -160,11 +192,22 @@ async def test_upload_and_search_endpoint(tmp_path: Path, disable_langflow_inges
|
|||
"text/markdown",
|
||||
)
|
||||
}
|
||||
upload_resp = await client.post("/upload", files=files)
|
||||
upload_resp = await client.post("/router/upload_ingest", files=files)
|
||||
body = upload_resp.json()
|
||||
assert upload_resp.status_code == 201, upload_resp.text
|
||||
assert body.get("status") in {"indexed", "unchanged"}
|
||||
assert isinstance(body.get("id"), str)
|
||||
assert upload_resp.status_code in (201, 202), upload_resp.text
|
||||
|
||||
# Handle different response formats based on whether Langflow is used
|
||||
if disable_langflow_ingest:
|
||||
# Traditional OpenRAG response (201)
|
||||
assert body.get("status") in {"indexed", "unchanged"}
|
||||
assert isinstance(body.get("id"), str)
|
||||
else:
|
||||
# Langflow task response (202)
|
||||
task_id = body.get("task_id")
|
||||
assert isinstance(task_id, str)
|
||||
assert body.get("file_count") == 1
|
||||
# Wait for task completion before searching
|
||||
await _wait_for_task_completion(client, task_id)
|
||||
|
||||
# Poll search for the specific content until it's indexed
|
||||
async def _wait_for_indexed(timeout_s: float = 30.0):
|
||||
|
|
@ -204,6 +247,353 @@ async def test_upload_and_search_endpoint(tmp_path: Path, disable_langflow_inges
|
|||
pass
|
||||
|
||||
|
||||
async def _wait_for_langflow_chat(
|
||||
client: httpx.AsyncClient, payload: dict, timeout_s: float = 120.0
|
||||
) -> dict:
|
||||
deadline = asyncio.get_event_loop().time() + timeout_s
|
||||
last_payload = None
|
||||
while asyncio.get_event_loop().time() < deadline:
|
||||
resp = await client.post("/langflow", json=payload)
|
||||
if resp.status_code == 200:
|
||||
try:
|
||||
data = resp.json()
|
||||
except Exception:
|
||||
last_payload = resp.text
|
||||
else:
|
||||
response_text = data.get("response")
|
||||
if isinstance(response_text, str) and response_text.strip():
|
||||
return data
|
||||
last_payload = data
|
||||
else:
|
||||
last_payload = resp.text
|
||||
await asyncio.sleep(1.0)
|
||||
|
||||
# Dump Langflow logs before raising error
|
||||
print(f"\n[DEBUG] /langflow timed out. Dumping Langflow container logs...")
|
||||
dump_docker_logs(container_name_pattern="langflow", tail=200)
|
||||
raise AssertionError(f"/langflow never returned a usable response. Last payload: {last_payload}")
|
||||
|
||||
|
||||
async def _wait_for_nudges(
|
||||
client: httpx.AsyncClient, chat_id: str | None = None, timeout_s: float = 90.0
|
||||
) -> dict:
|
||||
endpoint = "/nudges" if not chat_id else f"/nudges/{chat_id}"
|
||||
deadline = asyncio.get_event_loop().time() + timeout_s
|
||||
last_payload = None
|
||||
while asyncio.get_event_loop().time() < deadline:
|
||||
resp = await client.get(endpoint)
|
||||
if resp.status_code == 200:
|
||||
try:
|
||||
data = resp.json()
|
||||
except Exception:
|
||||
last_payload = resp.text
|
||||
else:
|
||||
response_text = data.get("response")
|
||||
if isinstance(response_text, str) and response_text.strip():
|
||||
return data
|
||||
last_payload = data
|
||||
else:
|
||||
last_payload = resp.text
|
||||
await asyncio.sleep(1.0)
|
||||
|
||||
# Dump Langflow logs before raising error
|
||||
print(f"\n[DEBUG] {endpoint} timed out. Dumping Langflow container logs...")
|
||||
dump_docker_logs(container_name_pattern="langflow", tail=200)
|
||||
raise AssertionError(f"{endpoint} never returned a usable response. Last payload: {last_payload}")
|
||||
|
||||
|
||||
async def _wait_for_task_completion(
|
||||
client: httpx.AsyncClient, task_id: str, timeout_s: float = 180.0
|
||||
) -> dict:
|
||||
deadline = asyncio.get_event_loop().time() + timeout_s
|
||||
last_payload = None
|
||||
while asyncio.get_event_loop().time() < deadline:
|
||||
resp = await client.get(f"/tasks/{task_id}")
|
||||
if resp.status_code == 200:
|
||||
try:
|
||||
data = resp.json()
|
||||
except Exception:
|
||||
last_payload = resp.text
|
||||
else:
|
||||
status = (data.get("status") or "").lower()
|
||||
if status == "completed":
|
||||
return data
|
||||
if status == "failed":
|
||||
raise AssertionError(f"Task {task_id} failed: {data}")
|
||||
last_payload = data
|
||||
elif resp.status_code == 404:
|
||||
last_payload = resp.text
|
||||
else:
|
||||
last_payload = resp.text
|
||||
await asyncio.sleep(1.0)
|
||||
raise AssertionError(
|
||||
f"Task {task_id} did not complete in time. Last payload: {last_payload}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skip
|
||||
async def test_langflow_chat_and_nudges_endpoints():
|
||||
"""Exercise /langflow and /nudges endpoints against a live Langflow backend."""
|
||||
required_env = ["LANGFLOW_CHAT_FLOW_ID", "NUDGES_FLOW_ID"]
|
||||
missing = [var for var in required_env if not os.getenv(var)]
|
||||
assert not missing, f"Missing required Langflow configuration: {missing}"
|
||||
|
||||
os.environ["DISABLE_INGEST_WITH_LANGFLOW"] = "true"
|
||||
os.environ["DISABLE_STARTUP_INGEST"] = "true"
|
||||
os.environ["GOOGLE_OAUTH_CLIENT_ID"] = ""
|
||||
os.environ["GOOGLE_OAUTH_CLIENT_SECRET"] = ""
|
||||
|
||||
import sys
|
||||
|
||||
for mod in [
|
||||
"src.api.chat",
|
||||
"api.chat",
|
||||
"src.api.nudges",
|
||||
"api.nudges",
|
||||
"src.api.router",
|
||||
"api.router",
|
||||
"src.api.connector_router",
|
||||
"api.connector_router",
|
||||
"src.config.settings",
|
||||
"config.settings",
|
||||
"src.auth_middleware",
|
||||
"auth_middleware",
|
||||
"src.main",
|
||||
"api",
|
||||
"src.api",
|
||||
"services",
|
||||
"src.services",
|
||||
"services.search_service",
|
||||
"src.services.search_service",
|
||||
"services.chat_service",
|
||||
"src.services.chat_service",
|
||||
]:
|
||||
sys.modules.pop(mod, None)
|
||||
|
||||
from src.main import create_app, startup_tasks
|
||||
from src.config.settings import clients, LANGFLOW_CHAT_FLOW_ID, NUDGES_FLOW_ID
|
||||
|
||||
assert LANGFLOW_CHAT_FLOW_ID, "LANGFLOW_CHAT_FLOW_ID must be configured for integration test"
|
||||
assert NUDGES_FLOW_ID, "NUDGES_FLOW_ID must be configured for integration test"
|
||||
|
||||
await clients.initialize()
|
||||
app = await create_app()
|
||||
await startup_tasks(app.state.services)
|
||||
|
||||
langflow_client = None
|
||||
deadline = asyncio.get_event_loop().time() + 60.0
|
||||
while asyncio.get_event_loop().time() < deadline:
|
||||
langflow_client = await clients.ensure_langflow_client()
|
||||
if langflow_client is not None:
|
||||
break
|
||||
await asyncio.sleep(1.0)
|
||||
assert langflow_client is not None, "Langflow client not initialized. Provide LANGFLOW_KEY or enable superuser auto-login."
|
||||
|
||||
transport = httpx.ASGITransport(app=app)
|
||||
try:
|
||||
async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client:
|
||||
await wait_for_service_ready(client)
|
||||
|
||||
# Ensure embedding model is configured via settings
|
||||
resp = await client.post(
|
||||
"/settings",
|
||||
json={
|
||||
"embedding_model": "text-embedding-3-small",
|
||||
"llm_model": "gpt-4o-mini",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200, resp.text
|
||||
|
||||
warmup_file = Path("./nudges_seed.md")
|
||||
warmup_file.write_text(
|
||||
"The user may care about different fruits including apples, hardy kiwi, and bananas"
|
||||
)
|
||||
files = {
|
||||
"file": (
|
||||
warmup_file.name,
|
||||
warmup_file.read_bytes(),
|
||||
"text/plain",
|
||||
)
|
||||
}
|
||||
upload_resp = await client.post("/router/upload_ingest", files=files)
|
||||
assert upload_resp.status_code in (201, 202), upload_resp.text
|
||||
payload = upload_resp.json()
|
||||
task_id = payload.get("task_id")
|
||||
if task_id:
|
||||
await _wait_for_task_completion(client, task_id)
|
||||
|
||||
prompt = "Respond with a brief acknowledgement for the OpenRAG integration test."
|
||||
langflow_payload = {"prompt": prompt, "limit": 5, "scoreThreshold": 0}
|
||||
langflow_data = await _wait_for_langflow_chat(client, langflow_payload)
|
||||
|
||||
assert isinstance(langflow_data.get("response"), str)
|
||||
assert langflow_data["response"].strip()
|
||||
|
||||
response_id = langflow_data.get("response_id")
|
||||
|
||||
nudges_data = await _wait_for_nudges(client)
|
||||
assert isinstance(nudges_data.get("response"), str)
|
||||
assert nudges_data["response"].strip()
|
||||
|
||||
if response_id:
|
||||
nudges_thread_data = await _wait_for_nudges(client, response_id)
|
||||
assert isinstance(nudges_thread_data.get("response"), str)
|
||||
assert nudges_thread_data["response"].strip()
|
||||
finally:
|
||||
from src.config.settings import clients
|
||||
|
||||
try:
|
||||
await clients.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_multi_embedding_models(
|
||||
tmp_path: Path
|
||||
):
|
||||
"""Ensure /search fans out across multiple embedding models when present."""
|
||||
os.environ["DISABLE_INGEST_WITH_LANGFLOW"] = "true"
|
||||
os.environ["DISABLE_STARTUP_INGEST"] = "true"
|
||||
os.environ["GOOGLE_OAUTH_CLIENT_ID"] = ""
|
||||
os.environ["GOOGLE_OAUTH_CLIENT_SECRET"] = ""
|
||||
|
||||
import sys
|
||||
|
||||
for mod in [
|
||||
"src.api.router",
|
||||
"api.router",
|
||||
"src.api.connector_router",
|
||||
"api.connector_router",
|
||||
"src.config.settings",
|
||||
"config.settings",
|
||||
"src.auth_middleware",
|
||||
"auth_middleware",
|
||||
"src.main",
|
||||
"services.search_service",
|
||||
"src.services.search_service",
|
||||
]:
|
||||
sys.modules.pop(mod, None)
|
||||
|
||||
from src.main import create_app, startup_tasks
|
||||
from src.config.settings import clients, INDEX_NAME
|
||||
|
||||
await clients.initialize()
|
||||
try:
|
||||
await clients.opensearch.indices.delete(index=INDEX_NAME)
|
||||
await asyncio.sleep(1)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
app = await create_app()
|
||||
await startup_tasks(app.state.services)
|
||||
|
||||
from src.main import _ensure_opensearch_index
|
||||
|
||||
await _ensure_opensearch_index()
|
||||
|
||||
transport = httpx.ASGITransport(app=app)
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client:
|
||||
await wait_for_service_ready(client)
|
||||
|
||||
async def _upload_doc(name: str, text: str) -> None:
|
||||
file_path = tmp_path / name
|
||||
file_path.write_text(text)
|
||||
files = {
|
||||
"file": (
|
||||
name,
|
||||
file_path.read_bytes(),
|
||||
"text/markdown",
|
||||
)
|
||||
}
|
||||
resp = await client.post("/router/upload_ingest", files=files)
|
||||
assert resp.status_code in (201, 202), resp.text
|
||||
payload = resp.json()
|
||||
task_id = payload.get("task_id")
|
||||
if task_id:
|
||||
await _wait_for_task_completion(client, task_id)
|
||||
|
||||
async def _wait_for_models(expected_models: set[str], query: str = "*"):
|
||||
deadline = asyncio.get_event_loop().time() + 60.0
|
||||
last_payload = None
|
||||
while asyncio.get_event_loop().time() < deadline:
|
||||
resp = await client.post(
|
||||
"/search",
|
||||
json={"query": query, "limit": 0, "scoreThreshold": 0},
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
last_payload = resp.text
|
||||
await asyncio.sleep(0.5)
|
||||
continue
|
||||
payload = resp.json()
|
||||
buckets = (
|
||||
payload.get("aggregations", {})
|
||||
.get("embedding_models", {})
|
||||
.get("buckets", [])
|
||||
)
|
||||
models = {b.get("key") for b in buckets if b.get("key")}
|
||||
if expected_models <= models:
|
||||
return payload
|
||||
last_payload = payload
|
||||
await asyncio.sleep(0.5)
|
||||
raise AssertionError(
|
||||
f"Embedding models not detected. Last payload: {last_payload}"
|
||||
)
|
||||
|
||||
# Start with explicit small embedding model
|
||||
resp = await client.post(
|
||||
"/settings",
|
||||
json={
|
||||
"embedding_model": "text-embedding-3-small",
|
||||
"llm_model": "gpt-4o-mini",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200, resp.text
|
||||
|
||||
# Ingest first document (small model)
|
||||
await _upload_doc("doc-small.md", "Physics basics and fundamental principles.")
|
||||
payload_small = await _wait_for_models({"text-embedding-3-small"})
|
||||
result_models_small = {
|
||||
r.get("embedding_model")
|
||||
for r in (payload_small.get("results") or [])
|
||||
if r.get("embedding_model")
|
||||
}
|
||||
assert "text-embedding-3-small" in result_models_small or not result_models_small
|
||||
|
||||
# Update embedding model via settings
|
||||
resp = await client.post(
|
||||
"/settings",
|
||||
json={"embedding_model": "text-embedding-3-large"},
|
||||
)
|
||||
assert resp.status_code == 200, resp.text
|
||||
|
||||
# Ingest second document which should use the large embedding model
|
||||
await _upload_doc("doc-large.md", "Advanced physics covers quantum topics extensively.")
|
||||
|
||||
payload = await _wait_for_models({"text-embedding-3-small", "text-embedding-3-large"})
|
||||
buckets = payload.get("aggregations", {}).get("embedding_models", {}).get("buckets", [])
|
||||
models = {b.get("key") for b in buckets}
|
||||
assert {"text-embedding-3-small", "text-embedding-3-large"} <= models
|
||||
|
||||
result_models = {
|
||||
r.get("embedding_model")
|
||||
for r in (payload.get("results") or [])
|
||||
if r.get("embedding_model")
|
||||
}
|
||||
assert {"text-embedding-3-small", "text-embedding-3-large"} <= result_models
|
||||
finally:
|
||||
from src.config.settings import clients
|
||||
|
||||
try:
|
||||
await clients.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.parametrize("disable_langflow_ingest", [True, False])
|
||||
@pytest.mark.asyncio
|
||||
async def test_router_upload_ingest_traditional(tmp_path: Path, disable_langflow_ingest: bool):