diff --git a/web/src/constants/agent.tsx b/web/src/constants/agent.tsx index 0ba1d927c..92d10e329 100644 --- a/web/src/constants/agent.tsx +++ b/web/src/constants/agent.tsx @@ -118,6 +118,8 @@ export enum Operator { Splitter = 'Splitter', HierarchicalMerger = 'HierarchicalMerger', Extractor = 'Extractor', + Loop = 'Loop', + LoopStart = 'LoopItem', } export enum ComparisonOperator { diff --git a/web/src/pages/agent/canvas/index.tsx b/web/src/pages/agent/canvas/index.tsx index f2fc983e2..3a217e1bc 100644 --- a/web/src/pages/agent/canvas/index.tsx +++ b/web/src/pages/agent/canvas/index.tsx @@ -62,6 +62,7 @@ import { InvokeNode } from './node/invoke-node'; import { IterationNode, IterationStartNode } from './node/iteration-node'; import { KeywordNode } from './node/keyword-node'; import { ListOperationsNode } from './node/list-operations-node'; +import { LoopNode, LoopStartNode } from './node/loop-node'; import { MessageNode } from './node/message-node'; import NoteNode from './node/note-node'; import ParserNode from './node/parser-node'; @@ -105,6 +106,8 @@ export const nodeTypes: NodeTypes = { listOperationsNode: ListOperationsNode, variableAssignerNode: VariableAssignerNode, variableAggregatorNode: VariableAggregatorNode, + loopNode: LoopNode, + loopStartNode: LoopStartNode, }; const edgeTypes = { diff --git a/web/src/pages/agent/canvas/node/dropdown.tsx b/web/src/pages/agent/canvas/node/dropdown.tsx deleted file mode 100644 index dd5263abc..000000000 --- a/web/src/pages/agent/canvas/node/dropdown.tsx +++ /dev/null @@ -1,58 +0,0 @@ -import OperateDropdown from '@/components/operate-dropdown'; -import { CopyOutlined } from '@ant-design/icons'; -import { Flex, MenuProps } from 'antd'; -import { useCallback } from 'react'; -import { useTranslation } from 'react-i18next'; -import { Operator } from '../../constant'; -import { useDuplicateNode } from '../../hooks'; -import useGraphStore from '../../store'; - -interface IProps { - id: string; - iconFontColor?: string; - label: string; -} - -const NodeDropdown = ({ id, iconFontColor, label }: IProps) => { - const { t } = useTranslation(); - const deleteNodeById = useGraphStore((store) => store.deleteNodeById); - const deleteIterationNodeById = useGraphStore( - (store) => store.deleteIterationNodeById, - ); - - const deleteNode = useCallback(() => { - if (label === Operator.Iteration) { - deleteIterationNodeById(id); - } else { - deleteNodeById(id); - } - }, [label, deleteIterationNodeById, id, deleteNodeById]); - - const duplicateNode = useDuplicateNode(); - - const items: MenuProps['items'] = [ - { - key: '2', - onClick: () => duplicateNode(id, label), - label: ( - - {t('common.copy')} - - - ), - }, - ]; - - return ( - - ); -}; - -export default NodeDropdown; diff --git a/web/src/pages/agent/canvas/node/dropdown/accordion-operators.tsx b/web/src/pages/agent/canvas/node/dropdown/accordion-operators.tsx index 6021420c5..184cd1951 100644 --- a/web/src/pages/agent/canvas/node/dropdown/accordion-operators.tsx +++ b/web/src/pages/agent/canvas/node/dropdown/accordion-operators.tsx @@ -62,6 +62,7 @@ export function AccordionOperators({ operators={[ Operator.Switch, Operator.Iteration, + Operator.Loop, Operator.Categorize, ]} isCustomDropdown={isCustomDropdown} diff --git a/web/src/pages/agent/canvas/node/iteration-node.tsx b/web/src/pages/agent/canvas/node/iteration-node.tsx index a11da33dc..c893e7723 100644 --- a/web/src/pages/agent/canvas/node/iteration-node.tsx +++ b/web/src/pages/agent/canvas/node/iteration-node.tsx @@ -56,7 +56,7 @@ export function InnerIterationNode({ ); } -function InnerIterationStartNode({ +export function InnerIterationStartNode({ isConnectable = true, id, selected, diff --git a/web/src/pages/agent/canvas/node/labeled-group-node.tsx b/web/src/pages/agent/canvas/node/labeled-group-node.tsx new file mode 100644 index 000000000..b7f4f1d75 --- /dev/null +++ b/web/src/pages/agent/canvas/node/labeled-group-node.tsx @@ -0,0 +1,75 @@ +import { Panel, type NodeProps, type PanelPosition } from '@xyflow/react'; +import { type ComponentProps, type ReactNode } from 'react'; + +import { BaseNode } from '@/components/xyflow/base-node'; +import { cn } from '@/lib/utils'; + +/* GROUP NODE Label ------------------------------------------------------- */ + +export type GroupNodeLabelProps = ComponentProps<'div'>; + +export function GroupNodeLabel({ + children, + className, + ...props +}: GroupNodeLabelProps) { + return ( +
+
+ {children} +
+
+ ); +} + +export type GroupNodeProps = Partial & { + label?: ReactNode; + position?: PanelPosition; +}; + +/* GROUP NODE -------------------------------------------------------------- */ + +export function LabeledGroupNode({ + label = '', + position, + ...props +}: GroupNodeProps) { + const getLabelClassName = (position?: PanelPosition) => { + switch (position) { + case 'top-left': + return 'rounded-br-sm'; + case 'top-center': + return 'rounded-b-sm'; + case 'top-right': + return 'rounded-bl-sm'; + case 'bottom-left': + return 'rounded-tr-sm'; + case 'bottom-right': + return 'rounded-tl-sm'; + case 'bottom-center': + return 'rounded-t-sm'; + default: + return 'rounded-br-sm'; + } + }; + + return ( + + + {label && ( + + {label} + + )} + + + ); +} diff --git a/web/src/pages/agent/canvas/node/loop-node.tsx b/web/src/pages/agent/canvas/node/loop-node.tsx new file mode 100644 index 000000000..6afbe841a --- /dev/null +++ b/web/src/pages/agent/canvas/node/loop-node.tsx @@ -0,0 +1,16 @@ +import { BaseNode } from '@/interfaces/database/agent'; +import { NodeProps } from '@xyflow/react'; +import { memo } from 'react'; +import { InnerIterationNode, InnerIterationStartNode } from './iteration-node'; + +export function InnerLoopNode({ ...props }: NodeProps>) { + return ; +} + +export const LoopNode = memo(InnerLoopNode); + +export function InnerLoopStartNode({ ...props }: NodeProps>) { + return ; +} + +export const LoopStartNode = memo(InnerLoopStartNode); diff --git a/web/src/pages/agent/canvas/node/toolbar.tsx b/web/src/pages/agent/canvas/node/toolbar.tsx index 74f6ae3db..775ba228d 100644 --- a/web/src/pages/agent/canvas/node/toolbar.tsx +++ b/web/src/pages/agent/canvas/node/toolbar.tsx @@ -58,7 +58,7 @@ export function ToolBar({ const deleteNode: MouseEventHandler = useCallback( (e) => { e.stopPropagation(); - if (label === Operator.Iteration) { + if ([Operator.Iteration, Operator.Loop].includes(label as Operator)) { deleteIterationNodeById(id); } else { deleteNodeById(id); diff --git a/web/src/pages/agent/constant/index.tsx b/web/src/pages/agent/constant/index.tsx index c357e8cb7..16dbc12e3 100644 --- a/web/src/pages/agent/constant/index.tsx +++ b/web/src/pages/agent/constant/index.tsx @@ -625,6 +625,8 @@ export const initialVariableAssignerValues = {}; export const initialVariableAggregatorValues = { outputs: {}, groups: [] }; +export const initialLoopValues = { outputs: {} }; + export const CategorizeAnchorPointPositions = [ { top: 1, right: 34 }, { top: 8, right: 18 }, @@ -707,6 +709,8 @@ export const RestrictedUpstreamMap = { [Operator.Tokenizer]: [Operator.Begin], [Operator.Extractor]: [Operator.Begin], [Operator.File]: [Operator.Begin], + [Operator.Loop]: [Operator.Begin], + [Operator.LoopStart]: [Operator.Begin], }; export const NodeMap = { @@ -759,6 +763,8 @@ export const NodeMap = { [Operator.ListOperations]: 'listOperationsNode', [Operator.VariableAssigner]: 'variableAssignerNode', [Operator.VariableAggregator]: 'variableAggregatorNode', + [Operator.Loop]: 'loopNode', + [Operator.LoopStart]: 'loopStartNode', }; export enum BeginQueryType { diff --git a/web/src/pages/agent/hooks/use-add-node.ts b/web/src/pages/agent/hooks/use-add-node.ts index 44091f1b1..5392cc489 100644 --- a/web/src/pages/agent/hooks/use-add-node.ts +++ b/web/src/pages/agent/hooks/use-add-node.ts @@ -32,6 +32,7 @@ import { initialJin10Values, initialKeywordExtractValues, initialListOperationsValues, + initialLoopValues, initialMessageValues, initialNoteValues, initialParserValues, @@ -68,6 +69,63 @@ function isBottomSubAgent(type: string, position: Position) { type === Operator.Tool ); } + +const GroupStartNodeMap = { + [Operator.Iteration]: { + id: `${Operator.IterationStart}:${humanId()}`, + type: 'iterationStartNode', + position: { x: 50, y: 100 }, + data: { + label: Operator.IterationStart, + name: Operator.IterationStart, + form: initialIterationStartValues, + }, + extent: 'parent' as 'parent', + }, + [Operator.Loop]: { + id: `${Operator.LoopStart}:${humanId()}`, + type: 'loopStartNode', + position: { x: 50, y: 100 }, + data: { + label: Operator.LoopStart, + name: Operator.LoopStart, + form: {}, + }, + extent: 'parent' as 'parent', + }, +}; + +function useAddGroupNode() { + const { addEdge, addNode } = useGraphStore((state) => state); + + const addGroupNode = useCallback( + (operatorType: string, newNode: Node, nodeId?: string) => { + newNode.width = 500; + newNode.height = 250; + + const startNode: Node = + GroupStartNodeMap[operatorType as keyof typeof GroupStartNodeMap]; + + startNode.parentId = newNode.id; + + addNode(newNode); + addNode(startNode); + + if (nodeId) { + addEdge({ + source: nodeId, + target: newNode.id, + sourceHandle: NodeHandleId.Start, + targetHandle: NodeHandleId.End, + }); + } + return newNode.id; + }, + [addEdge, addNode], + ); + + return { addGroupNode }; +} export const useInitializeOperatorParams = () => { const llmId = useFetchModelId(); @@ -133,6 +191,8 @@ export const useInitializeOperatorParams = () => { [Operator.ListOperations]: initialListOperationsValues, [Operator.VariableAssigner]: initialVariableAssignerValues, [Operator.VariableAggregator]: initialVariableAggregatorValues, + [Operator.Loop]: initialLoopValues, + [Operator.LoopStart]: {}, }; }, [llmId]); @@ -311,6 +371,7 @@ export function useAddNode(reactFlowInstance?: ReactFlowInstance) { const { addChildEdge } = useAddChildEdge(); const { addToolNode } = useAddToolNode(); const { resizeIterationNode } = useResizeIterationNode(); + const { addGroupNode } = useAddGroupNode(); // const [reactFlowInstance, setReactFlowInstance] = // useState>(); @@ -376,33 +437,8 @@ export function useAddNode(reactFlowInstance?: ReactFlowInstance) { } } - if (type === Operator.Iteration) { - newNode.width = 500; - newNode.height = 250; - const iterationStartNode: Node = { - id: `${Operator.IterationStart}:${humanId()}`, - type: 'iterationStartNode', - position: { x: 50, y: 100 }, - // draggable: false, - data: { - label: Operator.IterationStart, - name: Operator.IterationStart, - form: initialIterationStartValues, - }, - parentId: newNode.id, - extent: 'parent', - }; - addNode(newNode); - addNode(iterationStartNode); - if (nodeId) { - addEdge({ - source: nodeId, - target: newNode.id, - sourceHandle: NodeHandleId.Start, - targetHandle: NodeHandleId.End, - }); - } - return newNode.id; + if ([Operator.Iteration, Operator.Loop].includes(type as Operator)) { + return addGroupNode(type, newNode, nodeId); } else if ( type === Operator.Agent && params.position === Position.Bottom @@ -456,6 +492,7 @@ export function useAddNode(reactFlowInstance?: ReactFlowInstance) { [ addChildEdge, addEdge, + addGroupNode, addNode, addToolNode, calculateNewlyBackChildPosition, diff --git a/web/src/pages/agent/hooks/use-show-drawer.tsx b/web/src/pages/agent/hooks/use-show-drawer.tsx index a350af074..27717ecb4 100644 --- a/web/src/pages/agent/hooks/use-show-drawer.tsx +++ b/web/src/pages/agent/hooks/use-show-drawer.tsx @@ -14,6 +14,7 @@ export const useShowFormDrawer = () => { setClickedNodeId, getNode, setClickedToolId, + getOperatorTypeFromId, } = useGraphStore((state) => state); const { visible: formDrawerVisible, @@ -25,14 +26,18 @@ export const useShowFormDrawer = () => { (e: React.MouseEvent, nodeId: string) => { const tool = get(e.target, 'dataset.tool'); // TODO: Operator type judgment should be used - if (nodeId.startsWith(Operator.Tool) && !tool) { + const operatorType = getOperatorTypeFromId(nodeId); + if ( + (operatorType === Operator.Tool && !tool) || + [Operator.LoopStart].includes(operatorType as Operator) + ) { return; } setClickedNodeId(nodeId); setClickedToolId(tool); showFormDrawer(); }, - [setClickedNodeId, setClickedToolId, showFormDrawer], + [getOperatorTypeFromId, setClickedNodeId, setClickedToolId, showFormDrawer], ); return { diff --git a/web/src/pages/agent/operator-icon.tsx b/web/src/pages/agent/operator-icon.tsx index bca93d7fa..d2b4233d6 100644 --- a/web/src/pages/agent/operator-icon.tsx +++ b/web/src/pages/agent/operator-icon.tsx @@ -14,7 +14,7 @@ import { ReactComponent as YahooFinanceIcon } from '@/assets/svg/yahoo-finance.s import { IconFontFill } from '@/components/icon-font'; import { cn } from '@/lib/utils'; -import { FileCode, HousePlus } from 'lucide-react'; +import { FileCode, HousePlus, Infinity as InfinityIcon } from 'lucide-react'; import { Operator } from './constant'; interface IProps { @@ -60,6 +60,7 @@ export const SVGIconMap = { }; export const LucideIconMap = { [Operator.DataOperations]: FileCode, + [Operator.Loop]: InfinityIcon, }; const Empty = () => {