diff --git a/api/core/tools/tool/tool.py b/api/core/tools/tool/tool.py index f26f9494a1..4203180992 100644 --- a/api/core/tools/tool/tool.py +++ b/api/core/tools/tool/tool.py @@ -243,8 +243,21 @@ class Tool(BaseModel, ABC): tool_parameters[parameter.name] = float(tool_parameters[parameter.name]) elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN: if not isinstance(tool_parameters[parameter.name], bool): - tool_parameters[parameter.name] = bool(tool_parameters[parameter.name]) - + # check if it is a string + if isinstance(tool_parameters[parameter.name], str): + # check true false + if tool_parameters[parameter.name].lower() in ['true', 'false']: + tool_parameters[parameter.name] = tool_parameters[parameter.name].lower() == 'true' + # check 1 0 + elif tool_parameters[parameter.name] in ['1', '0']: + tool_parameters[parameter.name] = tool_parameters[parameter.name] == '1' + else: + tool_parameters[parameter.name] = bool(tool_parameters[parameter.name]) + elif isinstance(tool_parameters[parameter.name], int | float): + tool_parameters[parameter.name] = tool_parameters[parameter.name] != 0 + else: + tool_parameters[parameter.name] = bool(tool_parameters[parameter.name]) + return tool_parameters @abstractmethod diff --git a/api/core/workflow/nodes/tool/entities.py b/api/core/workflow/nodes/tool/entities.py index aec8f34bb9..97fbe8a999 100644 --- a/api/core/workflow/nodes/tool/entities.py +++ b/api/core/workflow/nodes/tool/entities.py @@ -1,10 +1,9 @@ -from typing import Literal, Union +from typing import Any, Literal, Union from pydantic import BaseModel, validator from core.workflow.entities.base_node_data_entities import BaseNodeData -ToolParameterValue = Union[str, int, float, bool] class ToolEntity(BaseModel): provider_id: str @@ -12,11 +11,23 @@ class ToolEntity(BaseModel): provider_name: str # redundancy tool_name: str tool_label: str # redundancy - tool_configurations: dict[str, ToolParameterValue] + tool_configurations: dict[str, Any] + + @validator('tool_configurations', pre=True, always=True) + def validate_tool_configurations(cls, value, values): + if not isinstance(value, dict): + raise ValueError('tool_configurations must be a dictionary') + + for key in values.get('tool_configurations', {}).keys(): + value = values.get('tool_configurations', {}).get(key) + if not isinstance(value, str | int | float | bool): + raise ValueError(f'{key} must be a string') + + return value class ToolNodeData(BaseNodeData, ToolEntity): class ToolInput(BaseModel): - value: Union[ToolParameterValue, list[str]] + value: Union[Any, list[str]] type: Literal['mixed', 'variable', 'constant'] @validator('type', pre=True, always=True) @@ -25,12 +36,16 @@ class ToolNodeData(BaseNodeData, ToolEntity): value = values.get('value') if typ == 'mixed' and not isinstance(value, str): raise ValueError('value must be a string') - elif typ == 'variable' and not isinstance(value, list): - raise ValueError('value must be a list') - elif typ == 'constant' and not isinstance(value, ToolParameterValue): + elif typ == 'variable': + if not isinstance(value, list): + raise ValueError('value must be a list') + for val in value: + if not isinstance(val, str): + raise ValueError('value must be a list of strings') + elif typ == 'constant' and not isinstance(value, str | int | float | bool): raise ValueError('value must be a string, int, float, or bool') return typ - + """ Tool Node Schema """ diff --git a/web/app/components/workflow/nodes/tool/use-config.ts b/web/app/components/workflow/nodes/tool/use-config.ts index 449a39bb9f..5d499e7782 100644 --- a/web/app/components/workflow/nodes/tool/use-config.ts +++ b/web/app/components/workflow/nodes/tool/use-config.ts @@ -1,4 +1,4 @@ -import { useCallback, useEffect, useState } from 'react' +import { useCallback, useEffect, useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' import produce from 'immer' import { useBoolean } from 'ahooks' @@ -25,7 +25,7 @@ const useConfig = (id: string, payload: ToolNodeType) => { const { t } = useTranslation() const language = useLanguage() - const { inputs, setInputs } = useNodeCrud(id, payload) + const { inputs, setInputs: doSetInputs } = useNodeCrud(id, payload) /* * tool_configurations: tool setting, not dynamic setting * tool_parameters: tool dynamic setting(by user) @@ -58,10 +58,41 @@ const useConfig = (id: string, payload: ToolNodeType) => { }, [currCollection?.name, hideSetAuthModal, t, handleFetchAllTools, provider_type]) const currTool = currCollection?.tools.find(tool => tool.name === tool_name) - const formSchemas = currTool ? toolParametersToFormSchemas(currTool.parameters) : [] + const formSchemas = useMemo(() => { + return currTool ? toolParametersToFormSchemas(currTool.parameters) : [] + }, [currTool]) const toolInputVarSchema = formSchemas.filter((item: any) => item.form === 'llm') // use setting const toolSettingSchema = formSchemas.filter((item: any) => item.form !== 'llm') + const hasShouldTransferTypeSettingInput = toolSettingSchema.some(item => item.type === 'boolean' || item.type === 'number-input') + + const setInputs = useCallback((value: ToolNodeType) => { + if (!hasShouldTransferTypeSettingInput) { + doSetInputs(value) + return + } + const newInputs = produce(value, (draft) => { + const newConfig = { ...draft.tool_configurations } + Object.keys(draft.tool_configurations).forEach((key) => { + const schema = formSchemas.find(item => item.variable === key) + const value = newConfig[key] + if (schema?.type === 'boolean') { + if (typeof value === 'string') + newConfig[key] = parseInt(value, 10) + + if (typeof value === 'boolean') + newConfig[key] = value ? 1 : 0 + } + + if (schema?.type === 'number-input') { + if (typeof value === 'string' && value !== '') + newConfig[key] = parseFloat(value) + } + }) + draft.tool_configurations = newConfig + }) + doSetInputs(newInputs) + }, [doSetInputs, formSchemas, hasShouldTransferTypeSettingInput]) const [notSetDefaultValue, setNotSetDefaultValue] = useState(false) const toolSettingValue = (() => { if (notSetDefaultValue)