mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-13 11:39:01 +08:00
Feat/workflow phase2 (#4687)
This commit is contained in:
parent
45deaee762
commit
e852a21634
@ -137,6 +137,71 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
|
||||
logging.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
class AdvancedChatDraftRunIterationNodeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||
def post(self, app_model: App, node_id: str):
|
||||
"""
|
||||
Run draft workflow iteration node
|
||||
"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('inputs', type=dict, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate_single_iteration(
|
||||
app_model=app_model,
|
||||
user=current_user,
|
||||
node_id=node_id,
|
||||
args=args,
|
||||
streaming=True
|
||||
)
|
||||
|
||||
return helper.compact_generate_response(response)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
except services.errors.conversation.ConversationCompletedError:
|
||||
raise ConversationCompletedError()
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
class WorkflowDraftRunIterationNodeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||
def post(self, app_model: App, node_id: str):
|
||||
"""
|
||||
Run draft workflow iteration node
|
||||
"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('inputs', type=dict, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate_single_iteration(
|
||||
app_model=app_model,
|
||||
user=current_user,
|
||||
node_id=node_id,
|
||||
args=args,
|
||||
streaming=True
|
||||
)
|
||||
|
||||
return helper.compact_generate_response(response)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
except services.errors.conversation.ConversationCompletedError:
|
||||
raise ConversationCompletedError()
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
class DraftWorkflowRunApi(Resource):
|
||||
@setup_required
|
||||
@ -326,6 +391,8 @@ api.add_resource(AdvancedChatDraftWorkflowRunApi, '/apps/<uuid:app_id>/advanced-
|
||||
api.add_resource(DraftWorkflowRunApi, '/apps/<uuid:app_id>/workflows/draft/run')
|
||||
api.add_resource(WorkflowTaskStopApi, '/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop')
|
||||
api.add_resource(DraftWorkflowNodeRunApi, '/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run')
|
||||
api.add_resource(AdvancedChatDraftRunIterationNodeApi, '/apps/<uuid:app_id>/advanced-chat/workflows/draft/iteration/nodes/<string:node_id>/run')
|
||||
api.add_resource(WorkflowDraftRunIterationNodeApi, '/apps/<uuid:app_id>/workflows/draft/iteration/nodes/<string:node_id>/run')
|
||||
api.add_resource(PublishedWorkflowApi, '/apps/<uuid:app_id>/workflows/publish')
|
||||
api.add_resource(DefaultBlockConfigsApi, '/apps/<uuid:app_id>/workflows/default-workflow-block-configs')
|
||||
api.add_resource(DefaultBlockConfigApi, '/apps/<uuid:app_id>/workflows/default-workflow-block-configs'
|
||||
|
@ -9,8 +9,13 @@ from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.helper import alphanumeric, uuid_value
|
||||
from libs.login import login_required
|
||||
from services.tools_manage_service import ToolManageService
|
||||
from services.tools.api_tools_manage_service import ApiToolManageService
|
||||
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||
from services.tools.tool_labels_service import ToolLabelsService
|
||||
from services.tools.tools_manage_service import ToolCommonService
|
||||
from services.tools.workflow_tools_manage_service import WorkflowToolManageService
|
||||
|
||||
|
||||
class ToolProviderListApi(Resource):
|
||||
@ -21,7 +26,11 @@ class ToolProviderListApi(Resource):
|
||||
user_id = current_user.id
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
return ToolManageService.list_tool_providers(user_id, tenant_id)
|
||||
req = reqparse.RequestParser()
|
||||
req.add_argument('type', type=str, choices=['builtin', 'model', 'api', 'workflow'], required=False, nullable=True, location='args')
|
||||
args = req.parse_args()
|
||||
|
||||
return ToolCommonService.list_tool_providers(user_id, tenant_id, args.get('type', None))
|
||||
|
||||
class ToolBuiltinProviderListToolsApi(Resource):
|
||||
@setup_required
|
||||
@ -31,7 +40,7 @@ class ToolBuiltinProviderListToolsApi(Resource):
|
||||
user_id = current_user.id
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
return jsonable_encoder(ToolManageService.list_builtin_tool_provider_tools(
|
||||
return jsonable_encoder(BuiltinToolManageService.list_builtin_tool_provider_tools(
|
||||
user_id,
|
||||
tenant_id,
|
||||
provider,
|
||||
@ -48,7 +57,7 @@ class ToolBuiltinProviderDeleteApi(Resource):
|
||||
user_id = current_user.id
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
return ToolManageService.delete_builtin_tool_provider(
|
||||
return BuiltinToolManageService.delete_builtin_tool_provider(
|
||||
user_id,
|
||||
tenant_id,
|
||||
provider,
|
||||
@ -70,7 +79,7 @@ class ToolBuiltinProviderUpdateApi(Resource):
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
return ToolManageService.update_builtin_tool_provider(
|
||||
return BuiltinToolManageService.update_builtin_tool_provider(
|
||||
user_id,
|
||||
tenant_id,
|
||||
provider,
|
||||
@ -85,7 +94,7 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
|
||||
user_id = current_user.id
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
return ToolManageService.get_builtin_tool_provider_credentials(
|
||||
return BuiltinToolManageService.get_builtin_tool_provider_credentials(
|
||||
user_id,
|
||||
tenant_id,
|
||||
provider,
|
||||
@ -94,7 +103,7 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
|
||||
class ToolBuiltinProviderIconApi(Resource):
|
||||
@setup_required
|
||||
def get(self, provider):
|
||||
icon_bytes, mimetype = ToolManageService.get_builtin_tool_provider_icon(provider)
|
||||
icon_bytes, mimetype = BuiltinToolManageService.get_builtin_tool_provider_icon(provider)
|
||||
icon_cache_max_age = int(current_app.config.get('TOOL_ICON_CACHE_MAX_AGE'))
|
||||
return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age)
|
||||
|
||||
@ -116,11 +125,12 @@ class ToolApiProviderAddApi(Resource):
|
||||
parser.add_argument('provider', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('icon', type=dict, required=True, nullable=False, location='json')
|
||||
parser.add_argument('privacy_policy', type=str, required=False, nullable=True, location='json')
|
||||
parser.add_argument('labels', type=list[str], required=False, nullable=True, location='json', default=[])
|
||||
parser.add_argument('custom_disclaimer', type=str, required=False, nullable=True, location='json')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
return ToolManageService.create_api_tool_provider(
|
||||
return ApiToolManageService.create_api_tool_provider(
|
||||
user_id,
|
||||
tenant_id,
|
||||
args['provider'],
|
||||
@ -130,6 +140,7 @@ class ToolApiProviderAddApi(Resource):
|
||||
args['schema'],
|
||||
args.get('privacy_policy', ''),
|
||||
args.get('custom_disclaimer', ''),
|
||||
args.get('labels', []),
|
||||
)
|
||||
|
||||
class ToolApiProviderGetRemoteSchemaApi(Resource):
|
||||
@ -143,7 +154,7 @@ class ToolApiProviderGetRemoteSchemaApi(Resource):
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
return ToolManageService.get_api_tool_provider_remote_schema(
|
||||
return ApiToolManageService.get_api_tool_provider_remote_schema(
|
||||
current_user.id,
|
||||
current_user.current_tenant_id,
|
||||
args['url'],
|
||||
@ -163,7 +174,7 @@ class ToolApiProviderListToolsApi(Resource):
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
return jsonable_encoder(ToolManageService.list_api_tool_provider_tools(
|
||||
return jsonable_encoder(ApiToolManageService.list_api_tool_provider_tools(
|
||||
user_id,
|
||||
tenant_id,
|
||||
args['provider'],
|
||||
@ -188,11 +199,12 @@ class ToolApiProviderUpdateApi(Resource):
|
||||
parser.add_argument('original_provider', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('icon', type=dict, required=True, nullable=False, location='json')
|
||||
parser.add_argument('privacy_policy', type=str, required=True, nullable=True, location='json')
|
||||
parser.add_argument('labels', type=list[str], required=False, nullable=True, location='json')
|
||||
parser.add_argument('custom_disclaimer', type=str, required=True, nullable=True, location='json')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
return ToolManageService.update_api_tool_provider(
|
||||
return ApiToolManageService.update_api_tool_provider(
|
||||
user_id,
|
||||
tenant_id,
|
||||
args['provider'],
|
||||
@ -203,6 +215,7 @@ class ToolApiProviderUpdateApi(Resource):
|
||||
args['schema'],
|
||||
args['privacy_policy'],
|
||||
args['custom_disclaimer'],
|
||||
args.get('labels', []),
|
||||
)
|
||||
|
||||
class ToolApiProviderDeleteApi(Resource):
|
||||
@ -222,7 +235,7 @@ class ToolApiProviderDeleteApi(Resource):
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
return ToolManageService.delete_api_tool_provider(
|
||||
return ApiToolManageService.delete_api_tool_provider(
|
||||
user_id,
|
||||
tenant_id,
|
||||
args['provider'],
|
||||
@ -242,7 +255,7 @@ class ToolApiProviderGetApi(Resource):
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
return ToolManageService.get_api_tool_provider(
|
||||
return ApiToolManageService.get_api_tool_provider(
|
||||
user_id,
|
||||
tenant_id,
|
||||
args['provider'],
|
||||
@ -253,7 +266,7 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
return ToolManageService.list_builtin_provider_credentials_schema(provider)
|
||||
return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider)
|
||||
|
||||
class ToolApiProviderSchemaApi(Resource):
|
||||
@setup_required
|
||||
@ -266,7 +279,7 @@ class ToolApiProviderSchemaApi(Resource):
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
return ToolManageService.parser_api_schema(
|
||||
return ApiToolManageService.parser_api_schema(
|
||||
schema=args['schema'],
|
||||
)
|
||||
|
||||
@ -286,7 +299,7 @@ class ToolApiProviderPreviousTestApi(Resource):
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
return ToolManageService.test_api_tool_preview(
|
||||
return ApiToolManageService.test_api_tool_preview(
|
||||
current_user.current_tenant_id,
|
||||
args['provider_name'] if args['provider_name'] else '',
|
||||
args['tool_name'],
|
||||
@ -296,6 +309,153 @@ class ToolApiProviderPreviousTestApi(Resource):
|
||||
args['schema'],
|
||||
)
|
||||
|
||||
class ToolWorkflowProviderCreateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
user_id = current_user.id
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
reqparser = reqparse.RequestParser()
|
||||
reqparser.add_argument('workflow_app_id', type=uuid_value, required=True, nullable=False, location='json')
|
||||
reqparser.add_argument('name', type=alphanumeric, required=True, nullable=False, location='json')
|
||||
reqparser.add_argument('label', type=str, required=True, nullable=False, location='json')
|
||||
reqparser.add_argument('description', type=str, required=True, nullable=False, location='json')
|
||||
reqparser.add_argument('icon', type=dict, required=True, nullable=False, location='json')
|
||||
reqparser.add_argument('parameters', type=list[dict], required=True, nullable=False, location='json')
|
||||
reqparser.add_argument('privacy_policy', type=str, required=False, nullable=True, location='json', default='')
|
||||
reqparser.add_argument('labels', type=list[str], required=False, nullable=True, location='json')
|
||||
|
||||
args = reqparser.parse_args()
|
||||
|
||||
return WorkflowToolManageService.create_workflow_tool(
|
||||
user_id,
|
||||
tenant_id,
|
||||
args['workflow_app_id'],
|
||||
args['name'],
|
||||
args['label'],
|
||||
args['icon'],
|
||||
args['description'],
|
||||
args['parameters'],
|
||||
args['privacy_policy'],
|
||||
args.get('labels', []),
|
||||
)
|
||||
|
||||
class ToolWorkflowProviderUpdateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
user_id = current_user.id
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
reqparser = reqparse.RequestParser()
|
||||
reqparser.add_argument('workflow_tool_id', type=uuid_value, required=True, nullable=False, location='json')
|
||||
reqparser.add_argument('name', type=alphanumeric, required=True, nullable=False, location='json')
|
||||
reqparser.add_argument('label', type=str, required=True, nullable=False, location='json')
|
||||
reqparser.add_argument('description', type=str, required=True, nullable=False, location='json')
|
||||
reqparser.add_argument('icon', type=dict, required=True, nullable=False, location='json')
|
||||
reqparser.add_argument('parameters', type=list[dict], required=True, nullable=False, location='json')
|
||||
reqparser.add_argument('privacy_policy', type=str, required=False, nullable=True, location='json', default='')
|
||||
reqparser.add_argument('labels', type=list[str], required=False, nullable=True, location='json')
|
||||
|
||||
args = reqparser.parse_args()
|
||||
|
||||
if not args['workflow_tool_id']:
|
||||
raise ValueError('incorrect workflow_tool_id')
|
||||
|
||||
return WorkflowToolManageService.update_workflow_tool(
|
||||
user_id,
|
||||
tenant_id,
|
||||
args['workflow_tool_id'],
|
||||
args['name'],
|
||||
args['label'],
|
||||
args['icon'],
|
||||
args['description'],
|
||||
args['parameters'],
|
||||
args['privacy_policy'],
|
||||
args.get('labels', []),
|
||||
)
|
||||
|
||||
class ToolWorkflowProviderDeleteApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
user_id = current_user.id
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
reqparser = reqparse.RequestParser()
|
||||
reqparser.add_argument('workflow_tool_id', type=uuid_value, required=True, nullable=False, location='json')
|
||||
|
||||
args = reqparser.parse_args()
|
||||
|
||||
return WorkflowToolManageService.delete_workflow_tool(
|
||||
user_id,
|
||||
tenant_id,
|
||||
args['workflow_tool_id'],
|
||||
)
|
||||
|
||||
class ToolWorkflowProviderGetApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user_id = current_user.id
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('workflow_tool_id', type=uuid_value, required=False, nullable=True, location='args')
|
||||
parser.add_argument('workflow_app_id', type=uuid_value, required=False, nullable=True, location='args')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.get('workflow_tool_id'):
|
||||
tool = WorkflowToolManageService.get_workflow_tool_by_tool_id(
|
||||
user_id,
|
||||
tenant_id,
|
||||
args['workflow_tool_id'],
|
||||
)
|
||||
elif args.get('workflow_app_id'):
|
||||
tool = WorkflowToolManageService.get_workflow_tool_by_app_id(
|
||||
user_id,
|
||||
tenant_id,
|
||||
args['workflow_app_id'],
|
||||
)
|
||||
else:
|
||||
raise ValueError('incorrect workflow_tool_id or workflow_app_id')
|
||||
|
||||
return jsonable_encoder(tool)
|
||||
|
||||
class ToolWorkflowProviderListToolApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user_id = current_user.id
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('workflow_tool_id', type=uuid_value, required=True, nullable=False, location='args')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
return jsonable_encoder(WorkflowToolManageService.list_single_workflow_tools(
|
||||
user_id,
|
||||
tenant_id,
|
||||
args['workflow_tool_id'],
|
||||
))
|
||||
|
||||
class ToolBuiltinListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -304,7 +464,7 @@ class ToolBuiltinListApi(Resource):
|
||||
user_id = current_user.id
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
return jsonable_encoder([provider.to_dict() for provider in ToolManageService.list_builtin_tools(
|
||||
return jsonable_encoder([provider.to_dict() for provider in BuiltinToolManageService.list_builtin_tools(
|
||||
user_id,
|
||||
tenant_id,
|
||||
)])
|
||||
@ -317,18 +477,43 @@ class ToolApiListApi(Resource):
|
||||
user_id = current_user.id
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
return jsonable_encoder([provider.to_dict() for provider in ToolManageService.list_api_tools(
|
||||
return jsonable_encoder([provider.to_dict() for provider in ApiToolManageService.list_api_tools(
|
||||
user_id,
|
||||
tenant_id,
|
||||
)])
|
||||
|
||||
class ToolWorkflowListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user_id = current_user.id
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
return jsonable_encoder([provider.to_dict() for provider in WorkflowToolManageService.list_tenant_workflow_tools(
|
||||
user_id,
|
||||
tenant_id,
|
||||
)])
|
||||
|
||||
class ToolLabelsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
return jsonable_encoder(ToolLabelsService.list_tool_labels())
|
||||
|
||||
# tool provider
|
||||
api.add_resource(ToolProviderListApi, '/workspaces/current/tool-providers')
|
||||
|
||||
# builtin tool provider
|
||||
api.add_resource(ToolBuiltinProviderListToolsApi, '/workspaces/current/tool-provider/builtin/<provider>/tools')
|
||||
api.add_resource(ToolBuiltinProviderDeleteApi, '/workspaces/current/tool-provider/builtin/<provider>/delete')
|
||||
api.add_resource(ToolBuiltinProviderUpdateApi, '/workspaces/current/tool-provider/builtin/<provider>/update')
|
||||
api.add_resource(ToolBuiltinProviderGetCredentialsApi, '/workspaces/current/tool-provider/builtin/<provider>/credentials')
|
||||
api.add_resource(ToolBuiltinProviderCredentialsSchemaApi, '/workspaces/current/tool-provider/builtin/<provider>/credentials_schema')
|
||||
api.add_resource(ToolBuiltinProviderIconApi, '/workspaces/current/tool-provider/builtin/<provider>/icon')
|
||||
|
||||
# api tool provider
|
||||
api.add_resource(ToolApiProviderAddApi, '/workspaces/current/tool-provider/api/add')
|
||||
api.add_resource(ToolApiProviderGetRemoteSchemaApi, '/workspaces/current/tool-provider/api/remote')
|
||||
api.add_resource(ToolApiProviderListToolsApi, '/workspaces/current/tool-provider/api/tools')
|
||||
@ -338,5 +523,15 @@ api.add_resource(ToolApiProviderGetApi, '/workspaces/current/tool-provider/api/g
|
||||
api.add_resource(ToolApiProviderSchemaApi, '/workspaces/current/tool-provider/api/schema')
|
||||
api.add_resource(ToolApiProviderPreviousTestApi, '/workspaces/current/tool-provider/api/test/pre')
|
||||
|
||||
# workflow tool provider
|
||||
api.add_resource(ToolWorkflowProviderCreateApi, '/workspaces/current/tool-provider/workflow/create')
|
||||
api.add_resource(ToolWorkflowProviderUpdateApi, '/workspaces/current/tool-provider/workflow/update')
|
||||
api.add_resource(ToolWorkflowProviderDeleteApi, '/workspaces/current/tool-provider/workflow/delete')
|
||||
api.add_resource(ToolWorkflowProviderGetApi, '/workspaces/current/tool-provider/workflow/get')
|
||||
api.add_resource(ToolWorkflowProviderListToolApi, '/workspaces/current/tool-provider/workflow/tools')
|
||||
|
||||
api.add_resource(ToolBuiltinListApi, '/workspaces/current/tools/builtin')
|
||||
api.add_resource(ToolApiListApi, '/workspaces/current/tools/api')
|
||||
api.add_resource(ToolApiListApi, '/workspaces/current/tools/api')
|
||||
api.add_resource(ToolWorkflowListApi, '/workspaces/current/tools/workflow')
|
||||
|
||||
api.add_resource(ToolLabelsApi, '/workspaces/current/tool-labels')
|
@ -165,6 +165,7 @@ class BaseAgentRunner(AppRunner):
|
||||
tenant_id=self.tenant_id,
|
||||
app_id=self.app_config.app_id,
|
||||
agent_tool=tool,
|
||||
invoke_from=self.application_generate_entity.invoke_from
|
||||
)
|
||||
tool_entity.load_variables(self.variables_pool)
|
||||
|
||||
|
@ -8,7 +8,7 @@ class AgentToolEntity(BaseModel):
|
||||
"""
|
||||
Agent Tool Entity.
|
||||
"""
|
||||
provider_type: Literal["builtin", "api"]
|
||||
provider_type: Literal["builtin", "api", "workflow"]
|
||||
provider_id: str
|
||||
tool_name: str
|
||||
tool_parameters: dict[str, Any] = {}
|
||||
|
@ -1,7 +1,7 @@
|
||||
from typing import Optional
|
||||
|
||||
from core.agent.entities import AgentEntity, AgentPromptEntity, AgentToolEntity
|
||||
from core.tools.prompt.template import REACT_PROMPT_TEMPLATES
|
||||
from core.agent.prompt.template import REACT_PROMPT_TEMPLATES
|
||||
|
||||
|
||||
class AgentConfigManager:
|
||||
|
@ -239,4 +239,4 @@ class WorkflowUIBasedAppConfig(AppConfig):
|
||||
"""
|
||||
Workflow UI Based App Config Entity.
|
||||
"""
|
||||
workflow_id: str
|
||||
workflow_id: str
|
@ -98,6 +98,90 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
extras=extras
|
||||
)
|
||||
|
||||
return self._generate(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
invoke_from=invoke_from,
|
||||
application_generate_entity=application_generate_entity,
|
||||
conversation=conversation,
|
||||
stream=stream
|
||||
)
|
||||
|
||||
def single_iteration_generate(self, app_model: App,
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user: Account,
|
||||
args: dict,
|
||||
stream: bool = True) \
|
||||
-> Union[dict, Generator[dict, None, None]]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
:param app_model: App
|
||||
:param workflow: Workflow
|
||||
:param user: account or end user
|
||||
:param args: request args
|
||||
:param invoke_from: invoke from source
|
||||
:param stream: is stream
|
||||
"""
|
||||
if not node_id:
|
||||
raise ValueError('node_id is required')
|
||||
|
||||
if args.get('inputs') is None:
|
||||
raise ValueError('inputs is required')
|
||||
|
||||
extras = {
|
||||
"auto_generate_conversation_name": False
|
||||
}
|
||||
|
||||
# get conversation
|
||||
conversation = None
|
||||
if args.get('conversation_id'):
|
||||
conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user)
|
||||
|
||||
# convert to app config
|
||||
app_config = AdvancedChatAppConfigManager.get_app_config(
|
||||
app_model=app_model,
|
||||
workflow=workflow
|
||||
)
|
||||
|
||||
# init application generate entity
|
||||
application_generate_entity = AdvancedChatAppGenerateEntity(
|
||||
task_id=str(uuid.uuid4()),
|
||||
app_config=app_config,
|
||||
conversation_id=conversation.id if conversation else None,
|
||||
inputs={},
|
||||
query='',
|
||||
files=[],
|
||||
user_id=user.id,
|
||||
stream=stream,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
extras=extras,
|
||||
single_iteration_run=AdvancedChatAppGenerateEntity.SingleIterationRunEntity(
|
||||
node_id=node_id,
|
||||
inputs=args['inputs']
|
||||
)
|
||||
)
|
||||
|
||||
return self._generate(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
application_generate_entity=application_generate_entity,
|
||||
conversation=conversation,
|
||||
stream=stream
|
||||
)
|
||||
|
||||
def _generate(self, app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
invoke_from: InvokeFrom,
|
||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
conversation: Conversation = None,
|
||||
stream: bool = True) \
|
||||
-> Union[dict, Generator[dict, None, None]]:
|
||||
is_first_conversation = False
|
||||
if not conversation:
|
||||
is_first_conversation = True
|
||||
@ -167,18 +251,30 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
"""
|
||||
with flask_app.app_context():
|
||||
try:
|
||||
# get conversation and message
|
||||
conversation = self._get_conversation(conversation_id)
|
||||
message = self._get_message(message_id)
|
||||
|
||||
# chatbot app
|
||||
runner = AdvancedChatAppRunner()
|
||||
runner.run(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
conversation=conversation,
|
||||
message=message
|
||||
)
|
||||
if application_generate_entity.single_iteration_run:
|
||||
single_iteration_run = application_generate_entity.single_iteration_run
|
||||
runner.single_iteration_run(
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
workflow_id=application_generate_entity.app_config.workflow_id,
|
||||
queue_manager=queue_manager,
|
||||
inputs=single_iteration_run.inputs,
|
||||
node_id=single_iteration_run.node_id,
|
||||
user_id=application_generate_entity.user_id
|
||||
)
|
||||
else:
|
||||
# get conversation and message
|
||||
conversation = self._get_conversation(conversation_id)
|
||||
message = self._get_message(message_id)
|
||||
|
||||
# chatbot app
|
||||
runner = AdvancedChatAppRunner()
|
||||
runner.run(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
conversation=conversation,
|
||||
message=message
|
||||
)
|
||||
except GenerateTaskStoppedException:
|
||||
pass
|
||||
except InvokeAuthorizationError:
|
||||
|
@ -102,6 +102,7 @@ class AdvancedChatAppRunner(AppRunner):
|
||||
user_from=UserFrom.ACCOUNT
|
||||
if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
|
||||
else UserFrom.END_USER,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
user_inputs=inputs,
|
||||
system_inputs={
|
||||
SystemVariable.QUERY: query,
|
||||
@ -109,6 +110,35 @@ class AdvancedChatAppRunner(AppRunner):
|
||||
SystemVariable.CONVERSATION_ID: conversation.id,
|
||||
SystemVariable.USER_ID: user_id
|
||||
},
|
||||
callbacks=workflow_callbacks,
|
||||
call_depth=application_generate_entity.call_depth
|
||||
)
|
||||
|
||||
def single_iteration_run(self, app_id: str, workflow_id: str,
|
||||
queue_manager: AppQueueManager,
|
||||
inputs: dict, node_id: str, user_id: str) -> None:
|
||||
"""
|
||||
Single iteration run
|
||||
"""
|
||||
app_record: App = db.session.query(App).filter(App.id == app_id).first()
|
||||
if not app_record:
|
||||
raise ValueError("App not found")
|
||||
|
||||
workflow = self.get_workflow(app_model=app_record, workflow_id=workflow_id)
|
||||
if not workflow:
|
||||
raise ValueError("Workflow not initialized")
|
||||
|
||||
workflow_callbacks = [WorkflowEventTriggerCallback(
|
||||
queue_manager=queue_manager,
|
||||
workflow=workflow
|
||||
)]
|
||||
|
||||
workflow_engine_manager = WorkflowEngineManager()
|
||||
workflow_engine_manager.single_step_run_iteration_workflow_node(
|
||||
workflow=workflow,
|
||||
node_id=node_id,
|
||||
user_id=user_id,
|
||||
user_inputs=inputs,
|
||||
callbacks=workflow_callbacks
|
||||
)
|
||||
|
||||
|
@ -12,6 +12,9 @@ from core.app.entities.queue_entities import (
|
||||
QueueAdvancedChatMessageEndEvent,
|
||||
QueueAnnotationReplyEvent,
|
||||
QueueErrorEvent,
|
||||
QueueIterationCompletedEvent,
|
||||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
QueueMessageReplaceEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeStartedEvent,
|
||||
@ -64,6 +67,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
_workflow: Workflow
|
||||
_user: Union[Account, EndUser]
|
||||
_workflow_system_variables: dict[SystemVariable, Any]
|
||||
_iteration_nested_relations: dict[str, list[str]]
|
||||
|
||||
def __init__(self, application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
workflow: Workflow,
|
||||
@ -104,6 +108,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
)
|
||||
|
||||
self._stream_generate_routes = self._get_stream_generate_routes()
|
||||
self._iteration_nested_relations = self._get_iteration_nested_relations(self._workflow.graph_dict)
|
||||
self._conversation_name_generate_thread = None
|
||||
|
||||
def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
|
||||
@ -204,6 +209,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
# search stream_generate_routes if node id is answer start at node
|
||||
if not self._task_state.current_stream_generate_state and event.node_id in self._stream_generate_routes:
|
||||
self._task_state.current_stream_generate_state = self._stream_generate_routes[event.node_id]
|
||||
# reset current route position to 0
|
||||
self._task_state.current_stream_generate_state.current_route_position = 0
|
||||
|
||||
# generate stream outputs when node started
|
||||
yield from self._generate_stream_outputs_when_node_started()
|
||||
@ -225,6 +232,22 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution
|
||||
)
|
||||
|
||||
if isinstance(event, QueueNodeFailedEvent):
|
||||
yield from self._handle_iteration_exception(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
error=f'Child node failed: {event.error}'
|
||||
)
|
||||
elif isinstance(event, QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent):
|
||||
if isinstance(event, QueueIterationNextEvent):
|
||||
# clear ran node execution infos of current iteration
|
||||
iteration_relations = self._iteration_nested_relations.get(event.node_id)
|
||||
if iteration_relations:
|
||||
for node_id in iteration_relations:
|
||||
self._task_state.ran_node_execution_infos.pop(node_id, None)
|
||||
|
||||
yield self._handle_iteration_to_stream_response(self._application_generate_entity.task_id, event)
|
||||
self._handle_iteration_operation(event)
|
||||
elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent):
|
||||
workflow_run = self._handle_workflow_finished(event)
|
||||
if workflow_run:
|
||||
@ -263,10 +286,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
self._handle_retriever_resources(event)
|
||||
elif isinstance(event, QueueAnnotationReplyEvent):
|
||||
self._handle_annotation_reply(event)
|
||||
# elif isinstance(event, QueueMessageFileEvent):
|
||||
# response = self._message_file_to_stream_response(event)
|
||||
# if response:
|
||||
# yield response
|
||||
elif isinstance(event, QueueTextChunkEvent):
|
||||
delta_text = event.text
|
||||
if delta_text is None:
|
||||
@ -342,7 +361,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
id=self._message.id,
|
||||
**extras
|
||||
)
|
||||
|
||||
|
||||
def _get_stream_generate_routes(self) -> dict[str, ChatflowStreamGenerateRoute]:
|
||||
"""
|
||||
Get stream generate routes.
|
||||
@ -372,7 +391,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
)
|
||||
|
||||
return stream_generate_routes
|
||||
|
||||
|
||||
def _get_answer_start_at_node_ids(self, graph: dict, target_node_id: str) \
|
||||
-> list[str]:
|
||||
"""
|
||||
@ -401,14 +420,23 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
continue
|
||||
|
||||
node_type = source_node.get('data', {}).get('type')
|
||||
node_iteration_id = source_node.get('data', {}).get('iteration_id')
|
||||
iteration_start_node_id = None
|
||||
if node_iteration_id:
|
||||
iteration_node = next((node for node in nodes if node.get('id') == node_iteration_id), None)
|
||||
iteration_start_node_id = iteration_node.get('data', {}).get('start_node_id')
|
||||
|
||||
if node_type in [
|
||||
NodeType.ANSWER.value,
|
||||
NodeType.IF_ELSE.value,
|
||||
NodeType.QUESTION_CLASSIFIER.value
|
||||
NodeType.QUESTION_CLASSIFIER.value,
|
||||
NodeType.ITERATION.value,
|
||||
NodeType.LOOP.value
|
||||
]:
|
||||
start_node_id = target_node_id
|
||||
start_node_ids.append(start_node_id)
|
||||
elif node_type == NodeType.START.value:
|
||||
elif node_type == NodeType.START.value or \
|
||||
node_iteration_id is not None and iteration_start_node_id == source_node.get('id'):
|
||||
start_node_id = source_node_id
|
||||
start_node_ids.append(start_node_id)
|
||||
else:
|
||||
@ -417,7 +445,27 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
start_node_ids.extend(sub_start_node_ids)
|
||||
|
||||
return start_node_ids
|
||||
|
||||
def _get_iteration_nested_relations(self, graph: dict) -> dict[str, list[str]]:
|
||||
"""
|
||||
Get iteration nested relations.
|
||||
:param graph: graph
|
||||
:return:
|
||||
"""
|
||||
nodes = graph.get('nodes')
|
||||
|
||||
iteration_ids = [node.get('id') for node in nodes
|
||||
if node.get('data', {}).get('type') in [
|
||||
NodeType.ITERATION.value,
|
||||
NodeType.LOOP.value,
|
||||
]]
|
||||
|
||||
return {
|
||||
iteration_id: [
|
||||
node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id
|
||||
] for iteration_id in iteration_ids
|
||||
}
|
||||
|
||||
def _generate_stream_outputs_when_node_started(self) -> Generator:
|
||||
"""
|
||||
Generate stream outputs.
|
||||
@ -425,7 +473,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
"""
|
||||
if self._task_state.current_stream_generate_state:
|
||||
route_chunks = self._task_state.current_stream_generate_state.generate_route[
|
||||
self._task_state.current_stream_generate_state.current_route_position:]
|
||||
self._task_state.current_stream_generate_state.current_route_position:
|
||||
]
|
||||
|
||||
for route_chunk in route_chunks:
|
||||
if route_chunk.type == 'text':
|
||||
@ -458,7 +507,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
|
||||
route_chunks = self._task_state.current_stream_generate_state.generate_route[
|
||||
self._task_state.current_stream_generate_state.current_route_position:]
|
||||
|
||||
|
||||
for route_chunk in route_chunks:
|
||||
if route_chunk.type == 'text':
|
||||
route_chunk = cast(TextGenerateRouteChunk, route_chunk)
|
||||
|
@ -1,8 +1,11 @@
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
QueueIterationCompletedEvent,
|
||||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
@ -130,6 +133,66 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback):
|
||||
), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_iteration_started(self,
|
||||
node_id: str,
|
||||
node_type: NodeType,
|
||||
node_run_index: int = 1,
|
||||
node_data: Optional[BaseNodeData] = None,
|
||||
inputs: dict = None,
|
||||
predecessor_node_id: Optional[str] = None,
|
||||
metadata: Optional[dict] = None) -> None:
|
||||
"""
|
||||
Publish iteration started
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueIterationStartEvent(
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
node_run_index=node_run_index,
|
||||
node_data=node_data,
|
||||
inputs=inputs,
|
||||
predecessor_node_id=predecessor_node_id,
|
||||
metadata=metadata
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_iteration_next(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
index: int,
|
||||
node_run_index: int,
|
||||
output: Optional[Any]) -> None:
|
||||
"""
|
||||
Publish iteration next
|
||||
"""
|
||||
self._queue_manager._publish(
|
||||
QueueIterationNextEvent(
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
index=index,
|
||||
node_run_index=node_run_index,
|
||||
output=output
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_iteration_completed(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
node_run_index: int,
|
||||
outputs: dict) -> None:
|
||||
"""
|
||||
Publish iteration completed
|
||||
"""
|
||||
self._queue_manager._publish(
|
||||
QueueIterationCompletedEvent(
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
node_run_index=node_run_index,
|
||||
outputs=outputs
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_event(self, event: AppQueueEvent) -> None:
|
||||
"""
|
||||
Publish event
|
||||
|
@ -115,7 +115,8 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||
user_id=user.id,
|
||||
stream=stream,
|
||||
invoke_from=invoke_from,
|
||||
extras=extras
|
||||
extras=extras,
|
||||
call_depth=0
|
||||
)
|
||||
|
||||
# init generate records
|
||||
|
@ -34,7 +34,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
user: Union[Account, EndUser],
|
||||
args: dict,
|
||||
invoke_from: InvokeFrom,
|
||||
stream: bool = True) \
|
||||
stream: bool = True,
|
||||
call_depth: int = 0) \
|
||||
-> Union[dict, Generator[dict, None, None]]:
|
||||
"""
|
||||
Generate App response.
|
||||
@ -75,9 +76,38 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
files=file_objs,
|
||||
user_id=user.id,
|
||||
stream=stream,
|
||||
invoke_from=invoke_from
|
||||
invoke_from=invoke_from,
|
||||
call_depth=call_depth
|
||||
)
|
||||
|
||||
return self._generate(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
application_generate_entity=application_generate_entity,
|
||||
invoke_from=invoke_from,
|
||||
stream=stream,
|
||||
call_depth=call_depth
|
||||
)
|
||||
|
||||
def _generate(self, app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
application_generate_entity: WorkflowAppGenerateEntity,
|
||||
invoke_from: InvokeFrom,
|
||||
stream: bool = True,
|
||||
call_depth: int = 0) \
|
||||
-> Union[dict, Generator[dict, None, None]]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
:param app_model: App
|
||||
:param workflow: Workflow
|
||||
:param user: account or end user
|
||||
:param application_generate_entity: application generate entity
|
||||
:param invoke_from: invoke from source
|
||||
:param stream: is stream
|
||||
"""
|
||||
# init queue manager
|
||||
queue_manager = WorkflowAppQueueManager(
|
||||
task_id=application_generate_entity.task_id,
|
||||
@ -109,6 +139,64 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
invoke_from=invoke_from
|
||||
)
|
||||
|
||||
def single_iteration_generate(self, app_model: App,
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user: Account,
|
||||
args: dict,
|
||||
stream: bool = True) \
|
||||
-> Union[dict, Generator[dict, None, None]]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
:param app_model: App
|
||||
:param workflow: Workflow
|
||||
:param user: account or end user
|
||||
:param args: request args
|
||||
:param invoke_from: invoke from source
|
||||
:param stream: is stream
|
||||
"""
|
||||
if not node_id:
|
||||
raise ValueError('node_id is required')
|
||||
|
||||
if args.get('inputs') is None:
|
||||
raise ValueError('inputs is required')
|
||||
|
||||
extras = {
|
||||
"auto_generate_conversation_name": False
|
||||
}
|
||||
|
||||
# convert to app config
|
||||
app_config = WorkflowAppConfigManager.get_app_config(
|
||||
app_model=app_model,
|
||||
workflow=workflow
|
||||
)
|
||||
|
||||
# init application generate entity
|
||||
application_generate_entity = WorkflowAppGenerateEntity(
|
||||
task_id=str(uuid.uuid4()),
|
||||
app_config=app_config,
|
||||
inputs={},
|
||||
files=[],
|
||||
user_id=user.id,
|
||||
stream=stream,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
extras=extras,
|
||||
single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity(
|
||||
node_id=node_id,
|
||||
inputs=args['inputs']
|
||||
)
|
||||
)
|
||||
|
||||
return self._generate(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
application_generate_entity=application_generate_entity,
|
||||
stream=stream
|
||||
)
|
||||
|
||||
def _generate_worker(self, flask_app: Flask,
|
||||
application_generate_entity: WorkflowAppGenerateEntity,
|
||||
queue_manager: AppQueueManager) -> None:
|
||||
@ -123,10 +211,21 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
try:
|
||||
# workflow app
|
||||
runner = WorkflowAppRunner()
|
||||
runner.run(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager
|
||||
)
|
||||
if application_generate_entity.single_iteration_run:
|
||||
single_iteration_run = application_generate_entity.single_iteration_run
|
||||
runner.single_iteration_run(
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
workflow_id=application_generate_entity.app_config.workflow_id,
|
||||
queue_manager=queue_manager,
|
||||
inputs=single_iteration_run.inputs,
|
||||
node_id=single_iteration_run.node_id,
|
||||
user_id=application_generate_entity.user_id
|
||||
)
|
||||
else:
|
||||
runner.run(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager
|
||||
)
|
||||
except GenerateTaskStoppedException:
|
||||
pass
|
||||
except InvokeAuthorizationError:
|
||||
|
@ -73,11 +73,44 @@ class WorkflowAppRunner:
|
||||
user_from=UserFrom.ACCOUNT
|
||||
if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
|
||||
else UserFrom.END_USER,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
user_inputs=inputs,
|
||||
system_inputs={
|
||||
SystemVariable.FILES: files,
|
||||
SystemVariable.USER_ID: user_id
|
||||
},
|
||||
callbacks=workflow_callbacks,
|
||||
call_depth=application_generate_entity.call_depth
|
||||
)
|
||||
|
||||
def single_iteration_run(self, app_id: str, workflow_id: str,
|
||||
queue_manager: AppQueueManager,
|
||||
inputs: dict, node_id: str, user_id: str) -> None:
|
||||
"""
|
||||
Single iteration run
|
||||
"""
|
||||
app_record: App = db.session.query(App).filter(App.id == app_id).first()
|
||||
if not app_record:
|
||||
raise ValueError("App not found")
|
||||
|
||||
if not app_record.workflow_id:
|
||||
raise ValueError("Workflow not initialized")
|
||||
|
||||
workflow = self.get_workflow(app_model=app_record, workflow_id=workflow_id)
|
||||
if not workflow:
|
||||
raise ValueError("Workflow not initialized")
|
||||
|
||||
workflow_callbacks = [WorkflowEventTriggerCallback(
|
||||
queue_manager=queue_manager,
|
||||
workflow=workflow
|
||||
)]
|
||||
|
||||
workflow_engine_manager = WorkflowEngineManager()
|
||||
workflow_engine_manager.single_step_run_iteration_workflow_node(
|
||||
workflow=workflow,
|
||||
node_id=node_id,
|
||||
user_id=user_id,
|
||||
user_inputs=inputs,
|
||||
callbacks=workflow_callbacks
|
||||
)
|
||||
|
||||
|
@ -9,6 +9,9 @@ from core.app.entities.app_invoke_entities import (
|
||||
)
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueErrorEvent,
|
||||
QueueIterationCompletedEvent,
|
||||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
QueueMessageReplaceEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeStartedEvent,
|
||||
@ -58,6 +61,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
_task_state: WorkflowTaskState
|
||||
_application_generate_entity: WorkflowAppGenerateEntity
|
||||
_workflow_system_variables: dict[SystemVariable, Any]
|
||||
_iteration_nested_relations: dict[str, list[str]]
|
||||
|
||||
def __init__(self, application_generate_entity: WorkflowAppGenerateEntity,
|
||||
workflow: Workflow,
|
||||
@ -85,8 +89,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
SystemVariable.USER_ID: user_id
|
||||
}
|
||||
|
||||
self._task_state = WorkflowTaskState()
|
||||
self._task_state = WorkflowTaskState(
|
||||
iteration_nested_node_ids=[]
|
||||
)
|
||||
self._stream_generate_nodes = self._get_stream_generate_nodes()
|
||||
self._iteration_nested_relations = self._get_iteration_nested_relations(self._workflow.graph_dict)
|
||||
|
||||
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
|
||||
"""
|
||||
@ -191,6 +198,22 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution
|
||||
)
|
||||
|
||||
if isinstance(event, QueueNodeFailedEvent):
|
||||
yield from self._handle_iteration_exception(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
error=f'Child node failed: {event.error}'
|
||||
)
|
||||
elif isinstance(event, QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent):
|
||||
if isinstance(event, QueueIterationNextEvent):
|
||||
# clear ran node execution infos of current iteration
|
||||
iteration_relations = self._iteration_nested_relations.get(event.node_id)
|
||||
if iteration_relations:
|
||||
for node_id in iteration_relations:
|
||||
self._task_state.ran_node_execution_infos.pop(node_id, None)
|
||||
|
||||
yield self._handle_iteration_to_stream_response(self._application_generate_entity.task_id, event)
|
||||
self._handle_iteration_operation(event)
|
||||
elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent):
|
||||
workflow_run = self._handle_workflow_finished(event)
|
||||
|
||||
@ -331,13 +354,20 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
continue
|
||||
|
||||
node_type = source_node.get('data', {}).get('type')
|
||||
node_iteration_id = source_node.get('data', {}).get('iteration_id')
|
||||
iteration_start_node_id = None
|
||||
if node_iteration_id:
|
||||
iteration_node = next((node for node in nodes if node.get('id') == node_iteration_id), None)
|
||||
iteration_start_node_id = iteration_node.get('data', {}).get('start_node_id')
|
||||
|
||||
if node_type in [
|
||||
NodeType.IF_ELSE.value,
|
||||
NodeType.QUESTION_CLASSIFIER.value
|
||||
]:
|
||||
start_node_id = target_node_id
|
||||
start_node_ids.append(start_node_id)
|
||||
elif node_type == NodeType.START.value:
|
||||
elif node_type == NodeType.START.value or \
|
||||
node_iteration_id is not None and iteration_start_node_id == source_node.get('id'):
|
||||
start_node_id = source_node_id
|
||||
start_node_ids.append(start_node_id)
|
||||
else:
|
||||
@ -411,3 +441,24 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _get_iteration_nested_relations(self, graph: dict) -> dict[str, list[str]]:
|
||||
"""
|
||||
Get iteration nested relations.
|
||||
:param graph: graph
|
||||
:return:
|
||||
"""
|
||||
nodes = graph.get('nodes')
|
||||
|
||||
iteration_ids = [node.get('id') for node in nodes
|
||||
if node.get('data', {}).get('type') in [
|
||||
NodeType.ITERATION.value,
|
||||
NodeType.LOOP.value,
|
||||
]]
|
||||
|
||||
return {
|
||||
iteration_id: [
|
||||
node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id
|
||||
] for iteration_id in iteration_ids
|
||||
}
|
||||
|
@ -1,8 +1,11 @@
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
QueueIterationCompletedEvent,
|
||||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
@ -130,6 +133,66 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback):
|
||||
), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_iteration_started(self,
|
||||
node_id: str,
|
||||
node_type: NodeType,
|
||||
node_run_index: int = 1,
|
||||
node_data: Optional[BaseNodeData] = None,
|
||||
inputs: dict = None,
|
||||
predecessor_node_id: Optional[str] = None,
|
||||
metadata: Optional[dict] = None) -> None:
|
||||
"""
|
||||
Publish iteration started
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueIterationStartEvent(
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
node_run_index=node_run_index,
|
||||
node_data=node_data,
|
||||
inputs=inputs,
|
||||
predecessor_node_id=predecessor_node_id,
|
||||
metadata=metadata
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_iteration_next(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
index: int,
|
||||
node_run_index: int,
|
||||
output: Optional[Any]) -> None:
|
||||
"""
|
||||
Publish iteration next
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueIterationNextEvent(
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
index=index,
|
||||
node_run_index=node_run_index,
|
||||
output=output
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_iteration_completed(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
node_run_index: int,
|
||||
outputs: dict) -> None:
|
||||
"""
|
||||
Publish iteration completed
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueIterationCompletedEvent(
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
node_run_index=node_run_index,
|
||||
outputs=outputs
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_event(self, event: AppQueueEvent) -> None:
|
||||
"""
|
||||
Publish event
|
||||
|
@ -102,6 +102,39 @@ class WorkflowLoggingCallback(BaseWorkflowCallback):
|
||||
|
||||
self.print_text(text, color="pink", end="")
|
||||
|
||||
def on_workflow_iteration_started(self,
|
||||
node_id: str,
|
||||
node_type: NodeType,
|
||||
node_run_index: int = 1,
|
||||
node_data: Optional[BaseNodeData] = None,
|
||||
inputs: dict = None,
|
||||
predecessor_node_id: Optional[str] = None,
|
||||
metadata: Optional[dict] = None) -> None:
|
||||
"""
|
||||
Publish iteration started
|
||||
"""
|
||||
self.print_text("\n[on_workflow_iteration_started]", color='blue')
|
||||
self.print_text(f"Node ID: {node_id}", color='blue')
|
||||
|
||||
def on_workflow_iteration_next(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
index: int,
|
||||
node_run_index: int,
|
||||
output: Optional[dict]) -> None:
|
||||
"""
|
||||
Publish iteration next
|
||||
"""
|
||||
self.print_text("\n[on_workflow_iteration_next]", color='blue')
|
||||
|
||||
def on_workflow_iteration_completed(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
node_run_index: int,
|
||||
outputs: dict) -> None:
|
||||
"""
|
||||
Publish iteration completed
|
||||
"""
|
||||
self.print_text("\n[on_workflow_iteration_completed]", color='blue')
|
||||
|
||||
def on_event(self, event: AppQueueEvent) -> None:
|
||||
"""
|
||||
Publish event
|
||||
|
@ -80,6 +80,9 @@ class AppGenerateEntity(BaseModel):
|
||||
stream: bool
|
||||
invoke_from: InvokeFrom
|
||||
|
||||
# invoke call depth
|
||||
call_depth: int = 0
|
||||
|
||||
# extra parameters, like: auto_generate_conversation_name
|
||||
extras: dict[str, Any] = {}
|
||||
|
||||
@ -126,6 +129,14 @@ class AdvancedChatAppGenerateEntity(AppGenerateEntity):
|
||||
conversation_id: Optional[str] = None
|
||||
query: Optional[str] = None
|
||||
|
||||
class SingleIterationRunEntity(BaseModel):
|
||||
"""
|
||||
Single Iteration Run Entity.
|
||||
"""
|
||||
node_id: str
|
||||
inputs: dict
|
||||
|
||||
single_iteration_run: Optional[SingleIterationRunEntity] = None
|
||||
|
||||
class WorkflowAppGenerateEntity(AppGenerateEntity):
|
||||
"""
|
||||
@ -133,3 +144,12 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
|
||||
"""
|
||||
# app config
|
||||
app_config: WorkflowUIBasedAppConfig
|
||||
|
||||
class SingleIterationRunEntity(BaseModel):
|
||||
"""
|
||||
Single Iteration Run Entity.
|
||||
"""
|
||||
node_id: str
|
||||
inputs: dict
|
||||
|
||||
single_iteration_run: Optional[SingleIterationRunEntity] = None
|
@ -1,7 +1,7 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, validator
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
@ -21,6 +21,9 @@ class QueueEvent(Enum):
|
||||
WORKFLOW_STARTED = "workflow_started"
|
||||
WORKFLOW_SUCCEEDED = "workflow_succeeded"
|
||||
WORKFLOW_FAILED = "workflow_failed"
|
||||
ITERATION_START = "iteration_start"
|
||||
ITERATION_NEXT = "iteration_next"
|
||||
ITERATION_COMPLETED = "iteration_completed"
|
||||
NODE_STARTED = "node_started"
|
||||
NODE_SUCCEEDED = "node_succeeded"
|
||||
NODE_FAILED = "node_failed"
|
||||
@ -47,6 +50,55 @@ class QueueLLMChunkEvent(AppQueueEvent):
|
||||
event = QueueEvent.LLM_CHUNK
|
||||
chunk: LLMResultChunk
|
||||
|
||||
class QueueIterationStartEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueIterationStartEvent entity
|
||||
"""
|
||||
event = QueueEvent.ITERATION_START
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
|
||||
node_run_index: int
|
||||
inputs: dict = None
|
||||
predecessor_node_id: Optional[str] = None
|
||||
metadata: Optional[dict] = None
|
||||
|
||||
class QueueIterationNextEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueIterationNextEvent entity
|
||||
"""
|
||||
event = QueueEvent.ITERATION_NEXT
|
||||
|
||||
index: int
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
|
||||
node_run_index: int
|
||||
output: Optional[Any] # output for the current iteration
|
||||
|
||||
@validator('output', pre=True, always=True)
|
||||
def set_output(cls, v):
|
||||
"""
|
||||
Set output
|
||||
"""
|
||||
if v is None:
|
||||
return None
|
||||
if isinstance(v, int | float | str | bool | dict | list):
|
||||
return v
|
||||
raise ValueError('output must be a valid type')
|
||||
|
||||
class QueueIterationCompletedEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueIterationCompletedEvent entity
|
||||
"""
|
||||
event = QueueEvent.ITERATION_COMPLETED
|
||||
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
|
||||
node_run_index: int
|
||||
outputs: dict
|
||||
|
||||
class QueueTextChunkEvent(AppQueueEvent):
|
||||
"""
|
||||
|
@ -1,12 +1,14 @@
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.nodes.answer.entities import GenerateRouteChunk
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class WorkflowStreamGenerateNodes(BaseModel):
|
||||
@ -65,6 +67,7 @@ class WorkflowTaskState(TaskState):
|
||||
|
||||
current_stream_generate_state: Optional[WorkflowStreamGenerateNodes] = None
|
||||
|
||||
iteration_nested_node_ids: list[str] = None
|
||||
|
||||
class AdvancedChatTaskState(WorkflowTaskState):
|
||||
"""
|
||||
@ -91,6 +94,9 @@ class StreamEvent(Enum):
|
||||
WORKFLOW_FINISHED = "workflow_finished"
|
||||
NODE_STARTED = "node_started"
|
||||
NODE_FINISHED = "node_finished"
|
||||
ITERATION_STARTED = "iteration_started"
|
||||
ITERATION_NEXT = "iteration_next"
|
||||
ITERATION_COMPLETED = "iteration_completed"
|
||||
TEXT_CHUNK = "text_chunk"
|
||||
TEXT_REPLACE = "text_replace"
|
||||
|
||||
@ -319,6 +325,74 @@ class NodeFinishStreamResponse(StreamResponse):
|
||||
}
|
||||
}
|
||||
|
||||
class IterationNodeStartStreamResponse(StreamResponse):
|
||||
"""
|
||||
NodeStartStreamResponse entity
|
||||
"""
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
"""
|
||||
id: str
|
||||
node_id: str
|
||||
node_type: str
|
||||
title: str
|
||||
created_at: int
|
||||
extras: dict = {}
|
||||
metadata: dict = {}
|
||||
inputs: dict = {}
|
||||
|
||||
event: StreamEvent = StreamEvent.ITERATION_STARTED
|
||||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
class IterationNodeNextStreamResponse(StreamResponse):
|
||||
"""
|
||||
NodeStartStreamResponse entity
|
||||
"""
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
"""
|
||||
id: str
|
||||
node_id: str
|
||||
node_type: str
|
||||
title: str
|
||||
index: int
|
||||
created_at: int
|
||||
pre_iteration_output: Optional[Any]
|
||||
extras: dict = {}
|
||||
|
||||
event: StreamEvent = StreamEvent.ITERATION_NEXT
|
||||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
class IterationNodeCompletedStreamResponse(StreamResponse):
|
||||
"""
|
||||
NodeStartStreamResponse entity
|
||||
"""
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
"""
|
||||
id: str
|
||||
node_id: str
|
||||
node_type: str
|
||||
title: str
|
||||
outputs: Optional[dict]
|
||||
created_at: int
|
||||
extras: dict = None
|
||||
inputs: dict = None
|
||||
status: WorkflowNodeExecutionStatus
|
||||
error: Optional[str]
|
||||
elapsed_time: float
|
||||
total_tokens: int
|
||||
finished_at: int
|
||||
steps: int
|
||||
|
||||
event: StreamEvent = StreamEvent.ITERATION_COMPLETED
|
||||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
class TextChunkStreamResponse(StreamResponse):
|
||||
"""
|
||||
@ -454,3 +528,23 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
|
||||
|
||||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
class WorkflowIterationState(BaseModel):
|
||||
"""
|
||||
WorkflowIterationState entity
|
||||
"""
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
"""
|
||||
parent_iteration_id: Optional[str] = None
|
||||
iteration_id: str
|
||||
current_index: int
|
||||
iteration_steps_boundary: list[int] = None
|
||||
node_execution_id: str
|
||||
started_at: float
|
||||
inputs: dict = None
|
||||
total_tokens: int = 0
|
||||
node_data: BaseNodeData
|
||||
|
||||
current_iterations: dict[str, Data] = None
|
@ -1,9 +1,9 @@
|
||||
import json
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Optional, Union, cast
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeStartedEvent,
|
||||
@ -13,18 +13,17 @@ from core.app.entities.queue_entities import (
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from core.app.entities.task_entities import (
|
||||
AdvancedChatTaskState,
|
||||
NodeExecutionInfo,
|
||||
NodeFinishStreamResponse,
|
||||
NodeStartStreamResponse,
|
||||
WorkflowFinishStreamResponse,
|
||||
WorkflowStartStreamResponse,
|
||||
WorkflowTaskState,
|
||||
)
|
||||
from core.app.task_pipeline.workflow_iteration_cycle_manage import WorkflowIterationCycleManage
|
||||
from core.file.file_obj import FileVar
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType, SystemVariable
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType
|
||||
from core.workflow.nodes.tool.entities import ToolNodeData
|
||||
from core.workflow.workflow_engine_manager import WorkflowEngineManager
|
||||
from extensions.ext_database import db
|
||||
@ -42,13 +41,7 @@ from models.workflow import (
|
||||
)
|
||||
|
||||
|
||||
class WorkflowCycleManage:
|
||||
_application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity]
|
||||
_workflow: Workflow
|
||||
_user: Union[Account, EndUser]
|
||||
_task_state: Union[AdvancedChatTaskState, WorkflowTaskState]
|
||||
_workflow_system_variables: dict[SystemVariable, Any]
|
||||
|
||||
class WorkflowCycleManage(WorkflowIterationCycleManage):
|
||||
def _init_workflow_run(self, workflow: Workflow,
|
||||
triggered_from: WorkflowRunTriggeredFrom,
|
||||
user: Union[Account, EndUser],
|
||||
@ -237,6 +230,7 @@ class WorkflowCycleManage:
|
||||
inputs: Optional[dict] = None,
|
||||
process_data: Optional[dict] = None,
|
||||
outputs: Optional[dict] = None,
|
||||
execution_metadata: Optional[dict] = None
|
||||
) -> WorkflowNodeExecution:
|
||||
"""
|
||||
Workflow node execution failed
|
||||
@ -255,6 +249,8 @@ class WorkflowCycleManage:
|
||||
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
|
||||
workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
|
||||
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
|
||||
workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(execution_metadata)) \
|
||||
if execution_metadata else None
|
||||
|
||||
db.session.commit()
|
||||
db.session.refresh(workflow_node_execution)
|
||||
@ -444,6 +440,23 @@ class WorkflowCycleManage:
|
||||
current_node_execution = self._task_state.ran_node_execution_infos[event.node_id]
|
||||
workflow_node_execution = db.session.query(WorkflowNodeExecution).filter(
|
||||
WorkflowNodeExecution.id == current_node_execution.workflow_node_execution_id).first()
|
||||
|
||||
execution_metadata = event.execution_metadata if isinstance(event, QueueNodeSucceededEvent) else None
|
||||
|
||||
if self._iteration_state and self._iteration_state.current_iterations:
|
||||
if not execution_metadata:
|
||||
execution_metadata = {}
|
||||
current_iteration_data = None
|
||||
for iteration_node_id in self._iteration_state.current_iterations:
|
||||
data = self._iteration_state.current_iterations[iteration_node_id]
|
||||
if data.parent_iteration_id == None:
|
||||
current_iteration_data = data
|
||||
break
|
||||
|
||||
if current_iteration_data:
|
||||
execution_metadata[NodeRunMetadataKey.ITERATION_ID] = current_iteration_data.iteration_id
|
||||
execution_metadata[NodeRunMetadataKey.ITERATION_INDEX] = current_iteration_data.current_index
|
||||
|
||||
if isinstance(event, QueueNodeSucceededEvent):
|
||||
workflow_node_execution = self._workflow_node_execution_success(
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
@ -451,12 +464,18 @@ class WorkflowCycleManage:
|
||||
inputs=event.inputs,
|
||||
process_data=event.process_data,
|
||||
outputs=event.outputs,
|
||||
execution_metadata=event.execution_metadata
|
||||
execution_metadata=execution_metadata
|
||||
)
|
||||
|
||||
if event.execution_metadata and event.execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
|
||||
if execution_metadata and execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
|
||||
self._task_state.total_tokens += (
|
||||
int(event.execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)))
|
||||
int(execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)))
|
||||
|
||||
if self._iteration_state:
|
||||
for iteration_node_id in self._iteration_state.current_iterations:
|
||||
data = self._iteration_state.current_iterations[iteration_node_id]
|
||||
if execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
|
||||
data.total_tokens += int(execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS))
|
||||
|
||||
if workflow_node_execution.node_type == NodeType.LLM.value:
|
||||
outputs = workflow_node_execution.outputs_dict
|
||||
@ -469,7 +488,8 @@ class WorkflowCycleManage:
|
||||
error=event.error,
|
||||
inputs=event.inputs,
|
||||
process_data=event.process_data,
|
||||
outputs=event.outputs
|
||||
outputs=event.outputs,
|
||||
execution_metadata=execution_metadata
|
||||
)
|
||||
|
||||
db.session.close()
|
||||
|
16
api/core/app/task_pipeline/workflow_cycle_state_manager.py
Normal file
16
api/core/app/task_pipeline/workflow_cycle_state_manager.py
Normal file
@ -0,0 +1,16 @@
|
||||
from typing import Any, Union
|
||||
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
|
||||
from core.app.entities.task_entities import AdvancedChatTaskState, WorkflowTaskState
|
||||
from core.workflow.entities.node_entities import SystemVariable
|
||||
from models.account import Account
|
||||
from models.model import EndUser
|
||||
from models.workflow import Workflow
|
||||
|
||||
|
||||
class WorkflowCycleStateManager:
|
||||
_application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity]
|
||||
_workflow: Workflow
|
||||
_user: Union[Account, EndUser]
|
||||
_task_state: Union[AdvancedChatTaskState, WorkflowTaskState]
|
||||
_workflow_system_variables: dict[SystemVariable, Any]
|
281
api/core/app/task_pipeline/workflow_iteration_cycle_manage.py
Normal file
281
api/core/app/task_pipeline/workflow_iteration_cycle_manage.py
Normal file
@ -0,0 +1,281 @@
|
||||
import json
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union
|
||||
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueIterationCompletedEvent,
|
||||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
)
|
||||
from core.app.entities.task_entities import (
|
||||
IterationNodeCompletedStreamResponse,
|
||||
IterationNodeNextStreamResponse,
|
||||
IterationNodeStartStreamResponse,
|
||||
NodeExecutionInfo,
|
||||
WorkflowIterationState,
|
||||
)
|
||||
from core.app.task_pipeline.workflow_cycle_state_manager import WorkflowCycleStateManager
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from extensions.ext_database import db
|
||||
from models.workflow import (
|
||||
WorkflowNodeExecution,
|
||||
WorkflowNodeExecutionStatus,
|
||||
WorkflowNodeExecutionTriggeredFrom,
|
||||
WorkflowRun,
|
||||
)
|
||||
|
||||
|
||||
class WorkflowIterationCycleManage(WorkflowCycleStateManager):
|
||||
_iteration_state: WorkflowIterationState = None
|
||||
|
||||
def _init_iteration_state(self) -> WorkflowIterationState:
|
||||
if not self._iteration_state:
|
||||
self._iteration_state = WorkflowIterationState(
|
||||
current_iterations={}
|
||||
)
|
||||
|
||||
def _handle_iteration_to_stream_response(self, task_id: str, event: QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent) \
|
||||
-> Union[IterationNodeStartStreamResponse, IterationNodeNextStreamResponse, IterationNodeCompletedStreamResponse]:
|
||||
"""
|
||||
Handle iteration to stream response
|
||||
:param task_id: task id
|
||||
:param event: iteration event
|
||||
:return:
|
||||
"""
|
||||
if isinstance(event, QueueIterationStartEvent):
|
||||
return IterationNodeStartStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=self._task_state.workflow_run_id,
|
||||
data=IterationNodeStartStreamResponse.Data(
|
||||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
title=event.node_data.title,
|
||||
created_at=int(time.time()),
|
||||
extras={},
|
||||
inputs=event.inputs,
|
||||
metadata=event.metadata
|
||||
)
|
||||
)
|
||||
elif isinstance(event, QueueIterationNextEvent):
|
||||
current_iteration = self._iteration_state.current_iterations[event.node_id]
|
||||
|
||||
return IterationNodeNextStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=self._task_state.workflow_run_id,
|
||||
data=IterationNodeNextStreamResponse.Data(
|
||||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
title=current_iteration.node_data.title,
|
||||
index=event.index,
|
||||
pre_iteration_output=event.output,
|
||||
created_at=int(time.time()),
|
||||
extras={}
|
||||
)
|
||||
)
|
||||
elif isinstance(event, QueueIterationCompletedEvent):
|
||||
current_iteration = self._iteration_state.current_iterations[event.node_id]
|
||||
|
||||
return IterationNodeCompletedStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=self._task_state.workflow_run_id,
|
||||
data=IterationNodeCompletedStreamResponse.Data(
|
||||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
title=current_iteration.node_data.title,
|
||||
outputs=event.outputs,
|
||||
created_at=int(time.time()),
|
||||
extras={},
|
||||
inputs=current_iteration.inputs,
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
error=None,
|
||||
elapsed_time=time.perf_counter() - current_iteration.started_at,
|
||||
total_tokens=current_iteration.total_tokens,
|
||||
finished_at=int(time.time()),
|
||||
steps=current_iteration.current_index
|
||||
)
|
||||
)
|
||||
|
||||
def _init_iteration_execution_from_workflow_run(self,
|
||||
workflow_run: WorkflowRun,
|
||||
node_id: str,
|
||||
node_type: NodeType,
|
||||
node_title: str,
|
||||
node_run_index: int = 1,
|
||||
inputs: Optional[dict] = None,
|
||||
predecessor_node_id: Optional[str] = None
|
||||
) -> WorkflowNodeExecution:
|
||||
workflow_node_execution = WorkflowNodeExecution(
|
||||
tenant_id=workflow_run.tenant_id,
|
||||
app_id=workflow_run.app_id,
|
||||
workflow_id=workflow_run.workflow_id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
workflow_run_id=workflow_run.id,
|
||||
predecessor_node_id=predecessor_node_id,
|
||||
index=node_run_index,
|
||||
node_id=node_id,
|
||||
node_type=node_type.value,
|
||||
inputs=json.dumps(inputs) if inputs else None,
|
||||
title=node_title,
|
||||
status=WorkflowNodeExecutionStatus.RUNNING.value,
|
||||
created_by_role=workflow_run.created_by_role,
|
||||
created_by=workflow_run.created_by,
|
||||
execution_metadata=json.dumps({
|
||||
'started_run_index': node_run_index + 1,
|
||||
'current_index': 0,
|
||||
'steps_boundary': [],
|
||||
})
|
||||
)
|
||||
|
||||
db.session.add(workflow_node_execution)
|
||||
db.session.commit()
|
||||
db.session.refresh(workflow_node_execution)
|
||||
db.session.close()
|
||||
|
||||
return workflow_node_execution
|
||||
|
||||
def _handle_iteration_operation(self, event: QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent) -> WorkflowNodeExecution:
|
||||
if isinstance(event, QueueIterationStartEvent):
|
||||
return self._handle_iteration_started(event)
|
||||
elif isinstance(event, QueueIterationNextEvent):
|
||||
return self._handle_iteration_next(event)
|
||||
elif isinstance(event, QueueIterationCompletedEvent):
|
||||
return self._handle_iteration_completed(event)
|
||||
|
||||
def _handle_iteration_started(self, event: QueueIterationStartEvent) -> WorkflowNodeExecution:
|
||||
self._init_iteration_state()
|
||||
|
||||
workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first()
|
||||
workflow_node_execution = self._init_iteration_execution_from_workflow_run(
|
||||
workflow_run=workflow_run,
|
||||
node_id=event.node_id,
|
||||
node_type=NodeType.ITERATION,
|
||||
node_title=event.node_data.title,
|
||||
node_run_index=event.node_run_index,
|
||||
inputs=event.inputs,
|
||||
predecessor_node_id=event.predecessor_node_id
|
||||
)
|
||||
|
||||
latest_node_execution_info = NodeExecutionInfo(
|
||||
workflow_node_execution_id=workflow_node_execution.id,
|
||||
node_type=NodeType.ITERATION,
|
||||
start_at=time.perf_counter()
|
||||
)
|
||||
|
||||
self._task_state.ran_node_execution_infos[event.node_id] = latest_node_execution_info
|
||||
self._task_state.latest_node_execution_info = latest_node_execution_info
|
||||
|
||||
self._iteration_state.current_iterations[event.node_id] = WorkflowIterationState.Data(
|
||||
parent_iteration_id=None,
|
||||
iteration_id=event.node_id,
|
||||
current_index=0,
|
||||
iteration_steps_boundary=[],
|
||||
node_execution_id=workflow_node_execution.id,
|
||||
started_at=time.perf_counter(),
|
||||
inputs=event.inputs,
|
||||
total_tokens=0,
|
||||
node_data=event.node_data
|
||||
)
|
||||
|
||||
db.session.close()
|
||||
|
||||
return workflow_node_execution
|
||||
|
||||
def _handle_iteration_next(self, event: QueueIterationNextEvent) -> WorkflowNodeExecution:
|
||||
if event.node_id not in self._iteration_state.current_iterations:
|
||||
return
|
||||
current_iteration = self._iteration_state.current_iterations[event.node_id]
|
||||
current_iteration.current_index = event.index
|
||||
current_iteration.iteration_steps_boundary.append(event.node_run_index)
|
||||
workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter(
|
||||
WorkflowNodeExecution.id == current_iteration.node_execution_id
|
||||
).first()
|
||||
|
||||
original_node_execution_metadata = workflow_node_execution.execution_metadata_dict
|
||||
if original_node_execution_metadata:
|
||||
original_node_execution_metadata['current_index'] = event.index
|
||||
original_node_execution_metadata['steps_boundary'] = current_iteration.iteration_steps_boundary
|
||||
original_node_execution_metadata['total_tokens'] = current_iteration.total_tokens
|
||||
workflow_node_execution.execution_metadata = json.dumps(original_node_execution_metadata)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
db.session.close()
|
||||
|
||||
def _handle_iteration_completed(self, event: QueueIterationCompletedEvent) -> WorkflowNodeExecution:
|
||||
if event.node_id not in self._iteration_state.current_iterations:
|
||||
return
|
||||
|
||||
current_iteration = self._iteration_state.current_iterations[event.node_id]
|
||||
workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter(
|
||||
WorkflowNodeExecution.id == current_iteration.node_execution_id
|
||||
).first()
|
||||
|
||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
|
||||
workflow_node_execution.outputs = json.dumps(event.outputs) if event.outputs else None
|
||||
workflow_node_execution.elapsed_time = time.perf_counter() - current_iteration.started_at
|
||||
|
||||
original_node_execution_metadata = workflow_node_execution.execution_metadata_dict
|
||||
if original_node_execution_metadata:
|
||||
original_node_execution_metadata['steps_boundary'] = current_iteration.iteration_steps_boundary
|
||||
original_node_execution_metadata['total_tokens'] = current_iteration.total_tokens
|
||||
workflow_node_execution.execution_metadata = json.dumps(original_node_execution_metadata)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
# remove current iteration
|
||||
self._iteration_state.current_iterations.pop(event.node_id, None)
|
||||
|
||||
# set latest node execution info
|
||||
latest_node_execution_info = NodeExecutionInfo(
|
||||
workflow_node_execution_id=workflow_node_execution.id,
|
||||
node_type=NodeType.ITERATION,
|
||||
start_at=time.perf_counter()
|
||||
)
|
||||
|
||||
self._task_state.latest_node_execution_info = latest_node_execution_info
|
||||
|
||||
db.session.close()
|
||||
|
||||
def _handle_iteration_exception(self, task_id: str, error: str) -> Generator[IterationNodeCompletedStreamResponse, None, None]:
|
||||
"""
|
||||
Handle iteration exception
|
||||
"""
|
||||
if not self._iteration_state or not self._iteration_state.current_iterations:
|
||||
return
|
||||
|
||||
for node_id, current_iteration in self._iteration_state.current_iterations.items():
|
||||
workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter(
|
||||
WorkflowNodeExecution.id == current_iteration.node_execution_id
|
||||
).first()
|
||||
|
||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
|
||||
workflow_node_execution.error = error
|
||||
workflow_node_execution.elapsed_time = time.perf_counter() - current_iteration.started_at
|
||||
|
||||
db.session.commit()
|
||||
db.session.close()
|
||||
|
||||
yield IterationNodeCompletedStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=self._task_state.workflow_run_id,
|
||||
data=IterationNodeCompletedStreamResponse.Data(
|
||||
id=node_id,
|
||||
node_id=node_id,
|
||||
node_type=NodeType.ITERATION.value,
|
||||
title=current_iteration.node_data.title,
|
||||
outputs={},
|
||||
created_at=int(time.time()),
|
||||
extras={},
|
||||
inputs=current_iteration.inputs,
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=error,
|
||||
elapsed_time=time.perf_counter() - current_iteration.started_at,
|
||||
total_tokens=current_iteration.total_tokens,
|
||||
finished_at=int(time.time()),
|
||||
steps=current_iteration.current_index
|
||||
)
|
||||
)
|
@ -42,6 +42,8 @@ class MessageFileParser:
|
||||
raise ValueError('Invalid file url')
|
||||
if file.get('transfer_method') == FileTransferMethod.LOCAL_FILE.value and not file.get('upload_file_id'):
|
||||
raise ValueError('Missing file upload_file_id')
|
||||
if file.get('transform_method') == FileTransferMethod.TOOL_FILE.value and not file.get('tool_file_id'):
|
||||
raise ValueError('Missing file tool_file_id')
|
||||
|
||||
# transform files to file objs
|
||||
type_file_objs = self._to_file_objs(files, file_extra_config)
|
||||
@ -149,12 +151,21 @@ class MessageFileParser:
|
||||
"""
|
||||
if isinstance(file, dict):
|
||||
transfer_method = FileTransferMethod.value_of(file.get('transfer_method'))
|
||||
if transfer_method != FileTransferMethod.TOOL_FILE:
|
||||
return FileVar(
|
||||
tenant_id=self.tenant_id,
|
||||
type=FileType.value_of(file.get('type')),
|
||||
transfer_method=transfer_method,
|
||||
url=file.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None,
|
||||
related_id=file.get('upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None,
|
||||
extra_config=file_extra_config
|
||||
)
|
||||
return FileVar(
|
||||
tenant_id=self.tenant_id,
|
||||
type=FileType.value_of(file.get('type')),
|
||||
transfer_method=transfer_method,
|
||||
url=file.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None,
|
||||
related_id=file.get('upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None,
|
||||
url=None,
|
||||
related_id=file.get('tool_file_id'),
|
||||
extra_config=file_extra_config
|
||||
)
|
||||
else:
|
||||
|
@ -1,6 +1,7 @@
|
||||
from typing import cast
|
||||
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
ImagePromptMessageContent,
|
||||
PromptMessage,
|
||||
PromptMessageContentType,
|
||||
@ -21,13 +22,25 @@ class PromptMessageUtil:
|
||||
"""
|
||||
prompts = []
|
||||
if model_mode == ModelMode.CHAT.value:
|
||||
tool_calls = []
|
||||
for prompt_message in prompt_messages:
|
||||
if prompt_message.role == PromptMessageRole.USER:
|
||||
role = 'user'
|
||||
elif prompt_message.role == PromptMessageRole.ASSISTANT:
|
||||
role = 'assistant'
|
||||
if isinstance(prompt_message, AssistantPromptMessage):
|
||||
tool_calls = [{
|
||||
'id': tool_call.id,
|
||||
'type': 'function',
|
||||
'function': {
|
||||
'name': tool_call.function.name,
|
||||
'arguments': tool_call.function.arguments,
|
||||
}
|
||||
} for tool_call in prompt_message.tool_calls]
|
||||
elif prompt_message.role == PromptMessageRole.SYSTEM:
|
||||
role = 'system'
|
||||
elif prompt_message.role == PromptMessageRole.TOOL:
|
||||
role = 'tool'
|
||||
else:
|
||||
continue
|
||||
|
||||
@ -48,11 +61,16 @@ class PromptMessageUtil:
|
||||
else:
|
||||
text = prompt_message.content
|
||||
|
||||
prompts.append({
|
||||
prompt = {
|
||||
"role": role,
|
||||
"text": text,
|
||||
"files": files
|
||||
})
|
||||
}
|
||||
|
||||
if tool_calls:
|
||||
prompt['tool_calls'] = tool_calls
|
||||
|
||||
prompts.append(prompt)
|
||||
else:
|
||||
prompt_message = prompt_messages[0]
|
||||
text = ''
|
||||
|
@ -1,10 +1,10 @@
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from typing import Literal, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolProviderCredentials
|
||||
from core.tools.entities.tool_entities import ToolProviderCredentials, ToolProviderType
|
||||
from core.tools.tool.tool import ToolParameter
|
||||
|
||||
|
||||
@ -14,27 +14,38 @@ class UserTool(BaseModel):
|
||||
label: I18nObject # label
|
||||
description: I18nObject
|
||||
parameters: Optional[list[ToolParameter]]
|
||||
labels: list[str] = None
|
||||
|
||||
UserToolProviderTypeLiteral = Optional[Literal[
|
||||
'builtin', 'api', 'workflow'
|
||||
]]
|
||||
|
||||
class UserToolProvider(BaseModel):
|
||||
class ProviderType(Enum):
|
||||
BUILTIN = "builtin"
|
||||
APP = "app"
|
||||
API = "api"
|
||||
|
||||
id: str
|
||||
author: str
|
||||
name: str # identifier
|
||||
description: I18nObject
|
||||
icon: str
|
||||
label: I18nObject # label
|
||||
type: ProviderType
|
||||
type: ToolProviderType
|
||||
masked_credentials: dict = None
|
||||
original_credentials: dict = None
|
||||
is_team_authorization: bool = False
|
||||
allow_delete: bool = True
|
||||
tools: list[UserTool] = None
|
||||
labels: list[str] = None
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
# -------------
|
||||
# overwrite tool parameter types for temp fix
|
||||
tools = jsonable_encoder(self.tools)
|
||||
for tool in tools:
|
||||
if tool.get('parameters'):
|
||||
for parameter in tool.get('parameters'):
|
||||
if parameter.get('type') == ToolParameter.ToolParameterType.FILE.value:
|
||||
parameter['type'] = 'files'
|
||||
# -------------
|
||||
|
||||
return {
|
||||
'id': self.id,
|
||||
'author': self.author,
|
||||
@ -46,7 +57,8 @@ class UserToolProvider(BaseModel):
|
||||
'team_credentials': self.masked_credentials,
|
||||
'is_team_authorization': self.is_team_authorization,
|
||||
'allow_delete': self.allow_delete,
|
||||
'tools': self.tools
|
||||
'tools': tools,
|
||||
'labels': self.labels,
|
||||
}
|
||||
|
||||
class UserToolProviderCredentials(BaseModel):
|
@ -1,3 +0,0 @@
|
||||
class DEFAULT_PROVIDERS:
|
||||
API_BASED = '__api_based'
|
||||
APP_BASED = '__app_based'
|
@ -1,11 +1,11 @@
|
||||
from typing import Any, Optional
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.tools.entities.tool_entities import ToolParameter, ToolProviderType
|
||||
from core.tools.entities.tool_entities import ToolParameter
|
||||
|
||||
|
||||
class ApiBasedToolBundle(BaseModel):
|
||||
class ApiToolBundle(BaseModel):
|
||||
"""
|
||||
This class is used to store the schema information of an api based tool. such as the url, the method, the parameters, etc.
|
||||
"""
|
||||
@ -25,12 +25,3 @@ class ApiBasedToolBundle(BaseModel):
|
||||
icon: Optional[str] = None
|
||||
# openapi operation
|
||||
openapi: dict
|
||||
|
||||
class AppToolBundle(BaseModel):
|
||||
"""
|
||||
This class is used to store the schema information of an tool for an app.
|
||||
"""
|
||||
type: ToolProviderType
|
||||
credential: Optional[dict[str, Any]] = None
|
||||
provider_id: str
|
||||
tool_name: str
|
@ -10,10 +10,11 @@ class ToolProviderType(Enum):
|
||||
"""
|
||||
Enum class for tool provider
|
||||
"""
|
||||
BUILT_IN = "built-in"
|
||||
BUILT_IN = "builtin"
|
||||
WORKFLOW = "workflow"
|
||||
API = "api"
|
||||
APP = "app"
|
||||
DATASET_RETRIEVAL = "dataset-retrieval"
|
||||
APP_BASED = "app-based"
|
||||
API_BASED = "api-based"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> 'ToolProviderType':
|
||||
@ -77,6 +78,7 @@ class ToolInvokeMessage(BaseModel):
|
||||
LINK = "link"
|
||||
BLOB = "blob"
|
||||
IMAGE_LINK = "image_link"
|
||||
FILE_VAR = "file_var"
|
||||
|
||||
type: MessageType = MessageType.TEXT
|
||||
"""
|
||||
@ -90,6 +92,7 @@ class ToolInvokeMessageBinary(BaseModel):
|
||||
mimetype: str = Field(..., description="The mimetype of the binary")
|
||||
url: str = Field(..., description="The url of the binary")
|
||||
save_as: str = ''
|
||||
file_var: Optional[dict[str, Any]] = None
|
||||
|
||||
class ToolParameterOption(BaseModel):
|
||||
value: str = Field(..., description="The value of the option")
|
||||
@ -102,6 +105,7 @@ class ToolParameter(BaseModel):
|
||||
BOOLEAN = "boolean"
|
||||
SELECT = "select"
|
||||
SECRET_INPUT = "secret-input"
|
||||
FILE = "file"
|
||||
|
||||
class ToolParameterForm(Enum):
|
||||
SCHEMA = "schema" # should be set while adding tool
|
||||
@ -331,6 +335,15 @@ class ModelToolProviderConfiguration(BaseModel):
|
||||
models: list[ModelToolConfiguration] = Field(..., description="The models of the model tool")
|
||||
label: I18nObject = Field(..., description="The label of the model tool")
|
||||
|
||||
|
||||
class WorkflowToolParameterConfiguration(BaseModel):
|
||||
"""
|
||||
Workflow tool configuration
|
||||
"""
|
||||
name: str = Field(..., description="The name of the parameter")
|
||||
description: str = Field(..., description="The description of the parameter")
|
||||
form: ToolParameter.ToolParameterForm = Field(..., description="The form of the parameter")
|
||||
|
||||
class ToolInvokeMeta(BaseModel):
|
||||
"""
|
||||
Tool invoke meta
|
||||
@ -358,4 +371,19 @@ class ToolInvokeMeta(BaseModel):
|
||||
'time_cost': self.time_cost,
|
||||
'error': self.error,
|
||||
'tool_config': self.tool_config,
|
||||
}
|
||||
}
|
||||
|
||||
class ToolLabel(BaseModel):
|
||||
"""
|
||||
Tool label
|
||||
"""
|
||||
name: str = Field(..., description="The name of the tool")
|
||||
label: I18nObject = Field(..., description="The label of the tool")
|
||||
icon: str = Field(..., description="The icon of the tool")
|
||||
|
||||
class ToolInvokeFrom(Enum):
|
||||
"""
|
||||
Enum class for tool invoke
|
||||
"""
|
||||
WORKFLOW = "workflow"
|
||||
AGENT = "agent"
|
96
api/core/tools/entities/values.py
Normal file
96
api/core/tools/entities/values.py
Normal file
@ -0,0 +1,96 @@
|
||||
from enum import Enum
|
||||
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolLabel
|
||||
|
||||
|
||||
class ToolLabelEnum(Enum):
|
||||
SEARCH = 'search'
|
||||
IMAGE = 'image'
|
||||
VIDEOS = 'videos'
|
||||
WEATHER = 'weather'
|
||||
FINANCE = 'finance'
|
||||
DESIGN = 'design'
|
||||
TRAVEL = 'travel'
|
||||
SOCIAL = 'social'
|
||||
NEWS = 'news'
|
||||
MEDICAL = 'medical'
|
||||
PRODUCTIVITY = 'productivity'
|
||||
EDUCATION = 'education'
|
||||
BUSINESS = 'business'
|
||||
ENTERTAINMENT = 'entertainment'
|
||||
UTILITIES = 'utilities'
|
||||
OTHER = 'other'
|
||||
|
||||
ICONS = {
|
||||
ToolLabelEnum.SEARCH: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M7.33398 1.3335C10.646 1.3335 13.334 4.0215 13.334 7.3335C13.334 10.6455 10.646 13.3335 7.33398 13.3335C4.02198 13.3335 1.33398 10.6455 1.33398 7.3335C1.33398 4.0215 4.02198 1.3335 7.33398 1.3335ZM7.33398 12.0002C9.91232 12.0002 12.0007 9.91183 12.0007 7.3335C12.0007 4.75516 9.91232 2.66683 7.33398 2.66683C4.75565 2.66683 2.66732 4.75516 2.66732 7.3335C2.66732 9.91183 4.75565 12.0002 7.33398 12.0002ZM12.9909 12.0476L14.8764 13.9332L13.9337 14.876L12.0481 12.9904L12.9909 12.0476Z" fill="#344054"/>
|
||||
</svg>''',
|
||||
ToolLabelEnum.IMAGE: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M13.0514 9.71752L10.4718 7.13792C10.2115 6.87752 9.78932 6.87752 9.52898 7.13792L4.57721 12.0897C3.4097 11.1113 2.66732 9.64232 2.66732 7.99992C2.66732 5.0544 5.05513 2.66659 8.00065 2.66659C10.9462 2.66659 13.334 5.0544 13.334 7.99992C13.334 8.60085 13.2346 9.17852 13.0514 9.71752ZM5.72683 12.8257L10.0004 8.55212L12.4259 10.9777C11.4668 12.4001 9.84152 13.3331 8.00038 13.3331C7.18632 13.3331 6.41628 13.1511 5.72683 12.8257ZM8.00065 14.6666C11.6825 14.6666 14.6673 11.6818 14.6673 7.99992C14.6673 4.31802 11.6825 1.33325 8.00065 1.33325C4.31875 1.33325 1.33398 4.31802 1.33398 7.99992C1.33398 11.6818 4.31875 14.6666 8.00065 14.6666ZM7.33398 6.66658C7.33398 7.40299 6.73705 7.99992 6.00065 7.99992C5.26427 7.99992 4.66732 7.40299 4.66732 6.66658C4.66732 5.9302 5.26427 5.33325 6.00065 5.33325C6.73705 5.33325 7.33398 5.9302 7.33398 6.66658Z" fill="#344054"/>
|
||||
</svg>''',
|
||||
ToolLabelEnum.VIDEOS: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M8.00065 13.3333H13.334V14.6666H8.00065C4.31875 14.6666 1.33398 11.6818 1.33398 7.99992C1.33398 4.31802 4.31875 1.33325 8.00065 1.33325C11.6825 1.33325 14.6673 4.31802 14.6673 7.99992C14.6673 9.50072 14.1714 10.8857 13.3345 11.9999H11.5284C12.6356 11.0227 13.334 9.59285 13.334 7.99992C13.334 5.0544 10.9462 2.66659 8.00065 2.66659C5.05513 2.66659 2.66732 5.0544 2.66732 7.99992C2.66732 10.9455 5.05513 13.3333 8.00065 13.3333ZM8.00065 6.66658C7.26425 6.66658 6.66732 6.06963 6.66732 5.33325C6.66732 4.59687 7.26425 3.99992 8.00065 3.99992C8.73705 3.99992 9.33398 4.59687 9.33398 5.33325C9.33398 6.06963 8.73705 6.66658 8.00065 6.66658ZM5.33398 9.33325C4.5976 9.33325 4.00065 8.73632 4.00065 7.99992C4.00065 7.26352 4.5976 6.66658 5.33398 6.66658C6.07036 6.66658 6.66732 7.26352 6.66732 7.99992C6.66732 8.73632 6.07036 9.33325 5.33398 9.33325ZM10.6673 9.33325C9.93092 9.33325 9.33398 8.73632 9.33398 7.99992C9.33398 7.26352 9.93092 6.66658 10.6673 6.66658C11.4037 6.66658 12.0007 7.26352 12.0007 7.99992C12.0007 8.73632 11.4037 9.33325 10.6673 9.33325ZM8.00065 11.9999C7.26425 11.9999 6.66732 11.403 6.66732 10.6666C6.66732 9.93018 7.26425 9.33325 8.00065 9.33325C8.73705 9.33325 9.33398 9.93018 9.33398 10.6666C9.33398 11.403 8.73705 11.9999 8.00065 11.9999Z" fill="#344054"/>
|
||||
</svg>''',
|
||||
ToolLabelEnum.WEATHER: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M6.6553 3.37344C7.42088 2.1484 8.78162 1.3335 10.3327 1.3335C12.7259 1.3335 14.666 3.2736 14.666 5.66683C14.666 6.38704 14.4903 7.06623 14.1794 7.66383C14.8894 8.3325 15.3327 9.28123 15.3327 10.3335C15.3327 12.3586 13.6911 14.0002 11.666 14.0002H5.99935C3.05383 14.0002 0.666016 11.6124 0.666016 8.66683C0.666016 5.72131 3.05383 3.3335 5.99935 3.3335C6.22143 3.3335 6.44034 3.34707 6.6553 3.37344ZM8.03628 3.73629C9.37768 4.29108 10.4435 5.37735 10.9711 6.73256C11.1961 6.68943 11.4284 6.66683 11.666 6.66683C12.1561 6.66683 12.6237 6.76296 13.0511 6.93743C13.2317 6.55162 13.3327 6.12102 13.3327 5.66683C13.3327 4.00998 11.9895 2.66683 10.3327 2.66683C9.41115 2.66683 8.58662 3.08236 8.03628 3.73629ZM11.666 12.6668C12.9547 12.6668 13.9993 11.6222 13.9993 10.3335C13.9993 9.04483 12.9547 8.00016 11.666 8.00016C11.013 8.00016 10.4227 8.26836 9.99922 8.70063C9.99928 8.68936 9.99935 8.6781 9.99935 8.66683C9.99935 6.45769 8.20848 4.66683 5.99935 4.66683C3.79021 4.66683 1.99935 6.45769 1.99935 8.66683C1.99935 10.876 3.79021 12.6668 5.99935 12.6668H11.666Z" fill="#344054"/>
|
||||
</svg>''',
|
||||
ToolLabelEnum.FINANCE: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M8.00262 14.6685C4.32071 14.6685 1.33594 11.6838 1.33594 8.00184C1.33594 4.31997 4.32071 1.33521 8.00262 1.33521C11.6845 1.33521 14.6693 4.31997 14.6693 8.00184C14.6693 11.6838 11.6845 14.6685 8.00262 14.6685ZM8.00262 13.3352C10.9482 13.3352 13.336 10.9474 13.336 8.00184C13.336 5.05635 10.9482 2.66854 8.00262 2.66854C5.05708 2.66854 2.66927 5.05635 2.66927 8.00184C2.66927 10.9474 5.05708 13.3352 8.00262 13.3352ZM5.66927 9.33517H9.33595C9.52002 9.33517 9.66928 9.18597 9.66928 9.00184C9.66928 8.81777 9.52002 8.66851 9.33595 8.66851H6.66928C5.7488 8.66851 5.0026 7.92237 5.0026 7.00184C5.0026 6.08139 5.7488 5.33521 6.66928 5.33521H7.33595V4.00187H8.66928V5.33521H10.336V6.66851H6.66928C6.48518 6.66851 6.33594 6.81777 6.33594 7.00184C6.33594 7.18597 6.48518 7.33517 6.66928 7.33517H9.33595C10.2564 7.33517 11.0026 8.08137 11.0026 9.00184C11.0026 9.92237 10.2564 10.6685 9.33595 10.6685H8.66928V12.0018H7.33595V10.6685H5.66927V9.33517Z" fill="#344054"/>
|
||||
</svg>''',
|
||||
ToolLabelEnum.DESIGN: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M4.70152 9.41416L3.2873 10.8284L5.17292 12.714L12.7154 5.17154L10.8298 3.28592L9.41557 4.70013L10.3584 5.64295L9.41557 6.58575L8.47277 5.64295L7.52997 6.58575L8.47277 7.52856L7.52997 8.47136L6.58713 7.52856L5.64433 8.47136L6.58713 9.41416L5.64433 10.357L4.70152 9.41416ZM11.3012 1.87171L14.1296 4.70013C14.39 4.96049 14.39 5.38259 14.1296 5.64295L5.64433 14.1282C5.38397 14.3886 4.96187 14.3886 4.70152 14.1282L1.87309 11.2998C1.61274 11.0394 1.61274 10.6174 1.87309 10.357L10.3584 1.87171C10.6187 1.61136 11.0408 1.61136 11.3012 1.87171ZM9.41557 12.2423L10.3584 11.2995L11.8534 12.7945H12.7962V11.8517L11.3012 10.3567L12.244 9.41383L14.0011 11.171V13.9999H11.1732L9.41557 12.2423ZM3.75861 6.58533L1.87299 4.69971C1.61265 4.43937 1.61265 4.01725 1.87299 3.75691L3.75861 1.87129C4.01896 1.61094 4.44107 1.61094 4.70142 1.87129L6.58704 3.75691L5.64423 4.69971L4.23002 3.2855L3.28721 4.22831L4.70142 5.64253L3.75861 6.58533Z" fill="#344054"/>
|
||||
</svg>''',
|
||||
ToolLabelEnum.TRAVEL: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M9.44839 2C9.80198 2 10.1411 2.14047 10.3912 2.39053L13.6101 5.60947C13.8602 5.85953 14.0007 6.19866 14.0007 6.55229V11.3333H15.334V12.6667L9.91652 12.6672C9.62032 13.8171 8.57638 14.6667 7.33398 14.6667C6.0916 14.6667 5.04766 13.8171 4.75146 12.6672L2.00065 12.6667C1.63246 12.6667 1.33398 12.3682 1.33398 12V3.33333C1.33398 2.59695 1.93094 2 2.66732 2H9.44839ZM7.33398 10.6667C6.5976 10.6667 6.00065 11.2636 6.00065 12C6.00065 12.7364 6.5976 13.3333 7.33398 13.3333C8.07038 13.3333 8.66732 12.7364 8.66732 12C8.66732 11.2636 8.07038 10.6667 7.33398 10.6667ZM9.44839 3.33333H2.66732V11.3333L4.75128 11.3335C5.04726 10.1833 6.09136 9.33333 7.33398 9.33333C8.57658 9.33333 9.62072 10.1833 9.91665 11.3335L12.6673 11.3333V6.55229L9.44839 3.33333ZM9.33398 4.66667V8.66667H4.00065V4.66667H9.33398ZM8.00065 6H5.33398V7.33333H8.00065V6Z" fill="#344054"/>
|
||||
</svg>''',
|
||||
ToolLabelEnum.SOCIAL: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M13.334 7.99992C13.334 5.0544 10.9462 2.66659 8.00065 2.66659C5.05513 2.66659 2.66732 5.0544 2.66732 7.99992C2.66732 10.9455 5.05513 13.3333 8.00065 13.3333C9.09518 13.3333 10.1127 13.0035 10.9594 12.438L11.699 13.5475C10.6408 14.2545 9.36885 14.6666 8.00065 14.6666C4.31875 14.6666 1.33398 11.6818 1.33398 7.99992C1.33398 4.31802 4.31875 1.33325 8.00065 1.33325C11.6825 1.33325 14.6673 4.31802 14.6673 7.99992V8.99992C14.6673 10.2886 13.6227 11.3333 12.334 11.3333C11.5312 11.3333 10.8231 10.9278 10.4032 10.3105C9.79678 10.9409 8.94452 11.3333 8.00065 11.3333C6.1597 11.3333 4.66732 9.84085 4.66732 7.99992C4.66732 6.15897 6.1597 4.66658 8.00065 4.66658C8.75118 4.66658 9.44378 4.91464 10.001 5.33325H11.334V8.99992C11.334 9.55219 11.7817 9.99992 12.334 9.99992C12.8863 9.99992 13.334 9.55219 13.334 8.99992V7.99992ZM8.00065 5.99992C6.89605 5.99992 6.00065 6.89532 6.00065 7.99992C6.00065 9.10452 6.89605 9.99992 8.00065 9.99992C9.10525 9.99992 10.0007 9.10452 10.0007 7.99992C10.0007 6.89532 9.10525 5.99992 8.00065 5.99992Z" fill="#344054"/>
|
||||
</svg>''',
|
||||
ToolLabelEnum.NEWS: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M10.6673 13.3335V2.66683H2.66732V12.6668C2.66732 13.035 2.9658 13.3335 3.33398 13.3335H10.6673ZM12.6673 14.6668H3.33398C2.22942 14.6668 1.33398 13.7714 1.33398 12.6668V2.00016C1.33398 1.63198 1.63246 1.3335 2.00065 1.3335H11.334C11.7022 1.3335 12.0007 1.63198 12.0007 2.00016V6.66683H14.6673V12.6668C14.6673 13.7714 13.7719 14.6668 12.6673 14.6668ZM12.0007 8.00016V12.6668C12.0007 13.035 12.2991 13.3335 12.6673 13.3335C13.0355 13.3335 13.334 13.035 13.334 12.6668V8.00016H12.0007ZM4.00065 4.00016H8.00065V8.00016H4.00065V4.00016ZM5.33398 5.3335V6.66683H6.66732V5.3335H5.33398ZM4.00065 8.66683H9.33398V10.0002H4.00065V8.66683ZM4.00065 10.6668H9.33398V12.0002H4.00065V10.6668Z" fill="#344054"/>
|
||||
</svg>''',
|
||||
ToolLabelEnum.MEDICAL: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M8.79747 1.51186L10.9641 5.26464C11.1482 5.5835 11.0389 5.99122 10.7201 6.17532L9.85373 6.67474L10.5207 7.83001L9.366 8.49668L8.699 7.34141L7.83333 7.84201C7.51447 8.02608 7.10673 7.91681 6.92267 7.59794L5.69747 5.47632C4.32922 5.89145 3.33333 7.16268 3.33333 8.66654C3.33333 9.08348 3.40987 9.48248 3.54965 9.85034C4.06613 9.52254 4.67762 9.33321 5.33333 9.33321C6.45605 9.33321 7.44913 9.88828 8.05313 10.7389L13.1787 7.78014L13.8454 8.93488L8.5932 11.9672C8.64133 12.1927 8.66667 12.4267 8.66667 12.6665C8.66667 12.895 8.64367 13.1181 8.59993 13.3337L14 13.3332V14.6665L2.66703 14.6673C2.2482 14.1101 2 13.4173 2 12.6665C2 11.9951 2.19855 11.3699 2.54014 10.8467C2.19517 10.1964 2 9.45428 2 8.66654C2 6.66968 3.25421 4.96575 5.01785 4.29953L4.75598 3.84519C4.38779 3.20747 4.60629 2.39202 5.24402 2.02382L6.97607 1.02382C7.6138 0.655637 8.42927 0.874138 8.79747 1.51186ZM5.33333 10.6665C4.22877 10.6665 3.33333 11.562 3.33333 12.6665C3.33333 12.9003 3.37343 13.1247 3.44711 13.3331H7.21953C7.29327 13.1247 7.33333 12.9003 7.33333 12.6665C7.33333 11.562 6.4379 10.6665 5.33333 10.6665ZM7.64273 2.17852L5.91068 3.17852L7.744 6.35395L9.47607 5.35395L7.64273 2.17852Z" fill="#344054"/>
|
||||
</svg>''',
|
||||
ToolLabelEnum.PRODUCTIVITY: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M6.64807 11.9999H9.35062C9.43862 11.1989 9.84742 10.5376 10.5111 9.81499C10.5858 9.73365 11.0652 9.23752 11.1221 9.16665C11.6872 8.46199 11.9993 7.58992 11.9993 6.66659C11.9993 4.45745 10.2085 2.66659 7.99935 2.66659C5.79021 2.66659 3.99935 4.45745 3.99935 6.66659C3.99935 7.58945 4.31118 8.46105 4.87576 9.16552C4.93271 9.23659 5.41322 9.73405 5.48704 9.81445C6.15112 10.5375 6.56004 11.1989 6.64807 11.9999ZM9.33268 13.3333H6.66602V13.9999H9.33268V13.3333ZM3.83532 9.99939C3.10365 9.08639 2.66602 7.92759 2.66602 6.66659C2.66602 3.72107 5.05383 1.33325 7.99935 1.33325C10.9449 1.33325 13.3327 3.72107 13.3327 6.66659C13.3327 7.92825 12.8945 9.08759 12.1622 10.0009C11.7487 10.5165 10.666 11.3333 10.666 12.3333V13.9999C10.666 14.7363 10.0691 15.3333 9.33268 15.3333H6.66602C5.92964 15.3333 5.33268 14.7363 5.33268 13.9999V12.3333C5.33268 11.3333 4.24907 10.5157 3.83532 9.99939ZM8.66602 6.66979H10.3327L7.33268 10.6698V8.00312H5.66602L8.66602 3.99992V6.66979Z" fill="#344054"/>
|
||||
</svg>''',
|
||||
ToolLabelEnum.EDUCATION: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M14 2.66683H4.66667C3.93029 2.66683 3.33333 3.26378 3.33333 4.00016C3.33333 4.73654 3.93029 5.3335 4.66667 5.3335H14V14.0002C14 14.3684 13.7015 14.6668 13.3333 14.6668H4.66667C3.19391 14.6668 2 13.4729 2 12.0002V4.00016C2 2.5274 3.19391 1.3335 4.66667 1.3335H13.3333C13.7015 1.3335 14 1.63198 14 2.00016V2.66683ZM3.33333 12.0002C3.33333 12.7366 3.93029 13.3335 4.66667 13.3335H12.6667V6.66683H4.66667C4.18095 6.66683 3.72557 6.53697 3.33333 6.31008V12.0002ZM13.3333 4.66683H4.66667C4.29848 4.66683 4 4.36835 4 4.00016C4 3.63198 4.29848 3.3335 4.66667 3.3335H13.3333V4.66683Z" fill="#344054"/>
|
||||
</svg>''',
|
||||
ToolLabelEnum.BUSINESS: '''<svg xmlns="http://www.w3.org/2000/svg" width="14" height="14" viewBox="0 0 14 14" fill="none">
|
||||
<path d="M3.66732 3.33341V1.33341C3.66732 0.965228 3.9658 0.666748 4.33398 0.666748H9.66732C10.0355 0.666748 10.334 0.965228 10.334 1.33341V3.33341H13.0007C13.3689 3.33341 13.6673 3.63189 13.6673 4.00008V13.3334C13.6673 13.7016 13.3689 14.0001 13.0007 14.0001H1.00065C0.632464 14.0001 0.333984 13.7016 0.333984 13.3334V4.00008C0.333984 3.63189 0.632464 3.33341 1.00065 3.33341H3.66732ZM12.334 8.66675H1.66732V12.6667H12.334V8.66675ZM12.334 4.66675H1.66732V7.33341H3.66732V6.00008H5.00065V7.33341H9.00065V6.00008H10.334V7.33341H12.334V4.66675ZM5.00065 2.00008V3.33341H9.00065V2.00008H5.00065Z" fill="#344054"/>
|
||||
</svg>''',
|
||||
ToolLabelEnum.ENTERTAINMENT: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M11.3327 2.66675C13.5418 2.66675 15.3327 4.45761 15.3327 6.66675V9.33342C15.3327 11.5425 13.5418 13.3334 11.3327 13.3334H4.66602C2.45688 13.3334 0.666016 11.5425 0.666016 9.33342V6.66675C0.666016 4.45761 2.45688 2.66675 4.66602 2.66675H11.3327ZM11.3327 4.00008H4.66602C3.23788 4.00008 2.07196 5.12273 2.00262 6.53365L1.99935 6.66675V9.33342C1.99935 10.7615 3.122 11.9275 4.53292 11.9968L4.66602 12.0001H11.3327C12.7608 12.0001 13.9267 10.8774 13.9961 9.46648L13.9993 9.33342V6.66675C13.9993 5.23861 12.8767 4.07269 11.4657 4.00335L11.3327 4.00008ZM6.66602 6.00008V7.33342H7.99935V8.66675H6.66535L6.66602 10.0001H5.33268L5.33202 8.66675H3.99935V7.33342H5.33268V6.00008H6.66602ZM11.9993 8.66675V10.0001H10.666V8.66675H11.9993ZM10.666 6.00008V7.33342H9.33268V6.00008H10.666Z" fill="#344054"/>
|
||||
</svg>''',
|
||||
ToolLabelEnum.UTILITIES: '''<svg xmlns="http://www.w3.org/2000/svg" width="13" height="15" viewBox="0 0 13 15" fill="none">
|
||||
<path d="M12.3346 0.333252C12.7028 0.333252 13.0013 0.631732 13.0013 0.999919V4.33325C13.0013 4.70144 12.7028 4.99992 12.3346 4.99992H9.0013V13.6666C9.0013 14.0348 8.70284 14.3333 8.33463 14.3333H5.66797C5.29978 14.3333 5.0013 14.0348 5.0013 13.6666V4.99992H1.33464C0.966449 4.99992 0.667969 4.70144 0.667969 4.33325V2.74527C0.667969 2.49276 0.810635 2.26192 1.0365 2.14899L4.66797 0.333252H12.3346ZM9.0013 1.66659H4.98273L2.0013 3.1573V3.66659H6.33464V12.9999H7.66797V3.66659H9.0013V1.66659ZM11.668 1.66659H10.3346V3.66659H11.668V1.66659Z" fill="#344054"/>
|
||||
</svg>''',
|
||||
ToolLabelEnum.OTHER: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M8.00052 0.666748L4.00065 7.33342H12.0007L8.00052 0.666748ZM8.00052 3.25828L9.64572 6.00008H6.35553L8.00052 3.25828ZM4.50065 13.3334C3.48813 13.3334 2.66732 12.5126 2.66732 11.5001C2.66732 10.4875 3.48813 9.66675 4.50065 9.66675C5.51317 9.66675 6.33398 10.4875 6.33398 11.5001C6.33398 12.5126 5.51317 13.3334 4.50065 13.3334ZM4.50065 14.6667C6.24955 14.6667 7.66732 13.249 7.66732 11.5001C7.66732 9.75115 6.24955 8.33342 4.50065 8.33342C2.75175 8.33342 1.33398 9.75115 1.33398 11.5001C1.33398 13.249 2.75175 14.6667 4.50065 14.6667ZM10.0007 10.3334V13.0001H12.6673V10.3334H10.0007ZM8.66732 14.3334V9.00008H14.0007V14.3334H8.66732Z" fill="#344054"/>
|
||||
</svg>'''
|
||||
}
|
||||
|
||||
default_tool_label_dict = {
|
||||
ToolLabelEnum.SEARCH: ToolLabel(name='search', label=I18nObject(en_US='Search', zh_Hans='搜索'), icon=ICONS[ToolLabelEnum.SEARCH]),
|
||||
ToolLabelEnum.IMAGE: ToolLabel(name='image', label=I18nObject(en_US='Image', zh_Hans='图片'), icon=ICONS[ToolLabelEnum.IMAGE]),
|
||||
ToolLabelEnum.VIDEOS: ToolLabel(name='videos', label=I18nObject(en_US='Videos', zh_Hans='视频'), icon=ICONS[ToolLabelEnum.VIDEOS]),
|
||||
ToolLabelEnum.WEATHER: ToolLabel(name='weather', label=I18nObject(en_US='Weather', zh_Hans='天气'), icon=ICONS[ToolLabelEnum.WEATHER]),
|
||||
ToolLabelEnum.FINANCE: ToolLabel(name='finance', label=I18nObject(en_US='Finance', zh_Hans='金融'), icon=ICONS[ToolLabelEnum.FINANCE]),
|
||||
ToolLabelEnum.DESIGN: ToolLabel(name='design', label=I18nObject(en_US='Design', zh_Hans='设计'), icon=ICONS[ToolLabelEnum.DESIGN]),
|
||||
ToolLabelEnum.TRAVEL: ToolLabel(name='travel', label=I18nObject(en_US='Travel', zh_Hans='旅行'), icon=ICONS[ToolLabelEnum.TRAVEL]),
|
||||
ToolLabelEnum.SOCIAL: ToolLabel(name='social', label=I18nObject(en_US='Social', zh_Hans='社交'), icon=ICONS[ToolLabelEnum.SOCIAL]),
|
||||
ToolLabelEnum.NEWS: ToolLabel(name='news', label=I18nObject(en_US='News', zh_Hans='新闻'), icon=ICONS[ToolLabelEnum.NEWS]),
|
||||
ToolLabelEnum.MEDICAL: ToolLabel(name='medical', label=I18nObject(en_US='Medical', zh_Hans='医疗'), icon=ICONS[ToolLabelEnum.MEDICAL]),
|
||||
ToolLabelEnum.PRODUCTIVITY: ToolLabel(name='productivity', label=I18nObject(en_US='Productivity', zh_Hans='生产力'), icon=ICONS[ToolLabelEnum.PRODUCTIVITY]),
|
||||
ToolLabelEnum.EDUCATION: ToolLabel(name='education', label=I18nObject(en_US='Education', zh_Hans='教育'), icon=ICONS[ToolLabelEnum.EDUCATION]),
|
||||
ToolLabelEnum.BUSINESS: ToolLabel(name='business', label=I18nObject(en_US='Business', zh_Hans='商业'), icon=ICONS[ToolLabelEnum.BUSINESS]),
|
||||
ToolLabelEnum.ENTERTAINMENT: ToolLabel(name='entertainment', label=I18nObject(en_US='Entertainment', zh_Hans='娱乐'), icon=ICONS[ToolLabelEnum.ENTERTAINMENT]),
|
||||
ToolLabelEnum.UTILITIES: ToolLabel(name='utilities', label=I18nObject(en_US='Utilities', zh_Hans='工具'), icon=ICONS[ToolLabelEnum.UTILITIES]),
|
||||
ToolLabelEnum.OTHER: ToolLabel(name='other', label=I18nObject(en_US='Other', zh_Hans='其他'), icon=ICONS[ToolLabelEnum.OTHER]),
|
||||
}
|
||||
|
||||
default_tool_labels = [v for k, v in default_tool_label_dict.items()]
|
||||
default_tool_label_name_list = [label.name for label in default_tool_labels]
|
@ -1,2 +0,0 @@
|
||||
class InvokeModelError(Exception):
|
||||
pass
|
@ -1,7 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_bundle import ApiBasedToolBundle
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import (
|
||||
ApiProviderAuthType,
|
||||
ToolCredentialsOption,
|
||||
@ -15,11 +14,11 @@ from extensions.ext_database import db
|
||||
from models.tools import ApiToolProvider
|
||||
|
||||
|
||||
class ApiBasedToolProviderController(ToolProviderController):
|
||||
class ApiToolProviderController(ToolProviderController):
|
||||
provider_id: str
|
||||
|
||||
@staticmethod
|
||||
def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> 'ApiBasedToolProviderController':
|
||||
def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> 'ApiToolProviderController':
|
||||
credentials_schema = {
|
||||
'auth_type': ToolProviderCredentials(
|
||||
name='auth_type',
|
||||
@ -79,9 +78,11 @@ class ApiBasedToolProviderController(ToolProviderController):
|
||||
else:
|
||||
raise ValueError(f'invalid auth type {auth_type}')
|
||||
|
||||
return ApiBasedToolProviderController(**{
|
||||
user_name = db_provider.user.name if db_provider.user_id else ''
|
||||
|
||||
return ApiToolProviderController(**{
|
||||
'identity': {
|
||||
'author': db_provider.user.name if db_provider.user_id and db_provider.user else '',
|
||||
'author': user_name,
|
||||
'name': db_provider.name,
|
||||
'label': {
|
||||
'en_US': db_provider.name,
|
||||
@ -98,16 +99,10 @@ class ApiBasedToolProviderController(ToolProviderController):
|
||||
})
|
||||
|
||||
@property
|
||||
def app_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.API_BASED
|
||||
|
||||
def _validate_credentials(self, tool_name: str, credentials: dict[str, Any]) -> None:
|
||||
pass
|
||||
def provider_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.API
|
||||
|
||||
def validate_parameters(self, tool_name: str, tool_parameters: dict[str, Any]) -> None:
|
||||
pass
|
||||
|
||||
def _parse_tool_bundle(self, tool_bundle: ApiBasedToolBundle) -> ApiTool:
|
||||
def _parse_tool_bundle(self, tool_bundle: ApiToolBundle) -> ApiTool:
|
||||
"""
|
||||
parse tool bundle to tool
|
||||
|
||||
@ -136,7 +131,7 @@ class ApiBasedToolProviderController(ToolProviderController):
|
||||
'parameters' : tool_bundle.parameters if tool_bundle.parameters else [],
|
||||
})
|
||||
|
||||
def load_bundled_tools(self, tools: list[ApiBasedToolBundle]) -> list[ApiTool]:
|
||||
def load_bundled_tools(self, tools: list[ApiToolBundle]) -> list[ApiTool]:
|
||||
"""
|
||||
load bundled tools
|
||||
|
||||
|
@ -11,10 +11,10 @@ from models.tools import PublishedAppTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class AppBasedToolProviderEntity(ToolProviderController):
|
||||
class AppToolProviderEntity(ToolProviderController):
|
||||
@property
|
||||
def app_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.APP_BASED
|
||||
def provider_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.APP
|
||||
|
||||
def _validate_credentials(self, tool_name: str, credentials: dict[str, Any]) -> None:
|
||||
pass
|
||||
|
@ -1,6 +1,6 @@
|
||||
import os.path
|
||||
|
||||
from core.tools.entities.user_entities import UserToolProvider
|
||||
from core.tools.entities.api_entities import UserToolProvider
|
||||
from core.utils.position_helper import get_position_map, sort_by_position_map
|
||||
|
||||
|
||||
|
@ -1,3 +1,4 @@
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.aippt.tools.aippt import AIPPTGenerateTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
@ -9,3 +10,9 @@ class AIPPTProvider(BuiltinToolProviderController):
|
||||
AIPPTGenerateTool._get_api_token(credentials, user_id='__dify_system__')
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.PRODUCTIVITY,
|
||||
ToolLabelEnum.DESIGN,
|
||||
]
|
@ -1,3 +1,4 @@
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.arxiv.tools.arxiv_search import ArxivSearchTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
@ -7,7 +8,7 @@ class ArxivProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
try:
|
||||
ArxivSearchTool().fork_tool_runtime(
|
||||
meta={
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
@ -17,4 +18,9 @@ class ArxivProvider(BuiltinToolProviderController):
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.SEARCH,
|
||||
]
|
@ -1,5 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.azuredalle.tools.dalle3 import DallE3Tool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
@ -9,7 +10,7 @@ class AzureDALLEProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
try:
|
||||
DallE3Tool().fork_tool_runtime(
|
||||
meta={
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
@ -22,3 +23,8 @@ class AzureDALLEProvider(BuiltinToolProviderController):
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.IMAGE
|
||||
]
|
@ -1,5 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.bing.tools.bing_web_search import BingSearchTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
@ -9,7 +10,7 @@ class BingProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
try:
|
||||
BingSearchTool().fork_tool_runtime(
|
||||
meta={
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).validate_credentials(
|
||||
@ -21,3 +22,8 @@ class BingProvider(BuiltinToolProviderController):
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.SEARCH
|
||||
]
|
@ -1,5 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.brave.tools.brave_search import BraveSearchTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
@ -9,7 +10,7 @@ class BraveProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
try:
|
||||
BraveSearchTool().fork_tool_runtime(
|
||||
meta={
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
@ -19,4 +20,9 @@ class BraveProvider(BuiltinToolProviderController):
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.SEARCH,
|
||||
]
|
@ -2,6 +2,7 @@ import matplotlib.pyplot as plt
|
||||
from fontTools.ttLib import TTFont
|
||||
from matplotlib.font_manager import findSystemFonts
|
||||
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.chart.tools.line import LinearChartTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
@ -44,7 +45,7 @@ class ChartProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
try:
|
||||
LinearChartTool().fork_tool_runtime(
|
||||
meta={
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
@ -54,4 +55,9 @@ class ChartProvider(BuiltinToolProviderController):
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.DESIGN, ToolLabelEnum.PRODUCTIVITY, ToolLabelEnum.UTILITIES
|
||||
]
|
@ -1,8 +1,14 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class CodeToolProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
pass
|
||||
pass
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.PRODUCTIVITY
|
||||
]
|
@ -1,5 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.dalle.tools.dalle2 import DallE2Tool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
@ -9,7 +10,7 @@ class DALLEProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
try:
|
||||
DallE2Tool().fork_tool_runtime(
|
||||
meta={
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
@ -21,4 +22,9 @@ class DALLEProvider(BuiltinToolProviderController):
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.IMAGE, ToolLabelEnum.PRODUCTIVITY
|
||||
]
|
@ -1,3 +1,4 @@
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.devdocs.tools.searchDevDocs import SearchDevDocsTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
@ -7,7 +8,7 @@ class DevDocsProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
try:
|
||||
SearchDevDocsTool().fork_tool_runtime(
|
||||
meta={
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
@ -18,4 +19,9 @@ class DevDocsProvider(BuiltinToolProviderController):
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.SEARCH, ToolLabelEnum.PRODUCTIVITY
|
||||
]
|
@ -1,3 +1,4 @@
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.provider.builtin.dingtalk.tools.dingtalk_group_bot import DingTalkGroupBotTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
@ -6,3 +7,8 @@ class DingTalkProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
DingTalkGroupBotTool()
|
||||
pass
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.SOCIAL
|
||||
]
|
@ -1,3 +1,4 @@
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.duckduckgo.tools.duckduckgo_search import DuckDuckGoSearchTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
@ -7,7 +8,7 @@ class DuckDuckGoProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
try:
|
||||
DuckDuckGoSearchTool().fork_tool_runtime(
|
||||
meta={
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
@ -17,4 +18,9 @@ class DuckDuckGoProvider(BuiltinToolProviderController):
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.SEARCH
|
||||
]
|
@ -1,3 +1,4 @@
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.provider.builtin.feishu.tools.feishu_group_bot import FeishuGroupBotTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
@ -6,3 +7,8 @@ class FeishuProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
FeishuGroupBotTool()
|
||||
pass
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.SOCIAL
|
||||
]
|
@ -1,3 +1,4 @@
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.firecrawl.tools.crawl import CrawlTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
@ -8,7 +9,7 @@ class FirecrawlProvider(BuiltinToolProviderController):
|
||||
try:
|
||||
# Example validation using the Crawl tool
|
||||
CrawlTool().fork_tool_runtime(
|
||||
meta={"credentials": credentials}
|
||||
runtime={"credentials": credentials}
|
||||
).invoke(
|
||||
user_id='',
|
||||
tool_parameters={
|
||||
@ -20,4 +21,9 @@ class FirecrawlProvider(BuiltinToolProviderController):
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.SEARCH, ToolLabelEnum.UTILITIES
|
||||
]
|
@ -2,6 +2,7 @@ import urllib.parse
|
||||
|
||||
import requests
|
||||
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
@ -24,3 +25,9 @@ class GaodeProvider(BuiltinToolProviderController):
|
||||
raise ToolProviderCredentialValidationError("Gaode API Key is invalid. {}".format(e))
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.UTILITIES, ToolLabelEnum.PRODUCTIVITY,
|
||||
ToolLabelEnum.WEATHER, ToolLabelEnum.TRAVEL
|
||||
]
|
@ -1,5 +1,6 @@
|
||||
import requests
|
||||
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
@ -30,3 +31,8 @@ class GihubProvider(BuiltinToolProviderController):
|
||||
raise ToolProviderCredentialValidationError("Github API Key and Api Version is invalid. {}".format(e))
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.UTILITIES
|
||||
]
|
@ -1,5 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.google.tools.google_search import GoogleSearchTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
@ -9,7 +10,7 @@ class GoogleProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
try:
|
||||
GoogleSearchTool().fork_tool_runtime(
|
||||
meta={
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
@ -20,4 +21,9 @@ class GoogleProvider(BuiltinToolProviderController):
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.SEARCH
|
||||
]
|
@ -1,5 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
@ -9,4 +10,9 @@ class GoogleProvider(BuiltinToolProviderController):
|
||||
try:
|
||||
pass
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.SEARCH, ToolLabelEnum.PRODUCTIVITY
|
||||
]
|
@ -1,5 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.judge0ce.tools.executeCode import ExecuteCodeTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
@ -9,7 +10,7 @@ class Judge0CEProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
try:
|
||||
ExecuteCodeTool().fork_tool_runtime(
|
||||
meta={
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
@ -20,4 +21,9 @@ class Judge0CEProvider(BuiltinToolProviderController):
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.OTHER, ToolLabelEnum.UTILITIES
|
||||
]
|
@ -1,5 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.maths.tools.eval_expression import EvaluateExpressionTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
@ -16,3 +17,8 @@ class MathsProvider(BuiltinToolProviderController):
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.UTILITIES, ToolLabelEnum.PRODUCTIVITY
|
||||
]
|
@ -1,5 +1,6 @@
|
||||
import requests
|
||||
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
@ -34,3 +35,8 @@ class OpenweatherProvider(BuiltinToolProviderController):
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.WEATHER
|
||||
]
|
@ -1,3 +1,4 @@
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.pubmed.tools.pubmed_search import PubMedSearchTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
@ -7,7 +8,7 @@ class PubMedProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
try:
|
||||
PubMedSearchTool().fork_tool_runtime(
|
||||
meta={
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
@ -17,4 +18,9 @@ class PubMedProvider(BuiltinToolProviderController):
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.MEDICAL, ToolLabelEnum.SEARCH
|
||||
]
|
@ -1,5 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.qrcode.tools.qrcode_generator import QRCodeGeneratorTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
@ -14,3 +15,8 @@ class QRCodeProvider(BuiltinToolProviderController):
|
||||
})
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.UTILITIES
|
||||
]
|
@ -1,5 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.searxng.tools.searxng_search import SearXNGSearchTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
@ -9,7 +10,7 @@ class SearXNGProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
try:
|
||||
SearXNGSearchTool().fork_tool_runtime(
|
||||
meta={
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
@ -22,4 +23,9 @@ class SearXNGProvider(BuiltinToolProviderController):
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.SEARCH, ToolLabelEnum.PRODUCTIVITY
|
||||
]
|
@ -1,3 +1,4 @@
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.provider.builtin.slack.tools.slack_webhook import SlackWebhookTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
@ -6,3 +7,8 @@ class SlackProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
SlackWebhookTool()
|
||||
pass
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.SOCIAL
|
||||
]
|
@ -1,5 +1,6 @@
|
||||
import json
|
||||
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.spark.tools.spark_img_generation import spark_response
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
@ -38,3 +39,8 @@ class SparkProvider(BuiltinToolProviderController):
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.IMAGE
|
||||
]
|
@ -1,5 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.provider.builtin.stability.tools.base import BaseStabilityAuthorization
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
@ -12,4 +13,9 @@ class StabilityToolProvider(BuiltinToolProviderController, BaseStabilityAuthoriz
|
||||
"""
|
||||
This method is responsible for validating the credentials.
|
||||
"""
|
||||
self.sd_validate_credentials(credentials)
|
||||
self.sd_validate_credentials(credentials)
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.IMAGE
|
||||
]
|
@ -1,5 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.stablediffusion.tools.stable_diffusion import StableDiffusionTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
@ -9,9 +10,14 @@ class StableDiffusionProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
try:
|
||||
StableDiffusionTool().fork_tool_runtime(
|
||||
meta={
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).validate_models()
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.IMAGE
|
||||
]
|
@ -1,3 +1,4 @@
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.stackexchange.tools.searchStackExQuestions import SearchStackExQuestionsTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
@ -7,7 +8,7 @@ class StackExchangeProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
try:
|
||||
SearchStackExQuestionsTool().fork_tool_runtime(
|
||||
meta={
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
@ -22,4 +23,9 @@ class StackExchangeProvider(BuiltinToolProviderController):
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.SEARCH, ToolLabelEnum.UTILITIES
|
||||
]
|
@ -1,5 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.tavily.tools.tavily_search import TavilySearchTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
@ -9,7 +10,7 @@ class TavilyProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
try:
|
||||
TavilySearchTool().fork_tool_runtime(
|
||||
meta={
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
@ -26,4 +27,9 @@ class TavilyProvider(BuiltinToolProviderController):
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.SEARCH
|
||||
]
|
@ -1,5 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.time.tools.current_time import CurrentTimeTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
@ -13,4 +14,9 @@ class WikiPediaProvider(BuiltinToolProviderController):
|
||||
tool_parameters={},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.UTILITIES
|
||||
]
|
@ -2,6 +2,7 @@ from typing import Any
|
||||
|
||||
import requests
|
||||
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
@ -31,4 +32,9 @@ class TrelloProvider(BuiltinToolProviderController):
|
||||
raise ToolProviderCredentialValidationError("Error validating Trello credentials")
|
||||
except requests.exceptions.RequestException as e:
|
||||
# Handle other exceptions, such as connection errors
|
||||
raise ToolProviderCredentialValidationError("Error validating Trello credentials")
|
||||
raise ToolProviderCredentialValidationError("Error validating Trello credentials")
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.PRODUCTIVITY
|
||||
]
|
@ -3,6 +3,7 @@ from typing import Any
|
||||
from twilio.base.exceptions import TwilioRestException
|
||||
from twilio.rest import Client
|
||||
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
@ -26,4 +27,9 @@ class TwilioProvider(BuiltinToolProviderController):
|
||||
except KeyError as e:
|
||||
raise ToolProviderCredentialValidationError(f"Missing required credential: {e}") from e
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.SOCIAL
|
||||
]
|
@ -1,5 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.vectorizer.tools.vectorizer import VectorizerTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
@ -9,7 +10,7 @@ class VectorizerProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
try:
|
||||
VectorizerTool().fork_tool_runtime(
|
||||
meta={
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
@ -20,4 +21,9 @@ class VectorizerProvider(BuiltinToolProviderController):
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.PRODUCTIVITY, ToolLabelEnum.IMAGE
|
||||
]
|
@ -1,5 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.webscraper.tools.webscraper import WebscraperTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
@ -9,7 +10,7 @@ class WebscraperProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
try:
|
||||
WebscraperTool().fork_tool_runtime(
|
||||
meta={
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
@ -20,4 +21,9 @@ class WebscraperProvider(BuiltinToolProviderController):
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.PRODUCTIVITY
|
||||
]
|
@ -1,3 +1,4 @@
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.provider.builtin.wecom.tools.wecom_group_bot import WecomGroupBotTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
@ -6,3 +7,8 @@ class WecomProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
WecomGroupBotTool()
|
||||
pass
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.SOCIAL
|
||||
]
|
@ -1,3 +1,4 @@
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.wikipedia.tools.wikipedia_search import WikiPediaSearchTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
@ -7,7 +8,7 @@ class WikiPediaProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
try:
|
||||
WikiPediaSearchTool().fork_tool_runtime(
|
||||
meta={
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
@ -17,4 +18,9 @@ class WikiPediaProvider(BuiltinToolProviderController):
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.SEARCH
|
||||
]
|
@ -1,5 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.wolframalpha.tools.wolframalpha import WolframAlphaTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
@ -9,7 +10,7 @@ class GoogleProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
try:
|
||||
WolframAlphaTool().fork_tool_runtime(
|
||||
meta={
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
@ -19,4 +20,9 @@ class GoogleProvider(BuiltinToolProviderController):
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.PRODUCTIVITY, ToolLabelEnum.UTILITIES
|
||||
]
|
@ -1,3 +1,4 @@
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.yahoo.tools.ticker import YahooFinanceSearchTickerTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
@ -7,7 +8,7 @@ class YahooFinanceProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
try:
|
||||
YahooFinanceSearchTickerTool().fork_tool_runtime(
|
||||
meta={
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
@ -17,4 +18,9 @@ class YahooFinanceProvider(BuiltinToolProviderController):
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.BUSINESS, ToolLabelEnum.FINANCE
|
||||
]
|
@ -1,3 +1,4 @@
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.youtube.tools.videos import YoutubeVideosAnalyticsTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
@ -7,7 +8,7 @@ class YahooFinanceProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
try:
|
||||
YoutubeVideosAnalyticsTool().fork_tool_runtime(
|
||||
meta={
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
@ -19,4 +20,9 @@ class YahooFinanceProvider(BuiltinToolProviderController):
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.VIDEOS
|
||||
]
|
@ -2,8 +2,9 @@ from abc import abstractmethod
|
||||
from os import listdir, path
|
||||
from typing import Any
|
||||
|
||||
from core.tools.entities.api_entities import UserToolProviderCredentials
|
||||
from core.tools.entities.tool_entities import ToolParameter, ToolProviderCredentials, ToolProviderType
|
||||
from core.tools.entities.user_entities import UserToolProviderCredentials
|
||||
from core.tools.entities.values import ToolLabelEnum, default_tool_label_dict
|
||||
from core.tools.errors import (
|
||||
ToolNotFoundError,
|
||||
ToolParameterValidationError,
|
||||
@ -19,7 +20,7 @@ from core.utils.module_import_helper import load_single_subclass_from_source
|
||||
|
||||
class BuiltinToolProviderController(ToolProviderController):
|
||||
def __init__(self, **data: Any) -> None:
|
||||
if self.app_type == ToolProviderType.API_BASED or self.app_type == ToolProviderType.APP_BASED:
|
||||
if self.provider_type == ToolProviderType.API or self.provider_type == ToolProviderType.APP:
|
||||
super().__init__(**data)
|
||||
return
|
||||
|
||||
@ -129,13 +130,29 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
len(self.credentials_schema) != 0
|
||||
|
||||
@property
|
||||
def app_type(self) -> ToolProviderType:
|
||||
def provider_type(self) -> ToolProviderType:
|
||||
"""
|
||||
returns the type of the provider
|
||||
|
||||
:return: type of the provider
|
||||
"""
|
||||
return ToolProviderType.BUILT_IN
|
||||
|
||||
@property
|
||||
def tool_labels(self) -> list[str]:
|
||||
"""
|
||||
returns the labels of the provider
|
||||
|
||||
:return: labels of the provider
|
||||
"""
|
||||
label_enums = self._get_tool_labels()
|
||||
return [default_tool_label_dict[label].name for label in label_enums]
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
"""
|
||||
returns the labels of the provider
|
||||
"""
|
||||
return []
|
||||
|
||||
def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: dict[str, Any]) -> None:
|
||||
"""
|
||||
|
@ -3,13 +3,13 @@ from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.tools.entities.api_entities import UserToolProviderCredentials
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolParameter,
|
||||
ToolProviderCredentials,
|
||||
ToolProviderIdentity,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.entities.user_entities import UserToolProviderCredentials
|
||||
from core.tools.errors import ToolNotFoundError, ToolParameterValidationError, ToolProviderCredentialValidationError
|
||||
from core.tools.tool.tool import Tool
|
||||
|
||||
@ -67,7 +67,7 @@ class ToolProviderController(BaseModel, ABC):
|
||||
return tool.parameters
|
||||
|
||||
@property
|
||||
def app_type(self) -> ToolProviderType:
|
||||
def provider_type(self) -> ToolProviderType:
|
||||
"""
|
||||
returns the type of the provider
|
||||
|
||||
@ -197,26 +197,4 @@ class ToolProviderController(BaseModel, ABC):
|
||||
default_value = str(default_value)
|
||||
|
||||
credentials[credential_name] = default_value
|
||||
|
||||
def validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
"""
|
||||
validate the credentials of the provider
|
||||
|
||||
:param tool_name: the name of the tool, defined in `get_tools`
|
||||
:param credentials: the credentials of the tool
|
||||
"""
|
||||
# validate credentials format
|
||||
self.validate_credentials_format(credentials)
|
||||
|
||||
# validate credentials
|
||||
self._validate_credentials(credentials)
|
||||
|
||||
@abstractmethod
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
"""
|
||||
validate the credentials of the provider
|
||||
|
||||
:param tool_name: the name of the tool, defined in `get_tools`
|
||||
:param credentials: the credentials of the tool
|
||||
"""
|
||||
pass
|
||||
|
230
api/core/tools/provider/workflow_tool_provider.py
Normal file
230
api/core/tools/provider/workflow_tool_provider.py
Normal file
@ -0,0 +1,230 @@
|
||||
from typing import Optional
|
||||
|
||||
from core.app.app_config.entities import VariableEntity
|
||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolDescription,
|
||||
ToolIdentity,
|
||||
ToolParameter,
|
||||
ToolParameterOption,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.provider.tool_provider import ToolProviderController
|
||||
from core.tools.tool.workflow_tool import WorkflowTool
|
||||
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, AppMode
|
||||
from models.tools import WorkflowToolProvider
|
||||
from models.workflow import Workflow
|
||||
|
||||
|
||||
class WorkflowToolProviderController(ToolProviderController):
|
||||
provider_id: str
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, db_provider: WorkflowToolProvider) -> 'WorkflowToolProviderController':
|
||||
app = db_provider.app
|
||||
|
||||
if not app:
|
||||
raise ValueError('app not found')
|
||||
|
||||
controller = WorkflowToolProviderController(**{
|
||||
'identity': {
|
||||
'author': db_provider.user.name if db_provider.user_id and db_provider.user else '',
|
||||
'name': db_provider.label,
|
||||
'label': {
|
||||
'en_US': db_provider.label,
|
||||
'zh_Hans': db_provider.label
|
||||
},
|
||||
'description': {
|
||||
'en_US': db_provider.description,
|
||||
'zh_Hans': db_provider.description
|
||||
},
|
||||
'icon': db_provider.icon,
|
||||
},
|
||||
'credentials_schema': {},
|
||||
'provider_id': db_provider.id or '',
|
||||
})
|
||||
|
||||
# init tools
|
||||
|
||||
controller.tools = [controller._get_db_provider_tool(db_provider, app)]
|
||||
|
||||
return controller
|
||||
|
||||
@property
|
||||
def provider_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.WORKFLOW
|
||||
|
||||
def _get_db_provider_tool(self, db_provider: WorkflowToolProvider, app: App) -> WorkflowTool:
|
||||
"""
|
||||
get db provider tool
|
||||
:param db_provider: the db provider
|
||||
:param app: the app
|
||||
:return: the tool
|
||||
"""
|
||||
workflow: Workflow = db.session.query(Workflow).filter(
|
||||
Workflow.app_id == db_provider.app_id,
|
||||
Workflow.version == db_provider.version
|
||||
).first()
|
||||
if not workflow:
|
||||
raise ValueError('workflow not found')
|
||||
|
||||
# fetch start node
|
||||
graph: dict = workflow.graph_dict
|
||||
features_dict: dict = workflow.features_dict
|
||||
features = WorkflowAppConfigManager.convert_features(
|
||||
config_dict=features_dict,
|
||||
app_mode=AppMode.WORKFLOW
|
||||
)
|
||||
|
||||
parameters = db_provider.parameter_configurations
|
||||
variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph)
|
||||
|
||||
def fetch_workflow_variable(variable_name: str) -> VariableEntity:
|
||||
return next(filter(lambda x: x.variable == variable_name, variables), None)
|
||||
|
||||
user = db_provider.user
|
||||
|
||||
workflow_tool_parameters = []
|
||||
for parameter in parameters:
|
||||
variable = fetch_workflow_variable(parameter.name)
|
||||
if variable:
|
||||
parameter_type = None
|
||||
options = None
|
||||
if variable.type in [
|
||||
VariableEntity.Type.TEXT_INPUT,
|
||||
VariableEntity.Type.PARAGRAPH,
|
||||
]:
|
||||
parameter_type = ToolParameter.ToolParameterType.STRING
|
||||
elif variable.type in [
|
||||
VariableEntity.Type.SELECT
|
||||
]:
|
||||
parameter_type = ToolParameter.ToolParameterType.SELECT
|
||||
elif variable.type in [
|
||||
VariableEntity.Type.NUMBER
|
||||
]:
|
||||
parameter_type = ToolParameter.ToolParameterType.NUMBER
|
||||
else:
|
||||
raise ValueError(f'unsupported variable type {variable.type}')
|
||||
|
||||
if variable.type == VariableEntity.Type.SELECT and variable.options:
|
||||
options = [
|
||||
ToolParameterOption(
|
||||
value=option,
|
||||
label=I18nObject(
|
||||
en_US=option,
|
||||
zh_Hans=option
|
||||
)
|
||||
) for option in variable.options
|
||||
]
|
||||
|
||||
workflow_tool_parameters.append(
|
||||
ToolParameter(
|
||||
name=parameter.name,
|
||||
label=I18nObject(
|
||||
en_US=variable.label,
|
||||
zh_Hans=variable.label
|
||||
),
|
||||
human_description=I18nObject(
|
||||
en_US=parameter.description,
|
||||
zh_Hans=parameter.description
|
||||
),
|
||||
type=parameter_type,
|
||||
form=parameter.form,
|
||||
llm_description=parameter.description,
|
||||
required=variable.required,
|
||||
options=options,
|
||||
default=variable.default
|
||||
)
|
||||
)
|
||||
elif features.file_upload:
|
||||
workflow_tool_parameters.append(
|
||||
ToolParameter(
|
||||
name=parameter.name,
|
||||
label=I18nObject(
|
||||
en_US=parameter.name,
|
||||
zh_Hans=parameter.name
|
||||
),
|
||||
human_description=I18nObject(
|
||||
en_US=parameter.description,
|
||||
zh_Hans=parameter.description
|
||||
),
|
||||
type=ToolParameter.ToolParameterType.FILE,
|
||||
llm_description=parameter.description,
|
||||
required=False,
|
||||
form=parameter.form,
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError('variable not found')
|
||||
|
||||
return WorkflowTool(
|
||||
identity=ToolIdentity(
|
||||
author=user.name if user else '',
|
||||
name=db_provider.name,
|
||||
label=I18nObject(
|
||||
en_US=db_provider.label,
|
||||
zh_Hans=db_provider.label
|
||||
),
|
||||
provider=self.provider_id,
|
||||
icon=db_provider.icon,
|
||||
),
|
||||
description=ToolDescription(
|
||||
human=I18nObject(
|
||||
en_US=db_provider.description,
|
||||
zh_Hans=db_provider.description
|
||||
),
|
||||
llm=db_provider.description,
|
||||
),
|
||||
parameters=workflow_tool_parameters,
|
||||
is_team_authorization=True,
|
||||
workflow_app_id=app.id,
|
||||
workflow_entities={
|
||||
'app': app,
|
||||
'workflow': workflow,
|
||||
},
|
||||
version=db_provider.version,
|
||||
workflow_call_depth=0,
|
||||
label=db_provider.label
|
||||
)
|
||||
|
||||
def get_tools(self, user_id: str, tenant_id: str) -> list[WorkflowTool]:
|
||||
"""
|
||||
fetch tools from database
|
||||
|
||||
:param user_id: the user id
|
||||
:param tenant_id: the tenant id
|
||||
:return: the tools
|
||||
"""
|
||||
if self.tools is not None:
|
||||
return self.tools
|
||||
|
||||
db_providers: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
WorkflowToolProvider.app_id == self.provider_id,
|
||||
).first()
|
||||
|
||||
if not db_providers:
|
||||
return []
|
||||
|
||||
self.tools = [self._get_db_provider_tool(db_providers, db_providers.app)]
|
||||
|
||||
return self.tools
|
||||
|
||||
def get_tool(self, tool_name: str) -> Optional[WorkflowTool]:
|
||||
"""
|
||||
get tool by name
|
||||
|
||||
:param tool_name: the name of the tool
|
||||
:return: the tool
|
||||
"""
|
||||
if self.tools is None:
|
||||
return None
|
||||
|
||||
for tool in self.tools:
|
||||
if tool.identity.name == tool_name:
|
||||
return tool
|
||||
|
||||
return None
|
@ -8,9 +8,8 @@ import httpx
|
||||
import requests
|
||||
|
||||
import core.helper.ssrf_proxy as ssrf_proxy
|
||||
from core.tools.entities.tool_bundle import ApiBasedToolBundle
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
|
||||
from core.tools.entities.user_entities import UserToolProvider
|
||||
from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError
|
||||
from core.tools.tool.tool import Tool
|
||||
|
||||
@ -20,12 +19,12 @@ API_TOOL_DEFAULT_TIMEOUT = (
|
||||
)
|
||||
|
||||
class ApiTool(Tool):
|
||||
api_bundle: ApiBasedToolBundle
|
||||
api_bundle: ApiToolBundle
|
||||
|
||||
"""
|
||||
Api tool
|
||||
"""
|
||||
def fork_tool_runtime(self, meta: dict[str, Any]) -> 'Tool':
|
||||
def fork_tool_runtime(self, runtime: dict[str, Any]) -> 'Tool':
|
||||
"""
|
||||
fork a new tool with meta data
|
||||
|
||||
@ -37,7 +36,7 @@ class ApiTool(Tool):
|
||||
parameters=self.parameters.copy() if self.parameters else None,
|
||||
description=self.description.copy() if self.description else None,
|
||||
api_bundle=self.api_bundle.copy() if self.api_bundle else None,
|
||||
runtime=Tool.Runtime(**meta)
|
||||
runtime=Tool.Runtime(**runtime)
|
||||
)
|
||||
|
||||
def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False) -> str:
|
||||
@ -55,7 +54,7 @@ class ApiTool(Tool):
|
||||
return self.validate_and_parse_response(response)
|
||||
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
return UserToolProvider.ProviderType.API
|
||||
return ToolProviderType.API
|
||||
|
||||
def assembling_request(self, parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
headers = {}
|
||||
|
@ -2,9 +2,8 @@
|
||||
from core.model_runtime.entities.llm_entities import LLMResult
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.entities.user_entities import UserToolProvider
|
||||
from core.tools.model.tool_model_manager import ToolModelManager
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.utils.model_invocation_utils import ModelInvocationUtils
|
||||
from core.tools.utils.web_reader_tool import get_url
|
||||
|
||||
_SUMMARY_PROMPT = """You are a professional language researcher, you are interested in the language
|
||||
@ -34,7 +33,7 @@ class BuiltinTool(Tool):
|
||||
:return: the model result
|
||||
"""
|
||||
# invoke model
|
||||
return ToolModelManager.invoke(
|
||||
return ModelInvocationUtils.invoke(
|
||||
user_id=user_id,
|
||||
tenant_id=self.runtime.tenant_id,
|
||||
tool_type='builtin',
|
||||
@ -43,7 +42,7 @@ class BuiltinTool(Tool):
|
||||
)
|
||||
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
return UserToolProvider.ProviderType.BUILTIN
|
||||
return ToolProviderType.BUILT_IN
|
||||
|
||||
def get_max_tokens(self) -> int:
|
||||
"""
|
||||
@ -52,7 +51,7 @@ class BuiltinTool(Tool):
|
||||
:param model_config: the model config
|
||||
:return: the max tokens
|
||||
"""
|
||||
return ToolModelManager.get_max_llm_context_tokens(
|
||||
return ModelInvocationUtils.get_max_llm_context_tokens(
|
||||
tenant_id=self.runtime.tenant_id,
|
||||
)
|
||||
|
||||
@ -63,7 +62,7 @@ class BuiltinTool(Tool):
|
||||
:param prompt_messages: the prompt messages
|
||||
:return: the tokens
|
||||
"""
|
||||
return ToolModelManager.calculate_tokens(
|
||||
return ModelInvocationUtils.calculate_tokens(
|
||||
tenant_id=self.runtime.tenant_id,
|
||||
prompt_messages=prompt_messages
|
||||
)
|
||||
|
@ -4,9 +4,12 @@ from typing import Any, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, validator
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file.file_obj import FileVar
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolDescription,
|
||||
ToolIdentity,
|
||||
ToolInvokeFrom,
|
||||
ToolInvokeMessage,
|
||||
ToolParameter,
|
||||
ToolProviderType,
|
||||
@ -25,10 +28,7 @@ class Tool(BaseModel, ABC):
|
||||
|
||||
@validator('parameters', pre=True, always=True)
|
||||
def set_parameters(cls, v, values):
|
||||
if not v:
|
||||
return []
|
||||
|
||||
return v
|
||||
return v or []
|
||||
|
||||
class Runtime(BaseModel):
|
||||
"""
|
||||
@ -41,6 +41,8 @@ class Tool(BaseModel, ABC):
|
||||
|
||||
tenant_id: str = None
|
||||
tool_id: str = None
|
||||
invoke_from: InvokeFrom = None
|
||||
tool_invoke_from: ToolInvokeFrom = None
|
||||
credentials: dict[str, Any] = None
|
||||
runtime_parameters: dict[str, Any] = None
|
||||
|
||||
@ -53,7 +55,7 @@ class Tool(BaseModel, ABC):
|
||||
class VARIABLE_KEY(Enum):
|
||||
IMAGE = 'image'
|
||||
|
||||
def fork_tool_runtime(self, meta: dict[str, Any]) -> 'Tool':
|
||||
def fork_tool_runtime(self, runtime: dict[str, Any]) -> 'Tool':
|
||||
"""
|
||||
fork a new tool with meta data
|
||||
|
||||
@ -64,7 +66,7 @@ class Tool(BaseModel, ABC):
|
||||
identity=self.identity.copy() if self.identity else None,
|
||||
parameters=self.parameters.copy() if self.parameters else None,
|
||||
description=self.description.copy() if self.description else None,
|
||||
runtime=Tool.Runtime(**meta),
|
||||
runtime=Tool.Runtime(**runtime),
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
@ -208,17 +210,17 @@ class Tool(BaseModel, ABC):
|
||||
if response.type == ToolInvokeMessage.MessageType.TEXT:
|
||||
result += response.message
|
||||
elif response.type == ToolInvokeMessage.MessageType.LINK:
|
||||
result += f"result link: {response.message}. please tell user to check it."
|
||||
result += f"result link: {response.message}. please tell user to check it. \n"
|
||||
elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
|
||||
response.type == ToolInvokeMessage.MessageType.IMAGE:
|
||||
result += "image has been created and sent to user already, you do not need to create it, just tell the user to check it now."
|
||||
result += "image has been created and sent to user already, you do not need to create it, just tell the user to check it now. \n"
|
||||
elif response.type == ToolInvokeMessage.MessageType.BLOB:
|
||||
if len(response.message) > 114:
|
||||
result += str(response.message[:114]) + '...'
|
||||
else:
|
||||
result += str(response.message)
|
||||
else:
|
||||
result += f"tool response: {response.message}."
|
||||
result += f"tool response: {response.message}. \n"
|
||||
|
||||
return result
|
||||
|
||||
@ -343,6 +345,14 @@ class Tool(BaseModel, ABC):
|
||||
message=image,
|
||||
save_as=save_as)
|
||||
|
||||
def create_file_var_message(self, file_var: FileVar) -> ToolInvokeMessage:
|
||||
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.FILE_VAR,
|
||||
message='',
|
||||
meta={
|
||||
'file_var': file_var
|
||||
},
|
||||
save_as='')
|
||||
|
||||
def create_link_message(self, link: str, save_as: str = '') -> ToolInvokeMessage:
|
||||
"""
|
||||
create a link message
|
||||
|
200
api/core/tools/tool/workflow_tool.py
Normal file
200
api/core/tools/tool/workflow_tool.py
Normal file
@ -0,0 +1,200 @@
|
||||
import json
|
||||
import logging
|
||||
from copy import deepcopy
|
||||
from typing import Any, Union
|
||||
|
||||
from core.file.file_obj import FileTransferMethod, FileVar
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderType
|
||||
from core.tools.tool.tool import Tool
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.model import App, EndUser
|
||||
from models.workflow import Workflow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class WorkflowTool(Tool):
|
||||
workflow_app_id: str
|
||||
version: str
|
||||
workflow_entities: dict[str, Any]
|
||||
workflow_call_depth: int
|
||||
|
||||
label: str
|
||||
|
||||
"""
|
||||
Workflow tool.
|
||||
"""
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
"""
|
||||
get the tool provider type
|
||||
|
||||
:return: the tool provider type
|
||||
"""
|
||||
return ToolProviderType.WORKFLOW
|
||||
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) \
|
||||
-> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke the tool
|
||||
"""
|
||||
app = self._get_app(app_id=self.workflow_app_id)
|
||||
workflow = self._get_workflow(app_id=self.workflow_app_id, version=self.version)
|
||||
|
||||
# transform the tool parameters
|
||||
tool_parameters, files = self._transform_args(tool_parameters)
|
||||
|
||||
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
||||
generator = WorkflowAppGenerator()
|
||||
result = generator.generate(
|
||||
app_model=app,
|
||||
workflow=workflow,
|
||||
user=self._get_user(user_id),
|
||||
args={
|
||||
'inputs': tool_parameters,
|
||||
'files': files
|
||||
},
|
||||
invoke_from=self.runtime.invoke_from,
|
||||
stream=False,
|
||||
call_depth=self.workflow_call_depth + 1,
|
||||
)
|
||||
|
||||
data = result.get('data', {})
|
||||
|
||||
if data.get('error'):
|
||||
raise Exception(data.get('error'))
|
||||
|
||||
result = []
|
||||
|
||||
outputs = data.get('outputs', {})
|
||||
outputs, files = self._extract_files(outputs)
|
||||
for file in files:
|
||||
result.append(self.create_file_var_message(file))
|
||||
|
||||
result.append(self.create_text_message(json.dumps(outputs)))
|
||||
|
||||
return result
|
||||
|
||||
def _get_user(self, user_id: str) -> Union[EndUser, Account]:
|
||||
"""
|
||||
get the user by user id
|
||||
"""
|
||||
|
||||
user = db.session.query(EndUser).filter(EndUser.id == user_id).first()
|
||||
if not user:
|
||||
user = db.session.query(Account).filter(Account.id == user_id).first()
|
||||
|
||||
if not user:
|
||||
raise ValueError('user not found')
|
||||
|
||||
return user
|
||||
|
||||
def fork_tool_runtime(self, runtime: dict[str, Any]) -> 'WorkflowTool':
|
||||
"""
|
||||
fork a new tool with meta data
|
||||
|
||||
:param meta: the meta data of a tool call processing, tenant_id is required
|
||||
:return: the new tool
|
||||
"""
|
||||
return self.__class__(
|
||||
identity=deepcopy(self.identity),
|
||||
parameters=deepcopy(self.parameters),
|
||||
description=deepcopy(self.description),
|
||||
runtime=Tool.Runtime(**runtime),
|
||||
workflow_app_id=self.workflow_app_id,
|
||||
workflow_entities=self.workflow_entities,
|
||||
workflow_call_depth=self.workflow_call_depth,
|
||||
version=self.version,
|
||||
label=self.label
|
||||
)
|
||||
|
||||
def _get_workflow(self, app_id: str, version: str) -> Workflow:
|
||||
"""
|
||||
get the workflow by app id and version
|
||||
"""
|
||||
if not version:
|
||||
workflow = db.session.query(Workflow).filter(
|
||||
Workflow.app_id == app_id,
|
||||
Workflow.version != 'draft'
|
||||
).order_by(Workflow.created_at.desc()).first()
|
||||
else:
|
||||
workflow = db.session.query(Workflow).filter(
|
||||
Workflow.app_id == app_id,
|
||||
Workflow.version == version
|
||||
).first()
|
||||
|
||||
if not workflow:
|
||||
raise ValueError('workflow not found or not published')
|
||||
|
||||
return workflow
|
||||
|
||||
def _get_app(self, app_id: str) -> App:
|
||||
"""
|
||||
get the app by app id
|
||||
"""
|
||||
app = db.session.query(App).filter(App.id == app_id).first()
|
||||
if not app:
|
||||
raise ValueError('app not found')
|
||||
|
||||
return app
|
||||
|
||||
def _transform_args(self, tool_parameters: dict) -> tuple[dict, list[dict]]:
|
||||
"""
|
||||
transform the tool parameters
|
||||
|
||||
:param tool_parameters: the tool parameters
|
||||
:return: tool_parameters, files
|
||||
"""
|
||||
parameter_rules = self.get_all_runtime_parameters()
|
||||
parameters_result = {}
|
||||
files = []
|
||||
for parameter in parameter_rules:
|
||||
if parameter.type == ToolParameter.ToolParameterType.FILE:
|
||||
file = tool_parameters.get(parameter.name)
|
||||
if file:
|
||||
try:
|
||||
file_var_list = [FileVar(**f) for f in file]
|
||||
for file_var in file_var_list:
|
||||
file_dict = {
|
||||
'transfer_method': file_var.transfer_method.value,
|
||||
'type': file_var.type.value,
|
||||
}
|
||||
if file_var.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
file_dict['tool_file_id'] = file_var.related_id
|
||||
elif file_var.transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||
file_dict['upload_file_id'] = file_var.related_id
|
||||
elif file_var.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
file_dict['url'] = file_var.preview_url
|
||||
|
||||
files.append(file_dict)
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
else:
|
||||
parameters_result[parameter.name] = tool_parameters.get(parameter.name)
|
||||
|
||||
return parameters_result, files
|
||||
|
||||
def _extract_files(self, outputs: dict) -> tuple[dict, list[FileVar]]:
|
||||
"""
|
||||
extract files from the result
|
||||
|
||||
:param result: the result
|
||||
:return: the result, files
|
||||
"""
|
||||
files = []
|
||||
result = {}
|
||||
for key, value in outputs.items():
|
||||
if isinstance(value, list):
|
||||
has_file = False
|
||||
for item in value:
|
||||
if isinstance(item, dict) and item.get('__variant') == 'FileVar':
|
||||
try:
|
||||
files.append(FileVar(**item))
|
||||
has_file = True
|
||||
except Exception as e:
|
||||
pass
|
||||
if has_file:
|
||||
continue
|
||||
|
||||
result[key] = value
|
||||
|
||||
return result, files
|
@ -1,7 +1,10 @@
|
||||
from copy import deepcopy
|
||||
from datetime import datetime, timezone
|
||||
from mimetypes import guess_type
|
||||
from typing import Union
|
||||
|
||||
from yarl import URL
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
|
||||
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
||||
@ -17,6 +20,7 @@ from core.tools.errors import (
|
||||
ToolProviderNotFoundError,
|
||||
)
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.tool.workflow_tool import WorkflowTool
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
from extensions.ext_database import db
|
||||
from models.model import Message, MessageFile
|
||||
@ -115,7 +119,8 @@ class ToolEngine:
|
||||
@staticmethod
|
||||
def workflow_invoke(tool: Tool, tool_parameters: dict,
|
||||
user_id: str, workflow_id: str,
|
||||
workflow_tool_callback: DifyWorkflowCallbackHandler) \
|
||||
workflow_tool_callback: DifyWorkflowCallbackHandler,
|
||||
workflow_call_depth: int) \
|
||||
-> list[ToolInvokeMessage]:
|
||||
"""
|
||||
Workflow invokes the tool with the given arguments.
|
||||
@ -127,6 +132,9 @@ class ToolEngine:
|
||||
tool_inputs=tool_parameters
|
||||
)
|
||||
|
||||
if isinstance(tool, WorkflowTool):
|
||||
tool.workflow_call_depth = workflow_call_depth + 1
|
||||
|
||||
response = tool.invoke(user_id, tool_parameters)
|
||||
|
||||
# hit the callback handler
|
||||
@ -195,8 +203,24 @@ class ToolEngine:
|
||||
for response in tool_response:
|
||||
if response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
|
||||
response.type == ToolInvokeMessage.MessageType.IMAGE:
|
||||
mimetype = None
|
||||
if response.meta.get('mime_type'):
|
||||
mimetype = response.meta.get('mime_type')
|
||||
else:
|
||||
try:
|
||||
url = URL(response.message)
|
||||
extension = url.suffix
|
||||
guess_type_result, _ = guess_type(f'a{extension}')
|
||||
if guess_type_result:
|
||||
mimetype = guess_type_result
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if not mimetype:
|
||||
mimetype = 'image/jpeg'
|
||||
|
||||
result.append(ToolInvokeMessageBinary(
|
||||
mimetype=response.meta.get('mime_type', 'octet/stream'),
|
||||
mimetype=response.meta.get('mime_type', 'image/jpeg'),
|
||||
url=response.message,
|
||||
save_as=response.save_as,
|
||||
))
|
||||
|
96
api/core/tools/tool_label_manager.py
Normal file
96
api/core/tools/tool_label_manager.py
Normal file
@ -0,0 +1,96 @@
|
||||
from core.tools.entities.values import default_tool_label_name_list
|
||||
from core.tools.provider.api_tool_provider import ApiToolProviderController
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
from core.tools.provider.tool_provider import ToolProviderController
|
||||
from core.tools.provider.workflow_tool_provider import WorkflowToolProviderController
|
||||
from extensions.ext_database import db
|
||||
from models.tools import ToolLabelBinding
|
||||
|
||||
|
||||
class ToolLabelManager:
|
||||
@classmethod
|
||||
def filter_tool_labels(cls, tool_labels: list[str]) -> list[str]:
|
||||
"""
|
||||
Filter tool labels
|
||||
"""
|
||||
tool_labels = [label for label in tool_labels if label in default_tool_label_name_list]
|
||||
return list(set(tool_labels))
|
||||
|
||||
@classmethod
|
||||
def update_tool_labels(cls, controller: ToolProviderController, labels: list[str]):
|
||||
"""
|
||||
Update tool labels
|
||||
"""
|
||||
labels = cls.filter_tool_labels(labels)
|
||||
|
||||
if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
|
||||
provider_id = controller.provider_id
|
||||
else:
|
||||
raise ValueError('Unsupported tool type')
|
||||
|
||||
# delete old labels
|
||||
db.session.query(ToolLabelBinding).filter(
|
||||
ToolLabelBinding.tool_id == provider_id
|
||||
).delete()
|
||||
|
||||
# insert new labels
|
||||
for label in labels:
|
||||
db.session.add(ToolLabelBinding(
|
||||
tool_id=provider_id,
|
||||
tool_type=controller.provider_type.value,
|
||||
label_name=label,
|
||||
))
|
||||
|
||||
db.session.commit()
|
||||
|
||||
@classmethod
|
||||
def get_tool_labels(cls, controller: ToolProviderController) -> list[str]:
|
||||
"""
|
||||
Get tool labels
|
||||
"""
|
||||
if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
|
||||
provider_id = controller.provider_id
|
||||
elif isinstance(controller, BuiltinToolProviderController):
|
||||
return controller.tool_labels
|
||||
else:
|
||||
raise ValueError('Unsupported tool type')
|
||||
|
||||
labels: list[ToolLabelBinding] = db.session.query(ToolLabelBinding.label_name).filter(
|
||||
ToolLabelBinding.tool_id == provider_id,
|
||||
ToolLabelBinding.tool_type == controller.provider_type.value,
|
||||
).all()
|
||||
|
||||
return [label.label_name for label in labels]
|
||||
|
||||
@classmethod
|
||||
def get_tools_labels(cls, tool_providers: list[ToolProviderController]) -> dict[str, list[str]]:
|
||||
"""
|
||||
Get tools labels
|
||||
|
||||
:param tool_providers: list of tool providers
|
||||
|
||||
:return: dict of tool labels
|
||||
:key: tool id
|
||||
:value: list of tool labels
|
||||
"""
|
||||
if not tool_providers:
|
||||
return {}
|
||||
|
||||
for controller in tool_providers:
|
||||
if not isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
|
||||
raise ValueError('Unsupported tool type')
|
||||
|
||||
provider_ids = [controller.provider_id for controller in tool_providers]
|
||||
|
||||
labels: list[ToolLabelBinding] = db.session.query(ToolLabelBinding).filter(
|
||||
ToolLabelBinding.tool_id.in_(provider_ids)
|
||||
).all()
|
||||
|
||||
tool_labels = {
|
||||
label.tool_id: [] for label in labels
|
||||
}
|
||||
|
||||
for label in labels:
|
||||
tool_labels[label.tool_id].append(label.label_name)
|
||||
|
||||
return tool_labels
|
@ -9,21 +9,24 @@ from typing import Any, Union
|
||||
from flask import current_app
|
||||
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools import *
|
||||
from core.tools.entities.api_entities import UserToolProvider, UserToolProviderTypeLiteral
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import (
|
||||
ApiProviderAuthType,
|
||||
ToolInvokeFrom,
|
||||
ToolParameter,
|
||||
)
|
||||
from core.tools.entities.user_entities import UserToolProvider
|
||||
from core.tools.errors import ToolProviderNotFoundError
|
||||
from core.tools.provider.api_tool_provider import ApiBasedToolProviderController
|
||||
from core.tools.provider.api_tool_provider import ApiToolProviderController
|
||||
from core.tools.provider.builtin._positions import BuiltinToolProviderSort
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
from core.tools.tool.api_tool import ApiTool
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
from core.tools.utils.configuration import (
|
||||
ToolConfigurationManager,
|
||||
ToolParameterConfigurationManager,
|
||||
@ -31,8 +34,8 @@ from core.tools.utils.configuration import (
|
||||
from core.utils.module_import_helper import load_single_subclass_from_source
|
||||
from core.workflow.nodes.tool.entities import ToolEntity
|
||||
from extensions.ext_database import db
|
||||
from models.tools import ApiToolProvider, BuiltinToolProvider
|
||||
from services.tools_transform_service import ToolTransformService
|
||||
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -99,7 +102,12 @@ class ToolManager:
|
||||
raise ToolProviderNotFoundError(f'provider type {provider_type} not found')
|
||||
|
||||
@classmethod
|
||||
def get_tool_runtime(cls, provider_type: str, provider_name: str, tool_name: str, tenant_id: str) \
|
||||
def get_tool_runtime(cls, provider_type: str,
|
||||
provider_id: str,
|
||||
tool_name: str,
|
||||
tenant_id: str,
|
||||
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
|
||||
tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT) \
|
||||
-> Union[BuiltinTool, ApiTool]:
|
||||
"""
|
||||
get the tool runtime
|
||||
@ -111,51 +119,76 @@ class ToolManager:
|
||||
:return: the tool
|
||||
"""
|
||||
if provider_type == 'builtin':
|
||||
builtin_tool = cls.get_builtin_tool(provider_name, tool_name)
|
||||
builtin_tool = cls.get_builtin_tool(provider_id, tool_name)
|
||||
|
||||
# check if the builtin tool need credentials
|
||||
provider_controller = cls.get_builtin_provider(provider_name)
|
||||
provider_controller = cls.get_builtin_provider(provider_id)
|
||||
if not provider_controller.need_credentials:
|
||||
return builtin_tool.fork_tool_runtime(meta={
|
||||
return builtin_tool.fork_tool_runtime(runtime={
|
||||
'tenant_id': tenant_id,
|
||||
'credentials': {},
|
||||
'invoke_from': invoke_from,
|
||||
'tool_invoke_from': tool_invoke_from,
|
||||
})
|
||||
|
||||
# get credentials
|
||||
builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.provider == provider_name,
|
||||
BuiltinToolProvider.provider == provider_id,
|
||||
).first()
|
||||
|
||||
if builtin_provider is None:
|
||||
raise ToolProviderNotFoundError(f'builtin provider {provider_name} not found')
|
||||
raise ToolProviderNotFoundError(f'builtin provider {provider_id} not found')
|
||||
|
||||
# decrypt the credentials
|
||||
credentials = builtin_provider.credentials
|
||||
controller = cls.get_builtin_provider(provider_name)
|
||||
controller = cls.get_builtin_provider(provider_id)
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller)
|
||||
|
||||
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
|
||||
|
||||
return builtin_tool.fork_tool_runtime(meta={
|
||||
return builtin_tool.fork_tool_runtime(runtime={
|
||||
'tenant_id': tenant_id,
|
||||
'credentials': decrypted_credentials,
|
||||
'runtime_parameters': {}
|
||||
'runtime_parameters': {},
|
||||
'invoke_from': invoke_from,
|
||||
'tool_invoke_from': tool_invoke_from,
|
||||
})
|
||||
|
||||
elif provider_type == 'api':
|
||||
if tenant_id is None:
|
||||
raise ValueError('tenant id is required for api provider')
|
||||
|
||||
api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_name)
|
||||
api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id)
|
||||
|
||||
# decrypt the credentials
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=api_provider)
|
||||
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
|
||||
|
||||
return api_provider.get_tool(tool_name).fork_tool_runtime(meta={
|
||||
return api_provider.get_tool(tool_name).fork_tool_runtime(runtime={
|
||||
'tenant_id': tenant_id,
|
||||
'credentials': decrypted_credentials,
|
||||
'invoke_from': invoke_from,
|
||||
'tool_invoke_from': tool_invoke_from,
|
||||
})
|
||||
elif provider_type == 'workflow':
|
||||
workflow_provider = db.session.query(WorkflowToolProvider).filter(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
WorkflowToolProvider.id == provider_id
|
||||
).first()
|
||||
|
||||
if workflow_provider is None:
|
||||
raise ToolProviderNotFoundError(f'workflow provider {provider_id} not found')
|
||||
|
||||
controller = ToolTransformService.workflow_provider_to_controller(
|
||||
db_provider=workflow_provider
|
||||
)
|
||||
|
||||
return controller.get_tools(user_id=None, tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(runtime={
|
||||
'tenant_id': tenant_id,
|
||||
'credentials': {},
|
||||
'invoke_from': invoke_from,
|
||||
'tool_invoke_from': tool_invoke_from,
|
||||
})
|
||||
elif provider_type == 'app':
|
||||
raise NotImplementedError('app provider not implemented')
|
||||
@ -207,18 +240,25 @@ class ToolManager:
|
||||
return parameter_value
|
||||
|
||||
@classmethod
|
||||
def get_agent_tool_runtime(cls, tenant_id: str, app_id: str, agent_tool: AgentToolEntity) -> Tool:
|
||||
def get_agent_tool_runtime(cls, tenant_id: str, app_id: str, agent_tool: AgentToolEntity, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER) -> Tool:
|
||||
"""
|
||||
get the agent tool runtime
|
||||
"""
|
||||
tool_entity = cls.get_tool_runtime(
|
||||
provider_type=agent_tool.provider_type, provider_name=agent_tool.provider_id,
|
||||
provider_type=agent_tool.provider_type,
|
||||
provider_id=agent_tool.provider_id,
|
||||
tool_name=agent_tool.tool_name,
|
||||
tenant_id=tenant_id,
|
||||
invoke_from=invoke_from,
|
||||
tool_invoke_from=ToolInvokeFrom.AGENT
|
||||
)
|
||||
runtime_parameters = {}
|
||||
parameters = tool_entity.get_all_runtime_parameters()
|
||||
for parameter in parameters:
|
||||
# check file types
|
||||
if parameter.type == ToolParameter.ToolParameterType.FILE:
|
||||
raise ValueError(f"file type parameter {parameter.name} not supported in agent")
|
||||
|
||||
if parameter.form == ToolParameter.ToolParameterForm.FORM:
|
||||
# save tool parameter to tool entity memory
|
||||
value = cls._init_runtime_parameter(parameter, agent_tool.tool_parameters)
|
||||
@ -238,15 +278,17 @@ class ToolManager:
|
||||
return tool_entity
|
||||
|
||||
@classmethod
|
||||
def get_workflow_tool_runtime(cls, tenant_id: str, app_id: str, node_id: str, workflow_tool: ToolEntity):
|
||||
def get_workflow_tool_runtime(cls, tenant_id: str, app_id: str, node_id: str, workflow_tool: ToolEntity, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER) -> Tool:
|
||||
"""
|
||||
get the workflow tool runtime
|
||||
"""
|
||||
tool_entity = cls.get_tool_runtime(
|
||||
provider_type=workflow_tool.provider_type,
|
||||
provider_name=workflow_tool.provider_id,
|
||||
provider_id=workflow_tool.provider_id,
|
||||
tool_name=workflow_tool.tool_name,
|
||||
tenant_id=tenant_id,
|
||||
invoke_from=invoke_from,
|
||||
tool_invoke_from=ToolInvokeFrom.WORKFLOW
|
||||
)
|
||||
runtime_parameters = {}
|
||||
parameters = tool_entity.get_all_runtime_parameters()
|
||||
@ -371,51 +413,91 @@ class ToolManager:
|
||||
return cls._builtin_tools_labels[tool_name]
|
||||
|
||||
@classmethod
|
||||
def user_list_providers(cls, user_id: str, tenant_id: str) -> list[UserToolProvider]:
|
||||
def user_list_providers(cls, user_id: str, tenant_id: str, typ: UserToolProviderTypeLiteral) -> list[UserToolProvider]:
|
||||
result_providers: dict[str, UserToolProvider] = {}
|
||||
|
||||
# get builtin providers
|
||||
builtin_providers = cls.list_builtin_providers()
|
||||
|
||||
# get db builtin providers
|
||||
db_builtin_providers: list[BuiltinToolProvider] = db.session.query(BuiltinToolProvider). \
|
||||
filter(BuiltinToolProvider.tenant_id == tenant_id).all()
|
||||
filters = []
|
||||
if not typ:
|
||||
filters.extend(['builtin', 'api', 'workflow'])
|
||||
else:
|
||||
filters.append(typ)
|
||||
|
||||
find_db_builtin_provider = lambda provider: next(
|
||||
(x for x in db_builtin_providers if x.provider == provider),
|
||||
None
|
||||
)
|
||||
if 'builtin' in filters:
|
||||
|
||||
# append builtin providers
|
||||
for provider in builtin_providers:
|
||||
user_provider = ToolTransformService.builtin_provider_to_user_provider(
|
||||
provider_controller=provider,
|
||||
db_provider=find_db_builtin_provider(provider.identity.name),
|
||||
decrypt_credentials=False
|
||||
# get builtin providers
|
||||
builtin_providers = cls.list_builtin_providers()
|
||||
|
||||
# get db builtin providers
|
||||
db_builtin_providers: list[BuiltinToolProvider] = db.session.query(BuiltinToolProvider). \
|
||||
filter(BuiltinToolProvider.tenant_id == tenant_id).all()
|
||||
|
||||
find_db_builtin_provider = lambda provider: next(
|
||||
(x for x in db_builtin_providers if x.provider == provider),
|
||||
None
|
||||
)
|
||||
|
||||
result_providers[provider.identity.name] = user_provider
|
||||
# append builtin providers
|
||||
for provider in builtin_providers:
|
||||
user_provider = ToolTransformService.builtin_provider_to_user_provider(
|
||||
provider_controller=provider,
|
||||
db_provider=find_db_builtin_provider(provider.identity.name),
|
||||
decrypt_credentials=False
|
||||
)
|
||||
|
||||
result_providers[provider.identity.name] = user_provider
|
||||
|
||||
# get db api providers
|
||||
db_api_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider). \
|
||||
filter(ApiToolProvider.tenant_id == tenant_id).all()
|
||||
|
||||
for db_api_provider in db_api_providers:
|
||||
provider_controller = ToolTransformService.api_provider_to_controller(
|
||||
db_provider=db_api_provider,
|
||||
)
|
||||
user_provider = ToolTransformService.api_provider_to_user_provider(
|
||||
provider_controller=provider_controller,
|
||||
db_provider=db_api_provider,
|
||||
decrypt_credentials=False
|
||||
)
|
||||
result_providers[db_api_provider.name] = user_provider
|
||||
if 'api' in filters:
|
||||
db_api_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider). \
|
||||
filter(ApiToolProvider.tenant_id == tenant_id).all()
|
||||
|
||||
api_provider_controllers = [{
|
||||
'provider': provider,
|
||||
'controller': ToolTransformService.api_provider_to_controller(provider)
|
||||
} for provider in db_api_providers]
|
||||
|
||||
# get labels
|
||||
labels = ToolLabelManager.get_tools_labels([x['controller'] for x in api_provider_controllers])
|
||||
|
||||
for api_provider_controller in api_provider_controllers:
|
||||
user_provider = ToolTransformService.api_provider_to_user_provider(
|
||||
provider_controller=api_provider_controller['controller'],
|
||||
db_provider=api_provider_controller['provider'],
|
||||
decrypt_credentials=False,
|
||||
labels=labels.get(api_provider_controller['controller'].provider_id, [])
|
||||
)
|
||||
result_providers[f'api_provider.{user_provider.name}'] = user_provider
|
||||
|
||||
if 'workflow' in filters:
|
||||
# get workflow providers
|
||||
workflow_providers: list[WorkflowToolProvider] = db.session.query(WorkflowToolProvider). \
|
||||
filter(WorkflowToolProvider.tenant_id == tenant_id).all()
|
||||
|
||||
workflow_provider_controllers = []
|
||||
for provider in workflow_providers:
|
||||
try:
|
||||
workflow_provider_controllers.append(
|
||||
ToolTransformService.workflow_provider_to_controller(db_provider=provider)
|
||||
)
|
||||
except Exception as e:
|
||||
# app has been deleted
|
||||
pass
|
||||
|
||||
labels = ToolLabelManager.get_tools_labels(workflow_provider_controllers)
|
||||
|
||||
for provider_controller in workflow_provider_controllers:
|
||||
user_provider = ToolTransformService.workflow_provider_to_user_provider(
|
||||
provider_controller=provider_controller,
|
||||
labels=labels.get(provider_controller.provider_id, []),
|
||||
)
|
||||
result_providers[f'workflow_provider.{user_provider.name}'] = user_provider
|
||||
|
||||
return BuiltinToolProviderSort.sort(list(result_providers.values()))
|
||||
|
||||
@classmethod
|
||||
def get_api_provider_controller(cls, tenant_id: str, provider_id: str) -> tuple[
|
||||
ApiBasedToolProviderController, dict[str, Any]]:
|
||||
ApiToolProviderController, dict[str, Any]]:
|
||||
"""
|
||||
get the api provider
|
||||
|
||||
@ -431,7 +513,7 @@ class ToolManager:
|
||||
if provider is None:
|
||||
raise ToolProviderNotFoundError(f'api provider {provider_id} not found')
|
||||
|
||||
controller = ApiBasedToolProviderController.from_db(
|
||||
controller = ApiToolProviderController.from_db(
|
||||
provider,
|
||||
ApiProviderAuthType.API_KEY if provider.credentials['auth_type'] == 'api_key' else
|
||||
ApiProviderAuthType.NONE
|
||||
@ -462,7 +544,7 @@ class ToolManager:
|
||||
credentials = {}
|
||||
|
||||
# package tool provider controller
|
||||
controller = ApiBasedToolProviderController.from_db(
|
||||
controller = ApiToolProviderController.from_db(
|
||||
provider, ApiProviderAuthType.API_KEY if credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE
|
||||
)
|
||||
# init tool configuration
|
||||
@ -479,6 +561,9 @@ class ToolManager:
|
||||
"content": "\ud83d\ude01"
|
||||
}
|
||||
|
||||
# add tool labels
|
||||
labels = ToolLabelManager.get_tool_labels(controller)
|
||||
|
||||
return jsonable_encoder({
|
||||
'schema_type': provider.schema_type,
|
||||
'schema': provider.schema,
|
||||
@ -487,7 +572,8 @@ class ToolManager:
|
||||
'description': provider.description,
|
||||
'credentials': masked_credentials,
|
||||
'privacy_policy': provider.privacy_policy,
|
||||
'custom_disclaimer': provider.custom_disclaimer
|
||||
'custom_disclaimer': provider.custom_disclaimer,
|
||||
'labels': labels,
|
||||
})
|
||||
|
||||
@classmethod
|
||||
@ -519,6 +605,15 @@ class ToolManager:
|
||||
"background": "#252525",
|
||||
"content": "\ud83d\ude01"
|
||||
}
|
||||
elif provider_type == 'workflow':
|
||||
provider: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
WorkflowToolProvider.id == provider_id
|
||||
).first()
|
||||
if provider is None:
|
||||
raise ToolProviderNotFoundError(f'workflow provider {provider_id} not found')
|
||||
|
||||
return json.loads(provider.icon)
|
||||
else:
|
||||
raise ValueError(f"provider type {provider_type} not found")
|
||||
|
||||
|
@ -72,8 +72,8 @@ class ToolConfigurationManager(BaseModel):
|
||||
return a deep copy of credentials with decrypted values
|
||||
"""
|
||||
cache = ToolProviderCredentialsCache(
|
||||
tenant_id=self.tenant_id,
|
||||
identity_id=f'{self.provider_controller.app_type.value}.{self.provider_controller.identity.name}',
|
||||
tenant_id=self.tenant_id,
|
||||
identity_id=f'{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}',
|
||||
cache_type=ToolProviderCredentialsCacheType.PROVIDER
|
||||
)
|
||||
cached_credentials = cache.get()
|
||||
@ -95,8 +95,8 @@ class ToolConfigurationManager(BaseModel):
|
||||
|
||||
def delete_tool_credentials_cache(self):
|
||||
cache = ToolProviderCredentialsCache(
|
||||
tenant_id=self.tenant_id,
|
||||
identity_id=f'{self.provider_controller.app_type.value}.{self.provider_controller.identity.name}',
|
||||
tenant_id=self.tenant_id,
|
||||
identity_id=f'{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}',
|
||||
cache_type=ToolProviderCredentialsCacheType.PROVIDER
|
||||
)
|
||||
cache.delete()
|
||||
|
@ -1,14 +1,15 @@
|
||||
import logging
|
||||
from mimetypes import guess_extension
|
||||
|
||||
from core.file.file_obj import FileTransferMethod, FileType, FileVar
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ToolFileMessageTransformer:
|
||||
@staticmethod
|
||||
def transform_tool_invoke_messages(messages: list[ToolInvokeMessage],
|
||||
@classmethod
|
||||
def transform_tool_invoke_messages(cls, messages: list[ToolInvokeMessage],
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
conversation_id: str) -> list[ToolInvokeMessage]:
|
||||
@ -62,7 +63,7 @@ class ToolFileMessageTransformer:
|
||||
mimetype=mimetype
|
||||
)
|
||||
|
||||
url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".bin"}'
|
||||
url = cls.get_tool_file_url(file.id, guess_extension(file.mimetype))
|
||||
|
||||
# check if file is image
|
||||
if 'image' in mimetype:
|
||||
@ -79,7 +80,30 @@ class ToolFileMessageTransformer:
|
||||
save_as=message.save_as,
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
))
|
||||
elif message.type == ToolInvokeMessage.MessageType.FILE_VAR:
|
||||
file_var: FileVar = message.meta.get('file_var')
|
||||
if file_var:
|
||||
if file_var.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
url = cls.get_tool_file_url(file_var.related_id, file_var.extension)
|
||||
if file_var.type == FileType.IMAGE:
|
||||
result.append(ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
|
||||
message=url,
|
||||
save_as=message.save_as,
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
))
|
||||
else:
|
||||
result.append(ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.LINK,
|
||||
message=url,
|
||||
save_as=message.save_as,
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
))
|
||||
else:
|
||||
result.append(message)
|
||||
|
||||
return result
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def get_tool_file_url(cls, tool_file_id: str, extension: str) -> str:
|
||||
return f'/files/tools/{tool_file_id}{extension or ".bin"}'
|
@ -20,12 +20,14 @@ from core.model_runtime.errors.invoke import (
|
||||
)
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel, ModelPropertyKey
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.model.errors import InvokeModelError
|
||||
from extensions.ext_database import db
|
||||
from models.tools import ToolModelInvoke
|
||||
|
||||
|
||||
class ToolModelManager:
|
||||
class InvokeModelError(Exception):
|
||||
pass
|
||||
|
||||
class ModelInvocationUtils:
|
||||
@staticmethod
|
||||
def get_max_llm_context_tokens(
|
||||
tenant_id: str,
|
@ -9,14 +9,14 @@ from requests import get
|
||||
from yaml import YAMLError, safe_load
|
||||
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_bundle import ApiBasedToolBundle
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolParameter
|
||||
from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolProviderNotFoundError
|
||||
|
||||
|
||||
class ApiBasedToolSchemaParser:
|
||||
@staticmethod
|
||||
def parse_openapi_to_tool_bundle(openapi: dict, extra_info: dict = None, warning: dict = None) -> list[ApiBasedToolBundle]:
|
||||
def parse_openapi_to_tool_bundle(openapi: dict, extra_info: dict = None, warning: dict = None) -> list[ApiToolBundle]:
|
||||
warning = warning if warning is not None else {}
|
||||
extra_info = extra_info if extra_info is not None else {}
|
||||
|
||||
@ -145,7 +145,7 @@ class ApiBasedToolSchemaParser:
|
||||
|
||||
interface['operation']['operationId'] = f'{path}_{interface["method"]}'
|
||||
|
||||
bundles.append(ApiBasedToolBundle(
|
||||
bundles.append(ApiToolBundle(
|
||||
server_url=server_url + interface['path'],
|
||||
method=interface['method'],
|
||||
summary=interface['operation']['description'] if 'description' in interface['operation'] else
|
||||
@ -176,7 +176,7 @@ class ApiBasedToolSchemaParser:
|
||||
return ToolParameter.ToolParameterType.STRING
|
||||
|
||||
@staticmethod
|
||||
def parse_openapi_yaml_to_tool_bundle(yaml: str, extra_info: dict = None, warning: dict = None) -> list[ApiBasedToolBundle]:
|
||||
def parse_openapi_yaml_to_tool_bundle(yaml: str, extra_info: dict = None, warning: dict = None) -> list[ApiToolBundle]:
|
||||
"""
|
||||
parse openapi yaml to tool bundle
|
||||
|
||||
@ -258,7 +258,7 @@ class ApiBasedToolSchemaParser:
|
||||
return openapi
|
||||
|
||||
@staticmethod
|
||||
def parse_openai_plugin_json_to_tool_bundle(json: str, extra_info: dict = None, warning: dict = None) -> list[ApiBasedToolBundle]:
|
||||
def parse_openai_plugin_json_to_tool_bundle(json: str, extra_info: dict = None, warning: dict = None) -> list[ApiToolBundle]:
|
||||
"""
|
||||
parse openapi plugin yaml to tool bundle
|
||||
|
||||
@ -290,7 +290,7 @@ class ApiBasedToolSchemaParser:
|
||||
return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle(response.text, extra_info=extra_info, warning=warning)
|
||||
|
||||
@staticmethod
|
||||
def auto_parse_to_tool_bundle(content: str, extra_info: dict = None, warning: dict = None) -> tuple[list[ApiBasedToolBundle], str]:
|
||||
def auto_parse_to_tool_bundle(content: str, extra_info: dict = None, warning: dict = None) -> tuple[list[ApiToolBundle], str]:
|
||||
"""
|
||||
auto parse to tool bundle
|
||||
|
||||
|
48
api/core/tools/utils/workflow_configuration_sync.py
Normal file
48
api/core/tools/utils/workflow_configuration_sync.py
Normal file
@ -0,0 +1,48 @@
|
||||
from core.app.app_config.entities import VariableEntity
|
||||
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
|
||||
|
||||
|
||||
class WorkflowToolConfigurationUtils:
|
||||
@classmethod
|
||||
def check_parameter_configurations(cls, configurations: list[dict]):
|
||||
"""
|
||||
check parameter configurations
|
||||
"""
|
||||
for configuration in configurations:
|
||||
if not WorkflowToolParameterConfiguration(**configuration):
|
||||
raise ValueError('invalid parameter configuration')
|
||||
|
||||
@classmethod
|
||||
def get_workflow_graph_variables(cls, graph: dict) -> list[VariableEntity]:
|
||||
"""
|
||||
get workflow graph variables
|
||||
"""
|
||||
nodes = graph.get('nodes', [])
|
||||
start_node = next(filter(lambda x: x.get('data', {}).get('type') == 'start', nodes), None)
|
||||
|
||||
if not start_node:
|
||||
return []
|
||||
|
||||
return [
|
||||
VariableEntity(**variable) for variable in start_node.get('data', {}).get('variables', [])
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def check_is_synced(cls,
|
||||
variables: list[VariableEntity],
|
||||
tool_configurations: list[WorkflowToolParameterConfiguration]) -> None:
|
||||
"""
|
||||
check is synced
|
||||
|
||||
raise ValueError if not synced
|
||||
"""
|
||||
variable_names = [variable.variable for variable in variables]
|
||||
|
||||
if len(tool_configurations) != len(variables):
|
||||
raise ValueError('parameter configuration mismatch, please republish the tool to update')
|
||||
|
||||
for parameter in tool_configurations:
|
||||
if parameter.name not in variable_names:
|
||||
raise ValueError('parameter configuration mismatch, please republish the tool to update')
|
||||
|
||||
return True
|
@ -1,5 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.app.entities.queue_entities import AppQueueEvent
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
@ -71,6 +71,42 @@ class BaseWorkflowCallback(ABC):
|
||||
Publish text chunk
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def on_workflow_iteration_started(self,
|
||||
node_id: str,
|
||||
node_type: NodeType,
|
||||
node_run_index: int = 1,
|
||||
node_data: Optional[BaseNodeData] = None,
|
||||
inputs: dict = None,
|
||||
predecessor_node_id: Optional[str] = None,
|
||||
metadata: Optional[dict] = None) -> None:
|
||||
"""
|
||||
Publish iteration started
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def on_workflow_iteration_next(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
index: int,
|
||||
node_run_index: int,
|
||||
output: Optional[Any],
|
||||
) -> None:
|
||||
"""
|
||||
Publish iteration next
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def on_workflow_iteration_completed(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
node_run_index: int,
|
||||
outputs: dict) -> None:
|
||||
"""
|
||||
Publish iteration completed
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def on_event(self, event: AppQueueEvent) -> None:
|
||||
|
@ -7,3 +7,16 @@ from pydantic import BaseModel
|
||||
class BaseNodeData(ABC, BaseModel):
|
||||
title: str
|
||||
desc: Optional[str] = None
|
||||
|
||||
class BaseIterationNodeData(BaseNodeData):
|
||||
start_node_id: str
|
||||
|
||||
class BaseIterationState(BaseModel):
|
||||
iteration_node_id: str
|
||||
index: int
|
||||
inputs: dict
|
||||
|
||||
class MetaData(BaseModel):
|
||||
pass
|
||||
|
||||
metadata: MetaData
|
@ -21,7 +21,11 @@ class NodeType(Enum):
|
||||
QUESTION_CLASSIFIER = 'question-classifier'
|
||||
HTTP_REQUEST = 'http-request'
|
||||
TOOL = 'tool'
|
||||
VARIABLE_AGGREGATOR = 'variable-aggregator'
|
||||
VARIABLE_ASSIGNER = 'variable-assigner'
|
||||
LOOP = 'loop'
|
||||
ITERATION = 'iteration'
|
||||
PARAMETER_EXTRACTOR = 'parameter-extractor'
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> 'NodeType':
|
||||
@ -68,6 +72,8 @@ class NodeRunMetadataKey(Enum):
|
||||
TOTAL_PRICE = 'total_price'
|
||||
CURRENCY = 'currency'
|
||||
TOOL_INFO = 'tool_info'
|
||||
ITERATION_ID = 'iteration_id'
|
||||
ITERATION_INDEX = 'iteration_index'
|
||||
|
||||
|
||||
class NodeRunResult(BaseModel):
|
||||
|
@ -90,3 +90,12 @@ class VariablePool:
|
||||
raise ValueError(f'Invalid value type: {target_value_type.value}')
|
||||
|
||||
return value
|
||||
|
||||
def clear_node_variables(self, node_id: str) -> None:
|
||||
"""
|
||||
Clear node variables
|
||||
:param node_id: node id
|
||||
:return:
|
||||
"""
|
||||
if node_id in self.variables_mapping:
|
||||
self.variables_mapping.pop(node_id)
|
@ -1,5 +1,9 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.base_node_data_entities import BaseIterationState
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.base_node import BaseNode, UserFrom
|
||||
@ -22,6 +26,9 @@ class WorkflowRunState:
|
||||
workflow_type: WorkflowType
|
||||
user_id: str
|
||||
user_from: UserFrom
|
||||
invoke_from: InvokeFrom
|
||||
|
||||
workflow_call_depth: int
|
||||
|
||||
start_at: float
|
||||
variable_pool: VariablePool
|
||||
@ -30,20 +37,37 @@ class WorkflowRunState:
|
||||
|
||||
workflow_nodes_and_results: list[WorkflowNodeAndResult]
|
||||
|
||||
class NodeRun(BaseModel):
|
||||
node_id: str
|
||||
iteration_node_id: str
|
||||
|
||||
workflow_node_runs: list[NodeRun]
|
||||
workflow_node_steps: int
|
||||
|
||||
current_iteration_state: Optional[BaseIterationState]
|
||||
|
||||
def __init__(self, workflow: Workflow,
|
||||
start_at: float,
|
||||
variable_pool: VariablePool,
|
||||
user_id: str,
|
||||
user_from: UserFrom):
|
||||
user_from: UserFrom,
|
||||
invoke_from: InvokeFrom,
|
||||
workflow_call_depth: int):
|
||||
self.workflow_id = workflow.id
|
||||
self.tenant_id = workflow.tenant_id
|
||||
self.app_id = workflow.app_id
|
||||
self.workflow_type = WorkflowType.value_of(workflow.type)
|
||||
self.user_id = user_id
|
||||
self.user_from = user_from
|
||||
self.invoke_from = invoke_from
|
||||
self.workflow_call_depth = workflow_call_depth
|
||||
|
||||
self.start_at = start_at
|
||||
self.variable_pool = variable_pool
|
||||
|
||||
self.total_tokens = 0
|
||||
self.workflow_nodes_and_results = []
|
||||
|
||||
self.current_iteration_state = None
|
||||
self.workflow_node_steps = 1
|
||||
self.workflow_node_runs = []
|
@ -2,8 +2,9 @@ from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.base_node_data_entities import BaseIterationState, BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
|
||||
@ -37,6 +38,9 @@ class BaseNode(ABC):
|
||||
workflow_id: str
|
||||
user_id: str
|
||||
user_from: UserFrom
|
||||
invoke_from: InvokeFrom
|
||||
|
||||
workflow_call_depth: int
|
||||
|
||||
node_id: str
|
||||
node_data: BaseNodeData
|
||||
@ -49,13 +53,17 @@ class BaseNode(ABC):
|
||||
workflow_id: str,
|
||||
user_id: str,
|
||||
user_from: UserFrom,
|
||||
invoke_from: InvokeFrom,
|
||||
config: dict,
|
||||
callbacks: list[BaseWorkflowCallback] = None) -> None:
|
||||
callbacks: list[BaseWorkflowCallback] = None,
|
||||
workflow_call_depth: int = 0) -> None:
|
||||
self.tenant_id = tenant_id
|
||||
self.app_id = app_id
|
||||
self.workflow_id = workflow_id
|
||||
self.user_id = user_id
|
||||
self.user_from = user_from
|
||||
self.invoke_from = invoke_from
|
||||
self.workflow_call_depth = workflow_call_depth
|
||||
|
||||
self.node_id = config.get("id")
|
||||
if not self.node_id:
|
||||
@ -140,3 +148,38 @@ class BaseNode(ABC):
|
||||
:return:
|
||||
"""
|
||||
return self._node_type
|
||||
|
||||
class BaseIterationNode(BaseNode):
|
||||
@abstractmethod
|
||||
def _run(self, variable_pool: VariablePool) -> BaseIterationState:
|
||||
"""
|
||||
Run node
|
||||
:param variable_pool: variable pool
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def run(self, variable_pool: VariablePool) -> BaseIterationState:
|
||||
"""
|
||||
Run node entry
|
||||
:param variable_pool: variable pool
|
||||
:return:
|
||||
"""
|
||||
return self._run(variable_pool=variable_pool)
|
||||
|
||||
def get_next_iteration(self, variable_pool: VariablePool, state: BaseIterationState) -> NodeRunResult | str:
|
||||
"""
|
||||
Get next iteration start node id based on the graph.
|
||||
:param graph: graph
|
||||
:return: next node id
|
||||
"""
|
||||
return self._get_next_iteration(variable_pool, state)
|
||||
|
||||
@abstractmethod
|
||||
def _get_next_iteration(self, variable_pool: VariablePool, state: BaseIterationState) -> NodeRunResult | str:
|
||||
"""
|
||||
Get next iteration start node id based on the graph.
|
||||
:param graph: graph
|
||||
:return: next node id
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
39
api/core/workflow/nodes/iteration/entities.py
Normal file
39
api/core/workflow/nodes/iteration/entities.py
Normal file
@ -0,0 +1,39 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseIterationNodeData, BaseIterationState
|
||||
|
||||
|
||||
class IterationNodeData(BaseIterationNodeData):
|
||||
"""
|
||||
Iteration Node Data.
|
||||
"""
|
||||
parent_loop_id: Optional[str] # redundant field, not used currently
|
||||
iterator_selector: list[str] # variable selector
|
||||
output_selector: list[str] # output selector
|
||||
|
||||
class IterationState(BaseIterationState):
|
||||
"""
|
||||
Iteration State.
|
||||
"""
|
||||
outputs: list[Any] = None
|
||||
current_output: Optional[Any] = None
|
||||
|
||||
class MetaData(BaseIterationState.MetaData):
|
||||
"""
|
||||
Data.
|
||||
"""
|
||||
iterator_length: int
|
||||
|
||||
def get_last_output(self) -> Optional[Any]:
|
||||
"""
|
||||
Get last output.
|
||||
"""
|
||||
if self.outputs:
|
||||
return self.outputs[-1]
|
||||
return None
|
||||
|
||||
def get_current_output(self) -> Optional[Any]:
|
||||
"""
|
||||
Get current output.
|
||||
"""
|
||||
return self.current_output
|
119
api/core/workflow/nodes/iteration/iteration_node.py
Normal file
119
api/core/workflow/nodes/iteration/iteration_node.py
Normal file
@ -0,0 +1,119 @@
|
||||
from typing import cast
|
||||
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.workflow.entities.base_node_data_entities import BaseIterationState
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.base_node import BaseIterationNode
|
||||
from core.workflow.nodes.iteration.entities import IterationNodeData, IterationState
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class IterationNode(BaseIterationNode):
|
||||
"""
|
||||
Iteration Node.
|
||||
"""
|
||||
_node_data_cls = IterationNodeData
|
||||
_node_type = NodeType.ITERATION
|
||||
|
||||
def _run(self, variable_pool: VariablePool) -> BaseIterationState:
|
||||
"""
|
||||
Run the node.
|
||||
"""
|
||||
iterator = variable_pool.get_variable_value(cast(IterationNodeData, self.node_data).iterator_selector)
|
||||
|
||||
if not isinstance(iterator, list):
|
||||
raise ValueError(f"Invalid iterator value: {iterator}, please provide a list.")
|
||||
|
||||
state = IterationState(iteration_node_id=self.node_id, index=-1, inputs={
|
||||
'iterator_selector': iterator
|
||||
}, outputs=[], metadata=IterationState.MetaData(
|
||||
iterator_length=len(iterator) if iterator is not None else 0
|
||||
))
|
||||
|
||||
self._set_current_iteration_variable(variable_pool, state)
|
||||
return state
|
||||
|
||||
def _get_next_iteration(self, variable_pool: VariablePool, state: IterationState) -> NodeRunResult | str:
|
||||
"""
|
||||
Get next iteration start node id based on the graph.
|
||||
:param graph: graph
|
||||
:return: next node id
|
||||
"""
|
||||
# resolve current output
|
||||
self._resolve_current_output(variable_pool, state)
|
||||
# move to next iteration
|
||||
self._next_iteration(variable_pool, state)
|
||||
|
||||
node_data = cast(IterationNodeData, self.node_data)
|
||||
if self._reached_iteration_limit(variable_pool, state):
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={
|
||||
'output': jsonable_encoder(state.outputs)
|
||||
}
|
||||
)
|
||||
|
||||
return node_data.start_node_id
|
||||
|
||||
def _set_current_iteration_variable(self, variable_pool: VariablePool, state: IterationState):
|
||||
"""
|
||||
Set current iteration variable.
|
||||
:variable_pool: variable pool
|
||||
"""
|
||||
node_data = cast(IterationNodeData, self.node_data)
|
||||
|
||||
variable_pool.append_variable(self.node_id, ['index'], state.index)
|
||||
# get the iterator value
|
||||
iterator = variable_pool.get_variable_value(node_data.iterator_selector)
|
||||
|
||||
if iterator is None or not isinstance(iterator, list):
|
||||
return
|
||||
|
||||
if state.index < len(iterator):
|
||||
variable_pool.append_variable(self.node_id, ['item'], iterator[state.index])
|
||||
|
||||
def _next_iteration(self, variable_pool: VariablePool, state: IterationState):
|
||||
"""
|
||||
Move to next iteration.
|
||||
:param variable_pool: variable pool
|
||||
"""
|
||||
state.index += 1
|
||||
self._set_current_iteration_variable(variable_pool, state)
|
||||
|
||||
def _reached_iteration_limit(self, variable_pool: VariablePool, state: IterationState):
|
||||
"""
|
||||
Check if iteration limit is reached.
|
||||
:return: True if iteration limit is reached, False otherwise
|
||||
"""
|
||||
node_data = cast(IterationNodeData, self.node_data)
|
||||
iterator = variable_pool.get_variable_value(node_data.iterator_selector)
|
||||
|
||||
if iterator is None or not isinstance(iterator, list):
|
||||
return True
|
||||
|
||||
return state.index >= len(iterator)
|
||||
|
||||
def _resolve_current_output(self, variable_pool: VariablePool, state: IterationState):
|
||||
"""
|
||||
Resolve current output.
|
||||
:param variable_pool: variable pool
|
||||
"""
|
||||
output_selector = cast(IterationNodeData, self.node_data).output_selector
|
||||
output = variable_pool.get_variable_value(output_selector)
|
||||
# clear the output for this iteration
|
||||
variable_pool.append_variable(self.node_id, output_selector[1:], None)
|
||||
state.current_output = output
|
||||
if output is not None:
|
||||
state.outputs.append(output)
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(cls, node_data: IterationNodeData) -> dict[str, list[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
return {
|
||||
'input_selector': node_data.iterator_selector,
|
||||
}
|
0
api/core/workflow/nodes/loop/__init__.py
Normal file
0
api/core/workflow/nodes/loop/__init__.py
Normal file
13
api/core/workflow/nodes/loop/entities.py
Normal file
13
api/core/workflow/nodes/loop/entities.py
Normal file
@ -0,0 +1,13 @@
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseIterationNodeData, BaseIterationState
|
||||
|
||||
|
||||
class LoopNodeData(BaseIterationNodeData):
|
||||
"""
|
||||
Loop Node Data.
|
||||
"""
|
||||
|
||||
class LoopState(BaseIterationState):
|
||||
"""
|
||||
Loop State.
|
||||
"""
|
20
api/core/workflow/nodes/loop/loop_node.py
Normal file
20
api/core/workflow/nodes/loop/loop_node.py
Normal file
@ -0,0 +1,20 @@
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.base_node import BaseIterationNode
|
||||
from core.workflow.nodes.loop.entities import LoopNodeData, LoopState
|
||||
|
||||
|
||||
class LoopNode(BaseIterationNode):
|
||||
"""
|
||||
Loop Node.
|
||||
"""
|
||||
_node_data_cls = LoopNodeData
|
||||
_node_type = NodeType.LOOP
|
||||
|
||||
def _run(self, variable_pool: VariablePool) -> LoopState:
|
||||
return super()._run(variable_pool)
|
||||
|
||||
def _get_next_iteration(self, variable_loop: VariablePool) -> NodeRunResult | str:
|
||||
"""
|
||||
Get next iteration start node id based on the graph.
|
||||
"""
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user