diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 0345a9e90b..641997f3f3 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -137,6 +137,71 @@ class AdvancedChatDraftWorkflowRunApi(Resource): logging.exception("internal server error.") raise InternalServerError() +class AdvancedChatDraftRunIterationNodeApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT]) + def post(self, app_model: App, node_id: str): + """ + Run draft workflow iteration node + """ + parser = reqparse.RequestParser() + parser.add_argument('inputs', type=dict, location='json') + args = parser.parse_args() + + try: + response = AppGenerateService.generate_single_iteration( + app_model=app_model, + user=current_user, + node_id=node_id, + args=args, + streaming=True + ) + + return helper.compact_generate_response(response) + except services.errors.conversation.ConversationNotExistsError: + raise NotFound("Conversation Not Exists.") + except services.errors.conversation.ConversationCompletedError: + raise ConversationCompletedError() + except ValueError as e: + raise e + except Exception as e: + logging.exception("internal server error.") + raise InternalServerError() + +class WorkflowDraftRunIterationNodeApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.WORKFLOW]) + def post(self, app_model: App, node_id: str): + """ + Run draft workflow iteration node + """ + parser = reqparse.RequestParser() + parser.add_argument('inputs', type=dict, location='json') + args = parser.parse_args() + + try: + response = AppGenerateService.generate_single_iteration( + app_model=app_model, + user=current_user, + node_id=node_id, + args=args, + streaming=True + ) + + return helper.compact_generate_response(response) + except 
services.errors.conversation.ConversationNotExistsError: + raise NotFound("Conversation Not Exists.") + except services.errors.conversation.ConversationCompletedError: + raise ConversationCompletedError() + except ValueError as e: + raise e + except Exception as e: + logging.exception("internal server error.") + raise InternalServerError() class DraftWorkflowRunApi(Resource): @setup_required @@ -326,6 +391,8 @@ api.add_resource(AdvancedChatDraftWorkflowRunApi, '/apps//advanced- api.add_resource(DraftWorkflowRunApi, '/apps//workflows/draft/run') api.add_resource(WorkflowTaskStopApi, '/apps//workflow-runs/tasks//stop') api.add_resource(DraftWorkflowNodeRunApi, '/apps//workflows/draft/nodes//run') +api.add_resource(AdvancedChatDraftRunIterationNodeApi, '/apps//advanced-chat/workflows/draft/iteration/nodes//run') +api.add_resource(WorkflowDraftRunIterationNodeApi, '/apps//workflows/draft/iteration/nodes//run') api.add_resource(PublishedWorkflowApi, '/apps//workflows/publish') api.add_resource(DefaultBlockConfigsApi, '/apps//workflows/default-workflow-block-configs') api.add_resource(DefaultBlockConfigApi, '/apps//workflows/default-workflow-block-configs' diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index 63f4613e7d..a911e9b2cb 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -9,8 +9,13 @@ from controllers.console import api from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from core.model_runtime.utils.encoders import jsonable_encoder +from libs.helper import alphanumeric, uuid_value from libs.login import login_required -from services.tools_manage_service import ToolManageService +from services.tools.api_tools_manage_service import ApiToolManageService +from services.tools.builtin_tools_manage_service import BuiltinToolManageService +from 
services.tools.tool_labels_service import ToolLabelsService +from services.tools.tools_manage_service import ToolCommonService +from services.tools.workflow_tools_manage_service import WorkflowToolManageService class ToolProviderListApi(Resource): @@ -21,7 +26,11 @@ class ToolProviderListApi(Resource): user_id = current_user.id tenant_id = current_user.current_tenant_id - return ToolManageService.list_tool_providers(user_id, tenant_id) + req = reqparse.RequestParser() + req.add_argument('type', type=str, choices=['builtin', 'model', 'api', 'workflow'], required=False, nullable=True, location='args') + args = req.parse_args() + + return ToolCommonService.list_tool_providers(user_id, tenant_id, args.get('type', None)) class ToolBuiltinProviderListToolsApi(Resource): @setup_required @@ -31,7 +40,7 @@ class ToolBuiltinProviderListToolsApi(Resource): user_id = current_user.id tenant_id = current_user.current_tenant_id - return jsonable_encoder(ToolManageService.list_builtin_tool_provider_tools( + return jsonable_encoder(BuiltinToolManageService.list_builtin_tool_provider_tools( user_id, tenant_id, provider, @@ -48,7 +57,7 @@ class ToolBuiltinProviderDeleteApi(Resource): user_id = current_user.id tenant_id = current_user.current_tenant_id - return ToolManageService.delete_builtin_tool_provider( + return BuiltinToolManageService.delete_builtin_tool_provider( user_id, tenant_id, provider, @@ -70,7 +79,7 @@ class ToolBuiltinProviderUpdateApi(Resource): args = parser.parse_args() - return ToolManageService.update_builtin_tool_provider( + return BuiltinToolManageService.update_builtin_tool_provider( user_id, tenant_id, provider, @@ -85,7 +94,7 @@ class ToolBuiltinProviderGetCredentialsApi(Resource): user_id = current_user.id tenant_id = current_user.current_tenant_id - return ToolManageService.get_builtin_tool_provider_credentials( + return BuiltinToolManageService.get_builtin_tool_provider_credentials( user_id, tenant_id, provider, @@ -94,7 +103,7 @@ class 
ToolBuiltinProviderGetCredentialsApi(Resource): class ToolBuiltinProviderIconApi(Resource): @setup_required def get(self, provider): - icon_bytes, mimetype = ToolManageService.get_builtin_tool_provider_icon(provider) + icon_bytes, mimetype = BuiltinToolManageService.get_builtin_tool_provider_icon(provider) icon_cache_max_age = int(current_app.config.get('TOOL_ICON_CACHE_MAX_AGE')) return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age) @@ -116,11 +125,12 @@ class ToolApiProviderAddApi(Resource): parser.add_argument('provider', type=str, required=True, nullable=False, location='json') parser.add_argument('icon', type=dict, required=True, nullable=False, location='json') parser.add_argument('privacy_policy', type=str, required=False, nullable=True, location='json') + parser.add_argument('labels', type=list[str], required=False, nullable=True, location='json', default=[]) parser.add_argument('custom_disclaimer', type=str, required=False, nullable=True, location='json') args = parser.parse_args() - return ToolManageService.create_api_tool_provider( + return ApiToolManageService.create_api_tool_provider( user_id, tenant_id, args['provider'], @@ -130,6 +140,7 @@ class ToolApiProviderAddApi(Resource): args['schema'], args.get('privacy_policy', ''), args.get('custom_disclaimer', ''), + args.get('labels', []), ) class ToolApiProviderGetRemoteSchemaApi(Resource): @@ -143,7 +154,7 @@ class ToolApiProviderGetRemoteSchemaApi(Resource): args = parser.parse_args() - return ToolManageService.get_api_tool_provider_remote_schema( + return ApiToolManageService.get_api_tool_provider_remote_schema( current_user.id, current_user.current_tenant_id, args['url'], @@ -163,7 +174,7 @@ class ToolApiProviderListToolsApi(Resource): args = parser.parse_args() - return jsonable_encoder(ToolManageService.list_api_tool_provider_tools( + return jsonable_encoder(ApiToolManageService.list_api_tool_provider_tools( user_id, tenant_id, args['provider'], @@ -188,11 +199,12 
@@ class ToolApiProviderUpdateApi(Resource): parser.add_argument('original_provider', type=str, required=True, nullable=False, location='json') parser.add_argument('icon', type=dict, required=True, nullable=False, location='json') parser.add_argument('privacy_policy', type=str, required=True, nullable=True, location='json') + parser.add_argument('labels', type=list[str], required=False, nullable=True, location='json') parser.add_argument('custom_disclaimer', type=str, required=True, nullable=True, location='json') args = parser.parse_args() - return ToolManageService.update_api_tool_provider( + return ApiToolManageService.update_api_tool_provider( user_id, tenant_id, args['provider'], @@ -203,6 +215,7 @@ class ToolApiProviderUpdateApi(Resource): args['schema'], args['privacy_policy'], args['custom_disclaimer'], + args.get('labels', []), ) class ToolApiProviderDeleteApi(Resource): @@ -222,7 +235,7 @@ class ToolApiProviderDeleteApi(Resource): args = parser.parse_args() - return ToolManageService.delete_api_tool_provider( + return ApiToolManageService.delete_api_tool_provider( user_id, tenant_id, args['provider'], @@ -242,7 +255,7 @@ class ToolApiProviderGetApi(Resource): args = parser.parse_args() - return ToolManageService.get_api_tool_provider( + return ApiToolManageService.get_api_tool_provider( user_id, tenant_id, args['provider'], @@ -253,7 +266,7 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource): @login_required @account_initialization_required def get(self, provider): - return ToolManageService.list_builtin_provider_credentials_schema(provider) + return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider) class ToolApiProviderSchemaApi(Resource): @setup_required @@ -266,7 +279,7 @@ class ToolApiProviderSchemaApi(Resource): args = parser.parse_args() - return ToolManageService.parser_api_schema( + return ApiToolManageService.parser_api_schema( schema=args['schema'], ) @@ -286,7 +299,7 @@ class 
ToolApiProviderPreviousTestApi(Resource): args = parser.parse_args() - return ToolManageService.test_api_tool_preview( + return ApiToolManageService.test_api_tool_preview( current_user.current_tenant_id, args['provider_name'] if args['provider_name'] else '', args['tool_name'], @@ -296,6 +309,153 @@ class ToolApiProviderPreviousTestApi(Resource): args['schema'], ) +class ToolWorkflowProviderCreateApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self): + if not current_user.is_admin_or_owner: + raise Forbidden() + + user_id = current_user.id + tenant_id = current_user.current_tenant_id + + reqparser = reqparse.RequestParser() + reqparser.add_argument('workflow_app_id', type=uuid_value, required=True, nullable=False, location='json') + reqparser.add_argument('name', type=alphanumeric, required=True, nullable=False, location='json') + reqparser.add_argument('label', type=str, required=True, nullable=False, location='json') + reqparser.add_argument('description', type=str, required=True, nullable=False, location='json') + reqparser.add_argument('icon', type=dict, required=True, nullable=False, location='json') + reqparser.add_argument('parameters', type=list[dict], required=True, nullable=False, location='json') + reqparser.add_argument('privacy_policy', type=str, required=False, nullable=True, location='json', default='') + reqparser.add_argument('labels', type=list[str], required=False, nullable=True, location='json') + + args = reqparser.parse_args() + + return WorkflowToolManageService.create_workflow_tool( + user_id, + tenant_id, + args['workflow_app_id'], + args['name'], + args['label'], + args['icon'], + args['description'], + args['parameters'], + args['privacy_policy'], + args.get('labels', []), + ) + +class ToolWorkflowProviderUpdateApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self): + if not current_user.is_admin_or_owner: + raise Forbidden() + + user_id = 
current_user.id + tenant_id = current_user.current_tenant_id + + reqparser = reqparse.RequestParser() + reqparser.add_argument('workflow_tool_id', type=uuid_value, required=True, nullable=False, location='json') + reqparser.add_argument('name', type=alphanumeric, required=True, nullable=False, location='json') + reqparser.add_argument('label', type=str, required=True, nullable=False, location='json') + reqparser.add_argument('description', type=str, required=True, nullable=False, location='json') + reqparser.add_argument('icon', type=dict, required=True, nullable=False, location='json') + reqparser.add_argument('parameters', type=list[dict], required=True, nullable=False, location='json') + reqparser.add_argument('privacy_policy', type=str, required=False, nullable=True, location='json', default='') + reqparser.add_argument('labels', type=list[str], required=False, nullable=True, location='json') + + args = reqparser.parse_args() + + if not args['workflow_tool_id']: + raise ValueError('incorrect workflow_tool_id') + + return WorkflowToolManageService.update_workflow_tool( + user_id, + tenant_id, + args['workflow_tool_id'], + args['name'], + args['label'], + args['icon'], + args['description'], + args['parameters'], + args['privacy_policy'], + args.get('labels', []), + ) + +class ToolWorkflowProviderDeleteApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self): + if not current_user.is_admin_or_owner: + raise Forbidden() + + user_id = current_user.id + tenant_id = current_user.current_tenant_id + + reqparser = reqparse.RequestParser() + reqparser.add_argument('workflow_tool_id', type=uuid_value, required=True, nullable=False, location='json') + + args = reqparser.parse_args() + + return WorkflowToolManageService.delete_workflow_tool( + user_id, + tenant_id, + args['workflow_tool_id'], + ) + +class ToolWorkflowProviderGetApi(Resource): + @setup_required + @login_required + @account_initialization_required + def 
get(self): + user_id = current_user.id + tenant_id = current_user.current_tenant_id + + parser = reqparse.RequestParser() + parser.add_argument('workflow_tool_id', type=uuid_value, required=False, nullable=True, location='args') + parser.add_argument('workflow_app_id', type=uuid_value, required=False, nullable=True, location='args') + + args = parser.parse_args() + + if args.get('workflow_tool_id'): + tool = WorkflowToolManageService.get_workflow_tool_by_tool_id( + user_id, + tenant_id, + args['workflow_tool_id'], + ) + elif args.get('workflow_app_id'): + tool = WorkflowToolManageService.get_workflow_tool_by_app_id( + user_id, + tenant_id, + args['workflow_app_id'], + ) + else: + raise ValueError('incorrect workflow_tool_id or workflow_app_id') + + return jsonable_encoder(tool) + +class ToolWorkflowProviderListToolApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self): + user_id = current_user.id + tenant_id = current_user.current_tenant_id + + parser = reqparse.RequestParser() + parser.add_argument('workflow_tool_id', type=uuid_value, required=True, nullable=False, location='args') + + args = parser.parse_args() + + return jsonable_encoder(WorkflowToolManageService.list_single_workflow_tools( + user_id, + tenant_id, + args['workflow_tool_id'], + )) + class ToolBuiltinListApi(Resource): @setup_required @login_required @@ -304,7 +464,7 @@ class ToolBuiltinListApi(Resource): user_id = current_user.id tenant_id = current_user.current_tenant_id - return jsonable_encoder([provider.to_dict() for provider in ToolManageService.list_builtin_tools( + return jsonable_encoder([provider.to_dict() for provider in BuiltinToolManageService.list_builtin_tools( user_id, tenant_id, )]) @@ -317,18 +477,43 @@ class ToolApiListApi(Resource): user_id = current_user.id tenant_id = current_user.current_tenant_id - return jsonable_encoder([provider.to_dict() for provider in ToolManageService.list_api_tools( + return 
jsonable_encoder([provider.to_dict() for provider in ApiToolManageService.list_api_tools( user_id, tenant_id, )]) + +class ToolWorkflowListApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self): + user_id = current_user.id + tenant_id = current_user.current_tenant_id + return jsonable_encoder([provider.to_dict() for provider in WorkflowToolManageService.list_tenant_workflow_tools( + user_id, + tenant_id, + )]) + +class ToolLabelsApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self): + return jsonable_encoder(ToolLabelsService.list_tool_labels()) + +# tool provider api.add_resource(ToolProviderListApi, '/workspaces/current/tool-providers') + +# builtin tool provider api.add_resource(ToolBuiltinProviderListToolsApi, '/workspaces/current/tool-provider/builtin//tools') api.add_resource(ToolBuiltinProviderDeleteApi, '/workspaces/current/tool-provider/builtin//delete') api.add_resource(ToolBuiltinProviderUpdateApi, '/workspaces/current/tool-provider/builtin//update') api.add_resource(ToolBuiltinProviderGetCredentialsApi, '/workspaces/current/tool-provider/builtin//credentials') api.add_resource(ToolBuiltinProviderCredentialsSchemaApi, '/workspaces/current/tool-provider/builtin//credentials_schema') api.add_resource(ToolBuiltinProviderIconApi, '/workspaces/current/tool-provider/builtin//icon') + +# api tool provider api.add_resource(ToolApiProviderAddApi, '/workspaces/current/tool-provider/api/add') api.add_resource(ToolApiProviderGetRemoteSchemaApi, '/workspaces/current/tool-provider/api/remote') api.add_resource(ToolApiProviderListToolsApi, '/workspaces/current/tool-provider/api/tools') @@ -338,5 +523,15 @@ api.add_resource(ToolApiProviderGetApi, '/workspaces/current/tool-provider/api/g api.add_resource(ToolApiProviderSchemaApi, '/workspaces/current/tool-provider/api/schema') api.add_resource(ToolApiProviderPreviousTestApi, '/workspaces/current/tool-provider/api/test/pre') +# 
workflow tool provider +api.add_resource(ToolWorkflowProviderCreateApi, '/workspaces/current/tool-provider/workflow/create') +api.add_resource(ToolWorkflowProviderUpdateApi, '/workspaces/current/tool-provider/workflow/update') +api.add_resource(ToolWorkflowProviderDeleteApi, '/workspaces/current/tool-provider/workflow/delete') +api.add_resource(ToolWorkflowProviderGetApi, '/workspaces/current/tool-provider/workflow/get') +api.add_resource(ToolWorkflowProviderListToolApi, '/workspaces/current/tool-provider/workflow/tools') + api.add_resource(ToolBuiltinListApi, '/workspaces/current/tools/builtin') -api.add_resource(ToolApiListApi, '/workspaces/current/tools/api') \ No newline at end of file +api.add_resource(ToolApiListApi, '/workspaces/current/tools/api') +api.add_resource(ToolWorkflowListApi, '/workspaces/current/tools/workflow') + +api.add_resource(ToolLabelsApi, '/workspaces/current/tool-labels') \ No newline at end of file diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index 485633cab1..5eb8cd4997 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -165,6 +165,7 @@ class BaseAgentRunner(AppRunner): tenant_id=self.tenant_id, app_id=self.app_config.app_id, agent_tool=tool, + invoke_from=self.application_generate_entity.invoke_from ) tool_entity.load_variables(self.variables_pool) diff --git a/api/core/agent/entities.py b/api/core/agent/entities.py index 5284faa02e..5274224de5 100644 --- a/api/core/agent/entities.py +++ b/api/core/agent/entities.py @@ -8,7 +8,7 @@ class AgentToolEntity(BaseModel): """ Agent Tool Entity. 
""" - provider_type: Literal["builtin", "api"] + provider_type: Literal["builtin", "api", "workflow"] provider_id: str tool_name: str tool_parameters: dict[str, Any] = {} diff --git a/api/core/tools/prompt/template.py b/api/core/agent/prompt/template.py similarity index 100% rename from api/core/tools/prompt/template.py rename to api/core/agent/prompt/template.py diff --git a/api/core/app/app_config/easy_ui_based_app/agent/manager.py b/api/core/app/app_config/easy_ui_based_app/agent/manager.py index a48316728b..f271aeed0c 100644 --- a/api/core/app/app_config/easy_ui_based_app/agent/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/agent/manager.py @@ -1,7 +1,7 @@ from typing import Optional from core.agent.entities import AgentEntity, AgentPromptEntity, AgentToolEntity -from core.tools.prompt.template import REACT_PROMPT_TEMPLATES +from core.agent.prompt.template import REACT_PROMPT_TEMPLATES class AgentConfigManager: diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index 101e25d582..d6b6d89416 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -239,4 +239,4 @@ class WorkflowUIBasedAppConfig(AppConfig): """ Workflow UI Based App Config Entity. 
""" - workflow_id: str + workflow_id: str \ No newline at end of file diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 0fa19addaa..3b1ee3578d 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -98,6 +98,90 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): extras=extras ) + return self._generate( + app_model=app_model, + workflow=workflow, + user=user, + invoke_from=invoke_from, + application_generate_entity=application_generate_entity, + conversation=conversation, + stream=stream + ) + + def single_iteration_generate(self, app_model: App, + workflow: Workflow, + node_id: str, + user: Account, + args: dict, + stream: bool = True) \ + -> Union[dict, Generator[dict, None, None]]: + """ + Generate App response. + + :param app_model: App + :param workflow: Workflow + :param user: account or end user + :param args: request args + :param invoke_from: invoke from source + :param stream: is stream + """ + if not node_id: + raise ValueError('node_id is required') + + if args.get('inputs') is None: + raise ValueError('inputs is required') + + extras = { + "auto_generate_conversation_name": False + } + + # get conversation + conversation = None + if args.get('conversation_id'): + conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user) + + # convert to app config + app_config = AdvancedChatAppConfigManager.get_app_config( + app_model=app_model, + workflow=workflow + ) + + # init application generate entity + application_generate_entity = AdvancedChatAppGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=app_config, + conversation_id=conversation.id if conversation else None, + inputs={}, + query='', + files=[], + user_id=user.id, + stream=stream, + invoke_from=InvokeFrom.DEBUGGER, + extras=extras, + single_iteration_run=AdvancedChatAppGenerateEntity.SingleIterationRunEntity( + 
node_id=node_id, + inputs=args['inputs'] + ) + ) + + return self._generate( + app_model=app_model, + workflow=workflow, + user=user, + invoke_from=InvokeFrom.DEBUGGER, + application_generate_entity=application_generate_entity, + conversation=conversation, + stream=stream + ) + + def _generate(self, app_model: App, + workflow: Workflow, + user: Union[Account, EndUser], + invoke_from: InvokeFrom, + application_generate_entity: AdvancedChatAppGenerateEntity, + conversation: Conversation = None, + stream: bool = True) \ + -> Union[dict, Generator[dict, None, None]]: is_first_conversation = False if not conversation: is_first_conversation = True @@ -167,18 +251,30 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): """ with flask_app.app_context(): try: - # get conversation and message - conversation = self._get_conversation(conversation_id) - message = self._get_message(message_id) - - # chatbot app runner = AdvancedChatAppRunner() - runner.run( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, - conversation=conversation, - message=message - ) + if application_generate_entity.single_iteration_run: + single_iteration_run = application_generate_entity.single_iteration_run + runner.single_iteration_run( + app_id=application_generate_entity.app_config.app_id, + workflow_id=application_generate_entity.app_config.workflow_id, + queue_manager=queue_manager, + inputs=single_iteration_run.inputs, + node_id=single_iteration_run.node_id, + user_id=application_generate_entity.user_id + ) + else: + # get conversation and message + conversation = self._get_conversation(conversation_id) + message = self._get_message(message_id) + + # chatbot app + runner = AdvancedChatAppRunner() + runner.run( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message + ) except GenerateTaskStoppedException: pass except InvokeAuthorizationError: diff --git 
a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index d858dcac12..de3632894d 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -102,6 +102,7 @@ class AdvancedChatAppRunner(AppRunner): user_from=UserFrom.ACCOUNT if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else UserFrom.END_USER, + invoke_from=application_generate_entity.invoke_from, user_inputs=inputs, system_inputs={ SystemVariable.QUERY: query, @@ -109,6 +110,35 @@ class AdvancedChatAppRunner(AppRunner): SystemVariable.CONVERSATION_ID: conversation.id, SystemVariable.USER_ID: user_id }, + callbacks=workflow_callbacks, + call_depth=application_generate_entity.call_depth + ) + + def single_iteration_run(self, app_id: str, workflow_id: str, + queue_manager: AppQueueManager, + inputs: dict, node_id: str, user_id: str) -> None: + """ + Single iteration run + """ + app_record: App = db.session.query(App).filter(App.id == app_id).first() + if not app_record: + raise ValueError("App not found") + + workflow = self.get_workflow(app_model=app_record, workflow_id=workflow_id) + if not workflow: + raise ValueError("Workflow not initialized") + + workflow_callbacks = [WorkflowEventTriggerCallback( + queue_manager=queue_manager, + workflow=workflow + )] + + workflow_engine_manager = WorkflowEngineManager() + workflow_engine_manager.single_step_run_iteration_workflow_node( + workflow=workflow, + node_id=node_id, + user_id=user_id, + user_inputs=inputs, callbacks=workflow_callbacks ) diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 67d8fe5cb1..4dfc155732 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -12,6 +12,9 @@ from core.app.entities.queue_entities import ( QueueAdvancedChatMessageEndEvent, 
QueueAnnotationReplyEvent, QueueErrorEvent, + QueueIterationCompletedEvent, + QueueIterationNextEvent, + QueueIterationStartEvent, QueueMessageReplaceEvent, QueueNodeFailedEvent, QueueNodeStartedEvent, @@ -64,6 +67,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc _workflow: Workflow _user: Union[Account, EndUser] _workflow_system_variables: dict[SystemVariable, Any] + _iteration_nested_relations: dict[str, list[str]] def __init__(self, application_generate_entity: AdvancedChatAppGenerateEntity, workflow: Workflow, @@ -104,6 +108,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc ) self._stream_generate_routes = self._get_stream_generate_routes() + self._iteration_nested_relations = self._get_iteration_nested_relations(self._workflow.graph_dict) self._conversation_name_generate_thread = None def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]: @@ -204,6 +209,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc # search stream_generate_routes if node id is answer start at node if not self._task_state.current_stream_generate_state and event.node_id in self._stream_generate_routes: self._task_state.current_stream_generate_state = self._stream_generate_routes[event.node_id] + # reset current route position to 0 + self._task_state.current_stream_generate_state.current_route_position = 0 # generate stream outputs when node started yield from self._generate_stream_outputs_when_node_started() @@ -225,6 +232,22 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution ) + + if isinstance(event, QueueNodeFailedEvent): + yield from self._handle_iteration_exception( + task_id=self._application_generate_entity.task_id, + error=f'Child node failed: {event.error}' + ) + elif isinstance(event, 
QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent): + if isinstance(event, QueueIterationNextEvent): + # clear ran node execution infos of current iteration + iteration_relations = self._iteration_nested_relations.get(event.node_id) + if iteration_relations: + for node_id in iteration_relations: + self._task_state.ran_node_execution_infos.pop(node_id, None) + + yield self._handle_iteration_to_stream_response(self._application_generate_entity.task_id, event) + self._handle_iteration_operation(event) elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent): workflow_run = self._handle_workflow_finished(event) if workflow_run: @@ -263,10 +286,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc self._handle_retriever_resources(event) elif isinstance(event, QueueAnnotationReplyEvent): self._handle_annotation_reply(event) - # elif isinstance(event, QueueMessageFileEvent): - # response = self._message_file_to_stream_response(event) - # if response: - # yield response elif isinstance(event, QueueTextChunkEvent): delta_text = event.text if delta_text is None: @@ -342,7 +361,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc id=self._message.id, **extras ) - + def _get_stream_generate_routes(self) -> dict[str, ChatflowStreamGenerateRoute]: """ Get stream generate routes. 
@@ -372,7 +391,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc ) return stream_generate_routes - + def _get_answer_start_at_node_ids(self, graph: dict, target_node_id: str) \ -> list[str]: """ @@ -401,14 +420,23 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc continue node_type = source_node.get('data', {}).get('type') + node_iteration_id = source_node.get('data', {}).get('iteration_id') + iteration_start_node_id = None + if node_iteration_id: + iteration_node = next((node for node in nodes if node.get('id') == node_iteration_id), None) + iteration_start_node_id = iteration_node.get('data', {}).get('start_node_id') + if node_type in [ NodeType.ANSWER.value, NodeType.IF_ELSE.value, - NodeType.QUESTION_CLASSIFIER.value + NodeType.QUESTION_CLASSIFIER.value, + NodeType.ITERATION.value, + NodeType.LOOP.value ]: start_node_id = target_node_id start_node_ids.append(start_node_id) - elif node_type == NodeType.START.value: + elif node_type == NodeType.START.value or \ + node_iteration_id is not None and iteration_start_node_id == source_node.get('id'): start_node_id = source_node_id start_node_ids.append(start_node_id) else: @@ -417,7 +445,27 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc start_node_ids.extend(sub_start_node_ids) return start_node_ids + + def _get_iteration_nested_relations(self, graph: dict) -> dict[str, list[str]]: + """ + Get iteration nested relations. + :param graph: graph + :return: + """ + nodes = graph.get('nodes') + iteration_ids = [node.get('id') for node in nodes + if node.get('data', {}).get('type') in [ + NodeType.ITERATION.value, + NodeType.LOOP.value, + ]] + + return { + iteration_id: [ + node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id + ] for iteration_id in iteration_ids + } + def _generate_stream_outputs_when_node_started(self) -> Generator: """ Generate stream outputs. 
@@ -425,7 +473,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc """ if self._task_state.current_stream_generate_state: route_chunks = self._task_state.current_stream_generate_state.generate_route[ - self._task_state.current_stream_generate_state.current_route_position:] + self._task_state.current_stream_generate_state.current_route_position: + ] for route_chunk in route_chunks: if route_chunk.type == 'text': @@ -458,7 +507,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc route_chunks = self._task_state.current_stream_generate_state.generate_route[ self._task_state.current_stream_generate_state.current_route_position:] - + for route_chunk in route_chunks: if route_chunk.type == 'text': route_chunk = cast(TextGenerateRouteChunk, route_chunk) diff --git a/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py b/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py index fef719a086..78fe077e6b 100644 --- a/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py +++ b/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py @@ -1,8 +1,11 @@ -from typing import Optional +from typing import Any, Optional from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.entities.queue_entities import ( AppQueueEvent, + QueueIterationCompletedEvent, + QueueIterationNextEvent, + QueueIterationStartEvent, QueueNodeFailedEvent, QueueNodeStartedEvent, QueueNodeSucceededEvent, @@ -130,6 +133,66 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback): ), PublishFrom.APPLICATION_MANAGER ) + def on_workflow_iteration_started(self, + node_id: str, + node_type: NodeType, + node_run_index: int = 1, + node_data: Optional[BaseNodeData] = None, + inputs: dict = None, + predecessor_node_id: Optional[str] = None, + metadata: Optional[dict] = None) -> None: + """ + Publish iteration started + """ + self._queue_manager.publish( + 
QueueIterationStartEvent( + node_id=node_id, + node_type=node_type, + node_run_index=node_run_index, + node_data=node_data, + inputs=inputs, + predecessor_node_id=predecessor_node_id, + metadata=metadata + ), + PublishFrom.APPLICATION_MANAGER + ) + + def on_workflow_iteration_next(self, node_id: str, + node_type: NodeType, + index: int, + node_run_index: int, + output: Optional[Any]) -> None: + """ + Publish iteration next + """ + self._queue_manager.publish( + QueueIterationNextEvent( + node_id=node_id, + node_type=node_type, + index=index, + node_run_index=node_run_index, + output=output + ), + PublishFrom.APPLICATION_MANAGER + ) + + def on_workflow_iteration_completed(self, node_id: str, + node_type: NodeType, + node_run_index: int, + outputs: dict) -> None: + """ + Publish iteration completed + """ + self._queue_manager.publish( + QueueIterationCompletedEvent( + node_id=node_id, + node_type=node_type, + node_run_index=node_run_index, + outputs=outputs + ), + PublishFrom.APPLICATION_MANAGER + ) + + def on_event(self, event: AppQueueEvent) -> None: + """ + Publish event diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index 1835efb40d..fb4c28a855 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -115,7 +115,8 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): user_id=user.id, stream=stream, invoke_from=invoke_from, - extras=extras + extras=extras, + call_depth=0 ) # init generate records diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index bf582801f8..c4324978d8 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -34,7 +34,8 @@ class WorkflowAppGenerator(BaseAppGenerator): user: Union[Account, EndUser], args: dict, invoke_from: InvokeFrom, - stream: bool = True) \ + stream: bool = True, + call_depth: int = 0) \ -> Union[dict, 
Generator[dict, None, None]]: """ Generate App response. @@ -75,9 +76,38 @@ class WorkflowAppGenerator(BaseAppGenerator): files=file_objs, user_id=user.id, stream=stream, - invoke_from=invoke_from + invoke_from=invoke_from, + call_depth=call_depth ) + return self._generate( + app_model=app_model, + workflow=workflow, + user=user, + application_generate_entity=application_generate_entity, + invoke_from=invoke_from, + stream=stream, + call_depth=call_depth + ) + + def _generate(self, app_model: App, + workflow: Workflow, + user: Union[Account, EndUser], + application_generate_entity: WorkflowAppGenerateEntity, + invoke_from: InvokeFrom, + stream: bool = True, + call_depth: int = 0) \ + -> Union[dict, Generator[dict, None, None]]: + """ + Generate App response. + + :param app_model: App + :param workflow: Workflow + :param user: account or end user + :param application_generate_entity: application generate entity + :param invoke_from: invoke from source + :param stream: is stream + """ # init queue manager queue_manager = WorkflowAppQueueManager( task_id=application_generate_entity.task_id, @@ -109,6 +139,64 @@ class WorkflowAppGenerator(BaseAppGenerator): invoke_from=invoke_from ) + def single_iteration_generate(self, app_model: App, + workflow: Workflow, + node_id: str, + user: Account, + args: dict, + stream: bool = True) \ + -> Union[dict, Generator[dict, None, None]]: + """ + Generate App response. 
+ + :param app_model: App + :param workflow: Workflow + :param user: account or end user + :param args: request args + :param invoke_from: invoke from source + :param stream: is stream + """ + if not node_id: + raise ValueError('node_id is required') + + if args.get('inputs') is None: + raise ValueError('inputs is required') + + extras = { + "auto_generate_conversation_name": False + } + + # convert to app config + app_config = WorkflowAppConfigManager.get_app_config( + app_model=app_model, + workflow=workflow + ) + + # init application generate entity + application_generate_entity = WorkflowAppGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=app_config, + inputs={}, + files=[], + user_id=user.id, + stream=stream, + invoke_from=InvokeFrom.DEBUGGER, + extras=extras, + single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity( + node_id=node_id, + inputs=args['inputs'] + ) + ) + + return self._generate( + app_model=app_model, + workflow=workflow, + user=user, + invoke_from=InvokeFrom.DEBUGGER, + application_generate_entity=application_generate_entity, + stream=stream + ) + def _generate_worker(self, flask_app: Flask, application_generate_entity: WorkflowAppGenerateEntity, queue_manager: AppQueueManager) -> None: @@ -123,10 +211,21 @@ class WorkflowAppGenerator(BaseAppGenerator): try: # workflow app runner = WorkflowAppRunner() - runner.run( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager - ) + if application_generate_entity.single_iteration_run: + single_iteration_run = application_generate_entity.single_iteration_run + runner.single_iteration_run( + app_id=application_generate_entity.app_config.app_id, + workflow_id=application_generate_entity.app_config.workflow_id, + queue_manager=queue_manager, + inputs=single_iteration_run.inputs, + node_id=single_iteration_run.node_id, + user_id=application_generate_entity.user_id + ) + else: + runner.run( + application_generate_entity=application_generate_entity, 
+ queue_manager=queue_manager + ) except GenerateTaskStoppedException: pass except InvokeAuthorizationError: diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index 9d854afe35..050319e552 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -73,11 +73,44 @@ class WorkflowAppRunner: user_from=UserFrom.ACCOUNT if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else UserFrom.END_USER, + invoke_from=application_generate_entity.invoke_from, user_inputs=inputs, system_inputs={ SystemVariable.FILES: files, SystemVariable.USER_ID: user_id }, + callbacks=workflow_callbacks, + call_depth=application_generate_entity.call_depth + ) + + def single_iteration_run(self, app_id: str, workflow_id: str, + queue_manager: AppQueueManager, + inputs: dict, node_id: str, user_id: str) -> None: + """ + Single iteration run + """ + app_record: App = db.session.query(App).filter(App.id == app_id).first() + if not app_record: + raise ValueError("App not found") + + if not app_record.workflow_id: + raise ValueError("Workflow not initialized") + + workflow = self.get_workflow(app_model=app_record, workflow_id=workflow_id) + if not workflow: + raise ValueError("Workflow not initialized") + + workflow_callbacks = [WorkflowEventTriggerCallback( + queue_manager=queue_manager, + workflow=workflow + )] + + workflow_engine_manager = WorkflowEngineManager() + workflow_engine_manager.single_step_run_iteration_workflow_node( + workflow=workflow, + node_id=node_id, + user_id=user_id, + user_inputs=inputs, callbacks=workflow_callbacks ) diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index a7061a77bb..8d961e0993 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -9,6 +9,9 @@ from core.app.entities.app_invoke_entities import ( 
) from core.app.entities.queue_entities import ( QueueErrorEvent, + QueueIterationCompletedEvent, + QueueIterationNextEvent, + QueueIterationStartEvent, QueueMessageReplaceEvent, QueueNodeFailedEvent, QueueNodeStartedEvent, @@ -58,6 +61,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa _task_state: WorkflowTaskState _application_generate_entity: WorkflowAppGenerateEntity _workflow_system_variables: dict[SystemVariable, Any] + _iteration_nested_relations: dict[str, list[str]] def __init__(self, application_generate_entity: WorkflowAppGenerateEntity, workflow: Workflow, @@ -85,8 +89,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa SystemVariable.USER_ID: user_id } - self._task_state = WorkflowTaskState() + self._task_state = WorkflowTaskState( + iteration_nested_node_ids=[] + ) self._stream_generate_nodes = self._get_stream_generate_nodes() + self._iteration_nested_relations = self._get_iteration_nested_relations(self._workflow.graph_dict) def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: """ @@ -191,6 +198,22 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution ) + + if isinstance(event, QueueNodeFailedEvent): + yield from self._handle_iteration_exception( + task_id=self._application_generate_entity.task_id, + error=f'Child node failed: {event.error}' + ) + elif isinstance(event, QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent): + if isinstance(event, QueueIterationNextEvent): + # clear ran node execution infos of current iteration + iteration_relations = self._iteration_nested_relations.get(event.node_id) + if iteration_relations: + for node_id in iteration_relations: + self._task_state.ran_node_execution_infos.pop(node_id, None) + + yield 
self._handle_iteration_to_stream_response(self._application_generate_entity.task_id, event) + self._handle_iteration_operation(event) elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent): workflow_run = self._handle_workflow_finished(event) @@ -331,13 +354,20 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa continue node_type = source_node.get('data', {}).get('type') + node_iteration_id = source_node.get('data', {}).get('iteration_id') + iteration_start_node_id = None + if node_iteration_id: + iteration_node = next((node for node in nodes if node.get('id') == node_iteration_id), None) + iteration_start_node_id = iteration_node.get('data', {}).get('start_node_id') + if node_type in [ NodeType.IF_ELSE.value, NodeType.QUESTION_CLASSIFIER.value ]: start_node_id = target_node_id start_node_ids.append(start_node_id) - elif node_type == NodeType.START.value: + elif node_type == NodeType.START.value or \ + node_iteration_id is not None and iteration_start_node_id == source_node.get('id'): start_node_id = source_node_id start_node_ids.append(start_node_id) else: @@ -411,3 +441,24 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa return False return True + + def _get_iteration_nested_relations(self, graph: dict) -> dict[str, list[str]]: + """ + Get iteration nested relations. 
+ :param graph: graph + :return: + """ + nodes = graph.get('nodes') + + iteration_ids = [node.get('id') for node in nodes + if node.get('data', {}).get('type') in [ + NodeType.ITERATION.value, + NodeType.LOOP.value, + ]] + + return { + iteration_id: [ + node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id + ] for iteration_id in iteration_ids + } + \ No newline at end of file diff --git a/api/core/app/apps/workflow/workflow_event_trigger_callback.py b/api/core/app/apps/workflow/workflow_event_trigger_callback.py index 2048abe464..e423a40bcb 100644 --- a/api/core/app/apps/workflow/workflow_event_trigger_callback.py +++ b/api/core/app/apps/workflow/workflow_event_trigger_callback.py @@ -1,8 +1,11 @@ -from typing import Optional +from typing import Any, Optional from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.entities.queue_entities import ( AppQueueEvent, + QueueIterationCompletedEvent, + QueueIterationNextEvent, + QueueIterationStartEvent, QueueNodeFailedEvent, QueueNodeStartedEvent, QueueNodeSucceededEvent, @@ -130,6 +133,66 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback): ), PublishFrom.APPLICATION_MANAGER ) + def on_workflow_iteration_started(self, + node_id: str, + node_type: NodeType, + node_run_index: int = 1, + node_data: Optional[BaseNodeData] = None, + inputs: dict = None, + predecessor_node_id: Optional[str] = None, + metadata: Optional[dict] = None) -> None: + """ + Publish iteration started + """ + self._queue_manager.publish( + QueueIterationStartEvent( + node_id=node_id, + node_type=node_type, + node_run_index=node_run_index, + node_data=node_data, + inputs=inputs, + predecessor_node_id=predecessor_node_id, + metadata=metadata + ), + PublishFrom.APPLICATION_MANAGER + ) + + def on_workflow_iteration_next(self, node_id: str, + node_type: NodeType, + index: int, + node_run_index: int, + output: Optional[Any]) -> None: + """ + Publish iteration next + """ + 
self._queue_manager.publish( + QueueIterationNextEvent( + node_id=node_id, + node_type=node_type, + index=index, + node_run_index=node_run_index, + output=output + ), + PublishFrom.APPLICATION_MANAGER + ) + + def on_workflow_iteration_completed(self, node_id: str, + node_type: NodeType, + node_run_index: int, + outputs: dict) -> None: + """ + Publish iteration completed + """ + self._queue_manager.publish( + QueueIterationCompletedEvent( + node_id=node_id, + node_type=node_type, + node_run_index=node_run_index, + outputs=outputs + ), + PublishFrom.APPLICATION_MANAGER + ) + def on_event(self, event: AppQueueEvent) -> None: """ Publish event diff --git a/api/core/app/apps/workflow_logging_callback.py b/api/core/app/apps/workflow_logging_callback.py index 4627c21c7a..f617c671e9 100644 --- a/api/core/app/apps/workflow_logging_callback.py +++ b/api/core/app/apps/workflow_logging_callback.py @@ -102,6 +102,39 @@ class WorkflowLoggingCallback(BaseWorkflowCallback): self.print_text(text, color="pink", end="") + def on_workflow_iteration_started(self, + node_id: str, + node_type: NodeType, + node_run_index: int = 1, + node_data: Optional[BaseNodeData] = None, + inputs: dict = None, + predecessor_node_id: Optional[str] = None, + metadata: Optional[dict] = None) -> None: + """ + Publish iteration started + """ + self.print_text("\n[on_workflow_iteration_started]", color='blue') + self.print_text(f"Node ID: {node_id}", color='blue') + + def on_workflow_iteration_next(self, node_id: str, + node_type: NodeType, + index: int, + node_run_index: int, + output: Optional[dict]) -> None: + """ + Publish iteration next + """ + self.print_text("\n[on_workflow_iteration_next]", color='blue') + + def on_workflow_iteration_completed(self, node_id: str, + node_type: NodeType, + node_run_index: int, + outputs: dict) -> None: + """ + Publish iteration completed + """ + self.print_text("\n[on_workflow_iteration_completed]", color='blue') + def on_event(self, event: AppQueueEvent) -> None: """ 
Publish event diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 09c62c802c..cc63fa4684 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -80,6 +80,9 @@ class AppGenerateEntity(BaseModel): stream: bool invoke_from: InvokeFrom + # invoke call depth + call_depth: int = 0 + # extra parameters, like: auto_generate_conversation_name extras: dict[str, Any] = {} @@ -126,6 +129,14 @@ class AdvancedChatAppGenerateEntity(AppGenerateEntity): conversation_id: Optional[str] = None query: Optional[str] = None + class SingleIterationRunEntity(BaseModel): + """ + Single Iteration Run Entity. + """ + node_id: str + inputs: dict + + single_iteration_run: Optional[SingleIterationRunEntity] = None class WorkflowAppGenerateEntity(AppGenerateEntity): """ @@ -133,3 +144,12 @@ class WorkflowAppGenerateEntity(AppGenerateEntity): """ # app config app_config: WorkflowUIBasedAppConfig + + class SingleIterationRunEntity(BaseModel): + """ + Single Iteration Run Entity. 
+ """ + node_id: str + inputs: dict + + single_iteration_run: Optional[SingleIterationRunEntity] = None \ No newline at end of file diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index bf174e30e1..47fa2ac19d 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -1,7 +1,7 @@ from enum import Enum from typing import Any, Optional -from pydantic import BaseModel +from pydantic import BaseModel, validator from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk from core.workflow.entities.base_node_data_entities import BaseNodeData @@ -21,6 +21,9 @@ class QueueEvent(Enum): WORKFLOW_STARTED = "workflow_started" WORKFLOW_SUCCEEDED = "workflow_succeeded" WORKFLOW_FAILED = "workflow_failed" + ITERATION_START = "iteration_start" + ITERATION_NEXT = "iteration_next" + ITERATION_COMPLETED = "iteration_completed" NODE_STARTED = "node_started" NODE_SUCCEEDED = "node_succeeded" NODE_FAILED = "node_failed" @@ -47,6 +50,55 @@ class QueueLLMChunkEvent(AppQueueEvent): event = QueueEvent.LLM_CHUNK chunk: LLMResultChunk +class QueueIterationStartEvent(AppQueueEvent): + """ + QueueIterationStartEvent entity + """ + event = QueueEvent.ITERATION_START + node_id: str + node_type: NodeType + node_data: BaseNodeData + + node_run_index: int + inputs: dict = None + predecessor_node_id: Optional[str] = None + metadata: Optional[dict] = None + +class QueueIterationNextEvent(AppQueueEvent): + """ + QueueIterationNextEvent entity + """ + event = QueueEvent.ITERATION_NEXT + + index: int + node_id: str + node_type: NodeType + + node_run_index: int + output: Optional[Any] # output for the current iteration + + @validator('output', pre=True, always=True) + def set_output(cls, v): + """ + Set output + """ + if v is None: + return None + if isinstance(v, int | float | str | bool | dict | list): + return v + raise ValueError('output must be a valid type') + +class 
QueueIterationCompletedEvent(AppQueueEvent): + """ + QueueIterationCompletedEvent entity + """ + event = QueueEvent.ITERATION_COMPLETED + + node_id: str + node_type: NodeType + + node_run_index: int + outputs: dict class QueueTextChunkEvent(AppQueueEvent): """ diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 1a11ac9aa3..5956bc35fa 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -1,12 +1,14 @@ from enum import Enum -from typing import Optional +from typing import Any, Optional from pydantic import BaseModel from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage from core.model_runtime.utils.encoders import jsonable_encoder +from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeType from core.workflow.nodes.answer.entities import GenerateRouteChunk +from models.workflow import WorkflowNodeExecutionStatus class WorkflowStreamGenerateNodes(BaseModel): @@ -65,6 +67,7 @@ class WorkflowTaskState(TaskState): current_stream_generate_state: Optional[WorkflowStreamGenerateNodes] = None + iteration_nested_node_ids: list[str] = None class AdvancedChatTaskState(WorkflowTaskState): """ @@ -91,6 +94,9 @@ class StreamEvent(Enum): WORKFLOW_FINISHED = "workflow_finished" NODE_STARTED = "node_started" NODE_FINISHED = "node_finished" + ITERATION_STARTED = "iteration_started" + ITERATION_NEXT = "iteration_next" + ITERATION_COMPLETED = "iteration_completed" TEXT_CHUNK = "text_chunk" TEXT_REPLACE = "text_replace" @@ -319,6 +325,74 @@ class NodeFinishStreamResponse(StreamResponse): } } +class IterationNodeStartStreamResponse(StreamResponse): + """ + NodeStartStreamResponse entity + """ + class Data(BaseModel): + """ + Data entity + """ + id: str + node_id: str + node_type: str + title: str + created_at: int + extras: dict = {} + metadata: dict = {} + inputs: dict = {} + + event: StreamEvent = 
StreamEvent.ITERATION_STARTED + workflow_run_id: str + data: Data + +class IterationNodeNextStreamResponse(StreamResponse): + """ + NodeStartStreamResponse entity + """ + class Data(BaseModel): + """ + Data entity + """ + id: str + node_id: str + node_type: str + title: str + index: int + created_at: int + pre_iteration_output: Optional[Any] + extras: dict = {} + + event: StreamEvent = StreamEvent.ITERATION_NEXT + workflow_run_id: str + data: Data + +class IterationNodeCompletedStreamResponse(StreamResponse): + """ + NodeStartStreamResponse entity + """ + class Data(BaseModel): + """ + Data entity + """ + id: str + node_id: str + node_type: str + title: str + outputs: Optional[dict] + created_at: int + extras: dict = None + inputs: dict = None + status: WorkflowNodeExecutionStatus + error: Optional[str] + elapsed_time: float + total_tokens: int + finished_at: int + steps: int + + event: StreamEvent = StreamEvent.ITERATION_COMPLETED + workflow_run_id: str + data: Data class TextChunkStreamResponse(StreamResponse): """ @@ -454,3 +528,23 @@ class WorkflowAppBlockingResponse(AppBlockingResponse): workflow_run_id: str data: Data + +class WorkflowIterationState(BaseModel): + """ + WorkflowIterationState entity + """ + class Data(BaseModel): + """ + Data entity + """ + parent_iteration_id: Optional[str] = None + iteration_id: str + current_index: int + iteration_steps_boundary: list[int] = None + node_execution_id: str + started_at: float + inputs: dict = None + total_tokens: int = 0 + node_data: BaseNodeData + + current_iterations: dict[str, Data] = None \ No newline at end of file diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index 48ff34fef9..978a318279 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -1,9 +1,9 @@ import json import time from datetime import datetime, timezone -from typing import Any, Optional, Union, 
cast +from typing import Optional, Union, cast -from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity +from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import ( QueueNodeFailedEvent, QueueNodeStartedEvent, @@ -13,18 +13,17 @@ from core.app.entities.queue_entities import ( QueueWorkflowSucceededEvent, ) from core.app.entities.task_entities import ( - AdvancedChatTaskState, NodeExecutionInfo, NodeFinishStreamResponse, NodeStartStreamResponse, WorkflowFinishStreamResponse, WorkflowStartStreamResponse, - WorkflowTaskState, ) +from core.app.task_pipeline.workflow_iteration_cycle_manage import WorkflowIterationCycleManage from core.file.file_obj import FileVar from core.model_runtime.utils.encoders import jsonable_encoder from core.tools.tool_manager import ToolManager -from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType, SystemVariable +from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType from core.workflow.nodes.tool.entities import ToolNodeData from core.workflow.workflow_engine_manager import WorkflowEngineManager from extensions.ext_database import db @@ -42,13 +41,7 @@ from models.workflow import ( ) -class WorkflowCycleManage: - _application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity] - _workflow: Workflow - _user: Union[Account, EndUser] - _task_state: Union[AdvancedChatTaskState, WorkflowTaskState] - _workflow_system_variables: dict[SystemVariable, Any] - +class WorkflowCycleManage(WorkflowIterationCycleManage): def _init_workflow_run(self, workflow: Workflow, triggered_from: WorkflowRunTriggeredFrom, user: Union[Account, EndUser], @@ -237,6 +230,7 @@ class WorkflowCycleManage: inputs: Optional[dict] = None, process_data: Optional[dict] = None, outputs: Optional[dict] = None, + execution_metadata: Optional[dict] = None ) -> WorkflowNodeExecution: """ Workflow 
node execution failed @@ -255,6 +249,8 @@ class WorkflowCycleManage: workflow_node_execution.inputs = json.dumps(inputs) if inputs else None workflow_node_execution.process_data = json.dumps(process_data) if process_data else None workflow_node_execution.outputs = json.dumps(outputs) if outputs else None + workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(execution_metadata)) \ + if execution_metadata else None db.session.commit() db.session.refresh(workflow_node_execution) @@ -444,6 +440,23 @@ class WorkflowCycleManage: current_node_execution = self._task_state.ran_node_execution_infos[event.node_id] workflow_node_execution = db.session.query(WorkflowNodeExecution).filter( WorkflowNodeExecution.id == current_node_execution.workflow_node_execution_id).first() + + execution_metadata = event.execution_metadata if isinstance(event, QueueNodeSucceededEvent) else None + + if self._iteration_state and self._iteration_state.current_iterations: + if not execution_metadata: + execution_metadata = {} + current_iteration_data = None + for iteration_node_id in self._iteration_state.current_iterations: + data = self._iteration_state.current_iterations[iteration_node_id] + if data.parent_iteration_id is None: + current_iteration_data = data + break + + if current_iteration_data: + execution_metadata[NodeRunMetadataKey.ITERATION_ID] = current_iteration_data.iteration_id + execution_metadata[NodeRunMetadataKey.ITERATION_INDEX] = current_iteration_data.current_index + if isinstance(event, QueueNodeSucceededEvent): workflow_node_execution = self._workflow_node_execution_success( workflow_node_execution=workflow_node_execution, @@ -451,12 +464,18 @@ class WorkflowCycleManage: inputs=event.inputs, process_data=event.process_data, outputs=event.outputs, - execution_metadata=event.execution_metadata + execution_metadata=execution_metadata ) - if event.execution_metadata and event.execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): + if execution_metadata and 
execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): self._task_state.total_tokens += ( - int(event.execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS))) + int(execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS))) + + if self._iteration_state: + for iteration_node_id in self._iteration_state.current_iterations: + data = self._iteration_state.current_iterations[iteration_node_id] + if execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): + data.total_tokens += int(execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)) if workflow_node_execution.node_type == NodeType.LLM.value: outputs = workflow_node_execution.outputs_dict @@ -469,7 +488,8 @@ class WorkflowCycleManage: error=event.error, inputs=event.inputs, process_data=event.process_data, - outputs=event.outputs + outputs=event.outputs, + execution_metadata=execution_metadata ) db.session.close() diff --git a/api/core/app/task_pipeline/workflow_cycle_state_manager.py b/api/core/app/task_pipeline/workflow_cycle_state_manager.py new file mode 100644 index 0000000000..545f31fddf --- /dev/null +++ b/api/core/app/task_pipeline/workflow_cycle_state_manager.py @@ -0,0 +1,16 @@ +from typing import Any, Union + +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity +from core.app.entities.task_entities import AdvancedChatTaskState, WorkflowTaskState +from core.workflow.entities.node_entities import SystemVariable +from models.account import Account +from models.model import EndUser +from models.workflow import Workflow + + +class WorkflowCycleStateManager: + _application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity] + _workflow: Workflow + _user: Union[Account, EndUser] + _task_state: Union[AdvancedChatTaskState, WorkflowTaskState] + _workflow_system_variables: dict[SystemVariable, Any] \ No newline at end of file diff --git a/api/core/app/task_pipeline/workflow_iteration_cycle_manage.py 
b/api/core/app/task_pipeline/workflow_iteration_cycle_manage.py new file mode 100644 index 0000000000..55e3e03173 --- /dev/null +++ b/api/core/app/task_pipeline/workflow_iteration_cycle_manage.py @@ -0,0 +1,281 @@ +import json +import time +from collections.abc import Generator +from typing import Optional, Union + +from core.app.entities.queue_entities import ( + QueueIterationCompletedEvent, + QueueIterationNextEvent, + QueueIterationStartEvent, +) +from core.app.entities.task_entities import ( + IterationNodeCompletedStreamResponse, + IterationNodeNextStreamResponse, + IterationNodeStartStreamResponse, + NodeExecutionInfo, + WorkflowIterationState, +) +from core.app.task_pipeline.workflow_cycle_state_manager import WorkflowCycleStateManager +from core.workflow.entities.node_entities import NodeType +from extensions.ext_database import db +from models.workflow import ( + WorkflowNodeExecution, + WorkflowNodeExecutionStatus, + WorkflowNodeExecutionTriggeredFrom, + WorkflowRun, +) + + +class WorkflowIterationCycleManage(WorkflowCycleStateManager): + _iteration_state: WorkflowIterationState = None + + def _init_iteration_state(self) -> None: + if not self._iteration_state: + self._iteration_state = WorkflowIterationState( + current_iterations={} + ) + + def _handle_iteration_to_stream_response(self, task_id: str, event: QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent) \ + -> Union[IterationNodeStartStreamResponse, IterationNodeNextStreamResponse, IterationNodeCompletedStreamResponse]: + """ + Handle iteration to stream response + :param task_id: task id + :param event: iteration event + :return: + """ + if isinstance(event, QueueIterationStartEvent): + return IterationNodeStartStreamResponse( + task_id=task_id, + workflow_run_id=self._task_state.workflow_run_id, + data=IterationNodeStartStreamResponse.Data( + id=event.node_id, + node_id=event.node_id, + node_type=event.node_type.value, + 
title=event.node_data.title, + created_at=int(time.time()), + extras={}, + inputs=event.inputs, + metadata=event.metadata + ) + ) + elif isinstance(event, QueueIterationNextEvent): + current_iteration = self._iteration_state.current_iterations[event.node_id] + + return IterationNodeNextStreamResponse( + task_id=task_id, + workflow_run_id=self._task_state.workflow_run_id, + data=IterationNodeNextStreamResponse.Data( + id=event.node_id, + node_id=event.node_id, + node_type=event.node_type.value, + title=current_iteration.node_data.title, + index=event.index, + pre_iteration_output=event.output, + created_at=int(time.time()), + extras={} + ) + ) + elif isinstance(event, QueueIterationCompletedEvent): + current_iteration = self._iteration_state.current_iterations[event.node_id] + + return IterationNodeCompletedStreamResponse( + task_id=task_id, + workflow_run_id=self._task_state.workflow_run_id, + data=IterationNodeCompletedStreamResponse.Data( + id=event.node_id, + node_id=event.node_id, + node_type=event.node_type.value, + title=current_iteration.node_data.title, + outputs=event.outputs, + created_at=int(time.time()), + extras={}, + inputs=current_iteration.inputs, + status=WorkflowNodeExecutionStatus.SUCCEEDED, + error=None, + elapsed_time=time.perf_counter() - current_iteration.started_at, + total_tokens=current_iteration.total_tokens, + finished_at=int(time.time()), + steps=current_iteration.current_index + ) + ) + + def _init_iteration_execution_from_workflow_run(self, + workflow_run: WorkflowRun, + node_id: str, + node_type: NodeType, + node_title: str, + node_run_index: int = 1, + inputs: Optional[dict] = None, + predecessor_node_id: Optional[str] = None + ) -> WorkflowNodeExecution: + workflow_node_execution = WorkflowNodeExecution( + tenant_id=workflow_run.tenant_id, + app_id=workflow_run.app_id, + workflow_id=workflow_run.workflow_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, + workflow_run_id=workflow_run.id, + 
predecessor_node_id=predecessor_node_id, + index=node_run_index, + node_id=node_id, + node_type=node_type.value, + inputs=json.dumps(inputs) if inputs else None, + title=node_title, + status=WorkflowNodeExecutionStatus.RUNNING.value, + created_by_role=workflow_run.created_by_role, + created_by=workflow_run.created_by, + execution_metadata=json.dumps({ + 'started_run_index': node_run_index + 1, + 'current_index': 0, + 'steps_boundary': [], + }) + ) + + db.session.add(workflow_node_execution) + db.session.commit() + db.session.refresh(workflow_node_execution) + db.session.close() + + return workflow_node_execution + + def _handle_iteration_operation(self, event: QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent) -> WorkflowNodeExecution: + if isinstance(event, QueueIterationStartEvent): + return self._handle_iteration_started(event) + elif isinstance(event, QueueIterationNextEvent): + return self._handle_iteration_next(event) + elif isinstance(event, QueueIterationCompletedEvent): + return self._handle_iteration_completed(event) + + def _handle_iteration_started(self, event: QueueIterationStartEvent) -> WorkflowNodeExecution: + self._init_iteration_state() + + workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first() + workflow_node_execution = self._init_iteration_execution_from_workflow_run( + workflow_run=workflow_run, + node_id=event.node_id, + node_type=NodeType.ITERATION, + node_title=event.node_data.title, + node_run_index=event.node_run_index, + inputs=event.inputs, + predecessor_node_id=event.predecessor_node_id + ) + + latest_node_execution_info = NodeExecutionInfo( + workflow_node_execution_id=workflow_node_execution.id, + node_type=NodeType.ITERATION, + start_at=time.perf_counter() + ) + + self._task_state.ran_node_execution_infos[event.node_id] = latest_node_execution_info + self._task_state.latest_node_execution_info = latest_node_execution_info + + 
self._iteration_state.current_iterations[event.node_id] = WorkflowIterationState.Data( + parent_iteration_id=None, + iteration_id=event.node_id, + current_index=0, + iteration_steps_boundary=[], + node_execution_id=workflow_node_execution.id, + started_at=time.perf_counter(), + inputs=event.inputs, + total_tokens=0, + node_data=event.node_data + ) + + db.session.close() + + return workflow_node_execution + + def _handle_iteration_next(self, event: QueueIterationNextEvent) -> WorkflowNodeExecution: + if event.node_id not in self._iteration_state.current_iterations: + return + current_iteration = self._iteration_state.current_iterations[event.node_id] + current_iteration.current_index = event.index + current_iteration.iteration_steps_boundary.append(event.node_run_index) + workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter( + WorkflowNodeExecution.id == current_iteration.node_execution_id + ).first() + + original_node_execution_metadata = workflow_node_execution.execution_metadata_dict + if original_node_execution_metadata: + original_node_execution_metadata['current_index'] = event.index + original_node_execution_metadata['steps_boundary'] = current_iteration.iteration_steps_boundary + original_node_execution_metadata['total_tokens'] = current_iteration.total_tokens + workflow_node_execution.execution_metadata = json.dumps(original_node_execution_metadata) + + db.session.commit() + + db.session.close() + + def _handle_iteration_completed(self, event: QueueIterationCompletedEvent) -> WorkflowNodeExecution: + if event.node_id not in self._iteration_state.current_iterations: + return + + current_iteration = self._iteration_state.current_iterations[event.node_id] + workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter( + WorkflowNodeExecution.id == current_iteration.node_execution_id + ).first() + + workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value + 
workflow_node_execution.outputs = json.dumps(event.outputs) if event.outputs else None + workflow_node_execution.elapsed_time = time.perf_counter() - current_iteration.started_at + + original_node_execution_metadata = workflow_node_execution.execution_metadata_dict + if original_node_execution_metadata: + original_node_execution_metadata['steps_boundary'] = current_iteration.iteration_steps_boundary + original_node_execution_metadata['total_tokens'] = current_iteration.total_tokens + workflow_node_execution.execution_metadata = json.dumps(original_node_execution_metadata) + + db.session.commit() + + # remove current iteration + self._iteration_state.current_iterations.pop(event.node_id, None) + + # set latest node execution info + latest_node_execution_info = NodeExecutionInfo( + workflow_node_execution_id=workflow_node_execution.id, + node_type=NodeType.ITERATION, + start_at=time.perf_counter() + ) + + self._task_state.latest_node_execution_info = latest_node_execution_info + + db.session.close() + + def _handle_iteration_exception(self, task_id: str, error: str) -> Generator[IterationNodeCompletedStreamResponse, None, None]: + """ + Handle iteration exception + """ + if not self._iteration_state or not self._iteration_state.current_iterations: + return + + for node_id, current_iteration in self._iteration_state.current_iterations.items(): + workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter( + WorkflowNodeExecution.id == current_iteration.node_execution_id + ).first() + + workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value + workflow_node_execution.error = error + workflow_node_execution.elapsed_time = time.perf_counter() - current_iteration.started_at + + db.session.commit() + db.session.close() + + yield IterationNodeCompletedStreamResponse( + task_id=task_id, + workflow_run_id=self._task_state.workflow_run_id, + data=IterationNodeCompletedStreamResponse.Data( + id=node_id, + node_id=node_id, + 
node_type=NodeType.ITERATION.value, + title=current_iteration.node_data.title, + outputs={}, + created_at=int(time.time()), + extras={}, + inputs=current_iteration.inputs, + status=WorkflowNodeExecutionStatus.FAILED, + error=error, + elapsed_time=time.perf_counter() - current_iteration.started_at, + total_tokens=current_iteration.total_tokens, + finished_at=int(time.time()), + steps=current_iteration.current_index + ) + ) \ No newline at end of file diff --git a/api/core/file/message_file_parser.py b/api/core/file/message_file_parser.py index 06f21c880a..52498eb871 100644 --- a/api/core/file/message_file_parser.py +++ b/api/core/file/message_file_parser.py @@ -42,6 +42,8 @@ class MessageFileParser: raise ValueError('Invalid file url') if file.get('transfer_method') == FileTransferMethod.LOCAL_FILE.value and not file.get('upload_file_id'): raise ValueError('Missing file upload_file_id') + if file.get('transform_method') == FileTransferMethod.TOOL_FILE.value and not file.get('tool_file_id'): + raise ValueError('Missing file tool_file_id') # transform files to file objs type_file_objs = self._to_file_objs(files, file_extra_config) @@ -149,12 +151,21 @@ class MessageFileParser: """ if isinstance(file, dict): transfer_method = FileTransferMethod.value_of(file.get('transfer_method')) + if transfer_method != FileTransferMethod.TOOL_FILE: + return FileVar( + tenant_id=self.tenant_id, + type=FileType.value_of(file.get('type')), + transfer_method=transfer_method, + url=file.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None, + related_id=file.get('upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None, + extra_config=file_extra_config + ) return FileVar( tenant_id=self.tenant_id, type=FileType.value_of(file.get('type')), transfer_method=transfer_method, - url=file.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None, - related_id=file.get('upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE 
else None, + url=None, + related_id=file.get('tool_file_id'), extra_config=file_extra_config ) else: diff --git a/api/core/prompt/utils/prompt_message_util.py b/api/core/prompt/utils/prompt_message_util.py index 5fceeb3595..befdceeda5 100644 --- a/api/core/prompt/utils/prompt_message_util.py +++ b/api/core/prompt/utils/prompt_message_util.py @@ -1,6 +1,7 @@ from typing import cast from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, ImagePromptMessageContent, PromptMessage, PromptMessageContentType, @@ -21,13 +22,25 @@ class PromptMessageUtil: """ prompts = [] if model_mode == ModelMode.CHAT.value: + tool_calls = [] for prompt_message in prompt_messages: if prompt_message.role == PromptMessageRole.USER: role = 'user' elif prompt_message.role == PromptMessageRole.ASSISTANT: role = 'assistant' + if isinstance(prompt_message, AssistantPromptMessage): + tool_calls = [{ + 'id': tool_call.id, + 'type': 'function', + 'function': { + 'name': tool_call.function.name, + 'arguments': tool_call.function.arguments, + } + } for tool_call in prompt_message.tool_calls] elif prompt_message.role == PromptMessageRole.SYSTEM: role = 'system' + elif prompt_message.role == PromptMessageRole.TOOL: + role = 'tool' else: continue @@ -48,11 +61,16 @@ class PromptMessageUtil: else: text = prompt_message.content - prompts.append({ + prompt = { "role": role, "text": text, "files": files - }) + } + + if tool_calls: + prompt['tool_calls'] = tool_calls + + prompts.append(prompt) else: prompt_message = prompt_messages[0] text = '' diff --git a/api/core/tools/entities/user_entities.py b/api/core/tools/entities/api_entities.py similarity index 62% rename from api/core/tools/entities/user_entities.py rename to api/core/tools/entities/api_entities.py index 48fe5b0ed5..b6c553f3f1 100644 --- a/api/core/tools/entities/user_entities.py +++ b/api/core/tools/entities/api_entities.py @@ -1,10 +1,10 @@ -from enum import Enum -from typing import Optional +from typing import 
Literal, Optional from pydantic import BaseModel +from core.model_runtime.utils.encoders import jsonable_encoder from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ToolProviderCredentials +from core.tools.entities.tool_entities import ToolProviderCredentials, ToolProviderType from core.tools.tool.tool import ToolParameter @@ -14,27 +14,38 @@ class UserTool(BaseModel): label: I18nObject # label description: I18nObject parameters: Optional[list[ToolParameter]] + labels: list[str] = None + +UserToolProviderTypeLiteral = Optional[Literal[ + 'builtin', 'api', 'workflow' +]] class UserToolProvider(BaseModel): - class ProviderType(Enum): - BUILTIN = "builtin" - APP = "app" - API = "api" - id: str author: str name: str # identifier description: I18nObject icon: str label: I18nObject # label - type: ProviderType + type: ToolProviderType masked_credentials: dict = None original_credentials: dict = None is_team_authorization: bool = False allow_delete: bool = True tools: list[UserTool] = None + labels: list[str] = None def to_dict(self) -> dict: + # ------------- + # overwrite tool parameter types for temp fix + tools = jsonable_encoder(self.tools) + for tool in tools: + if tool.get('parameters'): + for parameter in tool.get('parameters'): + if parameter.get('type') == ToolParameter.ToolParameterType.FILE.value: + parameter['type'] = 'files' + # ------------- + return { 'id': self.id, 'author': self.author, @@ -46,7 +57,8 @@ class UserToolProvider(BaseModel): 'team_credentials': self.masked_credentials, 'is_team_authorization': self.is_team_authorization, 'allow_delete': self.allow_delete, - 'tools': self.tools + 'tools': tools, + 'labels': self.labels, } class UserToolProviderCredentials(BaseModel): diff --git a/api/core/tools/entities/constant.py b/api/core/tools/entities/constant.py deleted file mode 100644 index 2e75fedf99..0000000000 --- a/api/core/tools/entities/constant.py +++ /dev/null @@ -1,3 +0,0 @@ -class 
DEFAULT_PROVIDERS: - API_BASED = '__api_based' - APP_BASED = '__app_based' \ No newline at end of file diff --git a/api/core/tools/entities/tool_bundle.py b/api/core/tools/entities/tool_bundle.py index efa10e792c..d18d27fb02 100644 --- a/api/core/tools/entities/tool_bundle.py +++ b/api/core/tools/entities/tool_bundle.py @@ -1,11 +1,11 @@ -from typing import Any, Optional +from typing import Optional from pydantic import BaseModel -from core.tools.entities.tool_entities import ToolParameter, ToolProviderType +from core.tools.entities.tool_entities import ToolParameter -class ApiBasedToolBundle(BaseModel): +class ApiToolBundle(BaseModel): """ This class is used to store the schema information of an api based tool. such as the url, the method, the parameters, etc. """ @@ -25,12 +25,3 @@ class ApiBasedToolBundle(BaseModel): icon: Optional[str] = None # openapi operation openapi: dict - -class AppToolBundle(BaseModel): - """ - This class is used to store the schema information of an tool for an app. 
- """ - type: ToolProviderType - credential: Optional[dict[str, Any]] = None - provider_id: str - tool_name: str \ No newline at end of file diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index fad91baf83..03ce44b219 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -10,10 +10,11 @@ class ToolProviderType(Enum): """ Enum class for tool provider """ - BUILT_IN = "built-in" + BUILT_IN = "builtin" + WORKFLOW = "workflow" + API = "api" + APP = "app" DATASET_RETRIEVAL = "dataset-retrieval" - APP_BASED = "app-based" - API_BASED = "api-based" @classmethod def value_of(cls, value: str) -> 'ToolProviderType': @@ -77,6 +78,7 @@ class ToolInvokeMessage(BaseModel): LINK = "link" BLOB = "blob" IMAGE_LINK = "image_link" + FILE_VAR = "file_var" type: MessageType = MessageType.TEXT """ @@ -90,6 +92,7 @@ class ToolInvokeMessageBinary(BaseModel): mimetype: str = Field(..., description="The mimetype of the binary") url: str = Field(..., description="The url of the binary") save_as: str = '' + file_var: Optional[dict[str, Any]] = None class ToolParameterOption(BaseModel): value: str = Field(..., description="The value of the option") @@ -102,6 +105,7 @@ class ToolParameter(BaseModel): BOOLEAN = "boolean" SELECT = "select" SECRET_INPUT = "secret-input" + FILE = "file" class ToolParameterForm(Enum): SCHEMA = "schema" # should be set while adding tool @@ -331,6 +335,15 @@ class ModelToolProviderConfiguration(BaseModel): models: list[ModelToolConfiguration] = Field(..., description="The models of the model tool") label: I18nObject = Field(..., description="The label of the model tool") + +class WorkflowToolParameterConfiguration(BaseModel): + """ + Workflow tool configuration + """ + name: str = Field(..., description="The name of the parameter") + description: str = Field(..., description="The description of the parameter") + form: ToolParameter.ToolParameterForm = Field(..., 
description="The form of the parameter") + class ToolInvokeMeta(BaseModel): """ Tool invoke meta @@ -358,4 +371,19 @@ class ToolInvokeMeta(BaseModel): 'time_cost': self.time_cost, 'error': self.error, 'tool_config': self.tool_config, - } \ No newline at end of file + } + +class ToolLabel(BaseModel): + """ + Tool label + """ + name: str = Field(..., description="The name of the tool") + label: I18nObject = Field(..., description="The label of the tool") + icon: str = Field(..., description="The icon of the tool") + +class ToolInvokeFrom(Enum): + """ + Enum class for tool invoke + """ + WORKFLOW = "workflow" + AGENT = "agent" \ No newline at end of file diff --git a/api/core/tools/entities/values.py b/api/core/tools/entities/values.py new file mode 100644 index 0000000000..5ca50a3ed3 --- /dev/null +++ b/api/core/tools/entities/values.py @@ -0,0 +1,96 @@ +from enum import Enum + +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ToolLabel + + +class ToolLabelEnum(Enum): + SEARCH = 'search' + IMAGE = 'image' + VIDEOS = 'videos' + WEATHER = 'weather' + FINANCE = 'finance' + DESIGN = 'design' + TRAVEL = 'travel' + SOCIAL = 'social' + NEWS = 'news' + MEDICAL = 'medical' + PRODUCTIVITY = 'productivity' + EDUCATION = 'education' + BUSINESS = 'business' + ENTERTAINMENT = 'entertainment' + UTILITIES = 'utilities' + OTHER = 'other' + +ICONS = { + ToolLabelEnum.SEARCH: ''' + +''', + ToolLabelEnum.IMAGE: ''' + +''', + ToolLabelEnum.VIDEOS: ''' + +''', + ToolLabelEnum.WEATHER: ''' + +''', + ToolLabelEnum.FINANCE: ''' + +''', + ToolLabelEnum.DESIGN: ''' + +''', + ToolLabelEnum.TRAVEL: ''' + +''', + ToolLabelEnum.SOCIAL: ''' + +''', + ToolLabelEnum.NEWS: ''' + +''', + ToolLabelEnum.MEDICAL: ''' + +''', + ToolLabelEnum.PRODUCTIVITY: ''' + +''', + ToolLabelEnum.EDUCATION: ''' + +''', + ToolLabelEnum.BUSINESS: ''' + +''', + ToolLabelEnum.ENTERTAINMENT: ''' + +''', + ToolLabelEnum.UTILITIES: ''' + +''', + ToolLabelEnum.OTHER: ''' 
+ +''' +} + +default_tool_label_dict = { + ToolLabelEnum.SEARCH: ToolLabel(name='search', label=I18nObject(en_US='Search', zh_Hans='搜索'), icon=ICONS[ToolLabelEnum.SEARCH]), + ToolLabelEnum.IMAGE: ToolLabel(name='image', label=I18nObject(en_US='Image', zh_Hans='图片'), icon=ICONS[ToolLabelEnum.IMAGE]), + ToolLabelEnum.VIDEOS: ToolLabel(name='videos', label=I18nObject(en_US='Videos', zh_Hans='视频'), icon=ICONS[ToolLabelEnum.VIDEOS]), + ToolLabelEnum.WEATHER: ToolLabel(name='weather', label=I18nObject(en_US='Weather', zh_Hans='天气'), icon=ICONS[ToolLabelEnum.WEATHER]), + ToolLabelEnum.FINANCE: ToolLabel(name='finance', label=I18nObject(en_US='Finance', zh_Hans='金融'), icon=ICONS[ToolLabelEnum.FINANCE]), + ToolLabelEnum.DESIGN: ToolLabel(name='design', label=I18nObject(en_US='Design', zh_Hans='设计'), icon=ICONS[ToolLabelEnum.DESIGN]), + ToolLabelEnum.TRAVEL: ToolLabel(name='travel', label=I18nObject(en_US='Travel', zh_Hans='旅行'), icon=ICONS[ToolLabelEnum.TRAVEL]), + ToolLabelEnum.SOCIAL: ToolLabel(name='social', label=I18nObject(en_US='Social', zh_Hans='社交'), icon=ICONS[ToolLabelEnum.SOCIAL]), + ToolLabelEnum.NEWS: ToolLabel(name='news', label=I18nObject(en_US='News', zh_Hans='新闻'), icon=ICONS[ToolLabelEnum.NEWS]), + ToolLabelEnum.MEDICAL: ToolLabel(name='medical', label=I18nObject(en_US='Medical', zh_Hans='医疗'), icon=ICONS[ToolLabelEnum.MEDICAL]), + ToolLabelEnum.PRODUCTIVITY: ToolLabel(name='productivity', label=I18nObject(en_US='Productivity', zh_Hans='生产力'), icon=ICONS[ToolLabelEnum.PRODUCTIVITY]), + ToolLabelEnum.EDUCATION: ToolLabel(name='education', label=I18nObject(en_US='Education', zh_Hans='教育'), icon=ICONS[ToolLabelEnum.EDUCATION]), + ToolLabelEnum.BUSINESS: ToolLabel(name='business', label=I18nObject(en_US='Business', zh_Hans='商业'), icon=ICONS[ToolLabelEnum.BUSINESS]), + ToolLabelEnum.ENTERTAINMENT: ToolLabel(name='entertainment', label=I18nObject(en_US='Entertainment', zh_Hans='娱乐'), icon=ICONS[ToolLabelEnum.ENTERTAINMENT]), + ToolLabelEnum.UTILITIES: 
ToolLabel(name='utilities', label=I18nObject(en_US='Utilities', zh_Hans='工具'), icon=ICONS[ToolLabelEnum.UTILITIES]), + ToolLabelEnum.OTHER: ToolLabel(name='other', label=I18nObject(en_US='Other', zh_Hans='其他'), icon=ICONS[ToolLabelEnum.OTHER]), +} + +default_tool_labels = [v for k, v in default_tool_label_dict.items()] +default_tool_label_name_list = [label.name for label in default_tool_labels] diff --git a/api/core/tools/model/errors.py b/api/core/tools/model/errors.py deleted file mode 100644 index 6e242b349a..0000000000 --- a/api/core/tools/model/errors.py +++ /dev/null @@ -1,2 +0,0 @@ -class InvokeModelError(Exception): - pass \ No newline at end of file diff --git a/api/core/tools/provider/api_tool_provider.py b/api/core/tools/provider/api_tool_provider.py index 11e6e892c9..ae80ad2114 100644 --- a/api/core/tools/provider/api_tool_provider.py +++ b/api/core/tools/provider/api_tool_provider.py @@ -1,7 +1,6 @@ -from typing import Any from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_bundle import ApiBasedToolBundle +from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ( ApiProviderAuthType, ToolCredentialsOption, @@ -15,11 +14,11 @@ from extensions.ext_database import db from models.tools import ApiToolProvider -class ApiBasedToolProviderController(ToolProviderController): +class ApiToolProviderController(ToolProviderController): provider_id: str @staticmethod - def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> 'ApiBasedToolProviderController': + def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> 'ApiToolProviderController': credentials_schema = { 'auth_type': ToolProviderCredentials( name='auth_type', @@ -79,9 +78,11 @@ class ApiBasedToolProviderController(ToolProviderController): else: raise ValueError(f'invalid auth type {auth_type}') - return ApiBasedToolProviderController(**{ + user_name = db_provider.user.name if 
db_provider.user_id else '' + + return ApiToolProviderController(**{ 'identity': { - 'author': db_provider.user.name if db_provider.user_id and db_provider.user else '', + 'author': user_name, 'name': db_provider.name, 'label': { 'en_US': db_provider.name, @@ -98,16 +99,10 @@ class ApiBasedToolProviderController(ToolProviderController): }) @property - def app_type(self) -> ToolProviderType: - return ToolProviderType.API_BASED - - def _validate_credentials(self, tool_name: str, credentials: dict[str, Any]) -> None: - pass + def provider_type(self) -> ToolProviderType: + return ToolProviderType.API - def validate_parameters(self, tool_name: str, tool_parameters: dict[str, Any]) -> None: - pass - - def _parse_tool_bundle(self, tool_bundle: ApiBasedToolBundle) -> ApiTool: + def _parse_tool_bundle(self, tool_bundle: ApiToolBundle) -> ApiTool: """ parse tool bundle to tool @@ -136,7 +131,7 @@ class ApiBasedToolProviderController(ToolProviderController): 'parameters' : tool_bundle.parameters if tool_bundle.parameters else [], }) - def load_bundled_tools(self, tools: list[ApiBasedToolBundle]) -> list[ApiTool]: + def load_bundled_tools(self, tools: list[ApiToolBundle]) -> list[ApiTool]: """ load bundled tools diff --git a/api/core/tools/provider/app_tool_provider.py b/api/core/tools/provider/app_tool_provider.py index 159c94bbf3..2d472e0a93 100644 --- a/api/core/tools/provider/app_tool_provider.py +++ b/api/core/tools/provider/app_tool_provider.py @@ -11,10 +11,10 @@ from models.tools import PublishedAppTool logger = logging.getLogger(__name__) -class AppBasedToolProviderEntity(ToolProviderController): +class AppToolProviderEntity(ToolProviderController): @property - def app_type(self) -> ToolProviderType: - return ToolProviderType.APP_BASED + def provider_type(self) -> ToolProviderType: + return ToolProviderType.APP def _validate_credentials(self, tool_name: str, credentials: dict[str, Any]) -> None: pass diff --git a/api/core/tools/provider/builtin/_positions.py 
b/api/core/tools/provider/builtin/_positions.py index d0826ddcf0..851736dc7a 100644 --- a/api/core/tools/provider/builtin/_positions.py +++ b/api/core/tools/provider/builtin/_positions.py @@ -1,6 +1,6 @@ import os.path -from core.tools.entities.user_entities import UserToolProvider +from core.tools.entities.api_entities import UserToolProvider from core.utils.position_helper import get_position_map, sort_by_position_map diff --git a/api/core/tools/provider/builtin/aippt/aippt.py b/api/core/tools/provider/builtin/aippt/aippt.py index 25133c51df..b156520e49 100644 --- a/api/core/tools/provider/builtin/aippt/aippt.py +++ b/api/core/tools/provider/builtin/aippt/aippt.py @@ -1,3 +1,4 @@ +from core.tools.entities.values import ToolLabelEnum from core.tools.errors import ToolProviderCredentialValidationError from core.tools.provider.builtin.aippt.tools.aippt import AIPPTGenerateTool from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController @@ -9,3 +10,9 @@ class AIPPTProvider(BuiltinToolProviderController): AIPPTGenerateTool._get_api_token(credentials, user_id='__dify_system__') except Exception as e: raise ToolProviderCredentialValidationError(str(e)) + + def _get_tool_labels(self) -> list[ToolLabelEnum]: + return [ + ToolLabelEnum.PRODUCTIVITY, + ToolLabelEnum.DESIGN, + ] \ No newline at end of file diff --git a/api/core/tools/provider/builtin/arxiv/arxiv.py b/api/core/tools/provider/builtin/arxiv/arxiv.py index 998128522e..f8613bbc78 100644 --- a/api/core/tools/provider/builtin/arxiv/arxiv.py +++ b/api/core/tools/provider/builtin/arxiv/arxiv.py @@ -1,3 +1,4 @@ +from core.tools.entities.values import ToolLabelEnum from core.tools.errors import ToolProviderCredentialValidationError from core.tools.provider.builtin.arxiv.tools.arxiv_search import ArxivSearchTool from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController @@ -7,7 +8,7 @@ class ArxivProvider(BuiltinToolProviderController): def _validate_credentials(self, 
credentials: dict) -> None: try: ArxivSearchTool().fork_tool_runtime( - meta={ + runtime={ "credentials": credentials, } ).invoke( @@ -17,4 +18,9 @@ class ArxivProvider(BuiltinToolProviderController): }, ) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) + + def _get_tool_labels(self) -> list[ToolLabelEnum]: + return [ + ToolLabelEnum.SEARCH, + ] \ No newline at end of file diff --git a/api/core/tools/provider/builtin/azuredalle/azuredalle.py b/api/core/tools/provider/builtin/azuredalle/azuredalle.py index 4278da54ba..ea4789d479 100644 --- a/api/core/tools/provider/builtin/azuredalle/azuredalle.py +++ b/api/core/tools/provider/builtin/azuredalle/azuredalle.py @@ -1,5 +1,6 @@ from typing import Any +from core.tools.entities.values import ToolLabelEnum from core.tools.errors import ToolProviderCredentialValidationError from core.tools.provider.builtin.azuredalle.tools.dalle3 import DallE3Tool from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController @@ -9,7 +10,7 @@ class AzureDALLEProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: DallE3Tool().fork_tool_runtime( - meta={ + runtime={ "credentials": credentials, } ).invoke( @@ -22,3 +23,8 @@ class AzureDALLEProvider(BuiltinToolProviderController): ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) + + def _get_tool_labels(self) -> list[ToolLabelEnum]: + return [ + ToolLabelEnum.IMAGE + ] \ No newline at end of file diff --git a/api/core/tools/provider/builtin/bing/bing.py b/api/core/tools/provider/builtin/bing/bing.py index 6e62abfc10..f64460c3f2 100644 --- a/api/core/tools/provider/builtin/bing/bing.py +++ b/api/core/tools/provider/builtin/bing/bing.py @@ -1,5 +1,6 @@ from typing import Any +from core.tools.entities.values import ToolLabelEnum from core.tools.errors import 
ToolProviderCredentialValidationError from core.tools.provider.builtin.bing.tools.bing_web_search import BingSearchTool from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController @@ -9,7 +10,7 @@ class BingProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: BingSearchTool().fork_tool_runtime( - meta={ + runtime={ "credentials": credentials, } ).validate_credentials( @@ -21,3 +22,8 @@ class BingProvider(BuiltinToolProviderController): ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) + + def _get_tool_labels(self) -> list[ToolLabelEnum]: + return [ + ToolLabelEnum.SEARCH + ] \ No newline at end of file diff --git a/api/core/tools/provider/builtin/brave/brave.py b/api/core/tools/provider/builtin/brave/brave.py index e26b28b46a..a1c3101418 100644 --- a/api/core/tools/provider/builtin/brave/brave.py +++ b/api/core/tools/provider/builtin/brave/brave.py @@ -1,5 +1,6 @@ from typing import Any +from core.tools.entities.values import ToolLabelEnum from core.tools.errors import ToolProviderCredentialValidationError from core.tools.provider.builtin.brave.tools.brave_search import BraveSearchTool from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController @@ -9,7 +10,7 @@ class BraveProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: BraveSearchTool().fork_tool_runtime( - meta={ + runtime={ "credentials": credentials, } ).invoke( @@ -19,4 +20,9 @@ class BraveProvider(BuiltinToolProviderController): }, ) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) + + def _get_tool_labels(self) -> list[ToolLabelEnum]: + return [ + ToolLabelEnum.SEARCH, + ] \ No newline at end of file diff --git a/api/core/tools/provider/builtin/chart/chart.py 
b/api/core/tools/provider/builtin/chart/chart.py index f5e42e766d..aed475155c 100644 --- a/api/core/tools/provider/builtin/chart/chart.py +++ b/api/core/tools/provider/builtin/chart/chart.py @@ -2,6 +2,7 @@ import matplotlib.pyplot as plt from fontTools.ttLib import TTFont from matplotlib.font_manager import findSystemFonts +from core.tools.entities.values import ToolLabelEnum from core.tools.errors import ToolProviderCredentialValidationError from core.tools.provider.builtin.chart.tools.line import LinearChartTool from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController @@ -44,7 +45,7 @@ class ChartProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: try: LinearChartTool().fork_tool_runtime( - meta={ + runtime={ "credentials": credentials, } ).invoke( @@ -54,4 +55,9 @@ class ChartProvider(BuiltinToolProviderController): }, ) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) + + def _get_tool_labels(self) -> list[ToolLabelEnum]: + return [ + ToolLabelEnum.DESIGN, ToolLabelEnum.PRODUCTIVITY, ToolLabelEnum.UTILITIES + ] \ No newline at end of file diff --git a/api/core/tools/provider/builtin/code/code.py b/api/core/tools/provider/builtin/code/code.py index fae5ecf769..93a4f2b6d2 100644 --- a/api/core/tools/provider/builtin/code/code.py +++ b/api/core/tools/provider/builtin/code/code.py @@ -1,8 +1,14 @@ from typing import Any +from core.tools.entities.values import ToolLabelEnum from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController class CodeToolProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: - pass \ No newline at end of file + pass + + def _get_tool_labels(self) -> list[ToolLabelEnum]: + return [ + ToolLabelEnum.PRODUCTIVITY + ] \ No newline at end of file diff --git 
a/api/core/tools/provider/builtin/dalle/dalle.py b/api/core/tools/provider/builtin/dalle/dalle.py index 34a24a7425..1c249cf440 100644 --- a/api/core/tools/provider/builtin/dalle/dalle.py +++ b/api/core/tools/provider/builtin/dalle/dalle.py @@ -1,5 +1,6 @@ from typing import Any +from core.tools.entities.values import ToolLabelEnum from core.tools.errors import ToolProviderCredentialValidationError from core.tools.provider.builtin.dalle.tools.dalle2 import DallE2Tool from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController @@ -9,7 +10,7 @@ class DALLEProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: DallE2Tool().fork_tool_runtime( - meta={ + runtime={ "credentials": credentials, } ).invoke( @@ -21,4 +22,9 @@ class DALLEProvider(BuiltinToolProviderController): }, ) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) + + def _get_tool_labels(self) -> list[ToolLabelEnum]: + return [ + ToolLabelEnum.IMAGE, ToolLabelEnum.PRODUCTIVITY + ] \ No newline at end of file diff --git a/api/core/tools/provider/builtin/devdocs/devdocs.py b/api/core/tools/provider/builtin/devdocs/devdocs.py index 25cbe4d053..eda05626d5 100644 --- a/api/core/tools/provider/builtin/devdocs/devdocs.py +++ b/api/core/tools/provider/builtin/devdocs/devdocs.py @@ -1,3 +1,4 @@ +from core.tools.entities.values import ToolLabelEnum from core.tools.errors import ToolProviderCredentialValidationError from core.tools.provider.builtin.devdocs.tools.searchDevDocs import SearchDevDocsTool from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController @@ -7,7 +8,7 @@ class DevDocsProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: try: SearchDevDocsTool().fork_tool_runtime( - meta={ + runtime={ "credentials": credentials, } ).invoke( @@ -18,4 +19,9 
@@ class DevDocsProvider(BuiltinToolProviderController): }, ) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) + + def _get_tool_labels(self) -> list[ToolLabelEnum]: + return [ + ToolLabelEnum.SEARCH, ToolLabelEnum.PRODUCTIVITY + ] \ No newline at end of file diff --git a/api/core/tools/provider/builtin/dingtalk/dingtalk.py b/api/core/tools/provider/builtin/dingtalk/dingtalk.py index be1d5e099c..d42f8c5bbf 100644 --- a/api/core/tools/provider/builtin/dingtalk/dingtalk.py +++ b/api/core/tools/provider/builtin/dingtalk/dingtalk.py @@ -1,3 +1,4 @@ +from core.tools.entities.values import ToolLabelEnum from core.tools.provider.builtin.dingtalk.tools.dingtalk_group_bot import DingTalkGroupBotTool from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController @@ -6,3 +7,8 @@ class DingTalkProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: DingTalkGroupBotTool() pass + + def _get_tool_labels(self) -> list[ToolLabelEnum]: + return [ + ToolLabelEnum.SOCIAL + ] \ No newline at end of file diff --git a/api/core/tools/provider/builtin/duckduckgo/duckduckgo.py b/api/core/tools/provider/builtin/duckduckgo/duckduckgo.py index 3e9b57ece7..53d0e90d42 100644 --- a/api/core/tools/provider/builtin/duckduckgo/duckduckgo.py +++ b/api/core/tools/provider/builtin/duckduckgo/duckduckgo.py @@ -1,3 +1,4 @@ +from core.tools.entities.values import ToolLabelEnum from core.tools.errors import ToolProviderCredentialValidationError from core.tools.provider.builtin.duckduckgo.tools.duckduckgo_search import DuckDuckGoSearchTool from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController @@ -7,7 +8,7 @@ class DuckDuckGoProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: try: DuckDuckGoSearchTool().fork_tool_runtime( - meta={ + runtime={ "credentials": 
credentials, } ).invoke( @@ -17,4 +18,9 @@ class DuckDuckGoProvider(BuiltinToolProviderController): }, ) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) + + def _get_tool_labels(self) -> list[ToolLabelEnum]: + return [ + ToolLabelEnum.SEARCH + ] \ No newline at end of file diff --git a/api/core/tools/provider/builtin/feishu/feishu.py b/api/core/tools/provider/builtin/feishu/feishu.py index 13303dbe64..4f455b093b 100644 --- a/api/core/tools/provider/builtin/feishu/feishu.py +++ b/api/core/tools/provider/builtin/feishu/feishu.py @@ -1,3 +1,4 @@ +from core.tools.entities.values import ToolLabelEnum from core.tools.provider.builtin.feishu.tools.feishu_group_bot import FeishuGroupBotTool from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController @@ -6,3 +7,8 @@ class FeishuProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: FeishuGroupBotTool() pass + + def _get_tool_labels(self) -> list[ToolLabelEnum]: + return [ + ToolLabelEnum.SOCIAL + ] \ No newline at end of file diff --git a/api/core/tools/provider/builtin/firecrawl/firecrawl.py b/api/core/tools/provider/builtin/firecrawl/firecrawl.py index 20ab978b8d..cc8fcb1006 100644 --- a/api/core/tools/provider/builtin/firecrawl/firecrawl.py +++ b/api/core/tools/provider/builtin/firecrawl/firecrawl.py @@ -1,3 +1,4 @@ +from core.tools.entities.values import ToolLabelEnum from core.tools.errors import ToolProviderCredentialValidationError from core.tools.provider.builtin.firecrawl.tools.crawl import CrawlTool from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController @@ -8,7 +9,7 @@ class FirecrawlProvider(BuiltinToolProviderController): try: # Example validation using the Crawl tool CrawlTool().fork_tool_runtime( - meta={"credentials": credentials} + runtime={"credentials": credentials} ).invoke( user_id='', 
tool_parameters={ @@ -20,4 +21,9 @@ class FirecrawlProvider(BuiltinToolProviderController): } ) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) + + def _get_tool_labels(self) -> list[ToolLabelEnum]: + return [ + ToolLabelEnum.SEARCH, ToolLabelEnum.UTILITIES + ] \ No newline at end of file diff --git a/api/core/tools/provider/builtin/gaode/gaode.py b/api/core/tools/provider/builtin/gaode/gaode.py index b55d93e07b..e6a992ff87 100644 --- a/api/core/tools/provider/builtin/gaode/gaode.py +++ b/api/core/tools/provider/builtin/gaode/gaode.py @@ -2,6 +2,7 @@ import urllib.parse import requests +from core.tools.entities.values import ToolLabelEnum from core.tools.errors import ToolProviderCredentialValidationError from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController @@ -24,3 +25,9 @@ class GaodeProvider(BuiltinToolProviderController): raise ToolProviderCredentialValidationError("Gaode API Key is invalid. {}".format(e)) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) + + def _get_tool_labels(self) -> list[ToolLabelEnum]: + return [ + ToolLabelEnum.UTILITIES, ToolLabelEnum.PRODUCTIVITY, + ToolLabelEnum.WEATHER, ToolLabelEnum.TRAVEL + ] \ No newline at end of file diff --git a/api/core/tools/provider/builtin/github/github.py b/api/core/tools/provider/builtin/github/github.py index 9275504208..fa22761206 100644 --- a/api/core/tools/provider/builtin/github/github.py +++ b/api/core/tools/provider/builtin/github/github.py @@ -1,5 +1,6 @@ import requests +from core.tools.entities.values import ToolLabelEnum from core.tools.errors import ToolProviderCredentialValidationError from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController @@ -30,3 +31,8 @@ class GihubProvider(BuiltinToolProviderController): raise ToolProviderCredentialValidationError("Github API Key and Api Version is invalid. 
{}".format(e)) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) + + def _get_tool_labels(self) -> list[ToolLabelEnum]: + return [ + ToolLabelEnum.UTILITIES + ] \ No newline at end of file diff --git a/api/core/tools/provider/builtin/google/google.py b/api/core/tools/provider/builtin/google/google.py index 3900804b45..0ec5baef26 100644 --- a/api/core/tools/provider/builtin/google/google.py +++ b/api/core/tools/provider/builtin/google/google.py @@ -1,5 +1,6 @@ from typing import Any +from core.tools.entities.values import ToolLabelEnum from core.tools.errors import ToolProviderCredentialValidationError from core.tools.provider.builtin.google.tools.google_search import GoogleSearchTool from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController @@ -9,7 +10,7 @@ class GoogleProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: GoogleSearchTool().fork_tool_runtime( - meta={ + runtime={ "credentials": credentials, } ).invoke( @@ -20,4 +21,9 @@ class GoogleProvider(BuiltinToolProviderController): }, ) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) + + def _get_tool_labels(self) -> list[ToolLabelEnum]: + return [ + ToolLabelEnum.SEARCH + ] \ No newline at end of file diff --git a/api/core/tools/provider/builtin/jina/jina.py b/api/core/tools/provider/builtin/jina/jina.py index ed1de6f6c1..b1a8d62138 100644 --- a/api/core/tools/provider/builtin/jina/jina.py +++ b/api/core/tools/provider/builtin/jina/jina.py @@ -1,5 +1,6 @@ from typing import Any +from core.tools.entities.values import ToolLabelEnum from core.tools.errors import ToolProviderCredentialValidationError from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController @@ -9,4 +10,9 @@ class GoogleProvider(BuiltinToolProviderController): try: pass except Exception as e: - 
raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) + + def _get_tool_labels(self) -> list[ToolLabelEnum]: + return [ + ToolLabelEnum.SEARCH, ToolLabelEnum.PRODUCTIVITY + ] \ No newline at end of file diff --git a/api/core/tools/provider/builtin/judge0ce/judge0ce.py b/api/core/tools/provider/builtin/judge0ce/judge0ce.py index c00747868b..c378b64759 100644 --- a/api/core/tools/provider/builtin/judge0ce/judge0ce.py +++ b/api/core/tools/provider/builtin/judge0ce/judge0ce.py @@ -1,5 +1,6 @@ from typing import Any +from core.tools.entities.values import ToolLabelEnum from core.tools.errors import ToolProviderCredentialValidationError from core.tools.provider.builtin.judge0ce.tools.executeCode import ExecuteCodeTool from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController @@ -9,7 +10,7 @@ class Judge0CEProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: ExecuteCodeTool().fork_tool_runtime( - meta={ + runtime={ "credentials": credentials, } ).invoke( @@ -20,4 +21,9 @@ class Judge0CEProvider(BuiltinToolProviderController): }, ) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) + + def _get_tool_labels(self) -> list[ToolLabelEnum]: + return [ + ToolLabelEnum.OTHER, ToolLabelEnum.UTILITIES + ] \ No newline at end of file diff --git a/api/core/tools/provider/builtin/maths/maths.py b/api/core/tools/provider/builtin/maths/maths.py index 7226a5c168..1471de503e 100644 --- a/api/core/tools/provider/builtin/maths/maths.py +++ b/api/core/tools/provider/builtin/maths/maths.py @@ -1,5 +1,6 @@ from typing import Any +from core.tools.entities.values import ToolLabelEnum from core.tools.errors import ToolProviderCredentialValidationError from core.tools.provider.builtin.maths.tools.eval_expression import 
EvaluateExpressionTool from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController @@ -16,3 +17,8 @@ class MathsProvider(BuiltinToolProviderController): ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) + + def _get_tool_labels(self) -> list[ToolLabelEnum]: + return [ + ToolLabelEnum.UTILITIES, ToolLabelEnum.PRODUCTIVITY + ] \ No newline at end of file diff --git a/api/core/tools/provider/builtin/openweather/openweather.py b/api/core/tools/provider/builtin/openweather/openweather.py index a2827177a3..dd0649f1aa 100644 --- a/api/core/tools/provider/builtin/openweather/openweather.py +++ b/api/core/tools/provider/builtin/openweather/openweather.py @@ -1,5 +1,6 @@ import requests +from core.tools.entities.values import ToolLabelEnum from core.tools.errors import ToolProviderCredentialValidationError from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController @@ -34,3 +35,8 @@ class OpenweatherProvider(BuiltinToolProviderController): ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) + + def _get_tool_labels(self) -> list[ToolLabelEnum]: + return [ + ToolLabelEnum.WEATHER + ] \ No newline at end of file diff --git a/api/core/tools/provider/builtin/pubmed/pubmed.py b/api/core/tools/provider/builtin/pubmed/pubmed.py index 663617c0c1..9f4087abd7 100644 --- a/api/core/tools/provider/builtin/pubmed/pubmed.py +++ b/api/core/tools/provider/builtin/pubmed/pubmed.py @@ -1,3 +1,4 @@ +from core.tools.entities.values import ToolLabelEnum from core.tools.errors import ToolProviderCredentialValidationError from core.tools.provider.builtin.pubmed.tools.pubmed_search import PubMedSearchTool from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController @@ -7,7 +8,7 @@ class PubMedProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: try: PubMedSearchTool().fork_tool_runtime( - meta={ + runtime={ "credentials": 
credentials, } ).invoke( @@ -17,4 +18,9 @@ class PubMedProvider(BuiltinToolProviderController): }, ) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) + + def _get_tool_labels(self) -> list[ToolLabelEnum]: + return [ + ToolLabelEnum.MEDICAL, ToolLabelEnum.SEARCH + ] \ No newline at end of file diff --git a/api/core/tools/provider/builtin/qrcode/qrcode.py b/api/core/tools/provider/builtin/qrcode/qrcode.py index 9fa7d01265..615e4ff6e6 100644 --- a/api/core/tools/provider/builtin/qrcode/qrcode.py +++ b/api/core/tools/provider/builtin/qrcode/qrcode.py @@ -1,5 +1,6 @@ from typing import Any +from core.tools.entities.values import ToolLabelEnum from core.tools.errors import ToolProviderCredentialValidationError from core.tools.provider.builtin.qrcode.tools.qrcode_generator import QRCodeGeneratorTool from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController @@ -14,3 +15,8 @@ class QRCodeProvider(BuiltinToolProviderController): }) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) + + def _get_tool_labels(self) -> list[ToolLabelEnum]: + return [ + ToolLabelEnum.UTILITIES + ] \ No newline at end of file diff --git a/api/core/tools/provider/builtin/searxng/searxng.py b/api/core/tools/provider/builtin/searxng/searxng.py index 8046056093..5f1b05135d 100644 --- a/api/core/tools/provider/builtin/searxng/searxng.py +++ b/api/core/tools/provider/builtin/searxng/searxng.py @@ -1,5 +1,6 @@ from typing import Any +from core.tools.entities.values import ToolLabelEnum from core.tools.errors import ToolProviderCredentialValidationError from core.tools.provider.builtin.searxng.tools.searxng_search import SearXNGSearchTool from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController @@ -9,7 +10,7 @@ class SearXNGProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, 
Any]) -> None: try: SearXNGSearchTool().fork_tool_runtime( - meta={ + runtime={ "credentials": credentials, } ).invoke( @@ -22,4 +23,9 @@ class SearXNGProvider(BuiltinToolProviderController): }, ) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) + + def _get_tool_labels(self) -> list[ToolLabelEnum]: + return [ + ToolLabelEnum.SEARCH, ToolLabelEnum.PRODUCTIVITY + ] \ No newline at end of file diff --git a/api/core/tools/provider/builtin/slack/slack.py b/api/core/tools/provider/builtin/slack/slack.py index 2de7911f63..b8843392c7 100644 --- a/api/core/tools/provider/builtin/slack/slack.py +++ b/api/core/tools/provider/builtin/slack/slack.py @@ -1,3 +1,4 @@ +from core.tools.entities.values import ToolLabelEnum from core.tools.provider.builtin.slack.tools.slack_webhook import SlackWebhookTool from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController @@ -6,3 +7,8 @@ class SlackProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: SlackWebhookTool() pass + + def _get_tool_labels(self) -> list[ToolLabelEnum]: + return [ + ToolLabelEnum.SOCIAL + ] \ No newline at end of file diff --git a/api/core/tools/provider/builtin/spark/spark.py b/api/core/tools/provider/builtin/spark/spark.py index cb8e69a59f..b4727912c3 100644 --- a/api/core/tools/provider/builtin/spark/spark.py +++ b/api/core/tools/provider/builtin/spark/spark.py @@ -1,5 +1,6 @@ import json +from core.tools.entities.values import ToolLabelEnum from core.tools.errors import ToolProviderCredentialValidationError from core.tools.provider.builtin.spark.tools.spark_img_generation import spark_response from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController @@ -38,3 +39,8 @@ class SparkProvider(BuiltinToolProviderController): ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) + + def 
_get_tool_labels(self) -> list[ToolLabelEnum]: + return [ + ToolLabelEnum.IMAGE + ] \ No newline at end of file diff --git a/api/core/tools/provider/builtin/stability/stability.py b/api/core/tools/provider/builtin/stability/stability.py index d00c3ecf00..9706f17468 100644 --- a/api/core/tools/provider/builtin/stability/stability.py +++ b/api/core/tools/provider/builtin/stability/stability.py @@ -1,5 +1,6 @@ from typing import Any +from core.tools.entities.values import ToolLabelEnum from core.tools.provider.builtin.stability.tools.base import BaseStabilityAuthorization from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController @@ -12,4 +13,9 @@ class StabilityToolProvider(BuiltinToolProviderController, BaseStabilityAuthoriz """ This method is responsible for validating the credentials. """ - self.sd_validate_credentials(credentials) \ No newline at end of file + self.sd_validate_credentials(credentials) + + def _get_tool_labels(self) -> list[ToolLabelEnum]: + return [ + ToolLabelEnum.IMAGE + ] \ No newline at end of file diff --git a/api/core/tools/provider/builtin/stablediffusion/stablediffusion.py b/api/core/tools/provider/builtin/stablediffusion/stablediffusion.py index 5748e8d4e2..2389e9c141 100644 --- a/api/core/tools/provider/builtin/stablediffusion/stablediffusion.py +++ b/api/core/tools/provider/builtin/stablediffusion/stablediffusion.py @@ -1,5 +1,6 @@ from typing import Any +from core.tools.entities.values import ToolLabelEnum from core.tools.errors import ToolProviderCredentialValidationError from core.tools.provider.builtin.stablediffusion.tools.stable_diffusion import StableDiffusionTool from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController @@ -9,9 +10,14 @@ class StableDiffusionProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: StableDiffusionTool().fork_tool_runtime( - meta={ + runtime={ "credentials": credentials, } 
).validate_models() except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) + + def _get_tool_labels(self) -> list[ToolLabelEnum]: + return [ + ToolLabelEnum.IMAGE + ] \ No newline at end of file diff --git a/api/core/tools/provider/builtin/stackexchange/stackexchange.py b/api/core/tools/provider/builtin/stackexchange/stackexchange.py index fab543c580..34e2a47b86 100644 --- a/api/core/tools/provider/builtin/stackexchange/stackexchange.py +++ b/api/core/tools/provider/builtin/stackexchange/stackexchange.py @@ -1,3 +1,4 @@ +from core.tools.entities.values import ToolLabelEnum from core.tools.errors import ToolProviderCredentialValidationError from core.tools.provider.builtin.stackexchange.tools.searchStackExQuestions import SearchStackExQuestionsTool from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController @@ -7,7 +8,7 @@ class StackExchangeProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: try: SearchStackExQuestionsTool().fork_tool_runtime( - meta={ + runtime={ "credentials": credentials, } ).invoke( @@ -22,4 +23,9 @@ class StackExchangeProvider(BuiltinToolProviderController): }, ) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) + + def _get_tool_labels(self) -> list[ToolLabelEnum]: + return [ + ToolLabelEnum.SEARCH, ToolLabelEnum.UTILITIES + ] \ No newline at end of file diff --git a/api/core/tools/provider/builtin/tavily/tavily.py b/api/core/tools/provider/builtin/tavily/tavily.py index 575d9268b9..b4e7e78ed3 100644 --- a/api/core/tools/provider/builtin/tavily/tavily.py +++ b/api/core/tools/provider/builtin/tavily/tavily.py @@ -1,5 +1,6 @@ from typing import Any +from core.tools.entities.values import ToolLabelEnum from core.tools.errors import 
ToolProviderCredentialValidationError from core.tools.provider.builtin.tavily.tools.tavily_search import TavilySearchTool from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController @@ -9,7 +10,7 @@ class TavilyProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: TavilySearchTool().fork_tool_runtime( - meta={ + runtime={ "credentials": credentials, } ).invoke( @@ -26,4 +27,9 @@ class TavilyProvider(BuiltinToolProviderController): }, ) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) + + def _get_tool_labels(self) -> list[ToolLabelEnum]: + return [ + ToolLabelEnum.SEARCH + ] \ No newline at end of file diff --git a/api/core/tools/provider/builtin/time/time.py b/api/core/tools/provider/builtin/time/time.py index 0d3285f495..ce6e67a759 100644 --- a/api/core/tools/provider/builtin/time/time.py +++ b/api/core/tools/provider/builtin/time/time.py @@ -1,5 +1,6 @@ from typing import Any +from core.tools.entities.values import ToolLabelEnum from core.tools.errors import ToolProviderCredentialValidationError from core.tools.provider.builtin.time.tools.current_time import CurrentTimeTool from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController @@ -13,4 +14,9 @@ class WikiPediaProvider(BuiltinToolProviderController): tool_parameters={}, ) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) + + def _get_tool_labels(self) -> list[ToolLabelEnum]: + return [ + ToolLabelEnum.UTILITIES + ] \ No newline at end of file diff --git a/api/core/tools/provider/builtin/trello/trello.py b/api/core/tools/provider/builtin/trello/trello.py index d27115d246..eb64714d14 100644 --- a/api/core/tools/provider/builtin/trello/trello.py +++ 
b/api/core/tools/provider/builtin/trello/trello.py @@ -2,6 +2,7 @@ from typing import Any import requests +from core.tools.entities.values import ToolLabelEnum from core.tools.errors import ToolProviderCredentialValidationError from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController @@ -31,4 +32,9 @@ class TrelloProvider(BuiltinToolProviderController): raise ToolProviderCredentialValidationError("Error validating Trello credentials") except requests.exceptions.RequestException as e: # Handle other exceptions, such as connection errors - raise ToolProviderCredentialValidationError("Error validating Trello credentials") \ No newline at end of file + raise ToolProviderCredentialValidationError("Error validating Trello credentials") + + def _get_tool_labels(self) -> list[ToolLabelEnum]: + return [ + ToolLabelEnum.PRODUCTIVITY + ] \ No newline at end of file diff --git a/api/core/tools/provider/builtin/twilio/twilio.py b/api/core/tools/provider/builtin/twilio/twilio.py index 7984d7b3b1..a881efb940 100644 --- a/api/core/tools/provider/builtin/twilio/twilio.py +++ b/api/core/tools/provider/builtin/twilio/twilio.py @@ -3,6 +3,7 @@ from typing import Any from twilio.base.exceptions import TwilioRestException from twilio.rest import Client +from core.tools.entities.values import ToolLabelEnum from core.tools.errors import ToolProviderCredentialValidationError from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController @@ -26,4 +27,9 @@ class TwilioProvider(BuiltinToolProviderController): except KeyError as e: raise ToolProviderCredentialValidationError(f"Missing required credential: {e}") from e except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) + + def _get_tool_labels(self) -> list[ToolLabelEnum]: + return [ + ToolLabelEnum.SOCIAL + ] \ No newline at end of file diff --git 
a/api/core/tools/provider/builtin/vectorizer/vectorizer.py b/api/core/tools/provider/builtin/vectorizer/vectorizer.py index 2b4d71e058..73ddc14bff 100644 --- a/api/core/tools/provider/builtin/vectorizer/vectorizer.py +++ b/api/core/tools/provider/builtin/vectorizer/vectorizer.py @@ -1,5 +1,6 @@ from typing import Any +from core.tools.entities.values import ToolLabelEnum from core.tools.errors import ToolProviderCredentialValidationError from core.tools.provider.builtin.vectorizer.tools.vectorizer import VectorizerTool from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController @@ -9,7 +10,7 @@ class VectorizerProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: VectorizerTool().fork_tool_runtime( - meta={ + runtime={ "credentials": credentials, } ).invoke( @@ -20,4 +21,9 @@ class VectorizerProvider(BuiltinToolProviderController): }, ) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) + + def _get_tool_labels(self) -> list[ToolLabelEnum]: + return [ + ToolLabelEnum.PRODUCTIVITY, ToolLabelEnum.IMAGE + ] \ No newline at end of file diff --git a/api/core/tools/provider/builtin/webscraper/webscraper.py b/api/core/tools/provider/builtin/webscraper/webscraper.py index 8761493e3b..df7d0fd549 100644 --- a/api/core/tools/provider/builtin/webscraper/webscraper.py +++ b/api/core/tools/provider/builtin/webscraper/webscraper.py @@ -1,5 +1,6 @@ from typing import Any +from core.tools.entities.values import ToolLabelEnum from core.tools.errors import ToolProviderCredentialValidationError from core.tools.provider.builtin.webscraper.tools.webscraper import WebscraperTool from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController @@ -9,7 +10,7 @@ class WebscraperProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> 
None: try: WebscraperTool().fork_tool_runtime( - meta={ + runtime={ "credentials": credentials, } ).invoke( @@ -20,4 +21,9 @@ class WebscraperProvider(BuiltinToolProviderController): }, ) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) + + def _get_tool_labels(self) -> list[ToolLabelEnum]: + return [ + ToolLabelEnum.PRODUCTIVITY + ] \ No newline at end of file diff --git a/api/core/tools/provider/builtin/wecom/wecom.py b/api/core/tools/provider/builtin/wecom/wecom.py index 7a2576b668..e60b64f21c 100644 --- a/api/core/tools/provider/builtin/wecom/wecom.py +++ b/api/core/tools/provider/builtin/wecom/wecom.py @@ -1,3 +1,4 @@ +from core.tools.entities.values import ToolLabelEnum from core.tools.provider.builtin.wecom.tools.wecom_group_bot import WecomGroupBotTool from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController @@ -6,3 +7,8 @@ class WecomProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: WecomGroupBotTool() pass + + def _get_tool_labels(self) -> list[ToolLabelEnum]: + return [ + ToolLabelEnum.SOCIAL + ] \ No newline at end of file diff --git a/api/core/tools/provider/builtin/wikipedia/wikipedia.py b/api/core/tools/provider/builtin/wikipedia/wikipedia.py index 8d53852255..4153a6871c 100644 --- a/api/core/tools/provider/builtin/wikipedia/wikipedia.py +++ b/api/core/tools/provider/builtin/wikipedia/wikipedia.py @@ -1,3 +1,4 @@ +from core.tools.entities.values import ToolLabelEnum from core.tools.errors import ToolProviderCredentialValidationError from core.tools.provider.builtin.wikipedia.tools.wikipedia_search import WikiPediaSearchTool from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController @@ -7,7 +8,7 @@ class WikiPediaProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: try: 
WikiPediaSearchTool().fork_tool_runtime( - meta={ + runtime={ "credentials": credentials, } ).invoke( @@ -17,4 +18,9 @@ class WikiPediaProvider(BuiltinToolProviderController): }, ) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) + + def _get_tool_labels(self) -> list[ToolLabelEnum]: + return [ + ToolLabelEnum.SEARCH + ] \ No newline at end of file diff --git a/api/core/tools/provider/builtin/wolframalpha/wolframalpha.py b/api/core/tools/provider/builtin/wolframalpha/wolframalpha.py index 4e8213d90c..86b0b65d27 100644 --- a/api/core/tools/provider/builtin/wolframalpha/wolframalpha.py +++ b/api/core/tools/provider/builtin/wolframalpha/wolframalpha.py @@ -1,5 +1,6 @@ from typing import Any +from core.tools.entities.values import ToolLabelEnum from core.tools.errors import ToolProviderCredentialValidationError from core.tools.provider.builtin.wolframalpha.tools.wolframalpha import WolframAlphaTool from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController @@ -9,7 +10,7 @@ class GoogleProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: WolframAlphaTool().fork_tool_runtime( - meta={ + runtime={ "credentials": credentials, } ).invoke( @@ -19,4 +20,9 @@ class GoogleProvider(BuiltinToolProviderController): }, ) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) + + def _get_tool_labels(self) -> list[ToolLabelEnum]: + return [ + ToolLabelEnum.PRODUCTIVITY, ToolLabelEnum.UTILITIES + ] \ No newline at end of file diff --git a/api/core/tools/provider/builtin/yahoo/yahoo.py b/api/core/tools/provider/builtin/yahoo/yahoo.py index ade33ffb63..7b5bf7f307 100644 --- a/api/core/tools/provider/builtin/yahoo/yahoo.py +++ b/api/core/tools/provider/builtin/yahoo/yahoo.py @@ -1,3 +1,4 @@ 
+from core.tools.entities.values import ToolLabelEnum from core.tools.errors import ToolProviderCredentialValidationError from core.tools.provider.builtin.yahoo.tools.ticker import YahooFinanceSearchTickerTool from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController @@ -7,7 +8,7 @@ class YahooFinanceProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: try: YahooFinanceSearchTickerTool().fork_tool_runtime( - meta={ + runtime={ "credentials": credentials, } ).invoke( @@ -17,4 +18,9 @@ class YahooFinanceProvider(BuiltinToolProviderController): }, ) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) + + def _get_tool_labels(self) -> list[ToolLabelEnum]: + return [ + ToolLabelEnum.BUSINESS, ToolLabelEnum.FINANCE + ] \ No newline at end of file diff --git a/api/core/tools/provider/builtin/youtube/youtube.py b/api/core/tools/provider/builtin/youtube/youtube.py index 8cca578c46..37d24b87eb 100644 --- a/api/core/tools/provider/builtin/youtube/youtube.py +++ b/api/core/tools/provider/builtin/youtube/youtube.py @@ -1,3 +1,4 @@ +from core.tools.entities.values import ToolLabelEnum from core.tools.errors import ToolProviderCredentialValidationError from core.tools.provider.builtin.youtube.tools.videos import YoutubeVideosAnalyticsTool from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController @@ -7,7 +8,7 @@ class YahooFinanceProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: try: YoutubeVideosAnalyticsTool().fork_tool_runtime( - meta={ + runtime={ "credentials": credentials, } ).invoke( @@ -19,4 +20,9 @@ class YahooFinanceProvider(BuiltinToolProviderController): }, ) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) 
+ + def _get_tool_labels(self) -> list[ToolLabelEnum]: + return [ + ToolLabelEnum.VIDEOS + ] \ No newline at end of file diff --git a/api/core/tools/provider/builtin_tool_provider.py b/api/core/tools/provider/builtin_tool_provider.py index 76ee473beb..3dc2a65de3 100644 --- a/api/core/tools/provider/builtin_tool_provider.py +++ b/api/core/tools/provider/builtin_tool_provider.py @@ -2,8 +2,9 @@ from abc import abstractmethod from os import listdir, path from typing import Any +from core.tools.entities.api_entities import UserToolProviderCredentials from core.tools.entities.tool_entities import ToolParameter, ToolProviderCredentials, ToolProviderType -from core.tools.entities.user_entities import UserToolProviderCredentials +from core.tools.entities.values import ToolLabelEnum, default_tool_label_dict from core.tools.errors import ( ToolNotFoundError, ToolParameterValidationError, @@ -19,7 +20,7 @@ from core.utils.module_import_helper import load_single_subclass_from_source class BuiltinToolProviderController(ToolProviderController): def __init__(self, **data: Any) -> None: - if self.app_type == ToolProviderType.API_BASED or self.app_type == ToolProviderType.APP_BASED: + if self.provider_type == ToolProviderType.API or self.provider_type == ToolProviderType.APP: super().__init__(**data) return @@ -129,13 +130,29 @@ class BuiltinToolProviderController(ToolProviderController): len(self.credentials_schema) != 0 @property - def app_type(self) -> ToolProviderType: + def provider_type(self) -> ToolProviderType: """ returns the type of the provider :return: type of the provider """ return ToolProviderType.BUILT_IN + + @property + def tool_labels(self) -> list[str]: + """ + returns the labels of the provider + + :return: labels of the provider + """ + label_enums = self._get_tool_labels() + return [default_tool_label_dict[label].name for label in label_enums] + + def _get_tool_labels(self) -> list[ToolLabelEnum]: + """ + returns the labels of the provider + """ + return [] 
def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: dict[str, Any]) -> None: """ diff --git a/api/core/tools/provider/tool_provider.py b/api/core/tools/provider/tool_provider.py index b527f2b274..0dcb8ca58c 100644 --- a/api/core/tools/provider/tool_provider.py +++ b/api/core/tools/provider/tool_provider.py @@ -3,13 +3,13 @@ from typing import Any, Optional from pydantic import BaseModel +from core.tools.entities.api_entities import UserToolProviderCredentials from core.tools.entities.tool_entities import ( ToolParameter, ToolProviderCredentials, ToolProviderIdentity, ToolProviderType, ) -from core.tools.entities.user_entities import UserToolProviderCredentials from core.tools.errors import ToolNotFoundError, ToolParameterValidationError, ToolProviderCredentialValidationError from core.tools.tool.tool import Tool @@ -67,7 +67,7 @@ class ToolProviderController(BaseModel, ABC): return tool.parameters @property - def app_type(self) -> ToolProviderType: + def provider_type(self) -> ToolProviderType: """ returns the type of the provider @@ -197,26 +197,4 @@ class ToolProviderController(BaseModel, ABC): default_value = str(default_value) credentials[credential_name] = default_value - - def validate_credentials(self, credentials: dict[str, Any]) -> None: - """ - validate the credentials of the provider - - :param tool_name: the name of the tool, defined in `get_tools` - :param credentials: the credentials of the tool - """ - # validate credentials format - self.validate_credentials_format(credentials) - - # validate credentials - self._validate_credentials(credentials) - - @abstractmethod - def _validate_credentials(self, credentials: dict[str, Any]) -> None: - """ - validate the credentials of the provider - - :param tool_name: the name of the tool, defined in `get_tools` - :param credentials: the credentials of the tool - """ - pass \ No newline at end of file + \ No newline at end of file diff --git 
a/api/core/tools/provider/workflow_tool_provider.py b/api/core/tools/provider/workflow_tool_provider.py new file mode 100644 index 0000000000..f98ad0f26a --- /dev/null +++ b/api/core/tools/provider/workflow_tool_provider.py @@ -0,0 +1,230 @@ +from typing import Optional + +from core.app.app_config.entities import VariableEntity +from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager +from core.model_runtime.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ( + ToolDescription, + ToolIdentity, + ToolParameter, + ToolParameterOption, + ToolProviderType, +) +from core.tools.provider.tool_provider import ToolProviderController +from core.tools.tool.workflow_tool import WorkflowTool +from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils +from extensions.ext_database import db +from models.model import App, AppMode +from models.tools import WorkflowToolProvider +from models.workflow import Workflow + + +class WorkflowToolProviderController(ToolProviderController): + provider_id: str + + @classmethod + def from_db(cls, db_provider: WorkflowToolProvider) -> 'WorkflowToolProviderController': + app = db_provider.app + + if not app: + raise ValueError('app not found') + + controller = WorkflowToolProviderController(**{ + 'identity': { + 'author': db_provider.user.name if db_provider.user_id and db_provider.user else '', + 'name': db_provider.label, + 'label': { + 'en_US': db_provider.label, + 'zh_Hans': db_provider.label + }, + 'description': { + 'en_US': db_provider.description, + 'zh_Hans': db_provider.description + }, + 'icon': db_provider.icon, + }, + 'credentials_schema': {}, + 'provider_id': db_provider.id or '', + }) + + # init tools + + controller.tools = [controller._get_db_provider_tool(db_provider, app)] + + return controller + + @property + def provider_type(self) -> ToolProviderType: + return ToolProviderType.WORKFLOW + + def _get_db_provider_tool(self, db_provider: 
WorkflowToolProvider, app: App) -> WorkflowTool: + """ + get db provider tool + :param db_provider: the db provider + :param app: the app + :return: the tool + """ + workflow: Workflow = db.session.query(Workflow).filter( + Workflow.app_id == db_provider.app_id, + Workflow.version == db_provider.version + ).first() + if not workflow: + raise ValueError('workflow not found') + + # fetch start node + graph: dict = workflow.graph_dict + features_dict: dict = workflow.features_dict + features = WorkflowAppConfigManager.convert_features( + config_dict=features_dict, + app_mode=AppMode.WORKFLOW + ) + + parameters = db_provider.parameter_configurations + variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph) + + def fetch_workflow_variable(variable_name: str) -> VariableEntity: + return next(filter(lambda x: x.variable == variable_name, variables), None) + + user = db_provider.user + + workflow_tool_parameters = [] + for parameter in parameters: + variable = fetch_workflow_variable(parameter.name) + if variable: + parameter_type = None + options = None + if variable.type in [ + VariableEntity.Type.TEXT_INPUT, + VariableEntity.Type.PARAGRAPH, + ]: + parameter_type = ToolParameter.ToolParameterType.STRING + elif variable.type in [ + VariableEntity.Type.SELECT + ]: + parameter_type = ToolParameter.ToolParameterType.SELECT + elif variable.type in [ + VariableEntity.Type.NUMBER + ]: + parameter_type = ToolParameter.ToolParameterType.NUMBER + else: + raise ValueError(f'unsupported variable type {variable.type}') + + if variable.type == VariableEntity.Type.SELECT and variable.options: + options = [ + ToolParameterOption( + value=option, + label=I18nObject( + en_US=option, + zh_Hans=option + ) + ) for option in variable.options + ] + + workflow_tool_parameters.append( + ToolParameter( + name=parameter.name, + label=I18nObject( + en_US=variable.label, + zh_Hans=variable.label + ), + human_description=I18nObject( + en_US=parameter.description, + 
zh_Hans=parameter.description + ), + type=parameter_type, + form=parameter.form, + llm_description=parameter.description, + required=variable.required, + options=options, + default=variable.default + ) + ) + elif features.file_upload: + workflow_tool_parameters.append( + ToolParameter( + name=parameter.name, + label=I18nObject( + en_US=parameter.name, + zh_Hans=parameter.name + ), + human_description=I18nObject( + en_US=parameter.description, + zh_Hans=parameter.description + ), + type=ToolParameter.ToolParameterType.FILE, + llm_description=parameter.description, + required=False, + form=parameter.form, + ) + ) + else: + raise ValueError('variable not found') + + return WorkflowTool( + identity=ToolIdentity( + author=user.name if user else '', + name=db_provider.name, + label=I18nObject( + en_US=db_provider.label, + zh_Hans=db_provider.label + ), + provider=self.provider_id, + icon=db_provider.icon, + ), + description=ToolDescription( + human=I18nObject( + en_US=db_provider.description, + zh_Hans=db_provider.description + ), + llm=db_provider.description, + ), + parameters=workflow_tool_parameters, + is_team_authorization=True, + workflow_app_id=app.id, + workflow_entities={ + 'app': app, + 'workflow': workflow, + }, + version=db_provider.version, + workflow_call_depth=0, + label=db_provider.label + ) + + def get_tools(self, user_id: str, tenant_id: str) -> list[WorkflowTool]: + """ + fetch tools from database + + :param user_id: the user id + :param tenant_id: the tenant id + :return: the tools + """ + if self.tools is not None: + return self.tools + + db_providers: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter( + WorkflowToolProvider.tenant_id == tenant_id, + WorkflowToolProvider.app_id == self.provider_id, + ).first() + + if not db_providers: + return [] + + self.tools = [self._get_db_provider_tool(db_providers, db_providers.app)] + + return self.tools + + def get_tool(self, tool_name: str) -> Optional[WorkflowTool]: + """ + get tool by 
name + + :param tool_name: the name of the tool + :return: the tool + """ + if self.tools is None: + return None + + for tool in self.tools: + if tool.identity.name == tool_name: + return tool + + return None diff --git a/api/core/tools/tool/api_tool.py b/api/core/tools/tool/api_tool.py index f7b963a92e..ff7d4015ab 100644 --- a/api/core/tools/tool/api_tool.py +++ b/api/core/tools/tool/api_tool.py @@ -8,9 +8,8 @@ import httpx import requests import core.helper.ssrf_proxy as ssrf_proxy -from core.tools.entities.tool_bundle import ApiBasedToolBundle +from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType -from core.tools.entities.user_entities import UserToolProvider from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError from core.tools.tool.tool import Tool @@ -20,12 +19,12 @@ API_TOOL_DEFAULT_TIMEOUT = ( ) class ApiTool(Tool): - api_bundle: ApiBasedToolBundle + api_bundle: ApiToolBundle """ Api tool """ - def fork_tool_runtime(self, meta: dict[str, Any]) -> 'Tool': + def fork_tool_runtime(self, runtime: dict[str, Any]) -> 'Tool': """ fork a new tool with meta data @@ -37,7 +36,7 @@ class ApiTool(Tool): parameters=self.parameters.copy() if self.parameters else None, description=self.description.copy() if self.description else None, api_bundle=self.api_bundle.copy() if self.api_bundle else None, - runtime=Tool.Runtime(**meta) + runtime=Tool.Runtime(**runtime) ) def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False) -> str: @@ -55,7 +54,7 @@ class ApiTool(Tool): return self.validate_and_parse_response(response) def tool_provider_type(self) -> ToolProviderType: - return UserToolProvider.ProviderType.API + return ToolProviderType.API def assembling_request(self, parameters: dict[str, Any]) -> dict[str, Any]: headers = {} diff --git a/api/core/tools/tool/builtin_tool.py 
b/api/core/tools/tool/builtin_tool.py index 68193e5f69..1ed05c112f 100644 --- a/api/core/tools/tool/builtin_tool.py +++ b/api/core/tools/tool/builtin_tool.py @@ -2,9 +2,8 @@ from core.model_runtime.entities.llm_entities import LLMResult from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage from core.tools.entities.tool_entities import ToolProviderType -from core.tools.entities.user_entities import UserToolProvider -from core.tools.model.tool_model_manager import ToolModelManager from core.tools.tool.tool import Tool +from core.tools.utils.model_invocation_utils import ModelInvocationUtils from core.tools.utils.web_reader_tool import get_url _SUMMARY_PROMPT = """You are a professional language researcher, you are interested in the language @@ -34,7 +33,7 @@ class BuiltinTool(Tool): :return: the model result """ # invoke model - return ToolModelManager.invoke( + return ModelInvocationUtils.invoke( user_id=user_id, tenant_id=self.runtime.tenant_id, tool_type='builtin', @@ -43,7 +42,7 @@ class BuiltinTool(Tool): ) def tool_provider_type(self) -> ToolProviderType: - return UserToolProvider.ProviderType.BUILTIN + return ToolProviderType.BUILT_IN def get_max_tokens(self) -> int: """ @@ -52,7 +51,7 @@ class BuiltinTool(Tool): :param model_config: the model config :return: the max tokens """ - return ToolModelManager.get_max_llm_context_tokens( + return ModelInvocationUtils.get_max_llm_context_tokens( tenant_id=self.runtime.tenant_id, ) @@ -63,7 +62,7 @@ class BuiltinTool(Tool): :param prompt_messages: the prompt messages :return: the tokens """ - return ToolModelManager.calculate_tokens( + return ModelInvocationUtils.calculate_tokens( tenant_id=self.runtime.tenant_id, prompt_messages=prompt_messages ) diff --git a/api/core/tools/tool/tool.py b/api/core/tools/tool/tool.py index 03aa0623fe..9de72dfe3a 100644 --- a/api/core/tools/tool/tool.py +++ b/api/core/tools/tool/tool.py @@ -4,9 +4,12 @@ from typing import Any, 
Optional, Union from pydantic import BaseModel, validator +from core.app.entities.app_invoke_entities import InvokeFrom +from core.file.file_obj import FileVar from core.tools.entities.tool_entities import ( ToolDescription, ToolIdentity, + ToolInvokeFrom, ToolInvokeMessage, ToolParameter, ToolProviderType, @@ -25,10 +28,7 @@ class Tool(BaseModel, ABC): @validator('parameters', pre=True, always=True) def set_parameters(cls, v, values): - if not v: - return [] - - return v + return v or [] class Runtime(BaseModel): """ @@ -41,6 +41,8 @@ class Tool(BaseModel, ABC): tenant_id: str = None tool_id: str = None + invoke_from: InvokeFrom = None + tool_invoke_from: ToolInvokeFrom = None credentials: dict[str, Any] = None runtime_parameters: dict[str, Any] = None @@ -53,7 +55,7 @@ class Tool(BaseModel, ABC): class VARIABLE_KEY(Enum): IMAGE = 'image' - def fork_tool_runtime(self, meta: dict[str, Any]) -> 'Tool': + def fork_tool_runtime(self, runtime: dict[str, Any]) -> 'Tool': """ fork a new tool with meta data @@ -64,7 +66,7 @@ class Tool(BaseModel, ABC): identity=self.identity.copy() if self.identity else None, parameters=self.parameters.copy() if self.parameters else None, description=self.description.copy() if self.description else None, - runtime=Tool.Runtime(**meta), + runtime=Tool.Runtime(**runtime), ) @abstractmethod @@ -208,17 +210,17 @@ class Tool(BaseModel, ABC): if response.type == ToolInvokeMessage.MessageType.TEXT: result += response.message elif response.type == ToolInvokeMessage.MessageType.LINK: - result += f"result link: {response.message}. please tell user to check it." + result += f"result link: {response.message}. please tell user to check it. \n" elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \ response.type == ToolInvokeMessage.MessageType.IMAGE: - result += "image has been created and sent to user already, you do not need to create it, just tell the user to check it now." 
+ result += "image has been created and sent to user already, you do not need to create it, just tell the user to check it now. \n" elif response.type == ToolInvokeMessage.MessageType.BLOB: if len(response.message) > 114: result += str(response.message[:114]) + '...' else: result += str(response.message) else: - result += f"tool response: {response.message}." + result += f"tool response: {response.message}. \n" return result @@ -343,6 +345,14 @@ class Tool(BaseModel, ABC): message=image, save_as=save_as) + def create_file_var_message(self, file_var: FileVar) -> ToolInvokeMessage: + return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.FILE_VAR, + message='', + meta={ + 'file_var': file_var + }, + save_as='') + def create_link_message(self, link: str, save_as: str = '') -> ToolInvokeMessage: """ create a link message diff --git a/api/core/tools/tool/workflow_tool.py b/api/core/tools/tool/workflow_tool.py new file mode 100644 index 0000000000..7da3d09d4c --- /dev/null +++ b/api/core/tools/tool/workflow_tool.py @@ -0,0 +1,200 @@ +import json +import logging +from copy import deepcopy +from typing import Any, Union + +from core.file.file_obj import FileTransferMethod, FileVar +from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderType +from core.tools.tool.tool import Tool +from extensions.ext_database import db +from models.account import Account +from models.model import App, EndUser +from models.workflow import Workflow + +logger = logging.getLogger(__name__) + +class WorkflowTool(Tool): + workflow_app_id: str + version: str + workflow_entities: dict[str, Any] + workflow_call_depth: int + + label: str + + """ + Workflow tool. 
+ """ + def tool_provider_type(self) -> ToolProviderType: + """ + get the tool provider type + + :return: the tool provider type + """ + return ToolProviderType.WORKFLOW + + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) \ + -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + invoke the tool + """ + app = self._get_app(app_id=self.workflow_app_id) + workflow = self._get_workflow(app_id=self.workflow_app_id, version=self.version) + + # transform the tool parameters + tool_parameters, files = self._transform_args(tool_parameters) + + from core.app.apps.workflow.app_generator import WorkflowAppGenerator + generator = WorkflowAppGenerator() + result = generator.generate( + app_model=app, + workflow=workflow, + user=self._get_user(user_id), + args={ + 'inputs': tool_parameters, + 'files': files + }, + invoke_from=self.runtime.invoke_from, + stream=False, + call_depth=self.workflow_call_depth + 1, + ) + + data = result.get('data', {}) + + if data.get('error'): + raise Exception(data.get('error')) + + result = [] + + outputs = data.get('outputs', {}) + outputs, files = self._extract_files(outputs) + for file in files: + result.append(self.create_file_var_message(file)) + + result.append(self.create_text_message(json.dumps(outputs))) + + return result + + def _get_user(self, user_id: str) -> Union[EndUser, Account]: + """ + get the user by user id + """ + + user = db.session.query(EndUser).filter(EndUser.id == user_id).first() + if not user: + user = db.session.query(Account).filter(Account.id == user_id).first() + + if not user: + raise ValueError('user not found') + + return user + + def fork_tool_runtime(self, runtime: dict[str, Any]) -> 'WorkflowTool': + """ + fork a new tool with meta data + + :param meta: the meta data of a tool call processing, tenant_id is required + :return: the new tool + """ + return self.__class__( + identity=deepcopy(self.identity), + parameters=deepcopy(self.parameters), + description=deepcopy(self.description), + 
runtime=Tool.Runtime(**runtime), + workflow_app_id=self.workflow_app_id, + workflow_entities=self.workflow_entities, + workflow_call_depth=self.workflow_call_depth, + version=self.version, + label=self.label + ) + + def _get_workflow(self, app_id: str, version: str) -> Workflow: + """ + get the workflow by app id and version + """ + if not version: + workflow = db.session.query(Workflow).filter( + Workflow.app_id == app_id, + Workflow.version != 'draft' + ).order_by(Workflow.created_at.desc()).first() + else: + workflow = db.session.query(Workflow).filter( + Workflow.app_id == app_id, + Workflow.version == version + ).first() + + if not workflow: + raise ValueError('workflow not found or not published') + + return workflow + + def _get_app(self, app_id: str) -> App: + """ + get the app by app id + """ + app = db.session.query(App).filter(App.id == app_id).first() + if not app: + raise ValueError('app not found') + + return app + + def _transform_args(self, tool_parameters: dict) -> tuple[dict, list[dict]]: + """ + transform the tool parameters + + :param tool_parameters: the tool parameters + :return: tool_parameters, files + """ + parameter_rules = self.get_all_runtime_parameters() + parameters_result = {} + files = [] + for parameter in parameter_rules: + if parameter.type == ToolParameter.ToolParameterType.FILE: + file = tool_parameters.get(parameter.name) + if file: + try: + file_var_list = [FileVar(**f) for f in file] + for file_var in file_var_list: + file_dict = { + 'transfer_method': file_var.transfer_method.value, + 'type': file_var.type.value, + } + if file_var.transfer_method == FileTransferMethod.TOOL_FILE: + file_dict['tool_file_id'] = file_var.related_id + elif file_var.transfer_method == FileTransferMethod.LOCAL_FILE: + file_dict['upload_file_id'] = file_var.related_id + elif file_var.transfer_method == FileTransferMethod.REMOTE_URL: + file_dict['url'] = file_var.preview_url + + files.append(file_dict) + except Exception as e: + logger.exception(e) + 
else: + parameters_result[parameter.name] = tool_parameters.get(parameter.name) + + return parameters_result, files + + def _extract_files(self, outputs: dict) -> tuple[dict, list[FileVar]]: + """ + extract files from the result + + :param result: the result + :return: the result, files + """ + files = [] + result = {} + for key, value in outputs.items(): + if isinstance(value, list): + has_file = False + for item in value: + if isinstance(item, dict) and item.get('__variant') == 'FileVar': + try: + files.append(FileVar(**item)) + has_file = True + except Exception as e: + pass + if has_file: + continue + + result[key] = value + + return result, files \ No newline at end of file diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index f96d7940bd..16fe9051e3 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -1,7 +1,10 @@ from copy import deepcopy from datetime import datetime, timezone +from mimetypes import guess_type from typing import Union +from yarl import URL + from core.app.entities.app_invoke_entities import InvokeFrom from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler @@ -17,6 +20,7 @@ from core.tools.errors import ( ToolProviderNotFoundError, ) from core.tools.tool.tool import Tool +from core.tools.tool.workflow_tool import WorkflowTool from core.tools.utils.message_transformer import ToolFileMessageTransformer from extensions.ext_database import db from models.model import Message, MessageFile @@ -115,7 +119,8 @@ class ToolEngine: @staticmethod def workflow_invoke(tool: Tool, tool_parameters: dict, user_id: str, workflow_id: str, - workflow_tool_callback: DifyWorkflowCallbackHandler) \ + workflow_tool_callback: DifyWorkflowCallbackHandler, + workflow_call_depth: int) \ -> list[ToolInvokeMessage]: """ Workflow invokes the tool with the given arguments. 
@@ -127,6 +132,9 @@ class ToolEngine: tool_inputs=tool_parameters ) + if isinstance(tool, WorkflowTool): + tool.workflow_call_depth = workflow_call_depth + 1 + response = tool.invoke(user_id, tool_parameters) # hit the callback handler @@ -195,8 +203,24 @@ class ToolEngine: for response in tool_response: if response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \ response.type == ToolInvokeMessage.MessageType.IMAGE: + mimetype = None + if response.meta.get('mime_type'): + mimetype = response.meta.get('mime_type') + else: + try: + url = URL(response.message) + extension = url.suffix + guess_type_result, _ = guess_type(f'a{extension}') + if guess_type_result: + mimetype = guess_type_result + except Exception: + pass + + if not mimetype: + mimetype = 'image/jpeg' + result.append(ToolInvokeMessageBinary( - mimetype=response.meta.get('mime_type', 'octet/stream'), + mimetype=mimetype, url=response.message, save_as=response.save_as, )) diff --git a/api/core/tools/tool_label_manager.py b/api/core/tools/tool_label_manager.py new file mode 100644 index 0000000000..97788a7a07 --- /dev/null +++ b/api/core/tools/tool_label_manager.py @@ -0,0 +1,96 @@ +from core.tools.entities.values import default_tool_label_name_list +from core.tools.provider.api_tool_provider import ApiToolProviderController +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController +from core.tools.provider.tool_provider import ToolProviderController +from core.tools.provider.workflow_tool_provider import WorkflowToolProviderController +from extensions.ext_database import db +from models.tools import ToolLabelBinding + + +class ToolLabelManager: + @classmethod + def filter_tool_labels(cls, tool_labels: list[str]) -> list[str]: + """ + Filter tool labels + """ + tool_labels = [label for label in tool_labels if label in default_tool_label_name_list] + return list(set(tool_labels)) + + @classmethod + def update_tool_labels(cls, controller: 
ToolProviderController, labels: list[str]): + """ + Update tool labels + """ + labels = cls.filter_tool_labels(labels) + + if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController): + provider_id = controller.provider_id + else: + raise ValueError('Unsupported tool type') + + # delete old labels + db.session.query(ToolLabelBinding).filter( + ToolLabelBinding.tool_id == provider_id + ).delete() + + # insert new labels + for label in labels: + db.session.add(ToolLabelBinding( + tool_id=provider_id, + tool_type=controller.provider_type.value, + label_name=label, + )) + + db.session.commit() + + @classmethod + def get_tool_labels(cls, controller: ToolProviderController) -> list[str]: + """ + Get tool labels + """ + if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController): + provider_id = controller.provider_id + elif isinstance(controller, BuiltinToolProviderController): + return controller.tool_labels + else: + raise ValueError('Unsupported tool type') + + labels: list[ToolLabelBinding] = db.session.query(ToolLabelBinding.label_name).filter( + ToolLabelBinding.tool_id == provider_id, + ToolLabelBinding.tool_type == controller.provider_type.value, + ).all() + + return [label.label_name for label in labels] + + @classmethod + def get_tools_labels(cls, tool_providers: list[ToolProviderController]) -> dict[str, list[str]]: + """ + Get tools labels + + :param tool_providers: list of tool providers + + :return: dict of tool labels + :key: tool id + :value: list of tool labels + """ + if not tool_providers: + return {} + + for controller in tool_providers: + if not isinstance(controller, ApiToolProviderController | WorkflowToolProviderController): + raise ValueError('Unsupported tool type') + + provider_ids = [controller.provider_id for controller in tool_providers] + + labels: list[ToolLabelBinding] = db.session.query(ToolLabelBinding).filter( + ToolLabelBinding.tool_id.in_(provider_ids) + ).all() + + tool_labels = { 
+ label.tool_id: [] for label in labels + } + + for label in labels: + tool_labels[label.tool_id].append(label.label_name) + + return tool_labels \ No newline at end of file diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index d544087594..5ae5b6622d 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -9,21 +9,24 @@ from typing import Any, Union from flask import current_app from core.agent.entities import AgentToolEntity +from core.app.entities.app_invoke_entities import InvokeFrom from core.model_runtime.utils.encoders import jsonable_encoder from core.tools import * +from core.tools.entities.api_entities import UserToolProvider, UserToolProviderTypeLiteral from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ( ApiProviderAuthType, + ToolInvokeFrom, ToolParameter, ) -from core.tools.entities.user_entities import UserToolProvider from core.tools.errors import ToolProviderNotFoundError -from core.tools.provider.api_tool_provider import ApiBasedToolProviderController +from core.tools.provider.api_tool_provider import ApiToolProviderController from core.tools.provider.builtin._positions import BuiltinToolProviderSort from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController from core.tools.tool.api_tool import ApiTool from core.tools.tool.builtin_tool import BuiltinTool from core.tools.tool.tool import Tool +from core.tools.tool_label_manager import ToolLabelManager from core.tools.utils.configuration import ( ToolConfigurationManager, ToolParameterConfigurationManager, @@ -31,8 +34,8 @@ from core.tools.utils.configuration import ( from core.utils.module_import_helper import load_single_subclass_from_source from core.workflow.nodes.tool.entities import ToolEntity from extensions.ext_database import db -from models.tools import ApiToolProvider, BuiltinToolProvider -from services.tools_transform_service import ToolTransformService 
+from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider +from services.tools.tools_transform_service import ToolTransformService logger = logging.getLogger(__name__) @@ -99,7 +102,12 @@ class ToolManager: raise ToolProviderNotFoundError(f'provider type {provider_type} not found') @classmethod - def get_tool_runtime(cls, provider_type: str, provider_name: str, tool_name: str, tenant_id: str) \ + def get_tool_runtime(cls, provider_type: str, + provider_id: str, + tool_name: str, + tenant_id: str, + invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, + tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT) \ -> Union[BuiltinTool, ApiTool]: """ get the tool runtime @@ -111,51 +119,76 @@ class ToolManager: :return: the tool """ if provider_type == 'builtin': - builtin_tool = cls.get_builtin_tool(provider_name, tool_name) + builtin_tool = cls.get_builtin_tool(provider_id, tool_name) # check if the builtin tool need credentials - provider_controller = cls.get_builtin_provider(provider_name) + provider_controller = cls.get_builtin_provider(provider_id) if not provider_controller.need_credentials: - return builtin_tool.fork_tool_runtime(meta={ + return builtin_tool.fork_tool_runtime(runtime={ 'tenant_id': tenant_id, 'credentials': {}, + 'invoke_from': invoke_from, + 'tool_invoke_from': tool_invoke_from, }) # get credentials builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter( BuiltinToolProvider.tenant_id == tenant_id, - BuiltinToolProvider.provider == provider_name, + BuiltinToolProvider.provider == provider_id, ).first() if builtin_provider is None: - raise ToolProviderNotFoundError(f'builtin provider {provider_name} not found') + raise ToolProviderNotFoundError(f'builtin provider {provider_id} not found') # decrypt the credentials credentials = builtin_provider.credentials - controller = cls.get_builtin_provider(provider_name) + controller = cls.get_builtin_provider(provider_id) tool_configuration = 
ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller) decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials) - return builtin_tool.fork_tool_runtime(meta={ + return builtin_tool.fork_tool_runtime(runtime={ 'tenant_id': tenant_id, 'credentials': decrypted_credentials, - 'runtime_parameters': {} + 'runtime_parameters': {}, + 'invoke_from': invoke_from, + 'tool_invoke_from': tool_invoke_from, }) elif provider_type == 'api': if tenant_id is None: raise ValueError('tenant id is required for api provider') - api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_name) + api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id) # decrypt the credentials tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=api_provider) decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials) - return api_provider.get_tool(tool_name).fork_tool_runtime(meta={ + return api_provider.get_tool(tool_name).fork_tool_runtime(runtime={ 'tenant_id': tenant_id, 'credentials': decrypted_credentials, + 'invoke_from': invoke_from, + 'tool_invoke_from': tool_invoke_from, + }) + elif provider_type == 'workflow': + workflow_provider = db.session.query(WorkflowToolProvider).filter( + WorkflowToolProvider.tenant_id == tenant_id, + WorkflowToolProvider.id == provider_id + ).first() + + if workflow_provider is None: + raise ToolProviderNotFoundError(f'workflow provider {provider_id} not found') + + controller = ToolTransformService.workflow_provider_to_controller( + db_provider=workflow_provider + ) + + return controller.get_tools(user_id=None, tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(runtime={ + 'tenant_id': tenant_id, + 'credentials': {}, + 'invoke_from': invoke_from, + 'tool_invoke_from': tool_invoke_from, }) elif provider_type == 'app': raise NotImplementedError('app provider not implemented') @@ -207,18 +240,25 @@ class 
ToolManager: return parameter_value @classmethod - def get_agent_tool_runtime(cls, tenant_id: str, app_id: str, agent_tool: AgentToolEntity) -> Tool: + def get_agent_tool_runtime(cls, tenant_id: str, app_id: str, agent_tool: AgentToolEntity, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER) -> Tool: """ get the agent tool runtime """ tool_entity = cls.get_tool_runtime( - provider_type=agent_tool.provider_type, provider_name=agent_tool.provider_id, + provider_type=agent_tool.provider_type, + provider_id=agent_tool.provider_id, tool_name=agent_tool.tool_name, tenant_id=tenant_id, + invoke_from=invoke_from, + tool_invoke_from=ToolInvokeFrom.AGENT ) runtime_parameters = {} parameters = tool_entity.get_all_runtime_parameters() for parameter in parameters: + # check file types + if parameter.type == ToolParameter.ToolParameterType.FILE: + raise ValueError(f"file type parameter {parameter.name} not supported in agent") + if parameter.form == ToolParameter.ToolParameterForm.FORM: # save tool parameter to tool entity memory value = cls._init_runtime_parameter(parameter, agent_tool.tool_parameters) @@ -238,15 +278,17 @@ class ToolManager: return tool_entity @classmethod - def get_workflow_tool_runtime(cls, tenant_id: str, app_id: str, node_id: str, workflow_tool: ToolEntity): + def get_workflow_tool_runtime(cls, tenant_id: str, app_id: str, node_id: str, workflow_tool: ToolEntity, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER) -> Tool: """ get the workflow tool runtime """ tool_entity = cls.get_tool_runtime( provider_type=workflow_tool.provider_type, - provider_name=workflow_tool.provider_id, + provider_id=workflow_tool.provider_id, tool_name=workflow_tool.tool_name, tenant_id=tenant_id, + invoke_from=invoke_from, + tool_invoke_from=ToolInvokeFrom.WORKFLOW ) runtime_parameters = {} parameters = tool_entity.get_all_runtime_parameters() @@ -371,51 +413,91 @@ class ToolManager: return cls._builtin_tools_labels[tool_name] @classmethod - def user_list_providers(cls, user_id: str, 
tenant_id: str) -> list[UserToolProvider]: + def user_list_providers(cls, user_id: str, tenant_id: str, typ: UserToolProviderTypeLiteral) -> list[UserToolProvider]: result_providers: dict[str, UserToolProvider] = {} - # get builtin providers - builtin_providers = cls.list_builtin_providers() - - # get db builtin providers - db_builtin_providers: list[BuiltinToolProvider] = db.session.query(BuiltinToolProvider). \ - filter(BuiltinToolProvider.tenant_id == tenant_id).all() + filters = [] + if not typ: + filters.extend(['builtin', 'api', 'workflow']) + else: + filters.append(typ) - find_db_builtin_provider = lambda provider: next( - (x for x in db_builtin_providers if x.provider == provider), - None - ) + if 'builtin' in filters: - # append builtin providers - for provider in builtin_providers: - user_provider = ToolTransformService.builtin_provider_to_user_provider( - provider_controller=provider, - db_provider=find_db_builtin_provider(provider.identity.name), - decrypt_credentials=False + # get builtin providers + builtin_providers = cls.list_builtin_providers() + + # get db builtin providers + db_builtin_providers: list[BuiltinToolProvider] = db.session.query(BuiltinToolProvider). \ + filter(BuiltinToolProvider.tenant_id == tenant_id).all() + + find_db_builtin_provider = lambda provider: next( + (x for x in db_builtin_providers if x.provider == provider), + None ) - result_providers[provider.identity.name] = user_provider + # append builtin providers + for provider in builtin_providers: + user_provider = ToolTransformService.builtin_provider_to_user_provider( + provider_controller=provider, + db_provider=find_db_builtin_provider(provider.identity.name), + decrypt_credentials=False + ) + + result_providers[provider.identity.name] = user_provider # get db api providers - db_api_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider). 
\ - filter(ApiToolProvider.tenant_id == tenant_id).all() - for db_api_provider in db_api_providers: - provider_controller = ToolTransformService.api_provider_to_controller( - db_provider=db_api_provider, - ) - user_provider = ToolTransformService.api_provider_to_user_provider( - provider_controller=provider_controller, - db_provider=db_api_provider, - decrypt_credentials=False - ) - result_providers[db_api_provider.name] = user_provider + if 'api' in filters: + db_api_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider). \ + filter(ApiToolProvider.tenant_id == tenant_id).all() + + api_provider_controllers = [{ + 'provider': provider, + 'controller': ToolTransformService.api_provider_to_controller(provider) + } for provider in db_api_providers] + + # get labels + labels = ToolLabelManager.get_tools_labels([x['controller'] for x in api_provider_controllers]) + + for api_provider_controller in api_provider_controllers: + user_provider = ToolTransformService.api_provider_to_user_provider( + provider_controller=api_provider_controller['controller'], + db_provider=api_provider_controller['provider'], + decrypt_credentials=False, + labels=labels.get(api_provider_controller['controller'].provider_id, []) + ) + result_providers[f'api_provider.{user_provider.name}'] = user_provider + + if 'workflow' in filters: + # get workflow providers + workflow_providers: list[WorkflowToolProvider] = db.session.query(WorkflowToolProvider). 
\ + filter(WorkflowToolProvider.tenant_id == tenant_id).all() + + workflow_provider_controllers = [] + for provider in workflow_providers: + try: + workflow_provider_controllers.append( + ToolTransformService.workflow_provider_to_controller(db_provider=provider) + ) + except Exception as e: + # app has been deleted + pass + + labels = ToolLabelManager.get_tools_labels(workflow_provider_controllers) + + for provider_controller in workflow_provider_controllers: + user_provider = ToolTransformService.workflow_provider_to_user_provider( + provider_controller=provider_controller, + labels=labels.get(provider_controller.provider_id, []), + ) + result_providers[f'workflow_provider.{user_provider.name}'] = user_provider return BuiltinToolProviderSort.sort(list(result_providers.values())) @classmethod def get_api_provider_controller(cls, tenant_id: str, provider_id: str) -> tuple[ - ApiBasedToolProviderController, dict[str, Any]]: + ApiToolProviderController, dict[str, Any]]: """ get the api provider @@ -431,7 +513,7 @@ class ToolManager: if provider is None: raise ToolProviderNotFoundError(f'api provider {provider_id} not found') - controller = ApiBasedToolProviderController.from_db( + controller = ApiToolProviderController.from_db( provider, ApiProviderAuthType.API_KEY if provider.credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE @@ -462,7 +544,7 @@ class ToolManager: credentials = {} # package tool provider controller - controller = ApiBasedToolProviderController.from_db( + controller = ApiToolProviderController.from_db( provider, ApiProviderAuthType.API_KEY if credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE ) # init tool configuration @@ -479,6 +561,9 @@ class ToolManager: "content": "\ud83d\ude01" } + # add tool labels + labels = ToolLabelManager.get_tool_labels(controller) + return jsonable_encoder({ 'schema_type': provider.schema_type, 'schema': provider.schema, @@ -487,7 +572,8 @@ class ToolManager: 'description': 
provider.description, 'credentials': masked_credentials, 'privacy_policy': provider.privacy_policy, - 'custom_disclaimer': provider.custom_disclaimer + 'custom_disclaimer': provider.custom_disclaimer, + 'labels': labels, }) @classmethod @@ -519,6 +605,15 @@ class ToolManager: "background": "#252525", "content": "\ud83d\ude01" } + elif provider_type == 'workflow': + provider: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter( + WorkflowToolProvider.tenant_id == tenant_id, + WorkflowToolProvider.id == provider_id + ).first() + if provider is None: + raise ToolProviderNotFoundError(f'workflow provider {provider_id} not found') + + return json.loads(provider.icon) else: raise ValueError(f"provider type {provider_type} not found") diff --git a/api/core/tools/utils/configuration.py b/api/core/tools/utils/configuration.py index 90b39a7fc9..b213879e96 100644 --- a/api/core/tools/utils/configuration.py +++ b/api/core/tools/utils/configuration.py @@ -72,8 +72,8 @@ class ToolConfigurationManager(BaseModel): return a deep copy of credentials with decrypted values """ cache = ToolProviderCredentialsCache( - tenant_id=self.tenant_id, - identity_id=f'{self.provider_controller.app_type.value}.{self.provider_controller.identity.name}', + tenant_id=self.tenant_id, + identity_id=f'{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}', cache_type=ToolProviderCredentialsCacheType.PROVIDER ) cached_credentials = cache.get() @@ -95,8 +95,8 @@ class ToolConfigurationManager(BaseModel): def delete_tool_credentials_cache(self): cache = ToolProviderCredentialsCache( - tenant_id=self.tenant_id, - identity_id=f'{self.provider_controller.app_type.value}.{self.provider_controller.identity.name}', + tenant_id=self.tenant_id, + identity_id=f'{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}', cache_type=ToolProviderCredentialsCacheType.PROVIDER ) cache.delete() diff --git 
a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index 3f456b4eb6..ef9e5b67ae 100644 --- a/api/core/tools/utils/message_transformer.py +++ b/api/core/tools/utils/message_transformer.py @@ -1,14 +1,15 @@ import logging from mimetypes import guess_extension +from core.file.file_obj import FileTransferMethod, FileType, FileVar from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool_file_manager import ToolFileManager logger = logging.getLogger(__name__) class ToolFileMessageTransformer: - @staticmethod - def transform_tool_invoke_messages(messages: list[ToolInvokeMessage], + @classmethod + def transform_tool_invoke_messages(cls, messages: list[ToolInvokeMessage], user_id: str, tenant_id: str, conversation_id: str) -> list[ToolInvokeMessage]: @@ -62,7 +63,7 @@ class ToolFileMessageTransformer: mimetype=mimetype ) - url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".bin"}' + url = cls.get_tool_file_url(file.id, guess_extension(file.mimetype)) # check if file is image if 'image' in mimetype: @@ -79,7 +80,30 @@ class ToolFileMessageTransformer: save_as=message.save_as, meta=message.meta.copy() if message.meta is not None else {}, )) + elif message.type == ToolInvokeMessage.MessageType.FILE_VAR: + file_var: FileVar = message.meta.get('file_var') + if file_var: + if file_var.transfer_method == FileTransferMethod.TOOL_FILE: + url = cls.get_tool_file_url(file_var.related_id, file_var.extension) + if file_var.type == FileType.IMAGE: + result.append(ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.IMAGE_LINK, + message=url, + save_as=message.save_as, + meta=message.meta.copy() if message.meta is not None else {}, + )) + else: + result.append(ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.LINK, + message=url, + save_as=message.save_as, + meta=message.meta.copy() if message.meta is not None else {}, + )) else: result.append(message) - return result \ No newline at end of file + 
return result + + @classmethod + def get_tool_file_url(cls, tool_file_id: str, extension: str) -> str: + return f'/files/tools/{tool_file_id}{extension or ".bin"}' \ No newline at end of file diff --git a/api/core/tools/model/tool_model_manager.py b/api/core/tools/utils/model_invocation_utils.py similarity index 98% rename from api/core/tools/model/tool_model_manager.py rename to api/core/tools/utils/model_invocation_utils.py index e97d78d699..6526df6aa5 100644 --- a/api/core/tools/model/tool_model_manager.py +++ b/api/core/tools/utils/model_invocation_utils.py @@ -20,12 +20,14 @@ from core.model_runtime.errors.invoke import ( ) from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel, ModelPropertyKey from core.model_runtime.utils.encoders import jsonable_encoder -from core.tools.model.errors import InvokeModelError from extensions.ext_database import db from models.tools import ToolModelInvoke -class ToolModelManager: +class InvokeModelError(Exception): + pass + +class ModelInvocationUtils: @staticmethod def get_max_llm_context_tokens( tenant_id: str, diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py index a96d8a6b7c..40ae6c66d5 100644 --- a/api/core/tools/utils/parser.py +++ b/api/core/tools/utils/parser.py @@ -9,14 +9,14 @@ from requests import get from yaml import YAMLError, safe_load from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_bundle import ApiBasedToolBundle +from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolParameter from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolProviderNotFoundError class ApiBasedToolSchemaParser: @staticmethod - def parse_openapi_to_tool_bundle(openapi: dict, extra_info: dict = None, warning: dict = None) -> list[ApiBasedToolBundle]: + def parse_openapi_to_tool_bundle(openapi: dict, extra_info: dict = None, warning: dict 
= None) -> list[ApiToolBundle]: warning = warning if warning is not None else {} extra_info = extra_info if extra_info is not None else {} @@ -145,7 +145,7 @@ class ApiBasedToolSchemaParser: interface['operation']['operationId'] = f'{path}_{interface["method"]}' - bundles.append(ApiBasedToolBundle( + bundles.append(ApiToolBundle( server_url=server_url + interface['path'], method=interface['method'], summary=interface['operation']['description'] if 'description' in interface['operation'] else @@ -176,7 +176,7 @@ class ApiBasedToolSchemaParser: return ToolParameter.ToolParameterType.STRING @staticmethod - def parse_openapi_yaml_to_tool_bundle(yaml: str, extra_info: dict = None, warning: dict = None) -> list[ApiBasedToolBundle]: + def parse_openapi_yaml_to_tool_bundle(yaml: str, extra_info: dict = None, warning: dict = None) -> list[ApiToolBundle]: """ parse openapi yaml to tool bundle @@ -258,7 +258,7 @@ class ApiBasedToolSchemaParser: return openapi @staticmethod - def parse_openai_plugin_json_to_tool_bundle(json: str, extra_info: dict = None, warning: dict = None) -> list[ApiBasedToolBundle]: + def parse_openai_plugin_json_to_tool_bundle(json: str, extra_info: dict = None, warning: dict = None) -> list[ApiToolBundle]: """ parse openapi plugin yaml to tool bundle @@ -290,7 +290,7 @@ class ApiBasedToolSchemaParser: return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle(response.text, extra_info=extra_info, warning=warning) @staticmethod - def auto_parse_to_tool_bundle(content: str, extra_info: dict = None, warning: dict = None) -> tuple[list[ApiBasedToolBundle], str]: + def auto_parse_to_tool_bundle(content: str, extra_info: dict = None, warning: dict = None) -> tuple[list[ApiToolBundle], str]: """ auto parse to tool bundle diff --git a/api/core/tools/utils/workflow_configuration_sync.py b/api/core/tools/utils/workflow_configuration_sync.py new file mode 100644 index 0000000000..ff5505bbbf --- /dev/null +++ 
b/api/core/tools/utils/workflow_configuration_sync.py @@ -0,0 +1,48 @@ +from core.app.app_config.entities import VariableEntity +from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration + + +class WorkflowToolConfigurationUtils: + @classmethod + def check_parameter_configurations(cls, configurations: list[dict]): + """ + check parameter configurations + """ + for configuration in configurations: + if not WorkflowToolParameterConfiguration(**configuration): + raise ValueError('invalid parameter configuration') + + @classmethod + def get_workflow_graph_variables(cls, graph: dict) -> list[VariableEntity]: + """ + get workflow graph variables + """ + nodes = graph.get('nodes', []) + start_node = next(filter(lambda x: x.get('data', {}).get('type') == 'start', nodes), None) + + if not start_node: + return [] + + return [ + VariableEntity(**variable) for variable in start_node.get('data', {}).get('variables', []) + ] + + @classmethod + def check_is_synced(cls, + variables: list[VariableEntity], + tool_configurations: list[WorkflowToolParameterConfiguration]) -> None: + """ + check is synced + + raise ValueError if not synced + """ + variable_names = [variable.variable for variable in variables] + + if len(tool_configurations) != len(variables): + raise ValueError('parameter configuration mismatch, please republish the tool to update') + + for parameter in tool_configurations: + if parameter.name not in variable_names: + raise ValueError('parameter configuration mismatch, please republish the tool to update') + + return True \ No newline at end of file diff --git a/api/core/workflow/callbacks/base_workflow_callback.py b/api/core/workflow/callbacks/base_workflow_callback.py index dd5a30f611..3b0d51d868 100644 --- a/api/core/workflow/callbacks/base_workflow_callback.py +++ b/api/core/workflow/callbacks/base_workflow_callback.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Optional +from typing import Any, Optional from 
core.app.entities.queue_entities import AppQueueEvent from core.workflow.entities.base_node_data_entities import BaseNodeData @@ -71,6 +71,42 @@ class BaseWorkflowCallback(ABC): Publish text chunk """ raise NotImplementedError + + @abstractmethod + def on_workflow_iteration_started(self, + node_id: str, + node_type: NodeType, + node_run_index: int = 1, + node_data: Optional[BaseNodeData] = None, + inputs: dict = None, + predecessor_node_id: Optional[str] = None, + metadata: Optional[dict] = None) -> None: + """ + Publish iteration started + """ + raise NotImplementedError + + @abstractmethod + def on_workflow_iteration_next(self, node_id: str, + node_type: NodeType, + index: int, + node_run_index: int, + output: Optional[Any], + ) -> None: + """ + Publish iteration next + """ + raise NotImplementedError + + @abstractmethod + def on_workflow_iteration_completed(self, node_id: str, + node_type: NodeType, + node_run_index: int, + outputs: dict) -> None: + """ + Publish iteration completed + """ + raise NotImplementedError @abstractmethod def on_event(self, event: AppQueueEvent) -> None: diff --git a/api/core/workflow/entities/base_node_data_entities.py b/api/core/workflow/entities/base_node_data_entities.py index fc6ee231ff..6bf0c11c7d 100644 --- a/api/core/workflow/entities/base_node_data_entities.py +++ b/api/core/workflow/entities/base_node_data_entities.py @@ -7,3 +7,16 @@ from pydantic import BaseModel class BaseNodeData(ABC, BaseModel): title: str desc: Optional[str] = None + +class BaseIterationNodeData(BaseNodeData): + start_node_id: str + +class BaseIterationState(BaseModel): + iteration_node_id: str + index: int + inputs: dict + + class MetaData(BaseModel): + pass + + metadata: MetaData \ No newline at end of file diff --git a/api/core/workflow/entities/node_entities.py b/api/core/workflow/entities/node_entities.py index 7eb9488792..ae86463407 100644 --- a/api/core/workflow/entities/node_entities.py +++ b/api/core/workflow/entities/node_entities.py @@ -21,7 
+21,11 @@ class NodeType(Enum): QUESTION_CLASSIFIER = 'question-classifier' HTTP_REQUEST = 'http-request' TOOL = 'tool' + VARIABLE_AGGREGATOR = 'variable-aggregator' VARIABLE_ASSIGNER = 'variable-assigner' + LOOP = 'loop' + ITERATION = 'iteration' + PARAMETER_EXTRACTOR = 'parameter-extractor' @classmethod def value_of(cls, value: str) -> 'NodeType': @@ -68,6 +72,8 @@ class NodeRunMetadataKey(Enum): TOTAL_PRICE = 'total_price' CURRENCY = 'currency' TOOL_INFO = 'tool_info' + ITERATION_ID = 'iteration_id' + ITERATION_INDEX = 'iteration_index' class NodeRunResult(BaseModel): diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index 690bdddaf6..c04770616c 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -90,3 +90,12 @@ class VariablePool: raise ValueError(f'Invalid value type: {target_value_type.value}') return value + + def clear_node_variables(self, node_id: str) -> None: + """ + Clear node variables + :param node_id: node id + :return: + """ + if node_id in self.variables_mapping: + self.variables_mapping.pop(node_id) \ No newline at end of file diff --git a/api/core/workflow/entities/workflow_entities.py b/api/core/workflow/entities/workflow_entities.py index e1c5eb6752..9b35b8df8a 100644 --- a/api/core/workflow/entities/workflow_entities.py +++ b/api/core/workflow/entities/workflow_entities.py @@ -1,5 +1,9 @@ from typing import Optional +from pydantic import BaseModel + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.base_node_data_entities import BaseIterationState from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode, UserFrom @@ -22,6 +26,9 @@ class WorkflowRunState: workflow_type: WorkflowType user_id: str user_from: UserFrom + invoke_from: InvokeFrom + + workflow_call_depth: int start_at: 
float variable_pool: VariablePool @@ -30,20 +37,37 @@ class WorkflowRunState: workflow_nodes_and_results: list[WorkflowNodeAndResult] + class NodeRun(BaseModel): + node_id: str + iteration_node_id: str + + workflow_node_runs: list[NodeRun] + workflow_node_steps: int + + current_iteration_state: Optional[BaseIterationState] + def __init__(self, workflow: Workflow, start_at: float, variable_pool: VariablePool, user_id: str, - user_from: UserFrom): + user_from: UserFrom, + invoke_from: InvokeFrom, + workflow_call_depth: int): self.workflow_id = workflow.id self.tenant_id = workflow.tenant_id self.app_id = workflow.app_id self.workflow_type = WorkflowType.value_of(workflow.type) self.user_id = user_id self.user_from = user_from + self.invoke_from = invoke_from + self.workflow_call_depth = workflow_call_depth self.start_at = start_at self.variable_pool = variable_pool self.total_tokens = 0 self.workflow_nodes_and_results = [] + + self.current_iteration_state = None + self.workflow_node_steps = 1 + self.workflow_node_runs = [] \ No newline at end of file diff --git a/api/core/workflow/nodes/base_node.py b/api/core/workflow/nodes/base_node.py index 7cc9c6ee3d..fa7d6424f1 100644 --- a/api/core/workflow/nodes/base_node.py +++ b/api/core/workflow/nodes/base_node.py @@ -2,8 +2,9 @@ from abc import ABC, abstractmethod from enum import Enum from typing import Optional +from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback -from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.base_node_data_entities import BaseIterationState, BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool @@ -37,6 +38,9 @@ class BaseNode(ABC): workflow_id: str user_id: str user_from: UserFrom + invoke_from: InvokeFrom + + workflow_call_depth: int node_id: str node_data: 
BaseNodeData @@ -49,13 +53,17 @@ class BaseNode(ABC): workflow_id: str, user_id: str, user_from: UserFrom, + invoke_from: InvokeFrom, config: dict, - callbacks: list[BaseWorkflowCallback] = None) -> None: + callbacks: list[BaseWorkflowCallback] = None, + workflow_call_depth: int = 0) -> None: self.tenant_id = tenant_id self.app_id = app_id self.workflow_id = workflow_id self.user_id = user_id self.user_from = user_from + self.invoke_from = invoke_from + self.workflow_call_depth = workflow_call_depth self.node_id = config.get("id") if not self.node_id: @@ -140,3 +148,38 @@ class BaseNode(ABC): :return: """ return self._node_type + +class BaseIterationNode(BaseNode): + @abstractmethod + def _run(self, variable_pool: VariablePool) -> BaseIterationState: + """ + Run node + :param variable_pool: variable pool + :return: + """ + raise NotImplementedError + + def run(self, variable_pool: VariablePool) -> BaseIterationState: + """ + Run node entry + :param variable_pool: variable pool + :return: + """ + return self._run(variable_pool=variable_pool) + + def get_next_iteration(self, variable_pool: VariablePool, state: BaseIterationState) -> NodeRunResult | str: + """ + Get next iteration start node id based on the graph. + :param graph: graph + :return: next node id + """ + return self._get_next_iteration(variable_pool, state) + + @abstractmethod + def _get_next_iteration(self, variable_pool: VariablePool, state: BaseIterationState) -> NodeRunResult | str: + """ + Get next iteration start node id based on the graph. 
+ :param graph: graph + :return: next node id + """ + raise NotImplementedError diff --git a/api/core/workflow/nodes/variable_assigner/__init__.py b/api/core/workflow/nodes/iteration/__init__.py similarity index 100% rename from api/core/workflow/nodes/variable_assigner/__init__.py rename to api/core/workflow/nodes/iteration/__init__.py diff --git a/api/core/workflow/nodes/iteration/entities.py b/api/core/workflow/nodes/iteration/entities.py new file mode 100644 index 0000000000..c85aa66c7b --- /dev/null +++ b/api/core/workflow/nodes/iteration/entities.py @@ -0,0 +1,39 @@ +from typing import Any, Optional + +from core.workflow.entities.base_node_data_entities import BaseIterationNodeData, BaseIterationState + + +class IterationNodeData(BaseIterationNodeData): + """ + Iteration Node Data. + """ + parent_loop_id: Optional[str] # redundant field, not used currently + iterator_selector: list[str] # variable selector + output_selector: list[str] # output selector + +class IterationState(BaseIterationState): + """ + Iteration State. + """ + outputs: list[Any] = None + current_output: Optional[Any] = None + + class MetaData(BaseIterationState.MetaData): + """ + Data. + """ + iterator_length: int + + def get_last_output(self) -> Optional[Any]: + """ + Get last output. + """ + if self.outputs: + return self.outputs[-1] + return None + + def get_current_output(self) -> Optional[Any]: + """ + Get current output. 
+ """ + return self.current_output \ No newline at end of file diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py new file mode 100644 index 0000000000..12d792f297 --- /dev/null +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -0,0 +1,119 @@ +from typing import cast + +from core.model_runtime.utils.encoders import jsonable_encoder +from core.workflow.entities.base_node_data_entities import BaseIterationState +from core.workflow.entities.node_entities import NodeRunResult, NodeType +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.nodes.base_node import BaseIterationNode +from core.workflow.nodes.iteration.entities import IterationNodeData, IterationState +from models.workflow import WorkflowNodeExecutionStatus + + +class IterationNode(BaseIterationNode): + """ + Iteration Node. + """ + _node_data_cls = IterationNodeData + _node_type = NodeType.ITERATION + + def _run(self, variable_pool: VariablePool) -> BaseIterationState: + """ + Run the node. + """ + iterator = variable_pool.get_variable_value(cast(IterationNodeData, self.node_data).iterator_selector) + + if not isinstance(iterator, list): + raise ValueError(f"Invalid iterator value: {iterator}, please provide a list.") + + state = IterationState(iteration_node_id=self.node_id, index=-1, inputs={ + 'iterator_selector': iterator + }, outputs=[], metadata=IterationState.MetaData( + iterator_length=len(iterator) if iterator is not None else 0 + )) + + self._set_current_iteration_variable(variable_pool, state) + return state + + def _get_next_iteration(self, variable_pool: VariablePool, state: IterationState) -> NodeRunResult | str: + """ + Get next iteration start node id based on the graph. 
+ :param graph: graph + :return: next node id + """ + # resolve current output + self._resolve_current_output(variable_pool, state) + # move to next iteration + self._next_iteration(variable_pool, state) + + node_data = cast(IterationNodeData, self.node_data) + if self._reached_iteration_limit(variable_pool, state): + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + outputs={ + 'output': jsonable_encoder(state.outputs) + } + ) + + return node_data.start_node_id + + def _set_current_iteration_variable(self, variable_pool: VariablePool, state: IterationState): + """ + Set current iteration variable. + :variable_pool: variable pool + """ + node_data = cast(IterationNodeData, self.node_data) + + variable_pool.append_variable(self.node_id, ['index'], state.index) + # get the iterator value + iterator = variable_pool.get_variable_value(node_data.iterator_selector) + + if iterator is None or not isinstance(iterator, list): + return + + if state.index < len(iterator): + variable_pool.append_variable(self.node_id, ['item'], iterator[state.index]) + + def _next_iteration(self, variable_pool: VariablePool, state: IterationState): + """ + Move to next iteration. + :param variable_pool: variable pool + """ + state.index += 1 + self._set_current_iteration_variable(variable_pool, state) + + def _reached_iteration_limit(self, variable_pool: VariablePool, state: IterationState): + """ + Check if iteration limit is reached. + :return: True if iteration limit is reached, False otherwise + """ + node_data = cast(IterationNodeData, self.node_data) + iterator = variable_pool.get_variable_value(node_data.iterator_selector) + + if iterator is None or not isinstance(iterator, list): + return True + + return state.index >= len(iterator) + + def _resolve_current_output(self, variable_pool: VariablePool, state: IterationState): + """ + Resolve current output. 
+ :param variable_pool: variable pool + """ + output_selector = cast(IterationNodeData, self.node_data).output_selector + output = variable_pool.get_variable_value(output_selector) + # clear the output for this iteration + variable_pool.append_variable(self.node_id, output_selector[1:], None) + state.current_output = output + if output is not None: + state.outputs.append(output) + + @classmethod + def _extract_variable_selector_to_variable_mapping(cls, node_data: IterationNodeData) -> dict[str, list[str]]: + """ + Extract variable selector to variable mapping + :param node_data: node data + :return: + """ + return { + 'input_selector': node_data.iterator_selector, + } \ No newline at end of file diff --git a/api/core/workflow/nodes/loop/__init__.py b/api/core/workflow/nodes/loop/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/workflow/nodes/loop/entities.py b/api/core/workflow/nodes/loop/entities.py new file mode 100644 index 0000000000..8a5684551e --- /dev/null +++ b/api/core/workflow/nodes/loop/entities.py @@ -0,0 +1,13 @@ + +from core.workflow.entities.base_node_data_entities import BaseIterationNodeData, BaseIterationState + + +class LoopNodeData(BaseIterationNodeData): + """ + Loop Node Data. + """ + +class LoopState(BaseIterationState): + """ + Loop State. + """ \ No newline at end of file diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py new file mode 100644 index 0000000000..7d53c6f5f2 --- /dev/null +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -0,0 +1,20 @@ +from core.workflow.entities.node_entities import NodeRunResult, NodeType +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.nodes.base_node import BaseIterationNode +from core.workflow.nodes.loop.entities import LoopNodeData, LoopState + + +class LoopNode(BaseIterationNode): + """ + Loop Node. 
+ """ + _node_data_cls = LoopNodeData + _node_type = NodeType.LOOP + + def _run(self, variable_pool: VariablePool) -> LoopState: + return super()._run(variable_pool) + + def _get_next_iteration(self, variable_loop: VariablePool) -> NodeRunResult | str: + """ + Get next iteration start node id based on the graph. + """ diff --git a/api/core/workflow/nodes/parameter_extractor/__init__.py b/api/core/workflow/nodes/parameter_extractor/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/workflow/nodes/parameter_extractor/entities.py b/api/core/workflow/nodes/parameter_extractor/entities.py new file mode 100644 index 0000000000..a89a6903ef --- /dev/null +++ b/api/core/workflow/nodes/parameter_extractor/entities.py @@ -0,0 +1,85 @@ +from typing import Any, Literal, Optional + +from pydantic import BaseModel, validator + +from core.prompt.entities.advanced_prompt_entities import MemoryConfig +from core.workflow.entities.base_node_data_entities import BaseNodeData + + +class ModelConfig(BaseModel): + """ + Model Config. + """ + provider: str + name: str + mode: str + completion_params: dict[str, Any] = {} + +class ParameterConfig(BaseModel): + """ + Parameter Config. + """ + name: str + type: Literal['string', 'number', 'bool', 'select', 'array[string]', 'array[number]', 'array[object]'] + options: Optional[list[str]] + description: str + required: bool + + @validator('name', pre=True, always=True) + def validate_name(cls, value): + if not value: + raise ValueError('Parameter name is required') + if value in ['__reason', '__is_success']: + raise ValueError('Invalid parameter name, __reason and __is_success are reserved') + return value + +class ParameterExtractorNodeData(BaseNodeData): + """ + Parameter Extractor Node Data. 
+ """ + model: ModelConfig + query: list[str] + parameters: list[ParameterConfig] + instruction: Optional[str] + memory: Optional[MemoryConfig] + reasoning_mode: Literal['function_call', 'prompt'] + + @validator('reasoning_mode', pre=True, always=True) + def set_reasoning_mode(cls, v): + return v or 'function_call' + + def get_parameter_json_schema(self) -> dict: + """ + Get parameter json schema. + + :return: parameter json schema + """ + parameters = { + 'type': 'object', + 'properties': {}, + 'required': [] + } + + for parameter in self.parameters: + parameter_schema = { + 'description': parameter.description + } + + if parameter.type in ['string', 'select']: + parameter_schema['type'] = 'string' + elif parameter.type.startswith('array'): + parameter_schema['type'] = 'array' + nested_type = parameter.type[6:-1] + parameter_schema['items'] = {'type': nested_type} + else: + parameter_schema['type'] = parameter.type + + if parameter.type == 'select': + parameter_schema['enum'] = parameter.options + + parameters['properties'][parameter.name] = parameter_schema + + if parameter.required: + parameters['required'].append(parameter.name) + + return parameters \ No newline at end of file diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py new file mode 100644 index 0000000000..6e7dbd2702 --- /dev/null +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -0,0 +1,711 @@ +import json +import uuid +from typing import Optional, cast + +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from core.memory.token_buffer_memory import TokenBufferMemory +from core.model_manager import ModelInstance +from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageRole, + PromptMessageTool, + 
ToolPromptMessage, + UserPromptMessage, +) +from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.model_runtime.utils.encoders import jsonable_encoder +from core.prompt.advanced_prompt_transform import AdvancedPromptTransform +from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate +from core.prompt.simple_prompt_transform import ModelMode +from core.prompt.utils.prompt_message_util import PromptMessageUtil +from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.nodes.llm.entities import ModelConfig +from core.workflow.nodes.llm.llm_node import LLMNode +from core.workflow.nodes.parameter_extractor.entities import ParameterExtractorNodeData +from core.workflow.nodes.parameter_extractor.prompts import ( + CHAT_EXAMPLE, + CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE, + COMPLETION_GENERATE_JSON_PROMPT, + FUNCTION_CALLING_EXTRACTOR_EXAMPLE, + FUNCTION_CALLING_EXTRACTOR_NAME, + FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT, + FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE, +) +from core.workflow.utils.variable_template_parser import VariableTemplateParser +from extensions.ext_database import db +from models.workflow import WorkflowNodeExecutionStatus + + +class ParameterExtractorNode(LLMNode): + """ + Parameter Extractor Node. 
+ """ + _node_data_cls = ParameterExtractorNodeData + _node_type = NodeType.PARAMETER_EXTRACTOR + + _model_instance: Optional[ModelInstance] = None + _model_config: Optional[ModelConfigWithCredentialsEntity] = None + + @classmethod + def get_default_config(cls, filters: Optional[dict] = None) -> dict: + return { + "model": { + "prompt_templates": { + "completion_model": { + "conversation_histories_role": { + "user_prefix": "Human", + "assistant_prefix": "Assistant" + }, + "stop": ["Human:"] + } + } + } + } + + + def _run(self, variable_pool: VariablePool) -> NodeRunResult: + """ + Run the node. + """ + + node_data = cast(ParameterExtractorNodeData, self.node_data) + query = variable_pool.get_variable_value(node_data.query) + if not query: + raise ValueError("Query not found") + + inputs={ + 'query': query, + 'parameters': jsonable_encoder(node_data.parameters), + 'instruction': jsonable_encoder(node_data.instruction), + } + + model_instance, model_config = self._fetch_model_config(node_data.model) + if not isinstance(model_instance.model_type_instance, LargeLanguageModel): + raise ValueError("Model is not a Large Language Model") + + llm_model = model_instance.model_type_instance + model_schema = llm_model.get_model_schema(model_config.model, model_config.credentials) + if not model_schema: + raise ValueError("Model schema not found") + + # fetch memory + memory = self._fetch_memory(node_data.memory, variable_pool, model_instance) + + if set(model_schema.features or []) & set([ModelFeature.MULTI_TOOL_CALL, ModelFeature.MULTI_TOOL_CALL]) \ + and node_data.reasoning_mode == 'function_call': + # use function call + prompt_messages, prompt_message_tools = self._generate_function_call_prompt( + node_data, query, variable_pool, model_config, memory + ) + else: + # use prompt engineering + prompt_messages = self._generate_prompt_engineering_prompt(node_data, query, variable_pool, model_config, memory) + prompt_message_tools = [] + + process_data = { + 'model_mode': 
model_config.mode, + 'prompts': PromptMessageUtil.prompt_messages_to_prompt_for_saving( + model_mode=model_config.mode, + prompt_messages=prompt_messages + ), + 'usage': None, + 'function': {} if not prompt_message_tools else jsonable_encoder(prompt_message_tools[0]), + 'tool_call': None, + } + + try: + text, usage, tool_call = self._invoke_llm( + node_data_model=node_data.model, + model_instance=model_instance, + prompt_messages=prompt_messages, + tools=prompt_message_tools, + stop=model_config.stop, + ) + process_data['usage'] = jsonable_encoder(usage) + process_data['tool_call'] = jsonable_encoder(tool_call) + process_data['llm_text'] = text + except Exception as e: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs=inputs, + process_data={}, + outputs={ + '__is_success': 0, + '__reason': str(e) + }, + error=str(e), + metadata={} + ) + + error = None + + if tool_call: + result = self._extract_json_from_tool_call(tool_call) + else: + result = self._extract_complete_json_response(text) + if not result: + result = self._generate_default_result(node_data) + error = "Failed to extract result from function call or text response, using empty result." 
+ + try: + result = self._validate_result(node_data, result) + except Exception as e: + error = str(e) + + # transform result into standard format + result = self._transform_result(node_data, result) + + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=inputs, + process_data=process_data, + outputs={ + '__is_success': 1 if not error else 0, + '__reason': error, + **result + }, + metadata={ + NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens, + NodeRunMetadataKey.TOTAL_PRICE: usage.total_price, + NodeRunMetadataKey.CURRENCY: usage.currency + } + ) + + def _invoke_llm(self, node_data_model: ModelConfig, + model_instance: ModelInstance, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool], + stop: list[str]) -> tuple[str, LLMUsage, Optional[AssistantPromptMessage.ToolCall]]: + """ + Invoke large language model + :param node_data_model: node data model + :param model_instance: model instance + :param prompt_messages: prompt messages + :param stop: stop + :return: + """ + db.session.close() + + invoke_result = model_instance.invoke_llm( + prompt_messages=prompt_messages, + model_parameters=node_data_model.completion_params, + tools=tools, + stop=stop, + stream=False, + user=self.user_id, + ) + + # handle invoke result + if not isinstance(invoke_result, LLMResult): + raise ValueError(f"Invalid invoke result: {invoke_result}") + + text = invoke_result.message.content + usage = invoke_result.usage + tool_call = invoke_result.message.tool_calls[0] if invoke_result.message.tool_calls else None + + # deduct quota + self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage) + + return text, usage, tool_call + + def _generate_function_call_prompt(self, + node_data: ParameterExtractorNodeData, + query: str, + variable_pool: VariablePool, + model_config: ModelConfigWithCredentialsEntity, + memory: Optional[TokenBufferMemory], + ) -> tuple[list[PromptMessage], list[PromptMessageTool]]: + """ + 
Generate function call prompt. + """ + query = FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE.format(content=query, structure=json.dumps(node_data.get_parameter_json_schema())) + + prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) + rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, '') + prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, memory, rest_token) + prompt_messages = prompt_transform.get_prompt( + prompt_template=prompt_template, + inputs={}, + query='', + files=[], + context='', + memory_config=node_data.memory, + memory=None, + model_config=model_config + ) + + # find last user message + last_user_message_idx = -1 + for i, prompt_message in enumerate(prompt_messages): + if prompt_message.role == PromptMessageRole.USER: + last_user_message_idx = i + + # add function call messages before last user message + example_messages = [] + for example in FUNCTION_CALLING_EXTRACTOR_EXAMPLE: + id = uuid.uuid4().hex + example_messages.extend([ + UserPromptMessage(content=example['user']['query']), + AssistantPromptMessage( + content=example['assistant']['text'], + tool_calls=[ + AssistantPromptMessage.ToolCall( + id=id, + type='function', + function=AssistantPromptMessage.ToolCall.ToolCallFunction( + name=example['assistant']['function_call']['name'], + arguments=json.dumps(example['assistant']['function_call']['parameters'] + ) + )) + ] + ), + ToolPromptMessage( + content='Great! 
You have called the function with the correct parameters.', + tool_call_id=id + ), + AssistantPromptMessage( + content='I have extracted the parameters, let\'s move on.', + ) + ]) + + prompt_messages = prompt_messages[:last_user_message_idx] + \ + example_messages + prompt_messages[last_user_message_idx:] + + # generate tool + tool = PromptMessageTool( + name=FUNCTION_CALLING_EXTRACTOR_NAME, + description='Extract parameters from the natural language text', + parameters=node_data.get_parameter_json_schema(), + ) + + return prompt_messages, [tool] + + def _generate_prompt_engineering_prompt(self, + data: ParameterExtractorNodeData, + query: str, + variable_pool: VariablePool, + model_config: ModelConfigWithCredentialsEntity, + memory: Optional[TokenBufferMemory], + ) -> list[PromptMessage]: + """ + Generate prompt engineering prompt. + """ + model_mode = ModelMode.value_of(data.model.mode) + + if model_mode == ModelMode.COMPLETION: + return self._generate_prompt_engineering_completion_prompt( + data, query, variable_pool, model_config, memory + ) + elif model_mode == ModelMode.CHAT: + return self._generate_prompt_engineering_chat_prompt( + data, query, variable_pool, model_config, memory + ) + else: + raise ValueError(f"Invalid model mode: {model_mode}") + + def _generate_prompt_engineering_completion_prompt(self, + node_data: ParameterExtractorNodeData, + query: str, + variable_pool: VariablePool, + model_config: ModelConfigWithCredentialsEntity, + memory: Optional[TokenBufferMemory], + ) -> list[PromptMessage]: + """ + Generate completion prompt. 
+ """ + prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) + rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, '') + prompt_template = self._get_prompt_engineering_prompt_template(node_data, query, variable_pool, memory, rest_token) + prompt_messages = prompt_transform.get_prompt( + prompt_template=prompt_template, + inputs={ + 'structure': json.dumps(node_data.get_parameter_json_schema()) + }, + query='', + files=[], + context='', + memory_config=node_data.memory, + memory=memory, + model_config=model_config + ) + + return prompt_messages + + def _generate_prompt_engineering_chat_prompt(self, + node_data: ParameterExtractorNodeData, + query: str, + variable_pool: VariablePool, + model_config: ModelConfigWithCredentialsEntity, + memory: Optional[TokenBufferMemory], + ) -> list[PromptMessage]: + """ + Generate chat prompt. + """ + prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) + rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, '') + prompt_template = self._get_prompt_engineering_prompt_template( + node_data, + CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format( + structure=json.dumps(node_data.get_parameter_json_schema()), + text=query + ), + variable_pool, memory, rest_token + ) + + prompt_messages = prompt_transform.get_prompt( + prompt_template=prompt_template, + inputs={}, + query='', + files=[], + context='', + memory_config=node_data.memory, + memory=memory, + model_config=model_config + ) + + # find last user message + last_user_message_idx = -1 + for i, prompt_message in enumerate(prompt_messages): + if prompt_message.role == PromptMessageRole.USER: + last_user_message_idx = i + + # add example messages before last user message + example_messages = [] + for example in CHAT_EXAMPLE: + example_messages.extend([ + UserPromptMessage(content=CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format( + structure=json.dumps(example['user']['json']), + 
text=example['user']['query'], + )), + AssistantPromptMessage( + content=json.dumps(example['assistant']['json']), + ) + ]) + + prompt_messages = prompt_messages[:last_user_message_idx] + \ + example_messages + prompt_messages[last_user_message_idx:] + + return prompt_messages + + def _validate_result(self, data: ParameterExtractorNodeData, result: dict) -> dict: + """ + Validate result. + """ + if len(data.parameters) != len(result): + raise ValueError("Invalid number of parameters") + + for parameter in data.parameters: + if parameter.required and parameter.name not in result: + raise ValueError(f"Parameter {parameter.name} is required") + + if parameter.type == 'select' and parameter.options and result.get(parameter.name) not in parameter.options: + raise ValueError(f"Invalid `select` value for parameter {parameter.name}") + + if parameter.type == 'number' and not isinstance(result.get(parameter.name), int | float): + raise ValueError(f"Invalid `number` value for parameter {parameter.name}") + + if parameter.type == 'bool' and not isinstance(result.get(parameter.name), bool): + raise ValueError(f"Invalid `bool` value for parameter {parameter.name}") + + if parameter.type == 'string' and not isinstance(result.get(parameter.name), str): + raise ValueError(f"Invalid `string` value for parameter {parameter.name}") + + if parameter.type.startswith('array'): + if not isinstance(result.get(parameter.name), list): + raise ValueError(f"Invalid `array` value for parameter {parameter.name}") + nested_type = parameter.type[6:-1] + for item in result.get(parameter.name): + if nested_type == 'number' and not isinstance(item, int | float): + raise ValueError(f"Invalid `array[number]` value for parameter {parameter.name}") + if nested_type == 'string' and not isinstance(item, str): + raise ValueError(f"Invalid `array[string]` value for parameter {parameter.name}") + if nested_type == 'object' and not isinstance(item, dict): + raise ValueError(f"Invalid `array[object]` value for 
parameter {parameter.name}") + return result + + def _transform_result(self, data: ParameterExtractorNodeData, result: dict) -> dict: + """ + Transform result into standard format. + """ + transformed_result = {} + for parameter in data.parameters: + if parameter.name in result: + # transform value + if parameter.type == 'number': + if isinstance(result[parameter.name], int | float): + transformed_result[parameter.name] = result[parameter.name] + elif isinstance(result[parameter.name], str): + try: + if '.' in result[parameter.name]: + result[parameter.name] = float(result[parameter.name]) + else: + result[parameter.name] = int(result[parameter.name]) + except ValueError: + pass + else: + pass + # TODO: bool is not supported in the current version + # elif parameter.type == 'bool': + # if isinstance(result[parameter.name], bool): + # transformed_result[parameter.name] = bool(result[parameter.name]) + # elif isinstance(result[parameter.name], str): + # if result[parameter.name].lower() in ['true', 'false']: + # transformed_result[parameter.name] = bool(result[parameter.name].lower() == 'true') + # elif isinstance(result[parameter.name], int): + # transformed_result[parameter.name] = bool(result[parameter.name]) + elif parameter.type in ['string', 'select']: + if isinstance(result[parameter.name], str): + transformed_result[parameter.name] = result[parameter.name] + elif parameter.type.startswith('array'): + if isinstance(result[parameter.name], list): + nested_type = parameter.type[6:-1] + transformed_result[parameter.name] = [] + for item in result[parameter.name]: + if nested_type == 'number': + if isinstance(item, int | float): + transformed_result[parameter.name].append(item) + elif isinstance(item, str): + try: + if '.' 
in item: + transformed_result[parameter.name].append(float(item)) + else: + transformed_result[parameter.name].append(int(item)) + except ValueError: + pass + elif nested_type == 'string': + if isinstance(item, str): + transformed_result[parameter.name].append(item) + elif nested_type == 'object': + if isinstance(item, dict): + transformed_result[parameter.name].append(item) + + if parameter.name not in transformed_result: + if parameter.type == 'number': + transformed_result[parameter.name] = 0 + elif parameter.type == 'bool': + transformed_result[parameter.name] = False + elif parameter.type in ['string', 'select']: + transformed_result[parameter.name] = '' + elif parameter.type.startswith('array'): + transformed_result[parameter.name] = [] + + return transformed_result + + def _extract_complete_json_response(self, result: str) -> Optional[dict]: + """ + Extract complete json response. + """ + def extract_json(text): + """ + From a given JSON started from '{' or '[' extract the complete JSON object. + """ + stack = [] + for i, c in enumerate(text): + if c == '{' or c == '[': + stack.append(c) + elif c == '}' or c == ']': + # check if stack is empty + if not stack: + return text[:i] + # check if the last element in stack is matching + if (c == '}' and stack[-1] == '{') or (c == ']' and stack[-1] == '['): + stack.pop() + if not stack: + return text[:i+1] + else: + return text[:i] + return None + + # extract json from the text + for idx in range(len(result)): + if result[idx] == '{' or result[idx] == '[': + json_str = extract_json(result[idx:]) + if json_str: + try: + return json.loads(json_str) + except Exception: + pass + + def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> Optional[dict]: + """ + Extract json from tool call. 
+ """ + if not tool_call or not tool_call.function.arguments: + return None + + return json.loads(tool_call.function.arguments) + + def _generate_default_result(self, data: ParameterExtractorNodeData) -> dict: + """ + Generate default result. + """ + result = {} + for parameter in data.parameters: + if parameter.type == 'number': + result[parameter.name] = 0 + elif parameter.type == 'bool': + result[parameter.name] = False + elif parameter.type in ['string', 'select']: + result[parameter.name] = '' + + return result + + def _render_instruction(self, instruction: str, variable_pool: VariablePool) -> str: + """ + Render instruction. + """ + variable_template_parser = VariableTemplateParser(instruction) + inputs = {} + for selector in variable_template_parser.extract_variable_selectors(): + inputs[selector.variable] = variable_pool.get_variable_value(selector.value_selector) + + return variable_template_parser.format(inputs) + + def _get_function_calling_prompt_template(self, node_data: ParameterExtractorNodeData, query: str, + variable_pool: VariablePool, + memory: Optional[TokenBufferMemory], + max_token_limit: int = 2000) \ + -> list[ChatModelMessage]: + model_mode = ModelMode.value_of(node_data.model.mode) + input_text = query + memory_str = '' + instruction = self._render_instruction(node_data.instruction or '', variable_pool) + + if memory: + memory_str = memory.get_history_prompt_text(max_token_limit=max_token_limit, + message_limit=node_data.memory.window.size) + if model_mode == ModelMode.CHAT: + system_prompt_messages = ChatModelMessage( + role=PromptMessageRole.SYSTEM, + text=FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT.format(histories=memory_str, instruction=instruction) + ) + user_prompt_message = ChatModelMessage( + role=PromptMessageRole.USER, + text=input_text + ) + return [system_prompt_messages, user_prompt_message] + else: + raise ValueError(f"Model mode {model_mode} not support.") + + def _get_prompt_engineering_prompt_template(self, node_data: 
ParameterExtractorNodeData, query: str, + variable_pool: VariablePool, + memory: Optional[TokenBufferMemory], + max_token_limit: int = 2000) \ + -> list[ChatModelMessage]: + + model_mode = ModelMode.value_of(node_data.model.mode) + input_text = query + memory_str = '' + instruction = self._render_instruction(node_data.instruction or '', variable_pool) + + if memory: + memory_str = memory.get_history_prompt_text(max_token_limit=max_token_limit, + message_limit=node_data.memory.window.size) + if model_mode == ModelMode.CHAT: + system_prompt_messages = ChatModelMessage( + role=PromptMessageRole.SYSTEM, + text=FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT.format(histories=memory_str, instruction=instruction) + ) + user_prompt_message = ChatModelMessage( + role=PromptMessageRole.USER, + text=input_text + ) + return [system_prompt_messages, user_prompt_message] + elif model_mode == ModelMode.COMPLETION: + return CompletionModelPromptTemplate( + text=COMPLETION_GENERATE_JSON_PROMPT.format(histories=memory_str, + text=input_text, + instruction=instruction) + .replace('{γγγ', '') + .replace('}γγγ', '') + ) + else: + raise ValueError(f"Model mode {model_mode} not support.") + + def _calculate_rest_token(self, node_data: ParameterExtractorNodeData, query: str, + variable_pool: VariablePool, + model_config: ModelConfigWithCredentialsEntity, + context: Optional[str]) -> int: + prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) + + model_instance, model_config = self._fetch_model_config(node_data.model) + if not isinstance(model_instance.model_type_instance, LargeLanguageModel): + raise ValueError("Model is not a Large Language Model") + + llm_model = model_instance.model_type_instance + model_schema = llm_model.get_model_schema(model_config.model, model_config.credentials) + if not model_schema: + raise ValueError("Model schema not found") + + if set(model_schema.features or []) & set([ModelFeature.MULTI_TOOL_CALL, ModelFeature.MULTI_TOOL_CALL]): + prompt_template = 
self._get_function_calling_prompt_template(node_data, query, variable_pool, None, 2000) + else: + prompt_template = self._get_prompt_engineering_prompt_template(node_data, query, variable_pool, None, 2000) + + prompt_messages = prompt_transform.get_prompt( + prompt_template=prompt_template, + inputs={}, + query='', + files=[], + context=context, + memory_config=node_data.memory, + memory=None, + model_config=model_config + ) + rest_tokens = 2000 + + model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) + if model_context_tokens: + model_type_instance = model_config.provider_model_bundle.model_type_instance + model_type_instance = cast(LargeLanguageModel, model_type_instance) + + curr_message_tokens = model_type_instance.get_num_tokens( + model_config.model, + model_config.credentials, + prompt_messages + ) + 1000 # add 1000 to ensure tool call messages + + max_tokens = 0 + for parameter_rule in model_config.model_schema.parameter_rules: + if (parameter_rule.name == 'max_tokens' + or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')): + max_tokens = (model_config.parameters.get(parameter_rule.name) + or model_config.parameters.get(parameter_rule.use_template)) or 0 + + rest_tokens = model_context_tokens - max_tokens - curr_message_tokens + rest_tokens = max(rest_tokens, 0) + + return rest_tokens + + def _fetch_model_config(self, node_data_model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: + """ + Fetch model config. 
+ """ + if not self._model_instance or not self._model_config: + self._model_instance, self._model_config = super()._fetch_model_config(node_data_model) + + return self._model_instance, self._model_config + + @classmethod + def _extract_variable_selector_to_variable_mapping(cls, node_data: ParameterExtractorNodeData) -> dict[str, list[str]]: + """ + Extract variable selector to variable mapping + :param node_data: node data + :return: + """ + node_data = node_data + + variable_mapping = { + 'query': node_data.query + } + + if node_data.instruction: + variable_template_parser = VariableTemplateParser(template=node_data.instruction) + for selector in variable_template_parser.extract_variable_selectors(): + variable_mapping[selector.variable] = selector.value_selector + + return variable_mapping \ No newline at end of file diff --git a/api/core/workflow/nodes/parameter_extractor/prompts.py b/api/core/workflow/nodes/parameter_extractor/prompts.py new file mode 100644 index 0000000000..499c58d505 --- /dev/null +++ b/api/core/workflow/nodes/parameter_extractor/prompts.py @@ -0,0 +1,206 @@ +FUNCTION_CALLING_EXTRACTOR_NAME = 'extract_parameters' + +FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT = f"""You are a helpful assistant tasked with extracting structured information based on specific criteria provided. Follow the guidelines below to ensure consistency and accuracy. +### Task +Always call the `{FUNCTION_CALLING_EXTRACTOR_NAME}` function with the correct parameters. Ensure that the information extraction is contextual and aligns with the provided criteria. +### Memory +Here is the chat history between the human and assistant, provided within tags: + +\x7bhistories\x7d + +### Instructions: +Some additional information is provided below. Always adhere to these instructions as closely as possible: + +\x7binstruction\x7d + +Steps: +1. Review the chat history provided within the tags. +2. 
Extract the relevant information based on the criteria given, output multiple values if there is multiple relevant information that match the criteria in the given text. +3. Generate a well-formatted output using the defined functions and arguments. +4. Use the `extract_parameter` function to create structured outputs with appropriate parameters. +5. Do not include any XML tags in your output. +### Example +To illustrate, if the task involves extracting a user's name and their request, your function call might look like this: Ensure your output follows a similar structure to examples. +### Final Output +Produce well-formatted function calls in json without XML tags, as shown in the example. +""" + +FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE = f"""extract structured information from context inside XML tags by calling the function {FUNCTION_CALLING_EXTRACTOR_NAME} with the correct parameters with structure inside XML tags. + +\x7bcontent\x7d + + + +\x7bstructure\x7d + +""" + +FUNCTION_CALLING_EXTRACTOR_EXAMPLE = [{ + 'user': { + 'query': 'What is the weather today in SF?', + 'function': { + 'name': FUNCTION_CALLING_EXTRACTOR_NAME, + 'parameters': { + 'type': 'object', + 'properties': { + 'location': { + 'type': 'string', + 'description': 'The location to get the weather information', + 'required': True + }, + }, + 'required': ['location'] + } + } + }, + 'assistant': { + 'text': 'I need always call the function with the correct parameters. 
in this case, I need to call the function with the location parameter.', + 'function_call' : { + 'name': FUNCTION_CALLING_EXTRACTOR_NAME, + 'parameters': { + 'location': 'San Francisco' + } + } + } +}, { + 'user': { + 'query': 'I want to eat some apple pie.', + 'function': { + 'name': FUNCTION_CALLING_EXTRACTOR_NAME, + 'parameters': { + 'type': 'object', + 'properties': { + 'food': { + 'type': 'string', + 'description': 'The food to eat', + 'required': True + } + }, + 'required': ['food'] + } + } + }, + 'assistant': { + 'text': 'I need always call the function with the correct parameters. in this case, I need to call the function with the food parameter.', + 'function_call' : { + 'name': FUNCTION_CALLING_EXTRACTOR_NAME, + 'parameters': { + 'food': 'apple pie' + } + } + } +}] + +COMPLETION_GENERATE_JSON_PROMPT = """### Instructions: +Some extra information are provided below, I should always follow the instructions as possible as I can. + +{instruction} + + +### Extract parameter Workflow +I need to extract the following information from the input text. The tag specifies the 'type', 'description' and 'required' of the information to be extracted. + +{{ structure }} + + +Step 1: Carefully read the input and understand the structure of the expected output. +Step 2: Extract relevant parameters from the provided text based on the name and description of object. +Step 3: Structure the extracted parameters to JSON object as specified in . +Step 4: Ensure that the JSON object is properly formatted and valid. The output should not contain any XML tags. Only the JSON object should be outputted. + +### Memory +Here is the chat histories between human and assistant, inside XML tags. + +{histories} + + +### Structure +Here is the structure of the expected output, I should always follow the output structure. 
+{{γγγ + 'properties1': 'relevant text extracted from input', + 'properties2': 'relevant text extracted from input', +}}γγγ + +### Input Text +Inside XML tags, there is a text that I should extract parameters and convert to a JSON object. + +{text} + + +### Answer +I should always output a valid JSON object. Output nothing other than the JSON object. +```JSON +""" + +CHAT_GENERATE_JSON_PROMPT = """You should always follow the instructions and output a valid JSON object. +The structure of the JSON object you can found in the instructions. + +### Memory +Here is the chat histories between human and assistant, inside XML tags. + +{histories} + + +### Instructions: +Some extra information are provided below, you should always follow the instructions as possible as you can. + +{{instructions}} + +""" + +CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE = """### Structure +Here is the structure of the JSON object, you should always follow the structure. + +{structure} + + +### Text to be converted to JSON +Inside XML tags, there is a text that you should convert to a JSON object. 
+ +{text} + +""" + +CHAT_EXAMPLE = [{ + 'user': { + 'query': 'What is the weather today in SF?', + 'json': { + 'type': 'object', + 'properties': { + 'location': { + 'type': 'string', + 'description': 'The location to get the weather information', + 'required': True + } + }, + 'required': ['location'] + } + }, + 'assistant': { + 'text': 'I need to output a valid JSON object.', + 'json': { + 'location': 'San Francisco' + } + } +}, { + 'user': { + 'query': 'I want to eat some apple pie.', + 'json': { + 'type': 'object', + 'properties': { + 'food': { + 'type': 'string', + 'description': 'The food to eat', + 'required': True + } + }, + 'required': ['food'] + } + }, + 'assistant': { + 'text': 'I need to output a valid JSON object.', + 'json': { + 'result': 'apple pie' + } + } +}] \ No newline at end of file diff --git a/api/core/workflow/nodes/tool/entities.py b/api/core/workflow/nodes/tool/entities.py index 97fbe8a999..98b28ac4f1 100644 --- a/api/core/workflow/nodes/tool/entities.py +++ b/api/core/workflow/nodes/tool/entities.py @@ -7,7 +7,7 @@ from core.workflow.entities.base_node_data_entities import BaseNodeData class ToolEntity(BaseModel): provider_id: str - provider_type: Literal['builtin', 'api'] + provider_type: Literal['builtin', 'api', 'workflow'] provider_name: str # redundancy tool_name: str tool_label: str # redundancy diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index edfff593dc..6194dd1f5d 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -1,13 +1,14 @@ from os import path -from typing import cast +from typing import Optional, cast from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler from core.file.file_obj import FileTransferMethod, FileType, FileVar -from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter +from 
core.tools.tool.tool import Tool from core.tools.tool_engine import ToolEngine from core.tools.tool_manager import ToolManager from core.tools.utils.message_transformer import ToolFileMessageTransformer -from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType +from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.tool.entities import ToolNodeData @@ -35,20 +36,23 @@ class ToolNode(BaseNode): 'provider_id': node_data.provider_id } - # get parameters - parameters = self._generate_parameters(variable_pool, node_data) # get tool runtime try: - tool_runtime = ToolManager.get_workflow_tool_runtime(self.tenant_id, self.app_id, self.node_id, node_data) + tool_runtime = ToolManager.get_workflow_tool_runtime( + self.tenant_id, self.app_id, self.node_id, node_data, self.invoke_from + ) except Exception as e: return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, - inputs=parameters, + inputs={}, metadata={ NodeRunMetadataKey.TOOL_INFO: tool_info }, error=f'Failed to get tool runtime: {str(e)}' ) + + # get parameters + parameters = self._generate_parameters(variable_pool, node_data, tool_runtime) try: messages = ToolEngine.workflow_invoke( @@ -56,7 +60,8 @@ class ToolNode(BaseNode): tool_parameters=parameters, user_id=self.user_id, workflow_id=self.workflow_id, - workflow_tool_callback=DifyWorkflowCallbackHandler() + workflow_tool_callback=DifyWorkflowCallbackHandler(), + workflow_call_depth=self.workflow_call_depth + 1 ) except Exception as e: return NodeRunResult( @@ -83,19 +88,32 @@ class ToolNode(BaseNode): inputs=parameters ) - def _generate_parameters(self, variable_pool: VariablePool, node_data: ToolNodeData) -> dict: + def _generate_parameters(self, variable_pool: VariablePool, node_data: ToolNodeData, tool_runtime: Tool) -> dict: 
""" Generate parameters """ + tool_parameters = tool_runtime.get_all_runtime_parameters() + + def fetch_parameter(name: str) -> Optional[ToolParameter]: + return next((parameter for parameter in tool_parameters if parameter.name == name), None) + result = {} for parameter_name in node_data.tool_parameters: - input = node_data.tool_parameters[parameter_name] - if input.type == 'mixed': - result[parameter_name] = self._format_variable_template(input.value, variable_pool) - elif input.type == 'variable': - result[parameter_name] = variable_pool.get_variable_value(input.value) - elif input.type == 'constant': - result[parameter_name] = input.value + parameter = fetch_parameter(parameter_name) + if not parameter: + continue + if parameter.type == ToolParameter.ToolParameterType.FILE: + result[parameter_name] = [ + v.to_dict() for v in self._fetch_files(variable_pool) + ] + else: + input = node_data.tool_parameters[parameter_name] + if input.type == 'mixed': + result[parameter_name] = self._format_variable_template(input.value, variable_pool) + elif input.type == 'variable': + result[parameter_name] = variable_pool.get_variable_value(input.value) + elif input.type == 'constant': + result[parameter_name] = input.value return result @@ -109,6 +127,13 @@ class ToolNode(BaseNode): inputs[selector.variable] = variable_pool.get_variable_value(selector.value_selector) return template_parser.format(inputs) + + def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]: + files = variable_pool.get_variable_value(['sys', SystemVariable.FILES.value]) + if not files: + return [] + + return files def _convert_tool_messages(self, messages: list[ToolInvokeMessage]) -> tuple[str, list[FileVar]]: """ diff --git a/api/core/workflow/nodes/variable_aggregator/__init__.py b/api/core/workflow/nodes/variable_aggregator/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/workflow/nodes/variable_aggregator/entities.py 
b/api/core/workflow/nodes/variable_aggregator/entities.py new file mode 100644 index 0000000000..21ab76670f --- /dev/null +++ b/api/core/workflow/nodes/variable_aggregator/entities.py @@ -0,0 +1,33 @@ + + +from typing import Literal, Optional + +from pydantic import BaseModel + +from core.workflow.entities.base_node_data_entities import BaseNodeData + + +class AdvancedSetting(BaseModel): + """ + Advanced setting. + """ + group_enabled: bool + + class Group(BaseModel): + """ + Group. + """ + output_type: Literal['string', 'number', 'array', 'object'] + variables: list[list[str]] + group_name: str + + groups: list[Group] + +class VariableAssignerNodeData(BaseNodeData): + """ + Knowledge retrieval Node Data. + """ + type: str = 'variable-assigner' + output_type: str + variables: list[list[str]] + advanced_setting: Optional[AdvancedSetting] \ No newline at end of file diff --git a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py new file mode 100644 index 0000000000..b18a4d7fbf --- /dev/null +++ b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py @@ -0,0 +1,52 @@ +from typing import cast + +from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.node_entities import NodeRunResult, NodeType +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.nodes.base_node import BaseNode +from core.workflow.nodes.variable_aggregator.entities import VariableAssignerNodeData +from models.workflow import WorkflowNodeExecutionStatus + + +class VariableAggregatorNode(BaseNode): + _node_data_cls = VariableAssignerNodeData + _node_type = NodeType.VARIABLE_AGGREGATOR + + def _run(self, variable_pool: VariablePool) -> NodeRunResult: + node_data = cast(VariableAssignerNodeData, self.node_data) + # Get variables + outputs = {} + inputs = {} + + if not node_data.advanced_setting or 
node_data.advanced_setting.group_enabled: + for variable in node_data.variables: + value = variable_pool.get_variable_value(variable) + + if value is not None: + outputs = { + "output": value + } + + inputs = { + '.'.join(variable[1:]): value + } + break + else: + for group in node_data.advanced_setting.groups: + for variable in group.variables: + value = variable_pool.get_variable_value(variable) + + if value is not None: + outputs[f'{group.group_name}_output'] = value + inputs['.'.join(variable[1:])] = value + break + + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + outputs=outputs, + inputs=inputs + ) + + @classmethod + def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: + return {} diff --git a/api/core/workflow/nodes/variable_assigner/entities.py b/api/core/workflow/nodes/variable_assigner/entities.py deleted file mode 100644 index 035618bd66..0000000000 --- a/api/core/workflow/nodes/variable_assigner/entities.py +++ /dev/null @@ -1,12 +0,0 @@ - - -from core.workflow.entities.base_node_data_entities import BaseNodeData - - -class VariableAssignerNodeData(BaseNodeData): - """ - Knowledge retrieval Node Data. 
- """ - type: str = 'variable-assigner' - output_type: str - variables: list[list[str]] diff --git a/api/core/workflow/nodes/variable_assigner/variable_assigner_node.py b/api/core/workflow/nodes/variable_assigner/variable_assigner_node.py deleted file mode 100644 index d0a1c9789c..0000000000 --- a/api/core/workflow/nodes/variable_assigner/variable_assigner_node.py +++ /dev/null @@ -1,41 +0,0 @@ -from typing import cast - -from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.base_node import BaseNode -from core.workflow.nodes.variable_assigner.entities import VariableAssignerNodeData -from models.workflow import WorkflowNodeExecutionStatus - - -class VariableAssignerNode(BaseNode): - _node_data_cls = VariableAssignerNodeData - _node_type = NodeType.VARIABLE_ASSIGNER - - def _run(self, variable_pool: VariablePool) -> NodeRunResult: - node_data: VariableAssignerNodeData = cast(self._node_data_cls, self.node_data) - # Get variables - outputs = {} - inputs = {} - for variable in node_data.variables: - value = variable_pool.get_variable_value(variable) - - if value is not None: - outputs = { - "output": value - } - - inputs = { - '.'.join(variable[1:]): value - } - break - - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs=outputs, - inputs=inputs - ) - - @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: - return {} diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index 07b0a2741a..86aa650056 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -6,6 +6,7 @@ from flask import current_app from core.app.app_config.entities import FileExtraConfig from 
core.app.apps.base_app_queue_manager import GenerateTaskStoppedException +from core.app.entities.app_invoke_entities import InvokeFrom from core.file.file_obj import FileTransferMethod, FileType, FileVar from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType @@ -13,19 +14,22 @@ from core.workflow.entities.variable_pool import VariablePool, VariableValue from core.workflow.entities.workflow_entities import WorkflowNodeAndResult, WorkflowRunState from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.nodes.answer.answer_node import AnswerNode -from core.workflow.nodes.base_node import BaseNode, UserFrom +from core.workflow.nodes.base_node import BaseIterationNode, BaseNode, UserFrom from core.workflow.nodes.code.code_node import CodeNode from core.workflow.nodes.end.end_node import EndNode from core.workflow.nodes.http_request.http_request_node import HttpRequestNode from core.workflow.nodes.if_else.if_else_node import IfElseNode +from core.workflow.nodes.iteration.entities import IterationState +from core.workflow.nodes.iteration.iteration_node import IterationNode from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode from core.workflow.nodes.llm.entities import LLMNodeData from core.workflow.nodes.llm.llm_node import LLMNode +from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode from core.workflow.nodes.start.start_node import StartNode from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode from core.workflow.nodes.tool.tool_node import ToolNode -from core.workflow.nodes.variable_assigner.variable_assigner_node import VariableAssignerNode +from 
core.workflow.nodes.variable_aggregator.variable_aggregator_node import VariableAggregatorNode from extensions.ext_database import db from models.workflow import ( Workflow, @@ -44,9 +48,14 @@ node_classes = { NodeType.QUESTION_CLASSIFIER: QuestionClassifierNode, NodeType.HTTP_REQUEST: HttpRequestNode, NodeType.TOOL: ToolNode, - NodeType.VARIABLE_ASSIGNER: VariableAssignerNode, + NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode, + NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode, + NodeType.ITERATION: IterationNode, + NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode } +WORKFLOW_CALL_MAX_DEPTH = 5 + logger = logging.getLogger(__name__) @@ -83,18 +92,20 @@ class WorkflowEngineManager: def run_workflow(self, workflow: Workflow, user_id: str, user_from: UserFrom, + invoke_from: InvokeFrom, user_inputs: dict, system_inputs: Optional[dict] = None, - callbacks: list[BaseWorkflowCallback] = None) -> None: + callbacks: list[BaseWorkflowCallback] = None, + call_depth: Optional[int] = 0, + variable_pool: Optional[VariablePool] = None) -> None: """ - Run workflow :param workflow: Workflow instance :param user_id: user id :param user_from: user from :param user_inputs: user variables inputs :param system_inputs: system inputs, like: query, files :param callbacks: workflow callbacks - :return: + :param call_depth: call depth """ # fetch workflow graph graph = workflow.graph_dict @@ -109,57 +120,185 @@ class WorkflowEngineManager: if not isinstance(graph.get('edges'), list): raise ValueError('edges in workflow graph must be a list') + + # init variable pool + if not variable_pool: + variable_pool = VariablePool( + system_variables=system_inputs, + user_inputs=user_inputs + ) + + if call_depth > WORKFLOW_CALL_MAX_DEPTH: + raise ValueError('Max workflow call depth reached.') + + # init workflow run state + workflow_run_state = WorkflowRunState( + workflow=workflow, + start_at=time.perf_counter(), + variable_pool=variable_pool, + user_id=user_id, + user_from=user_from, + 
invoke_from=invoke_from, + workflow_call_depth=call_depth + ) # init workflow run if callbacks: for callback in callbacks: callback.on_workflow_run_started() - # init workflow run state - workflow_run_state = WorkflowRunState( + # run workflow + self._run_workflow( workflow=workflow, - start_at=time.perf_counter(), - variable_pool=VariablePool( - system_variables=system_inputs, - user_inputs=user_inputs - ), - user_id=user_id, - user_from=user_from + workflow_run_state=workflow_run_state, + callbacks=callbacks, ) + def _run_workflow(self, workflow: Workflow, + workflow_run_state: WorkflowRunState, + callbacks: list[BaseWorkflowCallback] = None, + start_at: Optional[str] = None, + end_at: Optional[str] = None) -> None: + """ + Run workflow + :param workflow: Workflow instance + :param user_id: user id + :param user_from: user from + :param user_inputs: user variables inputs + :param system_inputs: system inputs, like: query, files + :param callbacks: workflow callbacks + :param call_depth: call depth + :param start_at: force specific start node + :param end_at: force specific end node + :return: + """ + graph = workflow.graph_dict + try: - predecessor_node = None + predecessor_node: BaseNode = None + current_iteration_node: BaseIterationNode = None has_entry_node = False max_execution_steps = current_app.config.get("WORKFLOW_MAX_EXECUTION_STEPS") max_execution_time = current_app.config.get("WORKFLOW_MAX_EXECUTION_TIME") while True: # get next node, multiple target nodes in the future - next_node = self._get_next_node( + next_node = self._get_next_overall_node( workflow_run_state=workflow_run_state, graph=graph, predecessor_node=predecessor_node, - callbacks=callbacks + callbacks=callbacks, + start_at=start_at, + end_at=end_at ) + if not next_node: + # reached loop/iteration end or overall end + if current_iteration_node and workflow_run_state.current_iteration_state: + # reached loop/iteration end + # get next iteration + next_iteration = 
current_iteration_node.get_next_iteration( + variable_pool=workflow_run_state.variable_pool, + state=workflow_run_state.current_iteration_state + ) + self._workflow_iteration_next( + graph=graph, + current_iteration_node=current_iteration_node, + workflow_run_state=workflow_run_state, + callbacks=callbacks + ) + if isinstance(next_iteration, NodeRunResult): + if next_iteration.outputs: + for variable_key, variable_value in next_iteration.outputs.items(): + # append variables to variable pool recursively + self._append_variables_recursively( + variable_pool=workflow_run_state.variable_pool, + node_id=current_iteration_node.node_id, + variable_key_list=[variable_key], + variable_value=variable_value + ) + self._workflow_iteration_completed( + current_iteration_node=current_iteration_node, + workflow_run_state=workflow_run_state, + callbacks=callbacks + ) + # iteration has ended + next_node = self._get_next_overall_node( + workflow_run_state=workflow_run_state, + graph=graph, + predecessor_node=current_iteration_node, + callbacks=callbacks, + start_at=start_at, + end_at=end_at + ) + current_iteration_node = None + workflow_run_state.current_iteration_state = None + # continue overall process + elif isinstance(next_iteration, str): + # move to next iteration + next_node_id = next_iteration + # get next id + next_node = self._get_node(workflow_run_state, graph, next_node_id, callbacks) + if not next_node: break # check is already ran - if next_node.node_id in [node_and_result.node.node_id - for node_and_result in workflow_run_state.workflow_nodes_and_results]: + if self._check_node_has_ran(workflow_run_state, next_node.node_id): predecessor_node = next_node continue has_entry_node = True # max steps reached - if len(workflow_run_state.workflow_nodes_and_results) > max_execution_steps: + if workflow_run_state.workflow_node_steps > max_execution_steps: raise ValueError('Max steps {} reached.'.format(max_execution_steps)) # or max execution time reached if 
self._is_timed_out(start_at=workflow_run_state.start_at, max_execution_time=max_execution_time): raise ValueError('Max execution time {}s reached.'.format(max_execution_time)) + # handle iteration nodes + if isinstance(next_node, BaseIterationNode): + current_iteration_node = next_node + workflow_run_state.current_iteration_state = next_node.run( + variable_pool=workflow_run_state.variable_pool + ) + self._workflow_iteration_started( + graph=graph, + current_iteration_node=current_iteration_node, + workflow_run_state=workflow_run_state, + predecessor_node_id=predecessor_node.node_id if predecessor_node else None, + callbacks=callbacks + ) + predecessor_node = next_node + # move to start node of iteration + next_node_id = next_node.get_next_iteration( + variable_pool=workflow_run_state.variable_pool, + state=workflow_run_state.current_iteration_state + ) + self._workflow_iteration_next( + graph=graph, + current_iteration_node=current_iteration_node, + workflow_run_state=workflow_run_state, + callbacks=callbacks + ) + if isinstance(next_node_id, NodeRunResult): + # iteration has ended + current_iteration_node.set_output( + variable_pool=workflow_run_state.variable_pool, + state=workflow_run_state.current_iteration_state + ) + self._workflow_iteration_completed( + current_iteration_node=current_iteration_node, + workflow_run_state=workflow_run_state, + callbacks=callbacks + ) + current_iteration_node = None + workflow_run_state.current_iteration_state = None + continue + else: + next_node = self._get_node(workflow_run_state, graph, next_node_id, callbacks) + # run workflow, run multiple target nodes in the future self._run_workflow_node( workflow_run_state=workflow_run_state, @@ -235,7 +374,9 @@ class WorkflowEngineManager: workflow_id=workflow.id, user_id=user_id, user_from=UserFrom.ACCOUNT, - config=node_config + invoke_from=InvokeFrom.DEBUGGER, + config=node_config, + workflow_call_depth=0 ) try: @@ -251,49 +392,14 @@ class WorkflowEngineManager: except 
NotImplementedError: variable_mapping = {} - for variable_key, variable_selector in variable_mapping.items(): - if variable_key not in user_inputs: - raise ValueError(f'Variable key {variable_key} not found in user inputs.') - - # fetch variable node id from variable selector - variable_node_id = variable_selector[0] - variable_key_list = variable_selector[1:] - - # get value - value = user_inputs.get(variable_key) - - # temp fix for image type - if node_type == NodeType.LLM: - new_value = [] - if isinstance(value, list): - node_data = node_instance.node_data - node_data = cast(LLMNodeData, node_data) - - detail = node_data.vision.configs.detail if node_data.vision.configs else None - - for item in value: - if isinstance(item, dict) and 'type' in item and item['type'] == 'image': - transfer_method = FileTransferMethod.value_of(item.get('transfer_method')) - file = FileVar( - tenant_id=workflow.tenant_id, - type=FileType.IMAGE, - transfer_method=transfer_method, - url=item.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None, - related_id=item.get( - 'upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None, - extra_config=FileExtraConfig(image_config={'detail': detail} if detail else None), - ) - new_value.append(file) - - if new_value: - value = new_value - - # append variable and value to variable pool - variable_pool.append_variable( - node_id=variable_node_id, - variable_key_list=variable_key_list, - value=value - ) + self._mapping_user_inputs_to_variable_pool( + variable_mapping=variable_mapping, + user_inputs=user_inputs, + variable_pool=variable_pool, + tenant_id=workflow.tenant_id, + node_instance=node_instance + ) + # run node node_run_result = node_instance.run( variable_pool=variable_pool @@ -311,6 +417,126 @@ class WorkflowEngineManager: return node_instance, node_run_result + def single_step_run_iteration_workflow_node(self, workflow: Workflow, + node_id: str, + user_id: str, + user_inputs: dict, + callbacks: 
list[BaseWorkflowCallback] = None, + ) -> None: + """ + Single iteration run workflow node + """ + # fetch node info from workflow graph + graph = workflow.graph_dict + if not graph: + raise ValueError('workflow graph not found') + + nodes = graph.get('nodes') + if not nodes: + raise ValueError('nodes not found in workflow graph') + + for node in nodes: + if node.get('id') == node_id: + if node.get('data', {}).get('type') in [ + NodeType.ITERATION.value, + NodeType.LOOP.value, + ]: + node_config = node + else: + raise ValueError('node id is not an iteration node') + + # init variable pool + variable_pool = VariablePool( + system_variables={}, + user_inputs={} + ) + + # variable selector to variable mapping + iteration_nested_nodes = [ + node for node in nodes + if node.get('data', {}).get('iteration_id') == node_id or node.get('id') == node_id + ] + iteration_nested_node_ids = [node.get('id') for node in iteration_nested_nodes] + + if not iteration_nested_nodes: + raise ValueError('iteration has no nested nodes') + + # init workflow run + if callbacks: + for callback in callbacks: + callback.on_workflow_run_started() + + for node_config in iteration_nested_nodes: + # mapping user inputs to variable pool + node_cls = node_classes.get(NodeType.value_of(node_config.get('data', {}).get('type'))) + try: + variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(node_config) + except NotImplementedError: + variable_mapping = {} + + # remove iteration variables + variable_mapping = { + f'{node_config.get("id")}.{key}': value for key, value in variable_mapping.items() + if value[0] != node_id + } + + # remove variable out from iteration + variable_mapping = { + key: value for key, value in variable_mapping.items() + if value[0] not in iteration_nested_node_ids + } + + # append variables to variable pool + node_instance = node_cls( + tenant_id=workflow.tenant_id, + app_id=workflow.app_id, + workflow_id=workflow.id, + user_id=user_id, + 
user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + config=node_config, + callbacks=callbacks, + workflow_call_depth=0 + ) + + self._mapping_user_inputs_to_variable_pool( + variable_mapping=variable_mapping, + user_inputs=user_inputs, + variable_pool=variable_pool, + tenant_id=workflow.tenant_id, + node_instance=node_instance + ) + + # fetch end node of iteration + end_node_id = None + for edge in graph.get('edges'): + if edge.get('source') == node_id: + end_node_id = edge.get('target') + break + + if not end_node_id: + raise ValueError('end node of iteration not found') + + # init workflow run state + workflow_run_state = WorkflowRunState( + workflow=workflow, + start_at=time.perf_counter(), + variable_pool=variable_pool, + user_id=user_id, + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + workflow_call_depth=0 + ) + + # run workflow + self._run_workflow( + workflow=workflow, + workflow_run_state=workflow_run_state, + callbacks=callbacks, + start_at=node_id, + end_at=end_node_id + ) + def _workflow_run_success(self, callbacks: list[BaseWorkflowCallback] = None) -> None: """ Workflow run success @@ -336,10 +562,96 @@ class WorkflowEngineManager: error=error ) - def _get_next_node(self, workflow_run_state: WorkflowRunState, + def _workflow_iteration_started(self, graph: dict, + current_iteration_node: BaseIterationNode, + workflow_run_state: WorkflowRunState, + predecessor_node_id: Optional[str] = None, + callbacks: list[BaseWorkflowCallback] = None) -> None: + """ + Workflow iteration started + :param current_iteration_node: current iteration node + :param workflow_run_state: workflow run state + :param callbacks: workflow callbacks + :return: + """ + # get nested nodes + iteration_nested_nodes = [ + node for node in graph.get('nodes') + if node.get('data', {}).get('iteration_id') == current_iteration_node.node_id + ] + + if not iteration_nested_nodes: + raise ValueError('iteration has no nested nodes') + + if callbacks: + if 
isinstance(workflow_run_state.current_iteration_state, IterationState): + for callback in callbacks: + callback.on_workflow_iteration_started( + node_id=current_iteration_node.node_id, + node_type=NodeType.ITERATION, + node_run_index=workflow_run_state.workflow_node_steps, + node_data=current_iteration_node.node_data, + inputs=workflow_run_state.current_iteration_state.inputs, + predecessor_node_id=predecessor_node_id, + metadata=workflow_run_state.current_iteration_state.metadata.dict() + ) + + # add steps + workflow_run_state.workflow_node_steps += 1 + + def _workflow_iteration_next(self, graph: dict, + current_iteration_node: BaseIterationNode, + workflow_run_state: WorkflowRunState, + callbacks: list[BaseWorkflowCallback] = None) -> None: + """ + Workflow iteration next + :param workflow_run_state: workflow run state + :return: + """ + if callbacks: + if isinstance(workflow_run_state.current_iteration_state, IterationState): + for callback in callbacks: + callback.on_workflow_iteration_next( + node_id=current_iteration_node.node_id, + node_type=NodeType.ITERATION, + index=workflow_run_state.current_iteration_state.index, + node_run_index=workflow_run_state.workflow_node_steps, + output=workflow_run_state.current_iteration_state.get_current_output() + ) + # clear ran nodes + workflow_run_state.workflow_node_runs = [ + node_run for node_run in workflow_run_state.workflow_node_runs + if node_run.iteration_node_id != current_iteration_node.node_id + ] + + # clear variables in current iteration + nodes = graph.get('nodes') + nodes = [node for node in nodes if node.get('data', {}).get('iteration_id') == current_iteration_node.node_id] + + for node in nodes: + workflow_run_state.variable_pool.clear_node_variables(node_id=node.get('id')) + + def _workflow_iteration_completed(self, current_iteration_node: BaseIterationNode, + workflow_run_state: WorkflowRunState, + callbacks: list[BaseWorkflowCallback] = None) -> None: + if callbacks: + if 
isinstance(workflow_run_state.current_iteration_state, IterationState): + for callback in callbacks: + callback.on_workflow_iteration_completed( + node_id=current_iteration_node.node_id, + node_type=NodeType.ITERATION, + node_run_index=workflow_run_state.workflow_node_steps, + outputs={ + 'output': workflow_run_state.current_iteration_state.outputs + } + ) + + def _get_next_overall_node(self, workflow_run_state: WorkflowRunState, graph: dict, predecessor_node: Optional[BaseNode] = None, - callbacks: list[BaseWorkflowCallback] = None) -> Optional[BaseNode]: + callbacks: list[BaseWorkflowCallback] = None, + start_at: Optional[str] = None, + end_at: Optional[str] = None) -> Optional[BaseNode]: """ Get next node multiple target nodes in the future. @@ -354,16 +666,26 @@ class WorkflowEngineManager: if not predecessor_node: for node_config in nodes: - if node_config.get('data', {}).get('type', '') == NodeType.START.value: - return StartNode( + node_cls = None + if start_at: + if node_config.get('id') == start_at: + node_cls = node_classes.get(NodeType.value_of(node_config.get('data', {}).get('type'))) + else: + if node_config.get('data', {}).get('type', '') == NodeType.START.value: + node_cls = StartNode + if node_cls: + return node_cls( tenant_id=workflow_run_state.tenant_id, app_id=workflow_run_state.app_id, workflow_id=workflow_run_state.workflow_id, user_id=workflow_run_state.user_id, user_from=workflow_run_state.user_from, + invoke_from=workflow_run_state.invoke_from, config=node_config, - callbacks=callbacks + callbacks=callbacks, + workflow_call_depth=workflow_run_state.workflow_call_depth ) + else: edges = graph.get('edges') source_node_id = predecessor_node.node_id @@ -390,6 +712,9 @@ class WorkflowEngineManager: target_node_id = outgoing_edge.get('target') + if end_at and target_node_id == end_at: + return None + # fetch target node from target node id target_node_config = None for node in nodes: @@ -409,9 +734,40 @@ class WorkflowEngineManager: 
workflow_id=workflow_run_state.workflow_id, user_id=workflow_run_state.user_id, user_from=workflow_run_state.user_from, + invoke_from=workflow_run_state.invoke_from, config=target_node_config, - callbacks=callbacks + callbacks=callbacks, + workflow_call_depth=workflow_run_state.workflow_call_depth ) + + def _get_node(self, workflow_run_state: WorkflowRunState, + graph: dict, + node_id: str, + callbacks: list[BaseWorkflowCallback]) -> Optional[BaseNode]: + """ + Get node from graph by node id + """ + nodes = graph.get('nodes') + if not nodes: + return None + + for node_config in nodes: + if node_config.get('id') == node_id: + node_type = NodeType.value_of(node_config.get('data', {}).get('type')) + node_cls = node_classes.get(node_type) + return node_cls( + tenant_id=workflow_run_state.tenant_id, + app_id=workflow_run_state.app_id, + workflow_id=workflow_run_state.workflow_id, + user_id=workflow_run_state.user_id, + user_from=workflow_run_state.user_from, + invoke_from=workflow_run_state.invoke_from, + config=node_config, + callbacks=callbacks, + workflow_call_depth=workflow_run_state.workflow_call_depth + ) + + return None def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool: """ @@ -422,6 +778,15 @@ class WorkflowEngineManager: """ return time.perf_counter() - start_at > max_execution_time + def _check_node_has_ran(self, workflow_run_state: WorkflowRunState, node_id: str) -> bool: + """ + Check node has ran + """ + return bool([ + node_and_result for node_and_result in workflow_run_state.workflow_node_runs + if node_and_result.node_id == node_id + ]) + def _run_workflow_node(self, workflow_run_state: WorkflowRunState, node: BaseNode, predecessor_node: Optional[BaseNode] = None, @@ -432,7 +797,7 @@ class WorkflowEngineManager: node_id=node.node_id, node_type=node.node_type, node_data=node.node_data, - node_run_index=len(workflow_run_state.workflow_nodes_and_results) + 1, + node_run_index=workflow_run_state.workflow_node_steps, 
predecessor_node_id=predecessor_node.node_id if predecessor_node else None ) @@ -446,6 +811,16 @@ class WorkflowEngineManager: # add to workflow_nodes_and_results workflow_run_state.workflow_nodes_and_results.append(workflow_nodes_and_result) + # add steps + workflow_run_state.workflow_node_steps += 1 + + # mark node as running + if workflow_run_state.current_iteration_state: + workflow_run_state.workflow_node_runs.append(WorkflowRunState.NodeRun( + node_id=node.node_id, + iteration_node_id=workflow_run_state.current_iteration_state.iteration_node_id + )) + try: # run node, result must have inputs, process_data, outputs, execution_metadata node_run_result = node.run( @@ -565,3 +940,53 @@ class WorkflowEngineManager: new_value[key] = new_val return new_value + + def _mapping_user_inputs_to_variable_pool(self, + variable_mapping: dict, + user_inputs: dict, + variable_pool: VariablePool, + tenant_id: str, + node_instance: BaseNode): + for variable_key, variable_selector in variable_mapping.items(): + if variable_key not in user_inputs: + raise ValueError(f'Variable key {variable_key} not found in user inputs.') + + # fetch variable node id from variable selector + variable_node_id = variable_selector[0] + variable_key_list = variable_selector[1:] + + # get value + value = user_inputs.get(variable_key) + + # temp fix for image type + if node_instance.node_type == NodeType.LLM: + new_value = [] + if isinstance(value, list): + node_data = node_instance.node_data + node_data = cast(LLMNodeData, node_data) + + detail = node_data.vision.configs.detail if node_data.vision.configs else None + + for item in value: + if isinstance(item, dict) and 'type' in item and item['type'] == 'image': + transfer_method = FileTransferMethod.value_of(item.get('transfer_method')) + file = FileVar( + tenant_id=tenant_id, + type=FileType.IMAGE, + transfer_method=transfer_method, + url=item.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None, + related_id=item.get( + 
'upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None, + extra_config=FileExtraConfig(image_config={'detail': detail} if detail else None), + ) + new_value.append(file) + + if new_value: + value = new_value + + # append variable and value to variable pool + variable_pool.append_variable( + node_id=variable_node_id, + variable_key_list=variable_key_list, + value=value + ) \ No newline at end of file diff --git a/api/events/event_handlers/__init__.py b/api/events/event_handlers/__init__.py index 688b80aa8c..ceb50a252b 100644 --- a/api/events/event_handlers/__init__.py +++ b/api/events/event_handlers/__init__.py @@ -6,6 +6,7 @@ from .create_site_record_when_app_created import handle from .deduct_quota_when_messaeg_created import handle from .delete_installed_app_when_app_deleted import handle from .delete_tool_parameters_cache_when_sync_draft_workflow import handle +from .delete_workflow_as_tool_when_app_deleted import handle from .update_app_dataset_join_when_app_model_config_updated import handle from .update_app_dataset_join_when_app_published_workflow_updated import handle from .update_provider_last_used_at_when_messaeg_created import handle diff --git a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py index 2a127d903e..1f6da34ee2 100644 --- a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py +++ b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py @@ -10,18 +10,22 @@ def handle(sender, **kwargs): app = sender for node_data in kwargs.get('synced_draft_workflow').graph_dict.get('nodes', []): if node_data.get('data', {}).get('type') == NodeType.TOOL.value: - tool_entity = ToolEntity(**node_data["data"]) - tool_runtime = ToolManager.get_tool_runtime( - provider_type=tool_entity.provider_type, - provider_name=tool_entity.provider_id, - 
tool_name=tool_entity.tool_name, - tenant_id=app.tenant_id, - ) - manager = ToolParameterConfigurationManager( - tenant_id=app.tenant_id, - tool_runtime=tool_runtime, - provider_name=tool_entity.provider_name, - provider_type=tool_entity.provider_type, - identity_id=f'WORKFLOW.{app.id}.{node_data.get("id")}' - ) - manager.delete_tool_parameters_cache() + try: + tool_entity = ToolEntity(**node_data["data"]) + tool_runtime = ToolManager.get_tool_runtime( + provider_type=tool_entity.provider_type, + provider_id=tool_entity.provider_id, + tool_name=tool_entity.tool_name, + tenant_id=app.tenant_id, + ) + manager = ToolParameterConfigurationManager( + tenant_id=app.tenant_id, + tool_runtime=tool_runtime, + provider_name=tool_entity.provider_name, + provider_type=tool_entity.provider_type, + identity_id=f'WORKFLOW.{app.id}.{node_data.get("id")}' + ) + manager.delete_tool_parameters_cache() + except: + # tool dose not exist + pass diff --git a/api/events/event_handlers/delete_workflow_as_tool_when_app_deleted.py b/api/events/event_handlers/delete_workflow_as_tool_when_app_deleted.py new file mode 100644 index 0000000000..0c56688ff6 --- /dev/null +++ b/api/events/event_handlers/delete_workflow_as_tool_when_app_deleted.py @@ -0,0 +1,14 @@ +from events.app_event import app_was_deleted +from extensions.ext_database import db +from models.tools import WorkflowToolProvider + + +@app_was_deleted.connect +def handle(sender, **kwargs): + app = sender + workflow_tools = db.session.query(WorkflowToolProvider).filter( + WorkflowToolProvider.app_id == app.id + ).all() + for workflow_tool in workflow_tools: + db.session.delete(workflow_tool) + db.session.commit() diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index 94d16be8db..54d7ed55f8 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -11,5 +11,6 @@ workflow_fields = { 'created_by': fields.Nested(simple_account_fields, attribute='created_by_account'), 'created_at': 
TimestampField, 'updated_by': fields.Nested(simple_account_fields, attribute='updated_by_account', allow_null=True), - 'updated_at': TimestampField + 'updated_at': TimestampField, + 'tool_published': fields.Boolean, } diff --git a/api/libs/helper.py b/api/libs/helper.py index f9cf590b7a..f4be9c5531 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -46,7 +46,13 @@ def uuid_value(value): error = ('{value} is not a valid uuid.' .format(value=value)) raise ValueError(error) - + +def alphanumeric(value: str): + # check if the value is alphanumeric and underlined + if re.match(r'^[a-zA-Z0-9_]+$', value): + return value + + raise ValueError(f'{value} is not a valid alphanumeric value') def timestamp_value(timestamp): try: diff --git a/api/migrations/versions/03f98355ba0e_add_workflow_tool_label_and_tool_.py b/api/migrations/versions/03f98355ba0e_add_workflow_tool_label_and_tool_.py new file mode 100644 index 0000000000..0fba6a87eb --- /dev/null +++ b/api/migrations/versions/03f98355ba0e_add_workflow_tool_label_and_tool_.py @@ -0,0 +1,32 @@ +"""add workflow tool label and tool bindings idx + +Revision ID: 03f98355ba0e +Revises: 9e98fbaffb88 +Create Date: 2024-05-25 07:17:00.539125 + +""" +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. 
+revision = '03f98355ba0e' +down_revision = '9e98fbaffb88' +branch_labels = None +depends_on = None + + +def upgrade(): + with op.batch_alter_table('tool_label_bindings', schema=None) as batch_op: + batch_op.create_unique_constraint('unique_tool_label_bind', ['tool_id', 'label_name']) + + with op.batch_alter_table('tool_workflow_providers', schema=None) as batch_op: + batch_op.add_column(sa.Column('label', sa.String(length=255), server_default='', nullable=False)) + +def downgrade(): + with op.batch_alter_table('tool_workflow_providers', schema=None) as batch_op: + batch_op.drop_column('label') + + with op.batch_alter_table('tool_label_bindings', schema=None) as batch_op: + batch_op.drop_constraint('unique_tool_label_bind', type_='unique') diff --git a/api/migrations/versions/3b18fea55204_add_tool_label_bings.py b/api/migrations/versions/3b18fea55204_add_tool_label_bings.py new file mode 100644 index 0000000000..db3119badf --- /dev/null +++ b/api/migrations/versions/3b18fea55204_add_tool_label_bings.py @@ -0,0 +1,42 @@ +"""add tool label bings + +Revision ID: 3b18fea55204 +Revises: 7bdef072e63a +Create Date: 2024-05-14 09:27:18.857890 + +""" +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = '3b18fea55204' +down_revision = '7bdef072e63a' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table('tool_label_bindings', + sa.Column('id', models.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tool_id', sa.String(length=64), nullable=False), + sa.Column('tool_type', sa.String(length=40), nullable=False), + sa.Column('label_name', sa.String(length=40), nullable=False), + sa.PrimaryKeyConstraint('id', name='tool_label_bind_pkey') + ) + + with op.batch_alter_table('tool_workflow_providers', schema=None) as batch_op: + batch_op.add_column(sa.Column('privacy_policy', sa.String(length=255), server_default='', nullable=True)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_workflow_providers', schema=None) as batch_op: + batch_op.drop_column('privacy_policy') + + op.drop_table('tool_label_bindings') + # ### end Alembic commands ### diff --git a/api/migrations/versions/7bdef072e63a_add_workflow_tool.py b/api/migrations/versions/7bdef072e63a_add_workflow_tool.py new file mode 100644 index 0000000000..67b61e5c76 --- /dev/null +++ b/api/migrations/versions/7bdef072e63a_add_workflow_tool.py @@ -0,0 +1,42 @@ +"""add workflow tool + +Revision ID: 7bdef072e63a +Revises: 5fda94355fce +Create Date: 2024-05-04 09:47:19.366961 + +""" +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = '7bdef072e63a' +down_revision = '5fda94355fce' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table('tool_workflow_providers', + sa.Column('id', models.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('name', sa.String(length=40), nullable=False), + sa.Column('icon', sa.String(length=255), nullable=False), + sa.Column('app_id', models.StringUUID(), nullable=False), + sa.Column('user_id', models.StringUUID(), nullable=False), + sa.Column('tenant_id', models.StringUUID(), nullable=False), + sa.Column('description', sa.Text(), nullable=False), + sa.Column('parameter_configuration', sa.Text(), server_default='[]', nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='tool_workflow_provider_pkey'), + sa.UniqueConstraint('name', 'tenant_id', name='unique_workflow_tool_provider'), + sa.UniqueConstraint('tenant_id', 'app_id', name='unique_workflow_tool_provider_app_id') + ) + # ### end Alembic commands ### + + +def downgrade(): + op.drop_table('tool_workflow_providers') + # ### end Alembic commands ### diff --git a/api/migrations/versions/9e98fbaffb88_add_workflow_tool_version.py b/api/migrations/versions/9e98fbaffb88_add_workflow_tool_version.py new file mode 100644 index 0000000000..bfda7d619c --- /dev/null +++ b/api/migrations/versions/9e98fbaffb88_add_workflow_tool_version.py @@ -0,0 +1,26 @@ +"""add workflow tool version + +Revision ID: 9e98fbaffb88 +Revises: 3b18fea55204 +Create Date: 2024-05-21 10:25:40.434162 + +""" +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. 
+revision = '9e98fbaffb88' +down_revision = '3b18fea55204' +branch_labels = None +depends_on = None + + +def upgrade(): + with op.batch_alter_table('tool_workflow_providers', schema=None) as batch_op: + batch_op.add_column(sa.Column('version', sa.String(length=255), server_default='', nullable=False)) + +def downgrade(): + with op.batch_alter_table('tool_workflow_providers', schema=None) as batch_op: + batch_op.drop_column('version') diff --git a/api/models/model.py b/api/models/model.py index 372c5e8550..657db5a5c2 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -100,7 +100,7 @@ class App(db.Model): return None @property - def workflow(self): + def workflow(self) -> Optional['Workflow']: if self.workflow_id: from .workflow import Workflow return db.session.query(Workflow).filter(Workflow.id == self.workflow_id).first() diff --git a/api/models/tools.py b/api/models/tools.py index 64fc334549..49212916ec 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -3,8 +3,8 @@ import json from sqlalchemy import ForeignKey from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_bundle import ApiBasedToolBundle -from core.tools.entities.tool_entities import ApiProviderSchemaType +from core.tools.entities.tool_bundle import ApiToolBundle +from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration from extensions.ext_database import db from models import StringUUID from models.model import Account, App, Tenant @@ -118,8 +118,8 @@ class ApiToolProvider(db.Model): return ApiProviderSchemaType.value_of(self.schema_type_str) @property - def tools(self) -> list[ApiBasedToolBundle]: - return [ApiBasedToolBundle(**tool) for tool in json.loads(self.tools_str)] + def tools(self) -> list[ApiToolBundle]: + return [ApiToolBundle(**tool) for tool in json.loads(self.tools_str)] @property def credentials(self) -> dict: @@ -132,7 +132,84 @@ class ApiToolProvider(db.Model): @property def 
tenant(self) -> Tenant: return db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first() + +class ToolLabelBinding(db.Model): + """ + The table stores the labels for tools. + """ + __tablename__ = 'tool_label_bindings' + __table_args__ = ( + db.PrimaryKeyConstraint('id', name='tool_label_bind_pkey'), + db.UniqueConstraint('tool_id', 'label_name', name='unique_tool_label_bind'), + ) + + id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + # tool id + tool_id = db.Column(db.String(64), nullable=False) + # tool type + tool_type = db.Column(db.String(40), nullable=False) + # label name + label_name = db.Column(db.String(40), nullable=False) + +class WorkflowToolProvider(db.Model): + """ + The table stores the workflow providers. + """ + __tablename__ = 'tool_workflow_providers' + __table_args__ = ( + db.PrimaryKeyConstraint('id', name='tool_workflow_provider_pkey'), + db.UniqueConstraint('name', 'tenant_id', name='unique_workflow_tool_provider'), + db.UniqueConstraint('tenant_id', 'app_id', name='unique_workflow_tool_provider_app_id'), + ) + + id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + # name of the workflow provider + name = db.Column(db.String(40), nullable=False) + # label of the workflow provider + label = db.Column(db.String(255), nullable=False, server_default='') + # icon + icon = db.Column(db.String(255), nullable=False) + # app id of the workflow provider + app_id = db.Column(StringUUID, nullable=False) + # version of the workflow provider + version = db.Column(db.String(255), nullable=False, server_default='') + # who created this tool + user_id = db.Column(StringUUID, nullable=False) + # tenant id + tenant_id = db.Column(StringUUID, nullable=False) + # description of the provider + description = db.Column(db.Text, nullable=False) + # parameter configuration + parameter_configuration = db.Column(db.Text, nullable=False, server_default='[]') + # privacy policy + privacy_policy = 
db.Column(db.String(255), nullable=True, server_default='') + + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + + @property + def schema_type(self) -> ApiProviderSchemaType: + return ApiProviderSchemaType.value_of(self.schema_type_str) + @property + def user(self) -> Account: + return db.session.query(Account).filter(Account.id == self.user_id).first() + + @property + def tenant(self) -> Tenant: + return db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first() + + @property + def parameter_configurations(self) -> list[WorkflowToolParameterConfiguration]: + return [ + WorkflowToolParameterConfiguration(**config) + for config in json.loads(self.parameter_configuration) + ] + + @property + def app(self) -> App: + return db.session.query(App).filter(App.id == self.app_id).first() + class ToolModelInvoke(db.Model): """ store the invoke logs from tool invoke diff --git a/api/models/workflow.py b/api/models/workflow.py index 3f44641032..264fea4ecb 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -2,7 +2,6 @@ import json from enum import Enum from typing import Optional, Union -from core.tools.tool_manager import ToolManager from extensions.ext_database import db from libs import helper from models import StringUUID @@ -171,6 +170,12 @@ class Workflow(db.Model): return helper.generate_text_hash(json.dumps(entity, sort_keys=True)) + @property + def tool_published(self) -> bool: + from models.tools import WorkflowToolProvider + return db.session.query(WorkflowToolProvider).filter( + WorkflowToolProvider.app_id == self.app_id + ).first() is not None class WorkflowRunTriggeredFrom(Enum): """ @@ -473,6 +478,7 @@ class WorkflowNodeExecution(db.Model): @property def extras(self): + from core.tools.tool_manager import ToolManager extras = {} if self.execution_metadata_dict: from 
core.workflow.entities.node_entities import NodeType diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py index 185d9ba89f..f73a6dcbb6 100644 --- a/api/services/app_generate_service.py +++ b/api/services/app_generate_service.py @@ -75,6 +75,35 @@ class AppGenerateService: else: raise ValueError(f'Invalid app mode {app_model.mode}') + @classmethod + def generate_single_iteration(cls, app_model: App, + user: Union[Account, EndUser], + node_id: str, + args: Any, + streaming: bool = True): + if app_model.mode == AppMode.ADVANCED_CHAT.value: + workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER) + return AdvancedChatAppGenerator().single_iteration_generate( + app_model=app_model, + workflow=workflow, + node_id=node_id, + user=user, + args=args, + stream=streaming + ) + elif app_model.mode == AppMode.WORKFLOW.value: + workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER) + return WorkflowAppGenerator().single_iteration_generate( + app_model=app_model, + workflow=workflow, + node_id=node_id, + user=user, + args=args, + stream=streaming + ) + else: + raise ValueError(f'Invalid app mode {app_model.mode}') + @classmethod def generate_more_like_this(cls, app_model: App, user: Union[Account, EndUser], message_id: str, invoke_from: InvokeFrom, streaming: bool = True) \ diff --git a/api/services/tools_manage_service.py b/api/services/tools/api_tools_manage_service.py similarity index 58% rename from api/services/tools_manage_service.py rename to api/services/tools/api_tools_manage_service.py index 7100d79ee0..9a0d6ca8d9 100644 --- a/api/services/tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -4,97 +4,30 @@ import logging from httpx import get from core.model_runtime.utils.encoders import jsonable_encoder +from core.tools.entities.api_entities import UserTool, UserToolProvider from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_bundle import ApiBasedToolBundle 
+from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ( ApiProviderAuthType, ApiProviderSchemaType, ToolCredentialsOption, ToolProviderCredentials, ) -from core.tools.entities.user_entities import UserTool, UserToolProvider -from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError -from core.tools.provider.api_tool_provider import ApiBasedToolProviderController -from core.tools.provider.builtin._positions import BuiltinToolProviderSort -from core.tools.provider.tool_provider import ToolProviderController +from core.tools.provider.api_tool_provider import ApiToolProviderController +from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_manager import ToolManager from core.tools.utils.configuration import ToolConfigurationManager from core.tools.utils.parser import ApiBasedToolSchemaParser from extensions.ext_database import db -from models.tools import ApiToolProvider, BuiltinToolProvider -from services.model_provider_service import ModelProviderService -from services.tools_transform_service import ToolTransformService +from models.tools import ApiToolProvider +from services.tools.tools_transform_service import ToolTransformService logger = logging.getLogger(__name__) -class ToolManageService: +class ApiToolManageService: @staticmethod - def list_tool_providers(user_id: str, tenant_id: str): - """ - list tool providers - - :return: the list of tool providers - """ - providers = ToolManager.user_list_providers( - user_id, tenant_id - ) - - # add icon - for provider in providers: - ToolTransformService.repack_provider(provider) - - result = [provider.to_dict() for provider in providers] - - return result - - @staticmethod - def list_builtin_tool_provider_tools( - user_id: str, tenant_id: str, provider: str - ) -> list[UserTool]: - """ - list builtin tool provider tools - """ - provider_controller: ToolProviderController = 
ToolManager.get_builtin_provider(provider) - tools = provider_controller.get_tools() - - tool_provider_configurations = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) - # check if user has added the provider - builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter( - BuiltinToolProvider.tenant_id == tenant_id, - BuiltinToolProvider.provider == provider, - ).first() - - credentials = {} - if builtin_provider is not None: - # get credentials - credentials = builtin_provider.credentials - credentials = tool_provider_configurations.decrypt_tool_credentials(credentials) - - result = [] - for tool in tools: - result.append(ToolTransformService.tool_to_user_tool( - tool=tool, credentials=credentials, tenant_id=tenant_id - )) - - return result - - @staticmethod - def list_builtin_provider_credentials_schema( - provider_name - ): - """ - list builtin provider credentials schema - - :return: the list of tool providers - """ - provider = ToolManager.get_builtin_provider(provider_name) - return jsonable_encoder([ - v for _, v in (provider.credentials_schema or {}).items() - ]) - - @staticmethod - def parser_api_schema(schema: str) -> list[ApiBasedToolBundle]: + def parser_api_schema(schema: str) -> list[ApiToolBundle]: """ parse api schema to tool bundle """ @@ -162,7 +95,7 @@ class ToolManageService: raise ValueError(f'invalid schema: {str(e)}') @staticmethod - def convert_schema_to_tool_bundles(schema: str, extra_info: dict = None) -> list[ApiBasedToolBundle]: + def convert_schema_to_tool_bundles(schema: str, extra_info: dict = None) -> list[ApiToolBundle]: """ convert schema to tool bundles @@ -177,7 +110,7 @@ class ToolManageService: @staticmethod def create_api_tool_provider( user_id: str, tenant_id: str, provider_name: str, icon: dict, credentials: dict, - schema_type: str, schema: str, privacy_policy: str, custom_disclaimer: str + schema_type: str, schema: str, privacy_policy: str, custom_disclaimer: 
str, labels: list[str] ): """ create api tool provider @@ -197,7 +130,7 @@ class ToolManageService: # parse openapi to tool bundle extra_info = {} # extra info like description will be set here - tool_bundles, schema_type = ToolManageService.convert_schema_to_tool_bundles(schema, extra_info) + tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info) if len(tool_bundles) > 100: raise ValueError('the number of apis should be less than 100') @@ -224,7 +157,7 @@ class ToolManageService: auth_type = ApiProviderAuthType.value_of(credentials['auth_type']) # create provider entity - provider_controller = ApiBasedToolProviderController.from_db(db_provider, auth_type) + provider_controller = ApiToolProviderController.from_db(db_provider, auth_type) # load tools into provider entity provider_controller.load_bundled_tools(tool_bundles) @@ -236,6 +169,9 @@ class ToolManageService: db.session.add(db_provider) db.session.commit() + # update labels + ToolLabelManager.update_tool_labels(provider_controller, labels) + return { 'result': 'success' } @staticmethod @@ -257,7 +193,7 @@ class ToolManageService: schema = response.text # try to parse schema, avoid SSRF attack - ToolManageService.parser_api_schema(schema) + ApiToolManageService.parser_api_schema(schema) except Exception as e: logger.error(f"parse api schema error: {str(e)}") raise ValueError('invalid schema, please check the url you provided') @@ -281,91 +217,20 @@ class ToolManageService: if provider is None: raise ValueError(f'you have not added provider {provider}') - return [ - ToolTransformService.tool_to_user_tool(tool_bundle) for tool_bundle in provider.tools - ] - - @staticmethod - def update_builtin_tool_provider( - user_id: str, tenant_id: str, provider_name: str, credentials: dict - ): - """ - update builtin tool provider - """ - # get if the provider exists - provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter( - BuiltinToolProvider.tenant_id == 
tenant_id, - BuiltinToolProvider.provider == provider_name, - ).first() - - try: - # get provider - provider_controller = ToolManager.get_builtin_provider(provider_name) - if not provider_controller.need_credentials: - raise ValueError(f'provider {provider_name} does not need credentials') - tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) - # get original credentials if exists - if provider is not None: - original_credentials = tool_configuration.decrypt_tool_credentials(provider.credentials) - masked_credentials = tool_configuration.mask_tool_credentials(original_credentials) - # check if the credential has changed, save the original credential - for name, value in credentials.items(): - if name in masked_credentials and value == masked_credentials[name]: - credentials[name] = original_credentials[name] - # validate credentials - provider_controller.validate_credentials(credentials) - # encrypt credentials - credentials = tool_configuration.encrypt_tool_credentials(credentials) - except (ToolProviderNotFoundError, ToolNotFoundError, ToolProviderCredentialValidationError) as e: - raise ValueError(str(e)) - - if provider is None: - # create provider - provider = BuiltinToolProvider( - tenant_id=tenant_id, - user_id=user_id, - provider=provider_name, - encrypted_credentials=json.dumps(credentials), - ) - - db.session.add(provider) - db.session.commit() - - else: - provider.encrypted_credentials = json.dumps(credentials) - db.session.add(provider) - db.session.commit() - - # delete cache - tool_configuration.delete_tool_credentials_cache() - - return { 'result': 'success' } - - @staticmethod - def get_builtin_tool_provider_credentials( - user_id: str, tenant_id: str, provider: str - ): - """ - get builtin tool provider credentials - """ - provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter( - BuiltinToolProvider.tenant_id == tenant_id, - BuiltinToolProvider.provider == provider, - 
).first() - - if provider is None: - return {} + controller = ToolTransformService.api_provider_to_controller(db_provider=provider) + labels = ToolLabelManager.get_tool_labels(controller) - provider_controller = ToolManager.get_builtin_provider(provider.provider) - tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) - credentials = tool_configuration.decrypt_tool_credentials(provider.credentials) - credentials = tool_configuration.mask_tool_credentials(credentials) - return credentials + return [ + ToolTransformService.tool_to_user_tool( + tool_bundle, + labels=labels, + ) for tool_bundle in provider.tools + ] @staticmethod def update_api_tool_provider( user_id: str, tenant_id: str, provider_name: str, original_provider: str, icon: dict, credentials: dict, - schema_type: str, schema: str, privacy_policy: str, custom_disclaimer: str + schema_type: str, schema: str, privacy_policy: str, custom_disclaimer: str, labels: list[str] ): """ update api tool provider @@ -385,7 +250,7 @@ class ToolManageService: # parse openapi to tool bundle extra_info = {} # extra info like description will be set here - tool_bundles, schema_type = ToolManageService.convert_schema_to_tool_bundles(schema, extra_info) + tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info) # update db provider provider.name = provider_name @@ -404,7 +269,7 @@ class ToolManageService: auth_type = ApiProviderAuthType.value_of(credentials['auth_type']) # create provider entity - provider_controller = ApiBasedToolProviderController.from_db(provider, auth_type) + provider_controller = ApiToolProviderController.from_db(provider, auth_type) # load tools into provider entity provider_controller.load_bundled_tools(tool_bundles) @@ -427,84 +292,11 @@ class ToolManageService: # delete cache tool_configuration.delete_tool_credentials_cache() - return { 'result': 'success' } - - @staticmethod - def 
delete_builtin_tool_provider( - user_id: str, tenant_id: str, provider_name: str - ): - """ - delete tool provider - """ - provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter( - BuiltinToolProvider.tenant_id == tenant_id, - BuiltinToolProvider.provider == provider_name, - ).first() - - if provider is None: - raise ValueError(f'you have not added provider {provider_name}') - - db.session.delete(provider) - db.session.commit() - - # delete cache - provider_controller = ToolManager.get_builtin_provider(provider_name) - tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) - tool_configuration.delete_tool_credentials_cache() + # update labels + ToolLabelManager.update_tool_labels(provider_controller, labels) return { 'result': 'success' } - @staticmethod - def get_builtin_tool_provider_icon( - provider: str - ): - """ - get tool provider icon and it's mimetype - """ - icon_path, mime_type = ToolManager.get_builtin_provider_icon(provider) - with open(icon_path, 'rb') as f: - icon_bytes = f.read() - - return icon_bytes, mime_type - - @staticmethod - def get_model_tool_provider_icon( - provider: str - ): - """ - get tool provider icon and it's mimetype - """ - - service = ModelProviderService() - icon_bytes, mime_type = service.get_model_provider_icon(provider=provider, icon_type='icon_small', lang='en_US') - - if icon_bytes is None: - raise ValueError(f'provider {provider} does not exists') - - return icon_bytes, mime_type - - @staticmethod - def list_model_tool_provider_tools( - user_id: str, tenant_id: str, provider: str - ) -> list[UserTool]: - """ - list model tool provider tools - """ - provider_controller = ToolManager.get_model_provider(tenant_id=tenant_id, provider_name=provider) - tools = provider_controller.get_tools(user_id=user_id, tenant_id=tenant_id) - - result = [ - UserTool( - author=tool.identity.author, - name=tool.identity.name, - label=tool.identity.label, - 
description=tool.description.human, - parameters=tool.parameters or [] - ) for tool in tools - ] - - return jsonable_encoder(result) - @staticmethod def delete_api_tool_provider( user_id: str, tenant_id: str, provider_name: str @@ -583,7 +375,7 @@ class ToolManageService: auth_type = ApiProviderAuthType.value_of(credentials['auth_type']) # create provider entity - provider_controller = ApiBasedToolProviderController.from_db(db_provider, auth_type) + provider_controller = ApiToolProviderController.from_db(db_provider, auth_type) # load tools into provider entity provider_controller.load_bundled_tools(tool_bundles) @@ -604,7 +396,7 @@ class ToolManageService: provider_controller.validate_credentials_format(credentials) # get tool tool = provider_controller.get_tool(tool_name) - tool = tool.fork_tool_runtime(meta={ + tool = tool.fork_tool_runtime(runtime={ 'credentials': credentials, 'tenant_id': tenant_id, }) @@ -614,49 +406,6 @@ class ToolManageService: return { 'result': result or 'empty response' } - @staticmethod - def list_builtin_tools( - user_id: str, tenant_id: str - ) -> list[UserToolProvider]: - """ - list builtin tools - """ - # get all builtin providers - provider_controllers = ToolManager.list_builtin_providers() - - # get all user added providers - db_providers: list[BuiltinToolProvider] = db.session.query(BuiltinToolProvider).filter( - BuiltinToolProvider.tenant_id == tenant_id - ).all() or [] - - # find provider - find_provider = lambda provider: next(filter(lambda db_provider: db_provider.provider == provider, db_providers), None) - - result: list[UserToolProvider] = [] - - for provider_controller in provider_controllers: - # convert provider controller to user provider - user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider( - provider_controller=provider_controller, - db_provider=find_provider(provider_controller.identity.name), - decrypt_credentials=True - ) - - # add icon - 
ToolTransformService.repack_provider(user_builtin_provider) - - tools = provider_controller.get_tools() - for tool in tools: - user_builtin_provider.tools.append(ToolTransformService.tool_to_user_tool( - tenant_id=tenant_id, - tool=tool, - credentials=user_builtin_provider.original_credentials, - )) - - result.append(user_builtin_provider) - - return BuiltinToolProviderSort.sort(result) - @staticmethod def list_api_tools( user_id: str, tenant_id: str @@ -674,6 +423,7 @@ class ToolManageService: for provider in db_providers: # convert provider controller to user provider provider_controller = ToolTransformService.api_provider_to_controller(db_provider=provider) + labels = ToolLabelManager.get_tool_labels(provider_controller) user_provider = ToolTransformService.api_provider_to_user_provider( provider_controller, db_provider=provider, @@ -692,6 +442,7 @@ class ToolManageService: tenant_id=tenant_id, tool=tool, credentials=user_provider.original_credentials, + labels=labels )) result.append(user_provider) diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py new file mode 100644 index 0000000000..2503191b63 --- /dev/null +++ b/api/services/tools/builtin_tools_manage_service.py @@ -0,0 +1,226 @@ +import json +import logging + +from core.model_runtime.utils.encoders import jsonable_encoder +from core.tools.entities.api_entities import UserTool, UserToolProvider +from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError +from core.tools.provider.builtin._positions import BuiltinToolProviderSort +from core.tools.provider.tool_provider import ToolProviderController +from core.tools.tool_label_manager import ToolLabelManager +from core.tools.tool_manager import ToolManager +from core.tools.utils.configuration import ToolConfigurationManager +from extensions.ext_database import db +from models.tools import BuiltinToolProvider +from 
services.tools.tools_transform_service import ToolTransformService + +logger = logging.getLogger(__name__) + + +class BuiltinToolManageService: + @staticmethod + def list_builtin_tool_provider_tools( + user_id: str, tenant_id: str, provider: str + ) -> list[UserTool]: + """ + list builtin tool provider tools + """ + provider_controller: ToolProviderController = ToolManager.get_builtin_provider(provider) + tools = provider_controller.get_tools() + + tool_provider_configurations = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) + # check if user has added the provider + builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == provider, + ).first() + + credentials = {} + if builtin_provider is not None: + # get credentials + credentials = builtin_provider.credentials + credentials = tool_provider_configurations.decrypt_tool_credentials(credentials) + + result = [] + for tool in tools: + result.append(ToolTransformService.tool_to_user_tool( + tool=tool, + credentials=credentials, + tenant_id=tenant_id, + labels=ToolLabelManager.get_tool_labels(provider_controller) + )) + + return result + + @staticmethod + def list_builtin_provider_credentials_schema( + provider_name + ): + """ + list builtin provider credentials schema + + :return: the list of tool providers + """ + provider = ToolManager.get_builtin_provider(provider_name) + return jsonable_encoder([ + v for _, v in (provider.credentials_schema or {}).items() + ]) + + @staticmethod + def update_builtin_tool_provider( + user_id: str, tenant_id: str, provider_name: str, credentials: dict + ): + """ + update builtin tool provider + """ + # get if the provider exists + provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == provider_name, + ).first() + + try: + # get provider 
+ provider_controller = ToolManager.get_builtin_provider(provider_name) + if not provider_controller.need_credentials: + raise ValueError(f'provider {provider_name} does not need credentials') + tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) + # get original credentials if exists + if provider is not None: + original_credentials = tool_configuration.decrypt_tool_credentials(provider.credentials) + masked_credentials = tool_configuration.mask_tool_credentials(original_credentials) + # check if the credential has changed, save the original credential + for name, value in credentials.items(): + if name in masked_credentials and value == masked_credentials[name]: + credentials[name] = original_credentials[name] + # validate credentials + provider_controller.validate_credentials(credentials) + # encrypt credentials + credentials = tool_configuration.encrypt_tool_credentials(credentials) + except (ToolProviderNotFoundError, ToolNotFoundError, ToolProviderCredentialValidationError) as e: + raise ValueError(str(e)) + + if provider is None: + # create provider + provider = BuiltinToolProvider( + tenant_id=tenant_id, + user_id=user_id, + provider=provider_name, + encrypted_credentials=json.dumps(credentials), + ) + + db.session.add(provider) + db.session.commit() + + else: + provider.encrypted_credentials = json.dumps(credentials) + db.session.add(provider) + db.session.commit() + + # delete cache + tool_configuration.delete_tool_credentials_cache() + + return { 'result': 'success' } + + @staticmethod + def get_builtin_tool_provider_credentials( + user_id: str, tenant_id: str, provider: str + ): + """ + get builtin tool provider credentials + """ + provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == provider, + ).first() + + if provider is None: + return {} + + provider_controller = 
ToolManager.get_builtin_provider(provider.provider) + tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) + credentials = tool_configuration.decrypt_tool_credentials(provider.credentials) + credentials = tool_configuration.mask_tool_credentials(credentials) + return credentials + + @staticmethod + def delete_builtin_tool_provider( + user_id: str, tenant_id: str, provider_name: str + ): + """ + delete tool provider + """ + provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == provider_name, + ).first() + + if provider is None: + raise ValueError(f'you have not added provider {provider_name}') + + db.session.delete(provider) + db.session.commit() + + # delete cache + provider_controller = ToolManager.get_builtin_provider(provider_name) + tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) + tool_configuration.delete_tool_credentials_cache() + + return { 'result': 'success' } + + @staticmethod + def get_builtin_tool_provider_icon( + provider: str + ): + """ + get tool provider icon and it's mimetype + """ + icon_path, mime_type = ToolManager.get_builtin_provider_icon(provider) + with open(icon_path, 'rb') as f: + icon_bytes = f.read() + + return icon_bytes, mime_type + + @staticmethod + def list_builtin_tools( + user_id: str, tenant_id: str + ) -> list[UserToolProvider]: + """ + list builtin tools + """ + # get all builtin providers + provider_controllers = ToolManager.list_builtin_providers() + + # get all user added providers + db_providers: list[BuiltinToolProvider] = db.session.query(BuiltinToolProvider).filter( + BuiltinToolProvider.tenant_id == tenant_id + ).all() or [] + + # find provider + find_provider = lambda provider: next(filter(lambda db_provider: db_provider.provider == provider, db_providers), None) + + result: list[UserToolProvider] = [] + 
+ for provider_controller in provider_controllers: + # convert provider controller to user provider + user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider( + provider_controller=provider_controller, + db_provider=find_provider(provider_controller.identity.name), + decrypt_credentials=True + ) + + # add icon + ToolTransformService.repack_provider(user_builtin_provider) + + tools = provider_controller.get_tools() + for tool in tools: + user_builtin_provider.tools.append(ToolTransformService.tool_to_user_tool( + tenant_id=tenant_id, + tool=tool, + credentials=user_builtin_provider.original_credentials, + labels=ToolLabelManager.get_tool_labels(provider_controller) + )) + + result.append(user_builtin_provider) + + return BuiltinToolProviderSort.sort(result) + \ No newline at end of file diff --git a/api/services/tools/tool_labels_service.py b/api/services/tools/tool_labels_service.py new file mode 100644 index 0000000000..8a6aa025f2 --- /dev/null +++ b/api/services/tools/tool_labels_service.py @@ -0,0 +1,8 @@ +from core.tools.entities.tool_entities import ToolLabel +from core.tools.entities.values import default_tool_labels + + +class ToolLabelsService: + @classmethod + def list_tool_labels(cls) -> list[ToolLabel]: + return default_tool_labels \ No newline at end of file diff --git a/api/services/tools/tools_manage_service.py b/api/services/tools/tools_manage_service.py new file mode 100644 index 0000000000..76d2f53ae8 --- /dev/null +++ b/api/services/tools/tools_manage_service.py @@ -0,0 +1,29 @@ +import logging + +from core.tools.entities.api_entities import UserToolProviderTypeLiteral +from core.tools.tool_manager import ToolManager +from services.tools.tools_transform_service import ToolTransformService + +logger = logging.getLogger(__name__) + + +class ToolCommonService: + @staticmethod + def list_tool_providers(user_id: str, tenant_id: str, typ: UserToolProviderTypeLiteral = None): + """ + list tool providers + + :return: the list of 
tool providers + """ + providers = ToolManager.user_list_providers( + user_id, tenant_id, typ + ) + + # add icon + for provider in providers: + ToolTransformService.repack_provider(provider) + + result = [provider.to_dict() for provider in providers] + + return result + \ No newline at end of file diff --git a/api/services/tools_transform_service.py b/api/services/tools/tools_transform_service.py similarity index 72% rename from api/services/tools_transform_service.py rename to api/services/tools/tools_transform_service.py index ed7fd589b8..ba8c20d79b 100644 --- a/api/services/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -5,14 +5,21 @@ from typing import Optional, Union from flask import current_app from core.model_runtime.entities.common_entities import I18nObject -from core.tools.entities.tool_bundle import ApiBasedToolBundle -from core.tools.entities.tool_entities import ApiProviderAuthType, ToolParameter, ToolProviderCredentials -from core.tools.entities.user_entities import UserTool, UserToolProvider -from core.tools.provider.api_tool_provider import ApiBasedToolProviderController +from core.tools.entities.api_entities import UserTool, UserToolProvider +from core.tools.entities.tool_bundle import ApiToolBundle +from core.tools.entities.tool_entities import ( + ApiProviderAuthType, + ToolParameter, + ToolProviderCredentials, + ToolProviderType, +) +from core.tools.provider.api_tool_provider import ApiToolProviderController from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController +from core.tools.provider.workflow_tool_provider import WorkflowToolProviderController from core.tools.tool.tool import Tool +from core.tools.tool.workflow_tool import WorkflowTool from core.tools.utils.configuration import ToolConfigurationManager -from models.tools import ApiToolProvider, BuiltinToolProvider +from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider logger = 
logging.getLogger(__name__) @@ -25,9 +32,9 @@ class ToolTransformService: url_prefix = (current_app.config.get("CONSOLE_API_URL") + "/console/api/workspaces/current/tool-provider/") - if provider_type == UserToolProvider.ProviderType.BUILTIN.value: + if provider_type == ToolProviderType.BUILT_IN.value: return url_prefix + 'builtin/' + provider_name + '/icon' - elif provider_type == UserToolProvider.ProviderType.API.value: + elif provider_type in [ToolProviderType.API.value, ToolProviderType.WORKFLOW.value]: try: return json.loads(icon) except: @@ -62,7 +69,7 @@ class ToolTransformService: def builtin_provider_to_user_provider( provider_controller: BuiltinToolProviderController, db_provider: Optional[BuiltinToolProvider], - decrypt_credentials: bool = True + decrypt_credentials: bool = True, ) -> UserToolProvider: """ convert provider controller to user provider @@ -80,10 +87,11 @@ class ToolTransformService: en_US=provider_controller.identity.label.en_US, zh_Hans=provider_controller.identity.label.zh_Hans, ), - type=UserToolProvider.ProviderType.BUILTIN, + type=ToolProviderType.BUILT_IN, masked_credentials={}, is_team_authorization=False, - tools=[] + tools=[], + labels=provider_controller.tool_labels ) # get credentials schema @@ -119,24 +127,62 @@ class ToolTransformService: @staticmethod def api_provider_to_controller( db_provider: ApiToolProvider, - ) -> ApiBasedToolProviderController: + ) -> ApiToolProviderController: """ convert provider controller to user provider """ # package tool provider controller - controller = ApiBasedToolProviderController.from_db( + controller = ApiToolProviderController.from_db( db_provider=db_provider, auth_type=ApiProviderAuthType.API_KEY if db_provider.credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE ) return controller + + @staticmethod + def workflow_provider_to_controller( + db_provider: WorkflowToolProvider + ) -> WorkflowToolProviderController: + """ + convert provider controller to provider + """ + 
return WorkflowToolProviderController.from_db(db_provider) + + @staticmethod + def workflow_provider_to_user_provider( + provider_controller: WorkflowToolProviderController, + labels: list[str] = None + ): + """ + convert provider controller to user provider + """ + return UserToolProvider( + id=provider_controller.provider_id, + author=provider_controller.identity.author, + name=provider_controller.identity.name, + description=I18nObject( + en_US=provider_controller.identity.description.en_US, + zh_Hans=provider_controller.identity.description.zh_Hans, + ), + icon=provider_controller.identity.icon, + label=I18nObject( + en_US=provider_controller.identity.label.en_US, + zh_Hans=provider_controller.identity.label.zh_Hans, + ), + type=ToolProviderType.WORKFLOW, + masked_credentials={}, + is_team_authorization=True, + tools=[], + labels=labels or [] + ) @staticmethod def api_provider_to_user_provider( - provider_controller: ApiBasedToolProviderController, + provider_controller: ApiToolProviderController, db_provider: ApiToolProvider, - decrypt_credentials: bool = True + decrypt_credentials: bool = True, + labels: list[str] = None ) -> UserToolProvider: """ convert provider controller to user provider @@ -161,10 +207,11 @@ class ToolTransformService: en_US=db_provider.name, zh_Hans=db_provider.name, ), - type=UserToolProvider.ProviderType.API, + type=ToolProviderType.API, masked_credentials={}, is_team_authorization=True, - tools=[] + tools=[], + labels=labels or [] ) if decrypt_credentials: @@ -184,14 +231,17 @@ class ToolTransformService: @staticmethod def tool_to_user_tool( - tool: Union[ApiBasedToolBundle, Tool], credentials: dict = None, tenant_id: str = None + tool: Union[ApiToolBundle, WorkflowTool, Tool], + credentials: dict = None, + tenant_id: str = None, + labels: list[str] = None ) -> UserTool: """ convert tool to user tool """ if isinstance(tool, Tool): # fork tool runtime - tool = tool.fork_tool_runtime(meta={ + tool = tool.fork_tool_runtime(runtime={ 
'credentials': credentials, 'tenant_id': tenant_id, }) @@ -213,17 +263,15 @@ class ToolTransformService: if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM: current_parameters.append(runtime_parameter) - user_tool = UserTool( + return UserTool( author=tool.identity.author, name=tool.identity.name, label=tool.identity.label, description=tool.description.human, - parameters=current_parameters + parameters=current_parameters, + labels=labels ) - - return user_tool - - if isinstance(tool, ApiBasedToolBundle): + if isinstance(tool, ApiToolBundle): return UserTool( author=tool.author, name=tool.operation_id, @@ -235,5 +283,6 @@ class ToolTransformService: en_US=tool.summary or '', zh_Hans=tool.summary or '' ), - parameters=tool.parameters + parameters=tool.parameters, + labels=labels ) \ No newline at end of file diff --git a/api/services/tools/workflow_tools_manage_service.py b/api/services/tools/workflow_tools_manage_service.py new file mode 100644 index 0000000000..e89d94160c --- /dev/null +++ b/api/services/tools/workflow_tools_manage_service.py @@ -0,0 +1,326 @@ +import json +from datetime import datetime + +from sqlalchemy import or_ + +from core.model_runtime.utils.encoders import jsonable_encoder +from core.tools.entities.api_entities import UserToolProvider +from core.tools.provider.workflow_tool_provider import WorkflowToolProviderController +from core.tools.tool_label_manager import ToolLabelManager +from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils +from extensions.ext_database import db +from models.model import App +from models.tools import WorkflowToolProvider +from models.workflow import Workflow +from services.tools.tools_transform_service import ToolTransformService + + +class WorkflowToolManageService: + """ + Service class for managing workflow tools. 
+ """ + @classmethod + def create_workflow_tool(cls, user_id: str, tenant_id: str, workflow_app_id: str, name: str, + label: str, icon: dict, description: str, + parameters: list[dict], privacy_policy: str = '', labels: list[str] = None) -> dict: + """ + Create a workflow tool. + :param user_id: the user id + :param tenant_id: the tenant id + :param name: the name + :param icon: the icon + :param description: the description + :param parameters: the parameters + :param privacy_policy: the privacy policy + :return: the created tool + """ + WorkflowToolConfigurationUtils.check_parameter_configurations(parameters) + + # check if the name is unique + existing_workflow_tool_provider = db.session.query(WorkflowToolProvider).filter( + WorkflowToolProvider.tenant_id == tenant_id, + # name or app_id + or_(WorkflowToolProvider.name == name, WorkflowToolProvider.app_id == workflow_app_id) + ).first() + + if existing_workflow_tool_provider is not None: + raise ValueError(f'Tool with name {name} or app_id {workflow_app_id} already exists') + + app: App = db.session.query(App).filter( + App.id == workflow_app_id, + App.tenant_id == tenant_id + ).first() + + if app is None: + raise ValueError(f'App {workflow_app_id} not found') + + workflow: Workflow = app.workflow + if workflow is None: + raise ValueError(f'Workflow not found for app {workflow_app_id}') + + workflow_tool_provider = WorkflowToolProvider( + tenant_id=tenant_id, + user_id=user_id, + app_id=workflow_app_id, + name=name, + label=label, + icon=json.dumps(icon), + description=description, + parameter_configuration=json.dumps(parameters), + privacy_policy=privacy_policy, + version=workflow.version, + ) + + try: + WorkflowToolProviderController.from_db(workflow_tool_provider) + except Exception as e: + raise ValueError(str(e)) + + db.session.add(workflow_tool_provider) + db.session.commit() + + return { + 'result': 'success' + } + + + @classmethod + def update_workflow_tool(cls, user_id: str, tenant_id: str, 
workflow_tool_id: str, + name: str, label: str, icon: dict, description: str, + parameters: list[dict], privacy_policy: str = '', labels: list[str] = None) -> dict: + """ + Update a workflow tool. + :param user_id: the user id + :param tenant_id: the tenant id + :param tool: the tool + :return: the updated tool + """ + WorkflowToolConfigurationUtils.check_parameter_configurations(parameters) + + # check if the name is unique + existing_workflow_tool_provider = db.session.query(WorkflowToolProvider).filter( + WorkflowToolProvider.tenant_id == tenant_id, + WorkflowToolProvider.name == name, + WorkflowToolProvider.id != workflow_tool_id + ).first() + + if existing_workflow_tool_provider is not None: + raise ValueError(f'Tool with name {name} already exists') + + workflow_tool_provider: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter( + WorkflowToolProvider.tenant_id == tenant_id, + WorkflowToolProvider.id == workflow_tool_id + ).first() + + if workflow_tool_provider is None: + raise ValueError(f'Tool {workflow_tool_id} not found') + + app: App = db.session.query(App).filter( + App.id == workflow_tool_provider.app_id, + App.tenant_id == tenant_id + ).first() + + if app is None: + raise ValueError(f'App {workflow_tool_provider.app_id} not found') + + workflow: Workflow = app.workflow + if workflow is None: + raise ValueError(f'Workflow not found for app {workflow_tool_provider.app_id}') + + workflow_tool_provider.name = name + workflow_tool_provider.label = label + workflow_tool_provider.icon = json.dumps(icon) + workflow_tool_provider.description = description + workflow_tool_provider.parameter_configuration = json.dumps(parameters) + workflow_tool_provider.privacy_policy = privacy_policy + workflow_tool_provider.version = workflow.version + workflow_tool_provider.updated_at = datetime.now() + + try: + WorkflowToolProviderController.from_db(workflow_tool_provider) + except Exception as e: + raise ValueError(str(e)) + + 
db.session.add(workflow_tool_provider) + db.session.commit() + + if labels is not None: + ToolLabelManager.update_tool_labels( + ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), + labels + ) + + return { + 'result': 'success' + } + + @classmethod + def list_tenant_workflow_tools(cls, user_id: str, tenant_id: str) -> list[UserToolProvider]: + """ + List workflow tools. + :param user_id: the user id + :param tenant_id: the tenant id + :return: the list of tools + """ + db_tools = db.session.query(WorkflowToolProvider).filter( + WorkflowToolProvider.tenant_id == tenant_id + ).all() + + tools = [] + for provider in db_tools: + try: + tools.append(ToolTransformService.workflow_provider_to_controller(provider)) + except: + # skip deleted tools + pass + + labels = ToolLabelManager.get_tools_labels(tools) + + result = [] + + for tool in tools: + user_tool_provider = ToolTransformService.workflow_provider_to_user_provider( + provider_controller=tool, + labels=labels.get(tool.provider_id, []) + ) + ToolTransformService.repack_provider(user_tool_provider) + user_tool_provider.tools = [ + ToolTransformService.tool_to_user_tool( + tool.get_tools(user_id, tenant_id)[0], + labels=labels.get(tool.provider_id, []) + ) + ] + result.append(user_tool_provider) + + return result + + @classmethod + def delete_workflow_tool(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> dict: + """ + Delete a workflow tool. + :param user_id: the user id + :param tenant_id: the tenant id + :param workflow_app_id: the workflow app id + """ + db.session.query(WorkflowToolProvider).filter( + WorkflowToolProvider.tenant_id == tenant_id, + WorkflowToolProvider.id == workflow_tool_id + ).delete() + + db.session.commit() + + return { + 'result': 'success' + } + + @classmethod + def get_workflow_tool_by_tool_id(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> dict: + """ + Get a workflow tool. 
+ :param user_id: the user id + :param tenant_id: the tenant id + :param workflow_app_id: the workflow app id + :return: the tool + """ + db_tool: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter( + WorkflowToolProvider.tenant_id == tenant_id, + WorkflowToolProvider.id == workflow_tool_id + ).first() + + if db_tool is None: + raise ValueError(f'Tool {workflow_tool_id} not found') + + workflow_app: App = db.session.query(App).filter( + App.id == db_tool.app_id, + App.tenant_id == tenant_id + ).first() + + if workflow_app is None: + raise ValueError(f'App {db_tool.app_id} not found') + + tool = ToolTransformService.workflow_provider_to_controller(db_tool) + + return { + 'name': db_tool.name, + 'label': db_tool.label, + 'workflow_tool_id': db_tool.id, + 'workflow_app_id': db_tool.app_id, + 'icon': json.loads(db_tool.icon), + 'description': db_tool.description, + 'parameters': jsonable_encoder(db_tool.parameter_configurations), + 'tool': ToolTransformService.tool_to_user_tool( + tool.get_tools(user_id, tenant_id)[0], + labels=ToolLabelManager.get_tool_labels(tool) + ), + 'synced': workflow_app.workflow.version == db_tool.version, + 'privacy_policy': db_tool.privacy_policy, + } + + @classmethod + def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_id: str) -> dict: + """ + Get a workflow tool. 
+ :param user_id: the user id + :param tenant_id: the tenant id + :param workflow_app_id: the workflow app id + :return: the tool + """ + db_tool: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter( + WorkflowToolProvider.tenant_id == tenant_id, + WorkflowToolProvider.app_id == workflow_app_id + ).first() + + if db_tool is None: + raise ValueError(f'Tool {workflow_app_id} not found') + + workflow_app: App = db.session.query(App).filter( + App.id == db_tool.app_id, + App.tenant_id == tenant_id + ).first() + + if workflow_app is None: + raise ValueError(f'App {db_tool.app_id} not found') + + tool = ToolTransformService.workflow_provider_to_controller(db_tool) + + return { + 'name': db_tool.name, + 'label': db_tool.label, + 'workflow_tool_id': db_tool.id, + 'workflow_app_id': db_tool.app_id, + 'icon': json.loads(db_tool.icon), + 'description': db_tool.description, + 'parameters': jsonable_encoder(db_tool.parameter_configurations), + 'tool': ToolTransformService.tool_to_user_tool( + tool.get_tools(user_id, tenant_id)[0], + labels=ToolLabelManager.get_tool_labels(tool) + ), + 'synced': workflow_app.workflow.version == db_tool.version, + 'privacy_policy': db_tool.privacy_policy + } + + @classmethod + def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> list[dict]: + """ + List workflow tool provider tools. 
+ :param user_id: the user id + :param tenant_id: the tenant id + :param workflow_app_id: the workflow app id + :return: the list of tools + """ + db_tool: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter( + WorkflowToolProvider.tenant_id == tenant_id, + WorkflowToolProvider.id == workflow_tool_id + ).first() + + if db_tool is None: + raise ValueError(f'Tool {workflow_tool_id} not found') + + tool = ToolTransformService.workflow_provider_to_controller(db_tool) + + return [ + ToolTransformService.tool_to_user_tool( + tool.get_tools(user_id, tenant_id)[0], + labels=ToolLabelManager.get_tool_labels(tool) + ) + ] \ No newline at end of file diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py index c41d51caf7..15cf5367d3 100644 --- a/api/tests/integration_tests/workflow/nodes/test_code.py +++ b/api/tests/integration_tests/workflow/nodes/test_code.py @@ -4,6 +4,7 @@ import pytest from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.variable_pool import VariablePool +from core.workflow.nodes.base_node import UserFrom from core.workflow.nodes.code.code_node import CodeNode from models.workflow import WorkflowNodeExecutionStatus from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock @@ -25,7 +26,8 @@ def test_execute_code(setup_code_executor_mock): app_id='1', workflow_id='1', user_id='1', - user_from=InvokeFrom.WEB_APP, + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.WEB_APP, config={ 'id': '1', 'data': { @@ -78,7 +80,8 @@ def test_execute_code_output_validator(setup_code_executor_mock): app_id='1', workflow_id='1', user_id='1', - user_from=InvokeFrom.WEB_APP, + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.WEB_APP, config={ 'id': '1', 'data': { @@ -132,7 +135,8 @@ def test_execute_code_output_validator_depth(): app_id='1', workflow_id='1', user_id='1', - user_from=InvokeFrom.WEB_APP, + 
user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.WEB_APP, config={ 'id': '1', 'data': { @@ -285,7 +289,8 @@ def test_execute_code_output_object_list(): app_id='1', workflow_id='1', user_id='1', - user_from=InvokeFrom.WEB_APP, + invoke_from=InvokeFrom.WEB_APP, + user_from=UserFrom.ACCOUNT, config={ 'id': '1', 'data': { diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index 10e3d53608..ffa2741e55 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -2,6 +2,7 @@ import pytest from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.variable_pool import VariablePool +from core.workflow.nodes.base_node import UserFrom from core.workflow.nodes.http_request.http_request_node import HttpRequestNode from tests.integration_tests.workflow.nodes.__mock.http import setup_http_mock @@ -10,7 +11,8 @@ BASIC_NODE_DATA = { 'app_id': '1', 'workflow_id': '1', 'user_id': '1', - 'user_from': InvokeFrom.WEB_APP, + 'user_from': UserFrom.ACCOUNT, + 'invoke_from': InvokeFrom.WEB_APP, } # construct variable pool diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index d04497a187..394a3dcbd7 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -4,7 +4,7 @@ from unittest.mock import MagicMock import pytest -from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, SystemConfiguration from core.model_manager import ModelInstance @@ -30,6 +30,7 @@ def 
test_execute_llm(setup_openai_mock): app_id='1', workflow_id='1', user_id='1', + invoke_from=InvokeFrom.WEB_APP, user_from=UserFrom.ACCOUNT, config={ 'id': 'llm', @@ -130,6 +131,7 @@ def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock): app_id='1', workflow_id='1', user_id='1', + invoke_from=InvokeFrom.WEB_APP, user_from=UserFrom.ACCOUNT, config={ 'id': 'llm', diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py new file mode 100644 index 0000000000..342f371eea --- /dev/null +++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py @@ -0,0 +1,356 @@ +import json +import os +from unittest.mock import MagicMock + +import pytest + +from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity +from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle +from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, SystemConfiguration +from core.model_manager import ModelInstance +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from core.workflow.entities.node_entities import SystemVariable +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.nodes.base_node import UserFrom +from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode +from extensions.ext_database import db +from models.provider import ProviderType + +"""FOR MOCK FIXTURES, DO NOT REMOVE""" +from models.workflow import WorkflowNodeExecutionStatus +from tests.integration_tests.model_runtime.__mock.anthropic import setup_anthropic_mock +from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock + + +def get_mocked_fetch_model_config( + provider: str, model: str, 
mode: str, + credentials: dict, +): + provider_instance = ModelProviderFactory().get_provider_instance(provider) + model_type_instance = provider_instance.get_model_instance(ModelType.LLM) + provider_model_bundle = ProviderModelBundle( + configuration=ProviderConfiguration( + tenant_id='1', + provider=provider_instance.get_provider_schema(), + preferred_provider_type=ProviderType.CUSTOM, + using_provider_type=ProviderType.CUSTOM, + system_configuration=SystemConfiguration( + enabled=False + ), + custom_configuration=CustomConfiguration( + provider=CustomProviderConfiguration( + credentials=credentials + ) + ) + ), + provider_instance=provider_instance, + model_type_instance=model_type_instance + ) + model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model=model) + model_config = ModelConfigWithCredentialsEntity( + model=model, + provider=provider, + mode=mode, + credentials=credentials, + parameters={}, + model_schema=model_type_instance.get_model_schema(model), + provider_model_bundle=provider_model_bundle + ) + + return MagicMock(return_value=tuple([model_instance, model_config])) + +@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) +def test_function_calling_parameter_extractor(setup_openai_mock): + """ + Test function calling for parameter extractor. 
+ """ + node = ParameterExtractorNode( + tenant_id='1', + app_id='1', + workflow_id='1', + user_id='1', + invoke_from=InvokeFrom.WEB_APP, + user_from=UserFrom.ACCOUNT, + config={ + 'id': 'llm', + 'data': { + 'title': '123', + 'type': 'parameter-extractor', + 'model': { + 'provider': 'openai', + 'name': 'gpt-3.5-turbo', + 'mode': 'chat', + 'completion_params': {} + }, + 'query': ['sys', 'query'], + 'parameters': [{ + 'name': 'location', + 'type': 'string', + 'description': 'location', + 'required': True + }], + 'instruction': '', + 'reasoning_mode': 'function_call', + 'memory': None, + } + } + ) + + node._fetch_model_config = get_mocked_fetch_model_config( + provider='openai', model='gpt-3.5-turbo', mode='chat', credentials={ + 'openai_api_key': os.environ.get('OPENAI_API_KEY') + } + ) + db.session.close = MagicMock() + + # construct variable pool + pool = VariablePool(system_variables={ + SystemVariable.QUERY: 'what\'s the weather in SF', + SystemVariable.FILES: [], + SystemVariable.CONVERSATION_ID: 'abababa', + SystemVariable.USER_ID: 'aaa' + }, user_inputs={}) + + result = node.run(pool) + + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs.get('location') == 'kawaii' + assert result.outputs.get('__reason') == None + +@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) +def test_instructions(setup_openai_mock): + """ + Test chat parameter extractor. 
+ """ + node = ParameterExtractorNode( + tenant_id='1', + app_id='1', + workflow_id='1', + user_id='1', + invoke_from=InvokeFrom.WEB_APP, + user_from=UserFrom.ACCOUNT, + config={ + 'id': 'llm', + 'data': { + 'title': '123', + 'type': 'parameter-extractor', + 'model': { + 'provider': 'openai', + 'name': 'gpt-3.5-turbo', + 'mode': 'chat', + 'completion_params': {} + }, + 'query': ['sys', 'query'], + 'parameters': [{ + 'name': 'location', + 'type': 'string', + 'description': 'location', + 'required': True + }], + 'reasoning_mode': 'function_call', + 'instruction': '{{#sys.query#}}', + 'memory': None, + } + } + ) + + node._fetch_model_config = get_mocked_fetch_model_config( + provider='openai', model='gpt-3.5-turbo', mode='chat', credentials={ + 'openai_api_key': os.environ.get('OPENAI_API_KEY') + } + ) + db.session.close = MagicMock() + + # construct variable pool + pool = VariablePool(system_variables={ + SystemVariable.QUERY: 'what\'s the weather in SF', + SystemVariable.FILES: [], + SystemVariable.CONVERSATION_ID: 'abababa', + SystemVariable.USER_ID: 'aaa' + }, user_inputs={}) + + result = node.run(pool) + + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs.get('location') == 'kawaii' + assert result.outputs.get('__reason') == None + + process_data = result.process_data + + process_data.get('prompts') + + for prompt in process_data.get('prompts'): + if prompt.get('role') == 'system': + assert 'what\'s the weather in SF' in prompt.get('text') + +@pytest.mark.parametrize('setup_anthropic_mock', [['none']], indirect=True) +def test_chat_parameter_extractor(setup_anthropic_mock): + """ + Test chat parameter extractor. 
+ """ + node = ParameterExtractorNode( + tenant_id='1', + app_id='1', + workflow_id='1', + user_id='1', + invoke_from=InvokeFrom.WEB_APP, + user_from=UserFrom.ACCOUNT, + config={ + 'id': 'llm', + 'data': { + 'title': '123', + 'type': 'parameter-extractor', + 'model': { + 'provider': 'anthropic', + 'name': 'claude-2', + 'mode': 'chat', + 'completion_params': {} + }, + 'query': ['sys', 'query'], + 'parameters': [{ + 'name': 'location', + 'type': 'string', + 'description': 'location', + 'required': True + }], + 'reasoning_mode': 'prompt', + 'instruction': '', + 'memory': None, + } + } + ) + + node._fetch_model_config = get_mocked_fetch_model_config( + provider='anthropic', model='claude-2', mode='chat', credentials={ + 'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY') + } + ) + db.session.close = MagicMock() + + # construct variable pool + pool = VariablePool(system_variables={ + SystemVariable.QUERY: 'what\'s the weather in SF', + SystemVariable.FILES: [], + SystemVariable.CONVERSATION_ID: 'abababa', + SystemVariable.USER_ID: 'aaa' + }, user_inputs={}) + + result = node.run(pool) + + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs.get('location') == '' + assert result.outputs.get('__reason') == 'Failed to extract result from function call or text response, using empty result.' + prompts = result.process_data.get('prompts') + + for prompt in prompts: + if prompt.get('role') == 'user': + if '' in prompt.get('text'): + assert '\n{"type": "object"' in prompt.get('text') + +@pytest.mark.parametrize('setup_openai_mock', [['completion']], indirect=True) +def test_completion_parameter_extractor(setup_openai_mock): + """ + Test completion parameter extractor. 
+ """ + node = ParameterExtractorNode( + tenant_id='1', + app_id='1', + workflow_id='1', + user_id='1', + invoke_from=InvokeFrom.WEB_APP, + user_from=UserFrom.ACCOUNT, + config={ + 'id': 'llm', + 'data': { + 'title': '123', + 'type': 'parameter-extractor', + 'model': { + 'provider': 'openai', + 'name': 'gpt-3.5-turbo-instruct', + 'mode': 'completion', + 'completion_params': {} + }, + 'query': ['sys', 'query'], + 'parameters': [{ + 'name': 'location', + 'type': 'string', + 'description': 'location', + 'required': True + }], + 'reasoning_mode': 'prompt', + 'instruction': '{{#sys.query#}}', + 'memory': None, + } + } + ) + + node._fetch_model_config = get_mocked_fetch_model_config( + provider='openai', model='gpt-3.5-turbo-instruct', mode='completion', credentials={ + 'openai_api_key': os.environ.get('OPENAI_API_KEY') + } + ) + db.session.close = MagicMock() + + # construct variable pool + pool = VariablePool(system_variables={ + SystemVariable.QUERY: 'what\'s the weather in SF', + SystemVariable.FILES: [], + SystemVariable.CONVERSATION_ID: 'abababa', + SystemVariable.USER_ID: 'aaa' + }, user_inputs={}) + + result = node.run(pool) + + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs.get('location') == '' + assert result.outputs.get('__reason') == 'Failed to extract result from function call or text response, using empty result.' + assert len(result.process_data.get('prompts')) == 1 + assert 'SF' in result.process_data.get('prompts')[0].get('text') + +def test_extract_json_response(): + """ + Test extract json response. 
+ """ + + node = ParameterExtractorNode( + tenant_id='1', + app_id='1', + workflow_id='1', + user_id='1', + invoke_from=InvokeFrom.WEB_APP, + user_from=UserFrom.ACCOUNT, + config={ + 'id': 'llm', + 'data': { + 'title': '123', + 'type': 'parameter-extractor', + 'model': { + 'provider': 'openai', + 'name': 'gpt-3.5-turbo-instruct', + 'mode': 'completion', + 'completion_params': {} + }, + 'query': ['sys', 'query'], + 'parameters': [{ + 'name': 'location', + 'type': 'string', + 'description': 'location', + 'required': True + }], + 'reasoning_mode': 'prompt', + 'instruction': '{{#sys.query#}}', + 'memory': None, + } + } + ) + + result = node._extract_complete_json_response(""" + uwu{ovo} + { + "location": "kawaii" + } + hello world. + """) + + assert result['location'] == 'kawaii' \ No newline at end of file diff --git a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py index 4a31334056..02999bf0a2 100644 --- a/api/tests/integration_tests/workflow/nodes/test_template_transform.py +++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py @@ -1,5 +1,6 @@ import pytest +from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import UserFrom from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode @@ -15,6 +16,7 @@ def test_execute_code(setup_code_executor_mock): app_id='1', workflow_id='1', user_id='1', + invoke_from=InvokeFrom.WEB_APP, user_from=UserFrom.END_USER, config={ 'id': '1', diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py index 4bbd4ccee7..fffd074457 100644 --- a/api/tests/integration_tests/workflow/nodes/test_tool.py +++ b/api/tests/integration_tests/workflow/nodes/test_tool.py @@ -1,5 +1,6 @@ from core.app.entities.app_invoke_entities 
import InvokeFrom from core.workflow.entities.variable_pool import VariablePool +from core.workflow.nodes.base_node import UserFrom from core.workflow.nodes.tool.tool_node import ToolNode from models.workflow import WorkflowNodeExecutionStatus @@ -13,7 +14,8 @@ def test_tool_variable_invoke(): app_id='1', workflow_id='1', user_id='1', - user_from=InvokeFrom.WEB_APP, + invoke_from=InvokeFrom.WEB_APP, + user_from=UserFrom.ACCOUNT, config={ 'id': '1', 'data': { @@ -51,7 +53,8 @@ def test_tool_mixed_invoke(): app_id='1', workflow_id='1', user_id='1', - user_from=InvokeFrom.WEB_APP, + invoke_from=InvokeFrom.WEB_APP, + user_from=UserFrom.ACCOUNT, config={ 'id': '1', 'data': { diff --git a/api/tests/unit_tests/core/workflow/nodes/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/test_answer.py index cf21401eb2..102711b4b6 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_answer.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_answer.py @@ -1,5 +1,6 @@ from unittest.mock import MagicMock +from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.node_entities import SystemVariable from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.answer.answer_node import AnswerNode @@ -15,6 +16,7 @@ def test_execute_answer(): workflow_id='1', user_id='1', user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, config={ 'id': 'answer', 'data': { diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py index 99413540c5..6860b2fd97 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py @@ -1,5 +1,6 @@ from unittest.mock import MagicMock +from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.node_entities import SystemVariable from core.workflow.entities.variable_pool import VariablePool from 
core.workflow.nodes.base_node import UserFrom @@ -15,6 +16,7 @@ def test_execute_if_else_result_true(): workflow_id='1', user_id='1', user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, config={ 'id': 'if-else', 'data': { @@ -155,6 +157,7 @@ def test_execute_if_else_result_false(): workflow_id='1', user_id='1', user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, config={ 'id': 'if-else', 'data': {