feat: add CategorizeHandle #918 (#1282)

### What problem does this PR solve?

feat: add CategorizeHandle #918

### Type of change


- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
balibabu 2024-06-27 09:20:19 +08:00 committed by GitHub
parent e43208a1ca
commit fa5695c250
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 198 additions and 65 deletions

View File

@ -0,0 +1,39 @@
import { Handle, Position } from 'reactflow';
// import { v4 as uuid } from 'uuid';
import styles from './index.less';
const DEFAULT_HANDLE_STYLE = {
width: 6,
height: 6,
bottom: -5,
fontSize: 8,
};
interface IProps {
top: number;
right: number;
text: string;
idx: number;
}
const CategorizeHandle = ({ top, right, text, idx }: IProps) => {
return (
<Handle
type="source"
position={Position.Right}
id={`CategorizeHandle${idx}`}
isConnectable
style={{
...DEFAULT_HANDLE_STYLE,
top: `${top}%`,
right: `${right}%`,
background: 'red',
}}
>
<span className={styles.categorizeAnchorPointText}>{text}</span>
</Handle>
);
};
export default CategorizeHandle;

View File

@ -37,6 +37,12 @@
-3px 0 6px -4px rgba(0, 0, 0, 0.12), -3px 0 6px -4px rgba(0, 0, 0, 0.12),
-6px 0 16px 6px rgba(0, 0, 0, 0.05); -6px 0 16px 6px rgba(0, 0, 0, 0.05);
} }
.categorizeAnchorPointText {
position: absolute;
top: -4px;
left: 8px;
white-space: nowrap;
}
} }
.selectedNode { .selectedNode {
border: 1px solid rgb(59, 118, 244); border: 1px solid rgb(59, 118, 244);

View File

@ -4,12 +4,14 @@ import { Handle, NodeProps, Position } from 'reactflow';
import OperateDropdown from '@/components/operate-dropdown'; import OperateDropdown from '@/components/operate-dropdown';
import { CopyOutlined } from '@ant-design/icons'; import { CopyOutlined } from '@ant-design/icons';
import { Flex, MenuProps, Space } from 'antd'; import { Flex, MenuProps, Space } from 'antd';
import get from 'lodash/get';
import { useCallback } from 'react'; import { useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { Operator, operatorMap } from '../../constant'; import { CategorizeAnchorPointPositions, Operator } from '../../constant';
import { NodeData } from '../../interface'; import { NodeData } from '../../interface';
import OperatorIcon from '../../operator-icon'; import OperatorIcon from '../../operator-icon';
import useGraphStore from '../../store'; import useGraphStore from '../../store';
import CategorizeHandle from './categorize-handle';
import styles from './index.less'; import styles from './index.less';
export function RagNode({ export function RagNode({
@ -30,7 +32,8 @@ export function RagNode({
duplicateNodeById(id); duplicateNodeById(id);
}, [id, duplicateNodeById]); }, [id, duplicateNodeById]);
const description = operatorMap[data.label as Operator].description; const isCategorize = data.label === Operator.Categorize;
const categoryData = get(data, 'form.category_description') ?? {};
const items: MenuProps['items'] = [ const items: MenuProps['items'] = [
{ {
@ -57,9 +60,7 @@ export function RagNode({
position={Position.Left} position={Position.Left}
isConnectable={isConnectable} isConnectable={isConnectable}
className={styles.handle} className={styles.handle}
> ></Handle>
{/* <PlusCircleOutlined style={{ fontSize: 10 }} /> */}
</Handle>
<Handle type="source" position={Position.Top} id="d" isConnectable /> <Handle type="source" position={Position.Top} id="d" isConnectable />
<Handle <Handle
type="source" type="source"
@ -67,34 +68,32 @@ export function RagNode({
isConnectable={isConnectable} isConnectable={isConnectable}
className={styles.handle} className={styles.handle}
id="b" id="b"
> ></Handle>
{/* <PlusCircleOutlined style={{ fontSize: 10 }} /> */}
</Handle>
<Handle type="source" position={Position.Bottom} id="a" isConnectable /> <Handle type="source" position={Position.Bottom} id="a" isConnectable />
{isCategorize &&
Object.keys(categoryData).map((x, idx) => (
<CategorizeHandle
top={CategorizeAnchorPointPositions[idx].top}
right={CategorizeAnchorPointPositions[idx].right}
key={idx}
text={x}
idx={idx}
></CategorizeHandle>
))}
<Flex vertical align="center" justify="center"> <Flex vertical align="center" justify="center">
<Space size={6}> <Space size={6}>
<OperatorIcon <OperatorIcon
name={data.label as Operator} name={data.label as Operator}
fontSize={16} fontSize={16}
></OperatorIcon> ></OperatorIcon>
{/* {data.label} */}
<OperateDropdown <OperateDropdown
iconFontSize={14} iconFontSize={14}
deleteItem={deleteNode} deleteItem={deleteNode}
items={items} items={items}
></OperateDropdown> ></OperateDropdown>
</Space> </Space>
{/* <div className={styles.nodeName}>{id}</div> */}
</Flex> </Flex>
{/* <div>
<Text
ellipsis={{ tooltip: description }}
style={{ width: 130 }}
className={styles.description}
>
{description}
</Text>
</div> */}
<section className={styles.bottomBox}> <section className={styles.bottomBox}>
<div className={styles.nodeName}>{id}</div> <div className={styles.nodeName}>{id}</div>
</section> </section>

View File

@ -1,58 +1,78 @@
import { CloseOutlined } from '@ant-design/icons'; import { CloseOutlined } from '@ant-design/icons';
import { Button, Card, Form, Input, Select, Typography } from 'antd'; import { Button, Card, Form, Input, Select, Typography } from 'antd';
import { useBuildCategorizeToOptions } from './hooks'; import { useBuildCategorizeToOptions, useHandleToSelectChange } from './hooks';
const DynamicCategorize = () => { interface IProps {
nodeId?: string;
}
const DynamicCategorize = ({ nodeId }: IProps) => {
const form = Form.useFormInstance(); const form = Form.useFormInstance();
const options = useBuildCategorizeToOptions(); const options = useBuildCategorizeToOptions();
const { handleSelectChange } = useHandleToSelectChange(
options.map((x) => x.value),
nodeId,
);
return ( return (
<> <>
<Form.List name="items"> <Form.List name="items">
{(fields, { add, remove }) => ( {(fields, { add, remove }) => {
<div style={{ display: 'flex', rowGap: 16, flexDirection: 'column' }}> const handleAdd = () => {
{fields.map((field) => ( const idx = fields.length;
<Card add({ name: `Categorize ${idx + 1}` });
size="small" };
key={field.key} return (
extra={ <div
<CloseOutlined style={{ display: 'flex', rowGap: 10, flexDirection: 'column' }}
onClick={() => { >
remove(field.name); {fields.map((field) => (
}} <Card
/> size="small"
} key={field.key}
> extra={
<Form.Item <CloseOutlined
label="name" onClick={() => {
name={[field.name, 'name']} remove(field.name);
initialValue={`Categorize ${field.name + 1}`} }}
rules={[ />
{ required: true, message: 'Please input your name!' }, }
]}
> >
<Input /> <Form.Item
</Form.Item> label="name"
<Form.Item name={[field.name, 'name']}
label="description" // initialValue={`Categorize ${field.name + 1}`}
name={[field.name, 'description']} rules={[
> { required: true, message: 'Please input your name!' },
<Input.TextArea rows={3} /> ]}
</Form.Item> >
<Form.Item label="examples" name={[field.name, 'examples']}> <Input />
<Input.TextArea rows={3} /> </Form.Item>
</Form.Item> <Form.Item
<Form.Item label="to" name={[field.name, 'to']}> label="description"
<Select options={options} /> name={[field.name, 'description']}
</Form.Item> >
</Card> <Input.TextArea rows={3} />
))} </Form.Item>
<Form.Item label="examples" name={[field.name, 'examples']}>
<Input.TextArea rows={3} />
</Form.Item>
<Form.Item label="to" name={[field.name, 'to']}>
<Select
allowClear
options={options}
onChange={handleSelectChange}
/>
</Form.Item>
</Card>
))}
<Button type="dashed" onClick={() => add()} block> <Button type="dashed" onClick={handleAdd} block>
+ Add Item + Add Item
</Button> </Button>
</div> </div>
)} );
}}
</Form.List> </Form.List>
<Form.Item noStyle shouldUpdate> <Form.Item noStyle shouldUpdate>

View File

@ -1,6 +1,6 @@
import get from 'lodash/get'; import get from 'lodash/get';
import omit from 'lodash/omit'; import omit from 'lodash/omit';
import { useCallback, useEffect } from 'react'; import { useCallback, useEffect, useRef } from 'react';
import { Operator } from '../constant'; import { Operator } from '../constant';
import { import {
ICategorizeItem, ICategorizeItem,
@ -72,6 +72,7 @@ export const useHandleFormValuesChange = ({
}: IOperatorForm) => { }: IOperatorForm) => {
const handleValuesChange = useCallback( const handleValuesChange = useCallback(
(changedValues: any, values: any) => { (changedValues: any, values: any) => {
console.info(changedValues, values);
onValuesChange?.(changedValues, { onValuesChange?.(changedValues, {
...omit(values, 'items'), ...omit(values, 'items'),
category_description: buildCategorizeObjectFromList(values.items), category_description: buildCategorizeObjectFromList(values.items),
@ -90,3 +91,38 @@ export const useHandleFormValuesChange = ({
return { handleValuesChange }; return { handleValuesChange };
}; };
export const useHandleToSelectChange = (
opstionIds: string[],
nodeId?: string,
) => {
// const [previousTarget, setPreviousTarget] = useState('');
const previousTarget = useRef('');
const { addEdge, deleteEdgeBySourceAndTarget } = useGraphStore(
(state) => state,
);
const handleSelectChange = useCallback(
(value?: string) => {
if (nodeId) {
if (previousTarget.current) {
// delete previous edge
deleteEdgeBySourceAndTarget(nodeId, previousTarget.current);
}
if (value) {
addEdge({
source: nodeId,
target: value,
sourceHandle: 'b',
targetHandle: 'd',
});
} else {
// if the value is empty, delete the edges between the current node and all nodes in the drop-down box.
}
previousTarget.current = value;
}
},
[addEdge, nodeId, deleteEdgeBySourceAndTarget],
);
return { handleSelectChange };
};

View File

@ -32,7 +32,7 @@ const CategorizeForm = ({ form, onValuesChange, node }: IOperatorForm) => {
> >
<LLMSelect></LLMSelect> <LLMSelect></LLMSelect>
</Form.Item> </Form.Item>
<DynamicCategorize></DynamicCategorize> <DynamicCategorize nodeId={node?.id}></DynamicCategorize>
</Form> </Form>
); );
}; };

View File

@ -82,3 +82,18 @@ export const initialFormValuesMap = {
[Operator.Answer]: {}, [Operator.Answer]: {},
[Operator.Categorize]: {}, [Operator.Categorize]: {},
}; };
export const CategorizeAnchorPointPositions = [
{ top: 1, right: 34 },
{ top: 8, right: 18 },
{ top: 15, right: 10 },
{ top: 24, right: 4 },
{ top: 31, right: 1 },
{ top: 38, right: -2 },
{ top: 62, right: -2 }, //bottom
{ top: 71, right: 1 },
{ top: 79, right: 6 },
{ top: 86, right: 12 },
{ top: 91, right: 20 },
{ top: 98, right: 34 },
];

View File

@ -34,10 +34,12 @@ export type RFState = {
onSelectionChange: OnSelectionChangeFunc; onSelectionChange: OnSelectionChangeFunc;
addNode: (nodes: Node) => void; addNode: (nodes: Node) => void;
getNode: (id: string) => Node | undefined; getNode: (id: string) => Node | undefined;
addEdge: (connection: Connection) => void;
duplicateNode: (id: string) => void; duplicateNode: (id: string) => void;
deleteEdge: () => void; deleteEdge: () => void;
deleteEdgeById: (id: string) => void; deleteEdgeById: (id: string) => void;
deleteNodeById: (id: string) => void; deleteNodeById: (id: string) => void;
deleteEdgeBySourceAndTarget: (source: string, target: string) => void;
findNodeByName: (operatorName: Operator) => Node | undefined; findNodeByName: (operatorName: Operator) => Node | undefined;
findNodeById: (id: string) => Node | undefined; findNodeById: (id: string) => Node | undefined;
}; };
@ -83,6 +85,14 @@ const useGraphStore = create<RFState>()(
getNode: (id: string) => { getNode: (id: string) => {
return get().nodes.find((x) => x.id === id); return get().nodes.find((x) => x.id === id);
}, },
addEdge: (connection: Connection) => {
set({
edges: addEdge(connection, get().edges),
});
},
// addOnlyOneEdgeBetweenTwoNodes: (connection: Connection) => {
// },
duplicateNode: (id: string) => { duplicateNode: (id: string) => {
const { getNode, addNode } = get(); const { getNode, addNode } = get();
const node = getNode(id); const node = getNode(id);
@ -114,6 +124,14 @@ const useGraphStore = create<RFState>()(
edges: edges.filter((edge) => edge.id !== id), edges: edges.filter((edge) => edge.id !== id),
}); });
}, },
deleteEdgeBySourceAndTarget: (source: string, target: string) => {
const { edges } = get();
set({
edges: edges.filter(
(edge) => edge.target !== target && edge.source !== source,
),
});
},
deleteNodeById: (id: string) => { deleteNodeById: (id: string) => {
const { nodes, edges } = get(); const { nodes, edges } = get();
set({ set({