Feat/workflow phase2 (#4687)

This commit is contained in:
Yeuoly 2024-05-27 22:01:11 +08:00 committed by GitHub
parent 45deaee762
commit e852a21634
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
139 changed files with 5997 additions and 779 deletions

View File

@ -137,6 +137,71 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
logging.exception("internal server error.")
raise InternalServerError()
class AdvancedChatDraftRunIterationNodeApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
def post(self, app_model: App, node_id: str):
"""
Run draft workflow iteration node
"""
parser = reqparse.RequestParser()
parser.add_argument('inputs', type=dict, location='json')
args = parser.parse_args()
try:
response = AppGenerateService.generate_single_iteration(
app_model=app_model,
user=current_user,
node_id=node_id,
args=args,
streaming=True
)
return helper.compact_generate_response(response)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except services.errors.conversation.ConversationCompletedError:
raise ConversationCompletedError()
except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")
raise InternalServerError()
class WorkflowDraftRunIterationNodeApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
def post(self, app_model: App, node_id: str):
"""
Run draft workflow iteration node
"""
parser = reqparse.RequestParser()
parser.add_argument('inputs', type=dict, location='json')
args = parser.parse_args()
try:
response = AppGenerateService.generate_single_iteration(
app_model=app_model,
user=current_user,
node_id=node_id,
args=args,
streaming=True
)
return helper.compact_generate_response(response)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except services.errors.conversation.ConversationCompletedError:
raise ConversationCompletedError()
except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")
raise InternalServerError()
class DraftWorkflowRunApi(Resource):
@setup_required
@ -326,6 +391,8 @@ api.add_resource(AdvancedChatDraftWorkflowRunApi, '/apps/<uuid:app_id>/advanced-
api.add_resource(DraftWorkflowRunApi, '/apps/<uuid:app_id>/workflows/draft/run')
api.add_resource(WorkflowTaskStopApi, '/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop')
api.add_resource(DraftWorkflowNodeRunApi, '/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run')
api.add_resource(AdvancedChatDraftRunIterationNodeApi, '/apps/<uuid:app_id>/advanced-chat/workflows/draft/iteration/nodes/<string:node_id>/run')
api.add_resource(WorkflowDraftRunIterationNodeApi, '/apps/<uuid:app_id>/workflows/draft/iteration/nodes/<string:node_id>/run')
api.add_resource(PublishedWorkflowApi, '/apps/<uuid:app_id>/workflows/publish')
api.add_resource(DefaultBlockConfigsApi, '/apps/<uuid:app_id>/workflows/default-workflow-block-configs')
api.add_resource(DefaultBlockConfigApi, '/apps/<uuid:app_id>/workflows/default-workflow-block-configs'

View File

@ -9,8 +9,13 @@ from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.model_runtime.utils.encoders import jsonable_encoder
from libs.helper import alphanumeric, uuid_value
from libs.login import login_required
from services.tools_manage_service import ToolManageService
from services.tools.api_tools_manage_service import ApiToolManageService
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from services.tools.tool_labels_service import ToolLabelsService
from services.tools.tools_manage_service import ToolCommonService
from services.tools.workflow_tools_manage_service import WorkflowToolManageService
class ToolProviderListApi(Resource):
@ -21,7 +26,11 @@ class ToolProviderListApi(Resource):
user_id = current_user.id
tenant_id = current_user.current_tenant_id
return ToolManageService.list_tool_providers(user_id, tenant_id)
req = reqparse.RequestParser()
req.add_argument('type', type=str, choices=['builtin', 'model', 'api', 'workflow'], required=False, nullable=True, location='args')
args = req.parse_args()
return ToolCommonService.list_tool_providers(user_id, tenant_id, args.get('type', None))
class ToolBuiltinProviderListToolsApi(Resource):
@setup_required
@ -31,7 +40,7 @@ class ToolBuiltinProviderListToolsApi(Resource):
user_id = current_user.id
tenant_id = current_user.current_tenant_id
return jsonable_encoder(ToolManageService.list_builtin_tool_provider_tools(
return jsonable_encoder(BuiltinToolManageService.list_builtin_tool_provider_tools(
user_id,
tenant_id,
provider,
@ -48,7 +57,7 @@ class ToolBuiltinProviderDeleteApi(Resource):
user_id = current_user.id
tenant_id = current_user.current_tenant_id
return ToolManageService.delete_builtin_tool_provider(
return BuiltinToolManageService.delete_builtin_tool_provider(
user_id,
tenant_id,
provider,
@ -70,7 +79,7 @@ class ToolBuiltinProviderUpdateApi(Resource):
args = parser.parse_args()
return ToolManageService.update_builtin_tool_provider(
return BuiltinToolManageService.update_builtin_tool_provider(
user_id,
tenant_id,
provider,
@ -85,7 +94,7 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
user_id = current_user.id
tenant_id = current_user.current_tenant_id
return ToolManageService.get_builtin_tool_provider_credentials(
return BuiltinToolManageService.get_builtin_tool_provider_credentials(
user_id,
tenant_id,
provider,
@ -94,7 +103,7 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
class ToolBuiltinProviderIconApi(Resource):
@setup_required
def get(self, provider):
icon_bytes, mimetype = ToolManageService.get_builtin_tool_provider_icon(provider)
icon_bytes, mimetype = BuiltinToolManageService.get_builtin_tool_provider_icon(provider)
icon_cache_max_age = int(current_app.config.get('TOOL_ICON_CACHE_MAX_AGE'))
return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age)
@ -116,11 +125,12 @@ class ToolApiProviderAddApi(Resource):
parser.add_argument('provider', type=str, required=True, nullable=False, location='json')
parser.add_argument('icon', type=dict, required=True, nullable=False, location='json')
parser.add_argument('privacy_policy', type=str, required=False, nullable=True, location='json')
parser.add_argument('labels', type=list[str], required=False, nullable=True, location='json', default=[])
parser.add_argument('custom_disclaimer', type=str, required=False, nullable=True, location='json')
args = parser.parse_args()
return ToolManageService.create_api_tool_provider(
return ApiToolManageService.create_api_tool_provider(
user_id,
tenant_id,
args['provider'],
@ -130,6 +140,7 @@ class ToolApiProviderAddApi(Resource):
args['schema'],
args.get('privacy_policy', ''),
args.get('custom_disclaimer', ''),
args.get('labels', []),
)
class ToolApiProviderGetRemoteSchemaApi(Resource):
@ -143,7 +154,7 @@ class ToolApiProviderGetRemoteSchemaApi(Resource):
args = parser.parse_args()
return ToolManageService.get_api_tool_provider_remote_schema(
return ApiToolManageService.get_api_tool_provider_remote_schema(
current_user.id,
current_user.current_tenant_id,
args['url'],
@ -163,7 +174,7 @@ class ToolApiProviderListToolsApi(Resource):
args = parser.parse_args()
return jsonable_encoder(ToolManageService.list_api_tool_provider_tools(
return jsonable_encoder(ApiToolManageService.list_api_tool_provider_tools(
user_id,
tenant_id,
args['provider'],
@ -188,11 +199,12 @@ class ToolApiProviderUpdateApi(Resource):
parser.add_argument('original_provider', type=str, required=True, nullable=False, location='json')
parser.add_argument('icon', type=dict, required=True, nullable=False, location='json')
parser.add_argument('privacy_policy', type=str, required=True, nullable=True, location='json')
parser.add_argument('labels', type=list[str], required=False, nullable=True, location='json')
parser.add_argument('custom_disclaimer', type=str, required=True, nullable=True, location='json')
args = parser.parse_args()
return ToolManageService.update_api_tool_provider(
return ApiToolManageService.update_api_tool_provider(
user_id,
tenant_id,
args['provider'],
@ -203,6 +215,7 @@ class ToolApiProviderUpdateApi(Resource):
args['schema'],
args['privacy_policy'],
args['custom_disclaimer'],
args.get('labels', []),
)
class ToolApiProviderDeleteApi(Resource):
@ -222,7 +235,7 @@ class ToolApiProviderDeleteApi(Resource):
args = parser.parse_args()
return ToolManageService.delete_api_tool_provider(
return ApiToolManageService.delete_api_tool_provider(
user_id,
tenant_id,
args['provider'],
@ -242,7 +255,7 @@ class ToolApiProviderGetApi(Resource):
args = parser.parse_args()
return ToolManageService.get_api_tool_provider(
return ApiToolManageService.get_api_tool_provider(
user_id,
tenant_id,
args['provider'],
@ -253,7 +266,7 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource):
@login_required
@account_initialization_required
def get(self, provider):
return ToolManageService.list_builtin_provider_credentials_schema(provider)
return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider)
class ToolApiProviderSchemaApi(Resource):
@setup_required
@ -266,7 +279,7 @@ class ToolApiProviderSchemaApi(Resource):
args = parser.parse_args()
return ToolManageService.parser_api_schema(
return ApiToolManageService.parser_api_schema(
schema=args['schema'],
)
@ -286,7 +299,7 @@ class ToolApiProviderPreviousTestApi(Resource):
args = parser.parse_args()
return ToolManageService.test_api_tool_preview(
return ApiToolManageService.test_api_tool_preview(
current_user.current_tenant_id,
args['provider_name'] if args['provider_name'] else '',
args['tool_name'],
@ -296,6 +309,153 @@ class ToolApiProviderPreviousTestApi(Resource):
args['schema'],
)
class ToolWorkflowProviderCreateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
if not current_user.is_admin_or_owner:
raise Forbidden()
user_id = current_user.id
tenant_id = current_user.current_tenant_id
reqparser = reqparse.RequestParser()
reqparser.add_argument('workflow_app_id', type=uuid_value, required=True, nullable=False, location='json')
reqparser.add_argument('name', type=alphanumeric, required=True, nullable=False, location='json')
reqparser.add_argument('label', type=str, required=True, nullable=False, location='json')
reqparser.add_argument('description', type=str, required=True, nullable=False, location='json')
reqparser.add_argument('icon', type=dict, required=True, nullable=False, location='json')
reqparser.add_argument('parameters', type=list[dict], required=True, nullable=False, location='json')
reqparser.add_argument('privacy_policy', type=str, required=False, nullable=True, location='json', default='')
reqparser.add_argument('labels', type=list[str], required=False, nullable=True, location='json')
args = reqparser.parse_args()
return WorkflowToolManageService.create_workflow_tool(
user_id,
tenant_id,
args['workflow_app_id'],
args['name'],
args['label'],
args['icon'],
args['description'],
args['parameters'],
args['privacy_policy'],
args.get('labels', []),
)
class ToolWorkflowProviderUpdateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
if not current_user.is_admin_or_owner:
raise Forbidden()
user_id = current_user.id
tenant_id = current_user.current_tenant_id
reqparser = reqparse.RequestParser()
reqparser.add_argument('workflow_tool_id', type=uuid_value, required=True, nullable=False, location='json')
reqparser.add_argument('name', type=alphanumeric, required=True, nullable=False, location='json')
reqparser.add_argument('label', type=str, required=True, nullable=False, location='json')
reqparser.add_argument('description', type=str, required=True, nullable=False, location='json')
reqparser.add_argument('icon', type=dict, required=True, nullable=False, location='json')
reqparser.add_argument('parameters', type=list[dict], required=True, nullable=False, location='json')
reqparser.add_argument('privacy_policy', type=str, required=False, nullable=True, location='json', default='')
reqparser.add_argument('labels', type=list[str], required=False, nullable=True, location='json')
args = reqparser.parse_args()
if not args['workflow_tool_id']:
raise ValueError('incorrect workflow_tool_id')
return WorkflowToolManageService.update_workflow_tool(
user_id,
tenant_id,
args['workflow_tool_id'],
args['name'],
args['label'],
args['icon'],
args['description'],
args['parameters'],
args['privacy_policy'],
args.get('labels', []),
)
class ToolWorkflowProviderDeleteApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
if not current_user.is_admin_or_owner:
raise Forbidden()
user_id = current_user.id
tenant_id = current_user.current_tenant_id
reqparser = reqparse.RequestParser()
reqparser.add_argument('workflow_tool_id', type=uuid_value, required=True, nullable=False, location='json')
args = reqparser.parse_args()
return WorkflowToolManageService.delete_workflow_tool(
user_id,
tenant_id,
args['workflow_tool_id'],
)
class ToolWorkflowProviderGetApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
user_id = current_user.id
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument('workflow_tool_id', type=uuid_value, required=False, nullable=True, location='args')
parser.add_argument('workflow_app_id', type=uuid_value, required=False, nullable=True, location='args')
args = parser.parse_args()
if args.get('workflow_tool_id'):
tool = WorkflowToolManageService.get_workflow_tool_by_tool_id(
user_id,
tenant_id,
args['workflow_tool_id'],
)
elif args.get('workflow_app_id'):
tool = WorkflowToolManageService.get_workflow_tool_by_app_id(
user_id,
tenant_id,
args['workflow_app_id'],
)
else:
raise ValueError('incorrect workflow_tool_id or workflow_app_id')
return jsonable_encoder(tool)
class ToolWorkflowProviderListToolApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
user_id = current_user.id
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument('workflow_tool_id', type=uuid_value, required=True, nullable=False, location='args')
args = parser.parse_args()
return jsonable_encoder(WorkflowToolManageService.list_single_workflow_tools(
user_id,
tenant_id,
args['workflow_tool_id'],
))
class ToolBuiltinListApi(Resource):
@setup_required
@login_required
@ -304,7 +464,7 @@ class ToolBuiltinListApi(Resource):
user_id = current_user.id
tenant_id = current_user.current_tenant_id
return jsonable_encoder([provider.to_dict() for provider in ToolManageService.list_builtin_tools(
return jsonable_encoder([provider.to_dict() for provider in BuiltinToolManageService.list_builtin_tools(
user_id,
tenant_id,
)])
@ -317,18 +477,43 @@ class ToolApiListApi(Resource):
user_id = current_user.id
tenant_id = current_user.current_tenant_id
return jsonable_encoder([provider.to_dict() for provider in ToolManageService.list_api_tools(
return jsonable_encoder([provider.to_dict() for provider in ApiToolManageService.list_api_tools(
user_id,
tenant_id,
)])
class ToolWorkflowListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
user_id = current_user.id
tenant_id = current_user.current_tenant_id
return jsonable_encoder([provider.to_dict() for provider in WorkflowToolManageService.list_tenant_workflow_tools(
user_id,
tenant_id,
)])
class ToolLabelsApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
return jsonable_encoder(ToolLabelsService.list_tool_labels())
# tool provider
api.add_resource(ToolProviderListApi, '/workspaces/current/tool-providers')
# builtin tool provider
api.add_resource(ToolBuiltinProviderListToolsApi, '/workspaces/current/tool-provider/builtin/<provider>/tools')
api.add_resource(ToolBuiltinProviderDeleteApi, '/workspaces/current/tool-provider/builtin/<provider>/delete')
api.add_resource(ToolBuiltinProviderUpdateApi, '/workspaces/current/tool-provider/builtin/<provider>/update')
api.add_resource(ToolBuiltinProviderGetCredentialsApi, '/workspaces/current/tool-provider/builtin/<provider>/credentials')
api.add_resource(ToolBuiltinProviderCredentialsSchemaApi, '/workspaces/current/tool-provider/builtin/<provider>/credentials_schema')
api.add_resource(ToolBuiltinProviderIconApi, '/workspaces/current/tool-provider/builtin/<provider>/icon')
# api tool provider
api.add_resource(ToolApiProviderAddApi, '/workspaces/current/tool-provider/api/add')
api.add_resource(ToolApiProviderGetRemoteSchemaApi, '/workspaces/current/tool-provider/api/remote')
api.add_resource(ToolApiProviderListToolsApi, '/workspaces/current/tool-provider/api/tools')
@ -338,5 +523,15 @@ api.add_resource(ToolApiProviderGetApi, '/workspaces/current/tool-provider/api/g
api.add_resource(ToolApiProviderSchemaApi, '/workspaces/current/tool-provider/api/schema')
api.add_resource(ToolApiProviderPreviousTestApi, '/workspaces/current/tool-provider/api/test/pre')
# workflow tool provider
api.add_resource(ToolWorkflowProviderCreateApi, '/workspaces/current/tool-provider/workflow/create')
api.add_resource(ToolWorkflowProviderUpdateApi, '/workspaces/current/tool-provider/workflow/update')
api.add_resource(ToolWorkflowProviderDeleteApi, '/workspaces/current/tool-provider/workflow/delete')
api.add_resource(ToolWorkflowProviderGetApi, '/workspaces/current/tool-provider/workflow/get')
api.add_resource(ToolWorkflowProviderListToolApi, '/workspaces/current/tool-provider/workflow/tools')
api.add_resource(ToolBuiltinListApi, '/workspaces/current/tools/builtin')
api.add_resource(ToolApiListApi, '/workspaces/current/tools/api')
api.add_resource(ToolWorkflowListApi, '/workspaces/current/tools/workflow')
api.add_resource(ToolLabelsApi, '/workspaces/current/tool-labels')

View File

@ -165,6 +165,7 @@ class BaseAgentRunner(AppRunner):
tenant_id=self.tenant_id,
app_id=self.app_config.app_id,
agent_tool=tool,
invoke_from=self.application_generate_entity.invoke_from
)
tool_entity.load_variables(self.variables_pool)

View File

@ -8,7 +8,7 @@ class AgentToolEntity(BaseModel):
"""
Agent Tool Entity.
"""
provider_type: Literal["builtin", "api"]
provider_type: Literal["builtin", "api", "workflow"]
provider_id: str
tool_name: str
tool_parameters: dict[str, Any] = {}

View File

@ -1,7 +1,7 @@
from typing import Optional
from core.agent.entities import AgentEntity, AgentPromptEntity, AgentToolEntity
from core.tools.prompt.template import REACT_PROMPT_TEMPLATES
from core.agent.prompt.template import REACT_PROMPT_TEMPLATES
class AgentConfigManager:

View File

@ -98,6 +98,90 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
extras=extras
)
return self._generate(
app_model=app_model,
workflow=workflow,
user=user,
invoke_from=invoke_from,
application_generate_entity=application_generate_entity,
conversation=conversation,
stream=stream
)
def single_iteration_generate(self, app_model: App,
workflow: Workflow,
node_id: str,
user: Account,
args: dict,
stream: bool = True) \
-> Union[dict, Generator[dict, None, None]]:
"""
Generate App response.
:param app_model: App
:param workflow: Workflow
:param user: account or end user
:param args: request args
:param invoke_from: invoke from source
:param stream: is stream
"""
if not node_id:
raise ValueError('node_id is required')
if args.get('inputs') is None:
raise ValueError('inputs is required')
extras = {
"auto_generate_conversation_name": False
}
# get conversation
conversation = None
if args.get('conversation_id'):
conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user)
# convert to app config
app_config = AdvancedChatAppConfigManager.get_app_config(
app_model=app_model,
workflow=workflow
)
# init application generate entity
application_generate_entity = AdvancedChatAppGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
conversation_id=conversation.id if conversation else None,
inputs={},
query='',
files=[],
user_id=user.id,
stream=stream,
invoke_from=InvokeFrom.DEBUGGER,
extras=extras,
single_iteration_run=AdvancedChatAppGenerateEntity.SingleIterationRunEntity(
node_id=node_id,
inputs=args['inputs']
)
)
return self._generate(
app_model=app_model,
workflow=workflow,
user=user,
invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity,
conversation=conversation,
stream=stream
)
def _generate(self, app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
invoke_from: InvokeFrom,
application_generate_entity: AdvancedChatAppGenerateEntity,
conversation: Conversation = None,
stream: bool = True) \
-> Union[dict, Generator[dict, None, None]]:
is_first_conversation = False
if not conversation:
is_first_conversation = True
@ -167,6 +251,18 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
"""
with flask_app.app_context():
try:
runner = AdvancedChatAppRunner()
if application_generate_entity.single_iteration_run:
single_iteration_run = application_generate_entity.single_iteration_run
runner.single_iteration_run(
app_id=application_generate_entity.app_config.app_id,
workflow_id=application_generate_entity.app_config.workflow_id,
queue_manager=queue_manager,
inputs=single_iteration_run.inputs,
node_id=single_iteration_run.node_id,
user_id=application_generate_entity.user_id
)
else:
# get conversation and message
conversation = self._get_conversation(conversation_id)
message = self._get_message(message_id)

View File

@ -102,6 +102,7 @@ class AdvancedChatAppRunner(AppRunner):
user_from=UserFrom.ACCOUNT
if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
else UserFrom.END_USER,
invoke_from=application_generate_entity.invoke_from,
user_inputs=inputs,
system_inputs={
SystemVariable.QUERY: query,
@ -109,6 +110,35 @@ class AdvancedChatAppRunner(AppRunner):
SystemVariable.CONVERSATION_ID: conversation.id,
SystemVariable.USER_ID: user_id
},
callbacks=workflow_callbacks,
call_depth=application_generate_entity.call_depth
)
def single_iteration_run(self, app_id: str, workflow_id: str,
queue_manager: AppQueueManager,
inputs: dict, node_id: str, user_id: str) -> None:
"""
Single iteration run
"""
app_record: App = db.session.query(App).filter(App.id == app_id).first()
if not app_record:
raise ValueError("App not found")
workflow = self.get_workflow(app_model=app_record, workflow_id=workflow_id)
if not workflow:
raise ValueError("Workflow not initialized")
workflow_callbacks = [WorkflowEventTriggerCallback(
queue_manager=queue_manager,
workflow=workflow
)]
workflow_engine_manager = WorkflowEngineManager()
workflow_engine_manager.single_step_run_iteration_workflow_node(
workflow=workflow,
node_id=node_id,
user_id=user_id,
user_inputs=inputs,
callbacks=workflow_callbacks
)

View File

@ -12,6 +12,9 @@ from core.app.entities.queue_entities import (
QueueAdvancedChatMessageEndEvent,
QueueAnnotationReplyEvent,
QueueErrorEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueMessageReplaceEvent,
QueueNodeFailedEvent,
QueueNodeStartedEvent,
@ -64,6 +67,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
_workflow: Workflow
_user: Union[Account, EndUser]
_workflow_system_variables: dict[SystemVariable, Any]
_iteration_nested_relations: dict[str, list[str]]
def __init__(self, application_generate_entity: AdvancedChatAppGenerateEntity,
workflow: Workflow,
@ -104,6 +108,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
)
self._stream_generate_routes = self._get_stream_generate_routes()
self._iteration_nested_relations = self._get_iteration_nested_relations(self._workflow.graph_dict)
self._conversation_name_generate_thread = None
def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
@ -204,6 +209,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
# search stream_generate_routes if node id is answer start at node
if not self._task_state.current_stream_generate_state and event.node_id in self._stream_generate_routes:
self._task_state.current_stream_generate_state = self._stream_generate_routes[event.node_id]
# reset current route position to 0
self._task_state.current_stream_generate_state.current_route_position = 0
# generate stream outputs when node started
yield from self._generate_stream_outputs_when_node_started()
@ -225,6 +232,22 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution
)
if isinstance(event, QueueNodeFailedEvent):
yield from self._handle_iteration_exception(
task_id=self._application_generate_entity.task_id,
error=f'Child node failed: {event.error}'
)
elif isinstance(event, QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent):
if isinstance(event, QueueIterationNextEvent):
# clear ran node execution infos of current iteration
iteration_relations = self._iteration_nested_relations.get(event.node_id)
if iteration_relations:
for node_id in iteration_relations:
self._task_state.ran_node_execution_infos.pop(node_id, None)
yield self._handle_iteration_to_stream_response(self._application_generate_entity.task_id, event)
self._handle_iteration_operation(event)
elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent):
workflow_run = self._handle_workflow_finished(event)
if workflow_run:
@ -263,10 +286,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._handle_retriever_resources(event)
elif isinstance(event, QueueAnnotationReplyEvent):
self._handle_annotation_reply(event)
# elif isinstance(event, QueueMessageFileEvent):
# response = self._message_file_to_stream_response(event)
# if response:
# yield response
elif isinstance(event, QueueTextChunkEvent):
delta_text = event.text
if delta_text is None:
@ -401,14 +420,23 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
continue
node_type = source_node.get('data', {}).get('type')
node_iteration_id = source_node.get('data', {}).get('iteration_id')
iteration_start_node_id = None
if node_iteration_id:
iteration_node = next((node for node in nodes if node.get('id') == node_iteration_id), None)
iteration_start_node_id = iteration_node.get('data', {}).get('start_node_id')
if node_type in [
NodeType.ANSWER.value,
NodeType.IF_ELSE.value,
NodeType.QUESTION_CLASSIFIER.value
NodeType.QUESTION_CLASSIFIER.value,
NodeType.ITERATION.value,
NodeType.LOOP.value
]:
start_node_id = target_node_id
start_node_ids.append(start_node_id)
elif node_type == NodeType.START.value:
elif node_type == NodeType.START.value or \
node_iteration_id is not None and iteration_start_node_id == source_node.get('id'):
start_node_id = source_node_id
start_node_ids.append(start_node_id)
else:
@ -418,6 +446,26 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
return start_node_ids
def _get_iteration_nested_relations(self, graph: dict) -> dict[str, list[str]]:
"""
Get iteration nested relations.
:param graph: graph
:return:
"""
nodes = graph.get('nodes')
iteration_ids = [node.get('id') for node in nodes
if node.get('data', {}).get('type') in [
NodeType.ITERATION.value,
NodeType.LOOP.value,
]]
return {
iteration_id: [
node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id
] for iteration_id in iteration_ids
}
def _generate_stream_outputs_when_node_started(self) -> Generator:
"""
Generate stream outputs.
@ -425,7 +473,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
"""
if self._task_state.current_stream_generate_state:
route_chunks = self._task_state.current_stream_generate_state.generate_route[
self._task_state.current_stream_generate_state.current_route_position:]
self._task_state.current_stream_generate_state.current_route_position:
]
for route_chunk in route_chunks:
if route_chunk.type == 'text':

View File

@ -1,8 +1,11 @@
from typing import Optional
from typing import Any, Optional
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
@ -130,6 +133,66 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback):
), PublishFrom.APPLICATION_MANAGER
)
def on_workflow_iteration_started(self,
node_id: str,
node_type: NodeType,
node_run_index: int = 1,
node_data: Optional[BaseNodeData] = None,
inputs: dict = None,
predecessor_node_id: Optional[str] = None,
metadata: Optional[dict] = None) -> None:
"""
Publish iteration started
"""
self._queue_manager.publish(
QueueIterationStartEvent(
node_id=node_id,
node_type=node_type,
node_run_index=node_run_index,
node_data=node_data,
inputs=inputs,
predecessor_node_id=predecessor_node_id,
metadata=metadata
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_iteration_next(self, node_id: str,
node_type: NodeType,
index: int,
node_run_index: int,
output: Optional[Any]) -> None:
"""
Publish iteration next
"""
self._queue_manager._publish(
QueueIterationNextEvent(
node_id=node_id,
node_type=node_type,
index=index,
node_run_index=node_run_index,
output=output
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_iteration_completed(self, node_id: str,
node_type: NodeType,
node_run_index: int,
outputs: dict) -> None:
"""
Publish iteration completed
"""
self._queue_manager._publish(
QueueIterationCompletedEvent(
node_id=node_id,
node_type=node_type,
node_run_index=node_run_index,
outputs=outputs
),
PublishFrom.APPLICATION_MANAGER
)
def on_event(self, event: AppQueueEvent) -> None:
"""
Publish event

View File

@ -115,7 +115,8 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
user_id=user.id,
stream=stream,
invoke_from=invoke_from,
extras=extras
extras=extras,
call_depth=0
)
# init generate records

View File

@ -34,7 +34,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
user: Union[Account, EndUser],
args: dict,
invoke_from: InvokeFrom,
stream: bool = True) \
stream: bool = True,
call_depth: int = 0) \
-> Union[dict, Generator[dict, None, None]]:
"""
Generate App response.
@ -75,9 +76,38 @@ class WorkflowAppGenerator(BaseAppGenerator):
files=file_objs,
user_id=user.id,
stream=stream,
invoke_from=invoke_from
invoke_from=invoke_from,
call_depth=call_depth
)
return self._generate(
app_model=app_model,
workflow=workflow,
user=user,
application_generate_entity=application_generate_entity,
invoke_from=invoke_from,
stream=stream,
call_depth=call_depth
)
def _generate(self, app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
application_generate_entity: WorkflowAppGenerateEntity,
invoke_from: InvokeFrom,
stream: bool = True,
call_depth: int = 0) \
-> Union[dict, Generator[dict, None, None]]:
"""
Generate App response.
:param app_model: App
:param workflow: Workflow
:param user: account or end user
:param application_generate_entity: application generate entity
:param invoke_from: invoke from source
:param stream: is stream
"""
# init queue manager
queue_manager = WorkflowAppQueueManager(
task_id=application_generate_entity.task_id,
@ -109,6 +139,64 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from=invoke_from
)
def single_iteration_generate(self, app_model: App,
workflow: Workflow,
node_id: str,
user: Account,
args: dict,
stream: bool = True) \
-> Union[dict, Generator[dict, None, None]]:
"""
Generate App response.
:param app_model: App
:param workflow: Workflow
:param user: account or end user
:param args: request args
:param invoke_from: invoke from source
:param stream: is stream
"""
if not node_id:
raise ValueError('node_id is required')
if args.get('inputs') is None:
raise ValueError('inputs is required')
extras = {
"auto_generate_conversation_name": False
}
# convert to app config
app_config = WorkflowAppConfigManager.get_app_config(
app_model=app_model,
workflow=workflow
)
# init application generate entity
application_generate_entity = WorkflowAppGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
inputs={},
files=[],
user_id=user.id,
stream=stream,
invoke_from=InvokeFrom.DEBUGGER,
extras=extras,
single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity(
node_id=node_id,
inputs=args['inputs']
)
)
return self._generate(
app_model=app_model,
workflow=workflow,
user=user,
invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity,
stream=stream
)
def _generate_worker(self, flask_app: Flask,
application_generate_entity: WorkflowAppGenerateEntity,
queue_manager: AppQueueManager) -> None:
@ -123,6 +211,17 @@ class WorkflowAppGenerator(BaseAppGenerator):
try:
# workflow app
runner = WorkflowAppRunner()
if application_generate_entity.single_iteration_run:
single_iteration_run = application_generate_entity.single_iteration_run
runner.single_iteration_run(
app_id=application_generate_entity.app_config.app_id,
workflow_id=application_generate_entity.app_config.workflow_id,
queue_manager=queue_manager,
inputs=single_iteration_run.inputs,
node_id=single_iteration_run.node_id,
user_id=application_generate_entity.user_id
)
else:
runner.run(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager

View File

@ -73,11 +73,44 @@ class WorkflowAppRunner:
user_from=UserFrom.ACCOUNT
if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
else UserFrom.END_USER,
invoke_from=application_generate_entity.invoke_from,
user_inputs=inputs,
system_inputs={
SystemVariable.FILES: files,
SystemVariable.USER_ID: user_id
},
callbacks=workflow_callbacks,
call_depth=application_generate_entity.call_depth
)
def single_iteration_run(self, app_id: str, workflow_id: str,
queue_manager: AppQueueManager,
inputs: dict, node_id: str, user_id: str) -> None:
"""
Single iteration run
"""
app_record: App = db.session.query(App).filter(App.id == app_id).first()
if not app_record:
raise ValueError("App not found")
if not app_record.workflow_id:
raise ValueError("Workflow not initialized")
workflow = self.get_workflow(app_model=app_record, workflow_id=workflow_id)
if not workflow:
raise ValueError("Workflow not initialized")
workflow_callbacks = [WorkflowEventTriggerCallback(
queue_manager=queue_manager,
workflow=workflow
)]
workflow_engine_manager = WorkflowEngineManager()
workflow_engine_manager.single_step_run_iteration_workflow_node(
workflow=workflow,
node_id=node_id,
user_id=user_id,
user_inputs=inputs,
callbacks=workflow_callbacks
)

View File

@ -9,6 +9,9 @@ from core.app.entities.app_invoke_entities import (
)
from core.app.entities.queue_entities import (
QueueErrorEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueMessageReplaceEvent,
QueueNodeFailedEvent,
QueueNodeStartedEvent,
@ -58,6 +61,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
_task_state: WorkflowTaskState
_application_generate_entity: WorkflowAppGenerateEntity
_workflow_system_variables: dict[SystemVariable, Any]
_iteration_nested_relations: dict[str, list[str]]
def __init__(self, application_generate_entity: WorkflowAppGenerateEntity,
workflow: Workflow,
@ -85,8 +89,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
SystemVariable.USER_ID: user_id
}
self._task_state = WorkflowTaskState()
self._task_state = WorkflowTaskState(
iteration_nested_node_ids=[]
)
self._stream_generate_nodes = self._get_stream_generate_nodes()
self._iteration_nested_relations = self._get_iteration_nested_relations(self._workflow.graph_dict)
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
"""
@ -191,6 +198,22 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution
)
if isinstance(event, QueueNodeFailedEvent):
yield from self._handle_iteration_exception(
task_id=self._application_generate_entity.task_id,
error=f'Child node failed: {event.error}'
)
elif isinstance(event, QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent):
if isinstance(event, QueueIterationNextEvent):
# clear ran node execution infos of current iteration
iteration_relations = self._iteration_nested_relations.get(event.node_id)
if iteration_relations:
for node_id in iteration_relations:
self._task_state.ran_node_execution_infos.pop(node_id, None)
yield self._handle_iteration_to_stream_response(self._application_generate_entity.task_id, event)
self._handle_iteration_operation(event)
elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent):
workflow_run = self._handle_workflow_finished(event)
@ -331,13 +354,20 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
continue
node_type = source_node.get('data', {}).get('type')
node_iteration_id = source_node.get('data', {}).get('iteration_id')
iteration_start_node_id = None
if node_iteration_id:
iteration_node = next((node for node in nodes if node.get('id') == node_iteration_id), None)
iteration_start_node_id = iteration_node.get('data', {}).get('start_node_id')
if node_type in [
NodeType.IF_ELSE.value,
NodeType.QUESTION_CLASSIFIER.value
]:
start_node_id = target_node_id
start_node_ids.append(start_node_id)
elif node_type == NodeType.START.value:
elif node_type == NodeType.START.value or \
node_iteration_id is not None and iteration_start_node_id == source_node.get('id'):
start_node_id = source_node_id
start_node_ids.append(start_node_id)
else:
@ -411,3 +441,24 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
return False
return True
def _get_iteration_nested_relations(self, graph: dict) -> dict[str, list[str]]:
"""
Get iteration nested relations.
:param graph: graph
:return:
"""
nodes = graph.get('nodes')
iteration_ids = [node.get('id') for node in nodes
if node.get('data', {}).get('type') in [
NodeType.ITERATION.value,
NodeType.LOOP.value,
]]
return {
iteration_id: [
node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id
] for iteration_id in iteration_ids
}

View File

@ -1,8 +1,11 @@
from typing import Optional
from typing import Any, Optional
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
@ -130,6 +133,66 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback):
), PublishFrom.APPLICATION_MANAGER
)
def on_workflow_iteration_started(self,
node_id: str,
node_type: NodeType,
node_run_index: int = 1,
node_data: Optional[BaseNodeData] = None,
inputs: dict = None,
predecessor_node_id: Optional[str] = None,
metadata: Optional[dict] = None) -> None:
"""
Publish iteration started
"""
self._queue_manager.publish(
QueueIterationStartEvent(
node_id=node_id,
node_type=node_type,
node_run_index=node_run_index,
node_data=node_data,
inputs=inputs,
predecessor_node_id=predecessor_node_id,
metadata=metadata
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_iteration_next(self, node_id: str,
node_type: NodeType,
index: int,
node_run_index: int,
output: Optional[Any]) -> None:
"""
Publish iteration next
"""
self._queue_manager.publish(
QueueIterationNextEvent(
node_id=node_id,
node_type=node_type,
index=index,
node_run_index=node_run_index,
output=output
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_iteration_completed(self, node_id: str,
node_type: NodeType,
node_run_index: int,
outputs: dict) -> None:
"""
Publish iteration completed
"""
self._queue_manager.publish(
QueueIterationCompletedEvent(
node_id=node_id,
node_type=node_type,
node_run_index=node_run_index,
outputs=outputs
),
PublishFrom.APPLICATION_MANAGER
)
def on_event(self, event: AppQueueEvent) -> None:
"""
Publish event

View File

@ -102,6 +102,39 @@ class WorkflowLoggingCallback(BaseWorkflowCallback):
self.print_text(text, color="pink", end="")
def on_workflow_iteration_started(self,
node_id: str,
node_type: NodeType,
node_run_index: int = 1,
node_data: Optional[BaseNodeData] = None,
inputs: dict = None,
predecessor_node_id: Optional[str] = None,
metadata: Optional[dict] = None) -> None:
"""
Publish iteration started
"""
self.print_text("\n[on_workflow_iteration_started]", color='blue')
self.print_text(f"Node ID: {node_id}", color='blue')
def on_workflow_iteration_next(self, node_id: str,
node_type: NodeType,
index: int,
node_run_index: int,
output: Optional[dict]) -> None:
"""
Publish iteration next
"""
self.print_text("\n[on_workflow_iteration_next]", color='blue')
def on_workflow_iteration_completed(self, node_id: str,
node_type: NodeType,
node_run_index: int,
outputs: dict) -> None:
"""
Publish iteration completed
"""
self.print_text("\n[on_workflow_iteration_completed]", color='blue')
def on_event(self, event: AppQueueEvent) -> None:
"""
Publish event

View File

@ -80,6 +80,9 @@ class AppGenerateEntity(BaseModel):
stream: bool
invoke_from: InvokeFrom
# invoke call depth
call_depth: int = 0
# extra parameters, like: auto_generate_conversation_name
extras: dict[str, Any] = {}
@ -126,6 +129,14 @@ class AdvancedChatAppGenerateEntity(AppGenerateEntity):
conversation_id: Optional[str] = None
query: Optional[str] = None
class SingleIterationRunEntity(BaseModel):
"""
Single Iteration Run Entity.
"""
node_id: str
inputs: dict
single_iteration_run: Optional[SingleIterationRunEntity] = None
class WorkflowAppGenerateEntity(AppGenerateEntity):
"""
@ -133,3 +144,12 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
"""
# app config
app_config: WorkflowUIBasedAppConfig
class SingleIterationRunEntity(BaseModel):
"""
Single Iteration Run Entity.
"""
node_id: str
inputs: dict
single_iteration_run: Optional[SingleIterationRunEntity] = None

View File

@ -1,7 +1,7 @@
from enum import Enum
from typing import Any, Optional
from pydantic import BaseModel
from pydantic import BaseModel, validator
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.workflow.entities.base_node_data_entities import BaseNodeData
@ -21,6 +21,9 @@ class QueueEvent(Enum):
WORKFLOW_STARTED = "workflow_started"
WORKFLOW_SUCCEEDED = "workflow_succeeded"
WORKFLOW_FAILED = "workflow_failed"
ITERATION_START = "iteration_start"
ITERATION_NEXT = "iteration_next"
ITERATION_COMPLETED = "iteration_completed"
NODE_STARTED = "node_started"
NODE_SUCCEEDED = "node_succeeded"
NODE_FAILED = "node_failed"
@ -47,6 +50,55 @@ class QueueLLMChunkEvent(AppQueueEvent):
event = QueueEvent.LLM_CHUNK
chunk: LLMResultChunk
class QueueIterationStartEvent(AppQueueEvent):
"""
QueueIterationStartEvent entity
"""
event = QueueEvent.ITERATION_START
node_id: str
node_type: NodeType
node_data: BaseNodeData
node_run_index: int
inputs: dict = None
predecessor_node_id: Optional[str] = None
metadata: Optional[dict] = None
class QueueIterationNextEvent(AppQueueEvent):
"""
QueueIterationNextEvent entity
"""
event = QueueEvent.ITERATION_NEXT
index: int
node_id: str
node_type: NodeType
node_run_index: int
output: Optional[Any] # output for the current iteration
@validator('output', pre=True, always=True)
def set_output(cls, v):
"""
Set output
"""
if v is None:
return None
if isinstance(v, int | float | str | bool | dict | list):
return v
raise ValueError('output must be a valid type')
class QueueIterationCompletedEvent(AppQueueEvent):
"""
QueueIterationCompletedEvent entity
"""
event = QueueEvent.ITERATION_COMPLETED
node_id: str
node_type: NodeType
node_run_index: int
outputs: dict
class QueueTextChunkEvent(AppQueueEvent):
"""

View File

@ -1,12 +1,14 @@
from enum import Enum
from typing import Optional
from typing import Any, Optional
from pydantic import BaseModel
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType
from core.workflow.nodes.answer.entities import GenerateRouteChunk
from models.workflow import WorkflowNodeExecutionStatus
class WorkflowStreamGenerateNodes(BaseModel):
@ -65,6 +67,7 @@ class WorkflowTaskState(TaskState):
current_stream_generate_state: Optional[WorkflowStreamGenerateNodes] = None
iteration_nested_node_ids: list[str] = None
class AdvancedChatTaskState(WorkflowTaskState):
"""
@ -91,6 +94,9 @@ class StreamEvent(Enum):
WORKFLOW_FINISHED = "workflow_finished"
NODE_STARTED = "node_started"
NODE_FINISHED = "node_finished"
ITERATION_STARTED = "iteration_started"
ITERATION_NEXT = "iteration_next"
ITERATION_COMPLETED = "iteration_completed"
TEXT_CHUNK = "text_chunk"
TEXT_REPLACE = "text_replace"
@ -319,6 +325,74 @@ class NodeFinishStreamResponse(StreamResponse):
}
}
class IterationNodeStartStreamResponse(StreamResponse):
"""
NodeStartStreamResponse entity
"""
class Data(BaseModel):
"""
Data entity
"""
id: str
node_id: str
node_type: str
title: str
created_at: int
extras: dict = {}
metadata: dict = {}
inputs: dict = {}
event: StreamEvent = StreamEvent.ITERATION_STARTED
workflow_run_id: str
data: Data
class IterationNodeNextStreamResponse(StreamResponse):
"""
NodeStartStreamResponse entity
"""
class Data(BaseModel):
"""
Data entity
"""
id: str
node_id: str
node_type: str
title: str
index: int
created_at: int
pre_iteration_output: Optional[Any]
extras: dict = {}
event: StreamEvent = StreamEvent.ITERATION_NEXT
workflow_run_id: str
data: Data
class IterationNodeCompletedStreamResponse(StreamResponse):
"""
NodeStartStreamResponse entity
"""
class Data(BaseModel):
"""
Data entity
"""
id: str
node_id: str
node_type: str
title: str
outputs: Optional[dict]
created_at: int
extras: dict = None
inputs: dict = None
status: WorkflowNodeExecutionStatus
error: Optional[str]
elapsed_time: float
total_tokens: int
finished_at: int
steps: int
event: StreamEvent = StreamEvent.ITERATION_COMPLETED
workflow_run_id: str
data: Data
class TextChunkStreamResponse(StreamResponse):
"""
@ -454,3 +528,23 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
workflow_run_id: str
data: Data
class WorkflowIterationState(BaseModel):
"""
WorkflowIterationState entity
"""
class Data(BaseModel):
"""
Data entity
"""
parent_iteration_id: Optional[str] = None
iteration_id: str
current_index: int
iteration_steps_boundary: list[int] = None
node_execution_id: str
started_at: float
inputs: dict = None
total_tokens: int = 0
node_data: BaseNodeData
current_iterations: dict[str, Data] = None

View File

@ -1,9 +1,9 @@
import json
import time
from datetime import datetime, timezone
from typing import Any, Optional, Union, cast
from typing import Optional, Union, cast
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import (
QueueNodeFailedEvent,
QueueNodeStartedEvent,
@ -13,18 +13,17 @@ from core.app.entities.queue_entities import (
QueueWorkflowSucceededEvent,
)
from core.app.entities.task_entities import (
AdvancedChatTaskState,
NodeExecutionInfo,
NodeFinishStreamResponse,
NodeStartStreamResponse,
WorkflowFinishStreamResponse,
WorkflowStartStreamResponse,
WorkflowTaskState,
)
from core.app.task_pipeline.workflow_iteration_cycle_manage import WorkflowIterationCycleManage
from core.file.file_obj import FileVar
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.tool_manager import ToolManager
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType, SystemVariable
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType
from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.workflow_engine_manager import WorkflowEngineManager
from extensions.ext_database import db
@ -42,13 +41,7 @@ from models.workflow import (
)
class WorkflowCycleManage:
_application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity]
_workflow: Workflow
_user: Union[Account, EndUser]
_task_state: Union[AdvancedChatTaskState, WorkflowTaskState]
_workflow_system_variables: dict[SystemVariable, Any]
class WorkflowCycleManage(WorkflowIterationCycleManage):
def _init_workflow_run(self, workflow: Workflow,
triggered_from: WorkflowRunTriggeredFrom,
user: Union[Account, EndUser],
@ -237,6 +230,7 @@ class WorkflowCycleManage:
inputs: Optional[dict] = None,
process_data: Optional[dict] = None,
outputs: Optional[dict] = None,
execution_metadata: Optional[dict] = None
) -> WorkflowNodeExecution:
"""
Workflow node execution failed
@ -255,6 +249,8 @@ class WorkflowCycleManage:
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(execution_metadata)) \
if execution_metadata else None
db.session.commit()
db.session.refresh(workflow_node_execution)
@ -444,6 +440,23 @@ class WorkflowCycleManage:
current_node_execution = self._task_state.ran_node_execution_infos[event.node_id]
workflow_node_execution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == current_node_execution.workflow_node_execution_id).first()
execution_metadata = event.execution_metadata if isinstance(event, QueueNodeSucceededEvent) else None
if self._iteration_state and self._iteration_state.current_iterations:
if not execution_metadata:
execution_metadata = {}
current_iteration_data = None
for iteration_node_id in self._iteration_state.current_iterations:
data = self._iteration_state.current_iterations[iteration_node_id]
if data.parent_iteration_id == None:
current_iteration_data = data
break
if current_iteration_data:
execution_metadata[NodeRunMetadataKey.ITERATION_ID] = current_iteration_data.iteration_id
execution_metadata[NodeRunMetadataKey.ITERATION_INDEX] = current_iteration_data.current_index
if isinstance(event, QueueNodeSucceededEvent):
workflow_node_execution = self._workflow_node_execution_success(
workflow_node_execution=workflow_node_execution,
@ -451,12 +464,18 @@ class WorkflowCycleManage:
inputs=event.inputs,
process_data=event.process_data,
outputs=event.outputs,
execution_metadata=event.execution_metadata
execution_metadata=execution_metadata
)
if event.execution_metadata and event.execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
if execution_metadata and execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
self._task_state.total_tokens += (
int(event.execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)))
int(execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)))
if self._iteration_state:
for iteration_node_id in self._iteration_state.current_iterations:
data = self._iteration_state.current_iterations[iteration_node_id]
if execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
data.total_tokens += int(execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS))
if workflow_node_execution.node_type == NodeType.LLM.value:
outputs = workflow_node_execution.outputs_dict
@ -469,7 +488,8 @@ class WorkflowCycleManage:
error=event.error,
inputs=event.inputs,
process_data=event.process_data,
outputs=event.outputs
outputs=event.outputs,
execution_metadata=execution_metadata
)
db.session.close()

View File

@ -0,0 +1,16 @@
from typing import Any, Union
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
from core.app.entities.task_entities import AdvancedChatTaskState, WorkflowTaskState
from core.workflow.entities.node_entities import SystemVariable
from models.account import Account
from models.model import EndUser
from models.workflow import Workflow
class WorkflowCycleStateManager:
_application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity]
_workflow: Workflow
_user: Union[Account, EndUser]
_task_state: Union[AdvancedChatTaskState, WorkflowTaskState]
_workflow_system_variables: dict[SystemVariable, Any]

View File

@ -0,0 +1,281 @@
import json
import time
from collections.abc import Generator
from typing import Optional, Union
from core.app.entities.queue_entities import (
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
)
from core.app.entities.task_entities import (
IterationNodeCompletedStreamResponse,
IterationNodeNextStreamResponse,
IterationNodeStartStreamResponse,
NodeExecutionInfo,
WorkflowIterationState,
)
from core.app.task_pipeline.workflow_cycle_state_manager import WorkflowCycleStateManager
from core.workflow.entities.node_entities import NodeType
from extensions.ext_database import db
from models.workflow import (
WorkflowNodeExecution,
WorkflowNodeExecutionStatus,
WorkflowNodeExecutionTriggeredFrom,
WorkflowRun,
)
class WorkflowIterationCycleManage(WorkflowCycleStateManager):
_iteration_state: WorkflowIterationState = None
def _init_iteration_state(self) -> WorkflowIterationState:
if not self._iteration_state:
self._iteration_state = WorkflowIterationState(
current_iterations={}
)
def _handle_iteration_to_stream_response(self, task_id: str, event: QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent) \
-> Union[IterationNodeStartStreamResponse, IterationNodeNextStreamResponse, IterationNodeCompletedStreamResponse]:
"""
Handle iteration to stream response
:param task_id: task id
:param event: iteration event
:return:
"""
if isinstance(event, QueueIterationStartEvent):
return IterationNodeStartStreamResponse(
task_id=task_id,
workflow_run_id=self._task_state.workflow_run_id,
data=IterationNodeStartStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_data.title,
created_at=int(time.time()),
extras={},
inputs=event.inputs,
metadata=event.metadata
)
)
elif isinstance(event, QueueIterationNextEvent):
current_iteration = self._iteration_state.current_iterations[event.node_id]
return IterationNodeNextStreamResponse(
task_id=task_id,
workflow_run_id=self._task_state.workflow_run_id,
data=IterationNodeNextStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=current_iteration.node_data.title,
index=event.index,
pre_iteration_output=event.output,
created_at=int(time.time()),
extras={}
)
)
elif isinstance(event, QueueIterationCompletedEvent):
current_iteration = self._iteration_state.current_iterations[event.node_id]
return IterationNodeCompletedStreamResponse(
task_id=task_id,
workflow_run_id=self._task_state.workflow_run_id,
data=IterationNodeCompletedStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=current_iteration.node_data.title,
outputs=event.outputs,
created_at=int(time.time()),
extras={},
inputs=current_iteration.inputs,
status=WorkflowNodeExecutionStatus.SUCCEEDED,
error=None,
elapsed_time=time.perf_counter() - current_iteration.started_at,
total_tokens=current_iteration.total_tokens,
finished_at=int(time.time()),
steps=current_iteration.current_index
)
)
def _init_iteration_execution_from_workflow_run(self,
workflow_run: WorkflowRun,
node_id: str,
node_type: NodeType,
node_title: str,
node_run_index: int = 1,
inputs: Optional[dict] = None,
predecessor_node_id: Optional[str] = None
) -> WorkflowNodeExecution:
workflow_node_execution = WorkflowNodeExecution(
tenant_id=workflow_run.tenant_id,
app_id=workflow_run.app_id,
workflow_id=workflow_run.workflow_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
workflow_run_id=workflow_run.id,
predecessor_node_id=predecessor_node_id,
index=node_run_index,
node_id=node_id,
node_type=node_type.value,
inputs=json.dumps(inputs) if inputs else None,
title=node_title,
status=WorkflowNodeExecutionStatus.RUNNING.value,
created_by_role=workflow_run.created_by_role,
created_by=workflow_run.created_by,
execution_metadata=json.dumps({
'started_run_index': node_run_index + 1,
'current_index': 0,
'steps_boundary': [],
})
)
db.session.add(workflow_node_execution)
db.session.commit()
db.session.refresh(workflow_node_execution)
db.session.close()
return workflow_node_execution
def _handle_iteration_operation(self, event: QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent) -> WorkflowNodeExecution:
if isinstance(event, QueueIterationStartEvent):
return self._handle_iteration_started(event)
elif isinstance(event, QueueIterationNextEvent):
return self._handle_iteration_next(event)
elif isinstance(event, QueueIterationCompletedEvent):
return self._handle_iteration_completed(event)
def _handle_iteration_started(self, event: QueueIterationStartEvent) -> WorkflowNodeExecution:
self._init_iteration_state()
workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first()
workflow_node_execution = self._init_iteration_execution_from_workflow_run(
workflow_run=workflow_run,
node_id=event.node_id,
node_type=NodeType.ITERATION,
node_title=event.node_data.title,
node_run_index=event.node_run_index,
inputs=event.inputs,
predecessor_node_id=event.predecessor_node_id
)
latest_node_execution_info = NodeExecutionInfo(
workflow_node_execution_id=workflow_node_execution.id,
node_type=NodeType.ITERATION,
start_at=time.perf_counter()
)
self._task_state.ran_node_execution_infos[event.node_id] = latest_node_execution_info
self._task_state.latest_node_execution_info = latest_node_execution_info
self._iteration_state.current_iterations[event.node_id] = WorkflowIterationState.Data(
parent_iteration_id=None,
iteration_id=event.node_id,
current_index=0,
iteration_steps_boundary=[],
node_execution_id=workflow_node_execution.id,
started_at=time.perf_counter(),
inputs=event.inputs,
total_tokens=0,
node_data=event.node_data
)
db.session.close()
return workflow_node_execution
def _handle_iteration_next(self, event: QueueIterationNextEvent) -> WorkflowNodeExecution:
if event.node_id not in self._iteration_state.current_iterations:
return
current_iteration = self._iteration_state.current_iterations[event.node_id]
current_iteration.current_index = event.index
current_iteration.iteration_steps_boundary.append(event.node_run_index)
workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == current_iteration.node_execution_id
).first()
original_node_execution_metadata = workflow_node_execution.execution_metadata_dict
if original_node_execution_metadata:
original_node_execution_metadata['current_index'] = event.index
original_node_execution_metadata['steps_boundary'] = current_iteration.iteration_steps_boundary
original_node_execution_metadata['total_tokens'] = current_iteration.total_tokens
workflow_node_execution.execution_metadata = json.dumps(original_node_execution_metadata)
db.session.commit()
db.session.close()
def _handle_iteration_completed(self, event: QueueIterationCompletedEvent) -> WorkflowNodeExecution:
if event.node_id not in self._iteration_state.current_iterations:
return
current_iteration = self._iteration_state.current_iterations[event.node_id]
workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == current_iteration.node_execution_id
).first()
workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
workflow_node_execution.outputs = json.dumps(event.outputs) if event.outputs else None
workflow_node_execution.elapsed_time = time.perf_counter() - current_iteration.started_at
original_node_execution_metadata = workflow_node_execution.execution_metadata_dict
if original_node_execution_metadata:
original_node_execution_metadata['steps_boundary'] = current_iteration.iteration_steps_boundary
original_node_execution_metadata['total_tokens'] = current_iteration.total_tokens
workflow_node_execution.execution_metadata = json.dumps(original_node_execution_metadata)
db.session.commit()
# remove current iteration
self._iteration_state.current_iterations.pop(event.node_id, None)
# set latest node execution info
latest_node_execution_info = NodeExecutionInfo(
workflow_node_execution_id=workflow_node_execution.id,
node_type=NodeType.ITERATION,
start_at=time.perf_counter()
)
self._task_state.latest_node_execution_info = latest_node_execution_info
db.session.close()
def _handle_iteration_exception(self, task_id: str, error: str) -> Generator[IterationNodeCompletedStreamResponse, None, None]:
"""
Handle iteration exception
"""
if not self._iteration_state or not self._iteration_state.current_iterations:
return
for node_id, current_iteration in self._iteration_state.current_iterations.items():
workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == current_iteration.node_execution_id
).first()
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
workflow_node_execution.error = error
workflow_node_execution.elapsed_time = time.perf_counter() - current_iteration.started_at
db.session.commit()
db.session.close()
yield IterationNodeCompletedStreamResponse(
task_id=task_id,
workflow_run_id=self._task_state.workflow_run_id,
data=IterationNodeCompletedStreamResponse.Data(
id=node_id,
node_id=node_id,
node_type=NodeType.ITERATION.value,
title=current_iteration.node_data.title,
outputs={},
created_at=int(time.time()),
extras={},
inputs=current_iteration.inputs,
status=WorkflowNodeExecutionStatus.FAILED,
error=error,
elapsed_time=time.perf_counter() - current_iteration.started_at,
total_tokens=current_iteration.total_tokens,
finished_at=int(time.time()),
steps=current_iteration.current_index
)
)

View File

@ -42,6 +42,8 @@ class MessageFileParser:
raise ValueError('Invalid file url')
if file.get('transfer_method') == FileTransferMethod.LOCAL_FILE.value and not file.get('upload_file_id'):
raise ValueError('Missing file upload_file_id')
if file.get('transform_method') == FileTransferMethod.TOOL_FILE.value and not file.get('tool_file_id'):
raise ValueError('Missing file tool_file_id')
# transform files to file objs
type_file_objs = self._to_file_objs(files, file_extra_config)
@ -149,6 +151,7 @@ class MessageFileParser:
"""
if isinstance(file, dict):
transfer_method = FileTransferMethod.value_of(file.get('transfer_method'))
if transfer_method != FileTransferMethod.TOOL_FILE:
return FileVar(
tenant_id=self.tenant_id,
type=FileType.value_of(file.get('type')),
@ -157,6 +160,14 @@ class MessageFileParser:
related_id=file.get('upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None,
extra_config=file_extra_config
)
return FileVar(
tenant_id=self.tenant_id,
type=FileType.value_of(file.get('type')),
transfer_method=transfer_method,
url=None,
related_id=file.get('tool_file_id'),
extra_config=file_extra_config
)
else:
return FileVar(
id=file.id,

View File

@ -1,6 +1,7 @@
from typing import cast
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessage,
PromptMessageContentType,
@ -21,13 +22,25 @@ class PromptMessageUtil:
"""
prompts = []
if model_mode == ModelMode.CHAT.value:
tool_calls = []
for prompt_message in prompt_messages:
if prompt_message.role == PromptMessageRole.USER:
role = 'user'
elif prompt_message.role == PromptMessageRole.ASSISTANT:
role = 'assistant'
if isinstance(prompt_message, AssistantPromptMessage):
tool_calls = [{
'id': tool_call.id,
'type': 'function',
'function': {
'name': tool_call.function.name,
'arguments': tool_call.function.arguments,
}
} for tool_call in prompt_message.tool_calls]
elif prompt_message.role == PromptMessageRole.SYSTEM:
role = 'system'
elif prompt_message.role == PromptMessageRole.TOOL:
role = 'tool'
else:
continue
@ -48,11 +61,16 @@ class PromptMessageUtil:
else:
text = prompt_message.content
prompts.append({
prompt = {
"role": role,
"text": text,
"files": files
})
}
if tool_calls:
prompt['tool_calls'] = tool_calls
prompts.append(prompt)
else:
prompt_message = prompt_messages[0]
text = ''

View File

@ -1,10 +1,10 @@
from enum import Enum
from typing import Optional
from typing import Literal, Optional
from pydantic import BaseModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderCredentials
from core.tools.entities.tool_entities import ToolProviderCredentials, ToolProviderType
from core.tools.tool.tool import ToolParameter
@ -14,27 +14,38 @@ class UserTool(BaseModel):
label: I18nObject # label
description: I18nObject
parameters: Optional[list[ToolParameter]]
labels: list[str] = None
UserToolProviderTypeLiteral = Optional[Literal[
'builtin', 'api', 'workflow'
]]
class UserToolProvider(BaseModel):
class ProviderType(Enum):
BUILTIN = "builtin"
APP = "app"
API = "api"
id: str
author: str
name: str # identifier
description: I18nObject
icon: str
label: I18nObject # label
type: ProviderType
type: ToolProviderType
masked_credentials: dict = None
original_credentials: dict = None
is_team_authorization: bool = False
allow_delete: bool = True
tools: list[UserTool] = None
labels: list[str] = None
def to_dict(self) -> dict:
# -------------
# overwrite tool parameter types for temp fix
tools = jsonable_encoder(self.tools)
for tool in tools:
if tool.get('parameters'):
for parameter in tool.get('parameters'):
if parameter.get('type') == ToolParameter.ToolParameterType.FILE.value:
parameter['type'] = 'files'
# -------------
return {
'id': self.id,
'author': self.author,
@ -46,7 +57,8 @@ class UserToolProvider(BaseModel):
'team_credentials': self.masked_credentials,
'is_team_authorization': self.is_team_authorization,
'allow_delete': self.allow_delete,
'tools': self.tools
'tools': tools,
'labels': self.labels,
}
class UserToolProviderCredentials(BaseModel):

View File

@ -1,3 +0,0 @@
class DEFAULT_PROVIDERS:
API_BASED = '__api_based'
APP_BASED = '__app_based'

View File

@ -1,11 +1,11 @@
from typing import Any, Optional
from typing import Optional
from pydantic import BaseModel
from core.tools.entities.tool_entities import ToolParameter, ToolProviderType
from core.tools.entities.tool_entities import ToolParameter
class ApiBasedToolBundle(BaseModel):
class ApiToolBundle(BaseModel):
"""
This class is used to store the schema information of an api based tool. such as the url, the method, the parameters, etc.
"""
@ -25,12 +25,3 @@ class ApiBasedToolBundle(BaseModel):
icon: Optional[str] = None
# openapi operation
openapi: dict
class AppToolBundle(BaseModel):
"""
This class is used to store the schema information of an tool for an app.
"""
type: ToolProviderType
credential: Optional[dict[str, Any]] = None
provider_id: str
tool_name: str

View File

@ -10,10 +10,11 @@ class ToolProviderType(Enum):
"""
Enum class for tool provider
"""
BUILT_IN = "built-in"
BUILT_IN = "builtin"
WORKFLOW = "workflow"
API = "api"
APP = "app"
DATASET_RETRIEVAL = "dataset-retrieval"
APP_BASED = "app-based"
API_BASED = "api-based"
@classmethod
def value_of(cls, value: str) -> 'ToolProviderType':
@ -77,6 +78,7 @@ class ToolInvokeMessage(BaseModel):
LINK = "link"
BLOB = "blob"
IMAGE_LINK = "image_link"
FILE_VAR = "file_var"
type: MessageType = MessageType.TEXT
"""
@ -90,6 +92,7 @@ class ToolInvokeMessageBinary(BaseModel):
mimetype: str = Field(..., description="The mimetype of the binary")
url: str = Field(..., description="The url of the binary")
save_as: str = ''
file_var: Optional[dict[str, Any]] = None
class ToolParameterOption(BaseModel):
value: str = Field(..., description="The value of the option")
@ -102,6 +105,7 @@ class ToolParameter(BaseModel):
BOOLEAN = "boolean"
SELECT = "select"
SECRET_INPUT = "secret-input"
FILE = "file"
class ToolParameterForm(Enum):
SCHEMA = "schema" # should be set while adding tool
@ -331,6 +335,15 @@ class ModelToolProviderConfiguration(BaseModel):
models: list[ModelToolConfiguration] = Field(..., description="The models of the model tool")
label: I18nObject = Field(..., description="The label of the model tool")
class WorkflowToolParameterConfiguration(BaseModel):
"""
Workflow tool configuration
"""
name: str = Field(..., description="The name of the parameter")
description: str = Field(..., description="The description of the parameter")
form: ToolParameter.ToolParameterForm = Field(..., description="The form of the parameter")
class ToolInvokeMeta(BaseModel):
"""
Tool invoke meta
@ -359,3 +372,18 @@ class ToolInvokeMeta(BaseModel):
'error': self.error,
'tool_config': self.tool_config,
}
class ToolLabel(BaseModel):
"""
Tool label
"""
name: str = Field(..., description="The name of the tool")
label: I18nObject = Field(..., description="The label of the tool")
icon: str = Field(..., description="The icon of the tool")
class ToolInvokeFrom(Enum):
"""
Enum class for tool invoke
"""
WORKFLOW = "workflow"
AGENT = "agent"

View File

@ -0,0 +1,96 @@
from enum import Enum
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolLabel
class ToolLabelEnum(Enum):
SEARCH = 'search'
IMAGE = 'image'
VIDEOS = 'videos'
WEATHER = 'weather'
FINANCE = 'finance'
DESIGN = 'design'
TRAVEL = 'travel'
SOCIAL = 'social'
NEWS = 'news'
MEDICAL = 'medical'
PRODUCTIVITY = 'productivity'
EDUCATION = 'education'
BUSINESS = 'business'
ENTERTAINMENT = 'entertainment'
UTILITIES = 'utilities'
OTHER = 'other'
ICONS = {
ToolLabelEnum.SEARCH: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M7.33398 1.3335C10.646 1.3335 13.334 4.0215 13.334 7.3335C13.334 10.6455 10.646 13.3335 7.33398 13.3335C4.02198 13.3335 1.33398 10.6455 1.33398 7.3335C1.33398 4.0215 4.02198 1.3335 7.33398 1.3335ZM7.33398 12.0002C9.91232 12.0002 12.0007 9.91183 12.0007 7.3335C12.0007 4.75516 9.91232 2.66683 7.33398 2.66683C4.75565 2.66683 2.66732 4.75516 2.66732 7.3335C2.66732 9.91183 4.75565 12.0002 7.33398 12.0002ZM12.9909 12.0476L14.8764 13.9332L13.9337 14.876L12.0481 12.9904L12.9909 12.0476Z" fill="#344054"/>
</svg>''',
ToolLabelEnum.IMAGE: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M13.0514 9.71752L10.4718 7.13792C10.2115 6.87752 9.78932 6.87752 9.52898 7.13792L4.57721 12.0897C3.4097 11.1113 2.66732 9.64232 2.66732 7.99992C2.66732 5.0544 5.05513 2.66659 8.00065 2.66659C10.9462 2.66659 13.334 5.0544 13.334 7.99992C13.334 8.60085 13.2346 9.17852 13.0514 9.71752ZM5.72683 12.8257L10.0004 8.55212L12.4259 10.9777C11.4668 12.4001 9.84152 13.3331 8.00038 13.3331C7.18632 13.3331 6.41628 13.1511 5.72683 12.8257ZM8.00065 14.6666C11.6825 14.6666 14.6673 11.6818 14.6673 7.99992C14.6673 4.31802 11.6825 1.33325 8.00065 1.33325C4.31875 1.33325 1.33398 4.31802 1.33398 7.99992C1.33398 11.6818 4.31875 14.6666 8.00065 14.6666ZM7.33398 6.66658C7.33398 7.40299 6.73705 7.99992 6.00065 7.99992C5.26427 7.99992 4.66732 7.40299 4.66732 6.66658C4.66732 5.9302 5.26427 5.33325 6.00065 5.33325C6.73705 5.33325 7.33398 5.9302 7.33398 6.66658Z" fill="#344054"/>
</svg>''',
ToolLabelEnum.VIDEOS: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M8.00065 13.3333H13.334V14.6666H8.00065C4.31875 14.6666 1.33398 11.6818 1.33398 7.99992C1.33398 4.31802 4.31875 1.33325 8.00065 1.33325C11.6825 1.33325 14.6673 4.31802 14.6673 7.99992C14.6673 9.50072 14.1714 10.8857 13.3345 11.9999H11.5284C12.6356 11.0227 13.334 9.59285 13.334 7.99992C13.334 5.0544 10.9462 2.66659 8.00065 2.66659C5.05513 2.66659 2.66732 5.0544 2.66732 7.99992C2.66732 10.9455 5.05513 13.3333 8.00065 13.3333ZM8.00065 6.66658C7.26425 6.66658 6.66732 6.06963 6.66732 5.33325C6.66732 4.59687 7.26425 3.99992 8.00065 3.99992C8.73705 3.99992 9.33398 4.59687 9.33398 5.33325C9.33398 6.06963 8.73705 6.66658 8.00065 6.66658ZM5.33398 9.33325C4.5976 9.33325 4.00065 8.73632 4.00065 7.99992C4.00065 7.26352 4.5976 6.66658 5.33398 6.66658C6.07036 6.66658 6.66732 7.26352 6.66732 7.99992C6.66732 8.73632 6.07036 9.33325 5.33398 9.33325ZM10.6673 9.33325C9.93092 9.33325 9.33398 8.73632 9.33398 7.99992C9.33398 7.26352 9.93092 6.66658 10.6673 6.66658C11.4037 6.66658 12.0007 7.26352 12.0007 7.99992C12.0007 8.73632 11.4037 9.33325 10.6673 9.33325ZM8.00065 11.9999C7.26425 11.9999 6.66732 11.403 6.66732 10.6666C6.66732 9.93018 7.26425 9.33325 8.00065 9.33325C8.73705 9.33325 9.33398 9.93018 9.33398 10.6666C9.33398 11.403 8.73705 11.9999 8.00065 11.9999Z" fill="#344054"/>
</svg>''',
ToolLabelEnum.WEATHER: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M6.6553 3.37344C7.42088 2.1484 8.78162 1.3335 10.3327 1.3335C12.7259 1.3335 14.666 3.2736 14.666 5.66683C14.666 6.38704 14.4903 7.06623 14.1794 7.66383C14.8894 8.3325 15.3327 9.28123 15.3327 10.3335C15.3327 12.3586 13.6911 14.0002 11.666 14.0002H5.99935C3.05383 14.0002 0.666016 11.6124 0.666016 8.66683C0.666016 5.72131 3.05383 3.3335 5.99935 3.3335C6.22143 3.3335 6.44034 3.34707 6.6553 3.37344ZM8.03628 3.73629C9.37768 4.29108 10.4435 5.37735 10.9711 6.73256C11.1961 6.68943 11.4284 6.66683 11.666 6.66683C12.1561 6.66683 12.6237 6.76296 13.0511 6.93743C13.2317 6.55162 13.3327 6.12102 13.3327 5.66683C13.3327 4.00998 11.9895 2.66683 10.3327 2.66683C9.41115 2.66683 8.58662 3.08236 8.03628 3.73629ZM11.666 12.6668C12.9547 12.6668 13.9993 11.6222 13.9993 10.3335C13.9993 9.04483 12.9547 8.00016 11.666 8.00016C11.013 8.00016 10.4227 8.26836 9.99922 8.70063C9.99928 8.68936 9.99935 8.6781 9.99935 8.66683C9.99935 6.45769 8.20848 4.66683 5.99935 4.66683C3.79021 4.66683 1.99935 6.45769 1.99935 8.66683C1.99935 10.876 3.79021 12.6668 5.99935 12.6668H11.666Z" fill="#344054"/>
</svg>''',
ToolLabelEnum.FINANCE: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M8.00262 14.6685C4.32071 14.6685 1.33594 11.6838 1.33594 8.00184C1.33594 4.31997 4.32071 1.33521 8.00262 1.33521C11.6845 1.33521 14.6693 4.31997 14.6693 8.00184C14.6693 11.6838 11.6845 14.6685 8.00262 14.6685ZM8.00262 13.3352C10.9482 13.3352 13.336 10.9474 13.336 8.00184C13.336 5.05635 10.9482 2.66854 8.00262 2.66854C5.05708 2.66854 2.66927 5.05635 2.66927 8.00184C2.66927 10.9474 5.05708 13.3352 8.00262 13.3352ZM5.66927 9.33517H9.33595C9.52002 9.33517 9.66928 9.18597 9.66928 9.00184C9.66928 8.81777 9.52002 8.66851 9.33595 8.66851H6.66928C5.7488 8.66851 5.0026 7.92237 5.0026 7.00184C5.0026 6.08139 5.7488 5.33521 6.66928 5.33521H7.33595V4.00187H8.66928V5.33521H10.336V6.66851H6.66928C6.48518 6.66851 6.33594 6.81777 6.33594 7.00184C6.33594 7.18597 6.48518 7.33517 6.66928 7.33517H9.33595C10.2564 7.33517 11.0026 8.08137 11.0026 9.00184C11.0026 9.92237 10.2564 10.6685 9.33595 10.6685H8.66928V12.0018H7.33595V10.6685H5.66927V9.33517Z" fill="#344054"/>
</svg>''',
ToolLabelEnum.DESIGN: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M4.70152 9.41416L3.2873 10.8284L5.17292 12.714L12.7154 5.17154L10.8298 3.28592L9.41557 4.70013L10.3584 5.64295L9.41557 6.58575L8.47277 5.64295L7.52997 6.58575L8.47277 7.52856L7.52997 8.47136L6.58713 7.52856L5.64433 8.47136L6.58713 9.41416L5.64433 10.357L4.70152 9.41416ZM11.3012 1.87171L14.1296 4.70013C14.39 4.96049 14.39 5.38259 14.1296 5.64295L5.64433 14.1282C5.38397 14.3886 4.96187 14.3886 4.70152 14.1282L1.87309 11.2998C1.61274 11.0394 1.61274 10.6174 1.87309 10.357L10.3584 1.87171C10.6187 1.61136 11.0408 1.61136 11.3012 1.87171ZM9.41557 12.2423L10.3584 11.2995L11.8534 12.7945H12.7962V11.8517L11.3012 10.3567L12.244 9.41383L14.0011 11.171V13.9999H11.1732L9.41557 12.2423ZM3.75861 6.58533L1.87299 4.69971C1.61265 4.43937 1.61265 4.01725 1.87299 3.75691L3.75861 1.87129C4.01896 1.61094 4.44107 1.61094 4.70142 1.87129L6.58704 3.75691L5.64423 4.69971L4.23002 3.2855L3.28721 4.22831L4.70142 5.64253L3.75861 6.58533Z" fill="#344054"/>
</svg>''',
ToolLabelEnum.TRAVEL: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M9.44839 2C9.80198 2 10.1411 2.14047 10.3912 2.39053L13.6101 5.60947C13.8602 5.85953 14.0007 6.19866 14.0007 6.55229V11.3333H15.334V12.6667L9.91652 12.6672C9.62032 13.8171 8.57638 14.6667 7.33398 14.6667C6.0916 14.6667 5.04766 13.8171 4.75146 12.6672L2.00065 12.6667C1.63246 12.6667 1.33398 12.3682 1.33398 12V3.33333C1.33398 2.59695 1.93094 2 2.66732 2H9.44839ZM7.33398 10.6667C6.5976 10.6667 6.00065 11.2636 6.00065 12C6.00065 12.7364 6.5976 13.3333 7.33398 13.3333C8.07038 13.3333 8.66732 12.7364 8.66732 12C8.66732 11.2636 8.07038 10.6667 7.33398 10.6667ZM9.44839 3.33333H2.66732V11.3333L4.75128 11.3335C5.04726 10.1833 6.09136 9.33333 7.33398 9.33333C8.57658 9.33333 9.62072 10.1833 9.91665 11.3335L12.6673 11.3333V6.55229L9.44839 3.33333ZM9.33398 4.66667V8.66667H4.00065V4.66667H9.33398ZM8.00065 6H5.33398V7.33333H8.00065V6Z" fill="#344054"/>
</svg>''',
ToolLabelEnum.SOCIAL: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M13.334 7.99992C13.334 5.0544 10.9462 2.66659 8.00065 2.66659C5.05513 2.66659 2.66732 5.0544 2.66732 7.99992C2.66732 10.9455 5.05513 13.3333 8.00065 13.3333C9.09518 13.3333 10.1127 13.0035 10.9594 12.438L11.699 13.5475C10.6408 14.2545 9.36885 14.6666 8.00065 14.6666C4.31875 14.6666 1.33398 11.6818 1.33398 7.99992C1.33398 4.31802 4.31875 1.33325 8.00065 1.33325C11.6825 1.33325 14.6673 4.31802 14.6673 7.99992V8.99992C14.6673 10.2886 13.6227 11.3333 12.334 11.3333C11.5312 11.3333 10.8231 10.9278 10.4032 10.3105C9.79678 10.9409 8.94452 11.3333 8.00065 11.3333C6.1597 11.3333 4.66732 9.84085 4.66732 7.99992C4.66732 6.15897 6.1597 4.66658 8.00065 4.66658C8.75118 4.66658 9.44378 4.91464 10.001 5.33325H11.334V8.99992C11.334 9.55219 11.7817 9.99992 12.334 9.99992C12.8863 9.99992 13.334 9.55219 13.334 8.99992V7.99992ZM8.00065 5.99992C6.89605 5.99992 6.00065 6.89532 6.00065 7.99992C6.00065 9.10452 6.89605 9.99992 8.00065 9.99992C9.10525 9.99992 10.0007 9.10452 10.0007 7.99992C10.0007 6.89532 9.10525 5.99992 8.00065 5.99992Z" fill="#344054"/>
</svg>''',
ToolLabelEnum.NEWS: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M10.6673 13.3335V2.66683H2.66732V12.6668C2.66732 13.035 2.9658 13.3335 3.33398 13.3335H10.6673ZM12.6673 14.6668H3.33398C2.22942 14.6668 1.33398 13.7714 1.33398 12.6668V2.00016C1.33398 1.63198 1.63246 1.3335 2.00065 1.3335H11.334C11.7022 1.3335 12.0007 1.63198 12.0007 2.00016V6.66683H14.6673V12.6668C14.6673 13.7714 13.7719 14.6668 12.6673 14.6668ZM12.0007 8.00016V12.6668C12.0007 13.035 12.2991 13.3335 12.6673 13.3335C13.0355 13.3335 13.334 13.035 13.334 12.6668V8.00016H12.0007ZM4.00065 4.00016H8.00065V8.00016H4.00065V4.00016ZM5.33398 5.3335V6.66683H6.66732V5.3335H5.33398ZM4.00065 8.66683H9.33398V10.0002H4.00065V8.66683ZM4.00065 10.6668H9.33398V12.0002H4.00065V10.6668Z" fill="#344054"/>
</svg>''',
ToolLabelEnum.MEDICAL: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M8.79747 1.51186L10.9641 5.26464C11.1482 5.5835 11.0389 5.99122 10.7201 6.17532L9.85373 6.67474L10.5207 7.83001L9.366 8.49668L8.699 7.34141L7.83333 7.84201C7.51447 8.02608 7.10673 7.91681 6.92267 7.59794L5.69747 5.47632C4.32922 5.89145 3.33333 7.16268 3.33333 8.66654C3.33333 9.08348 3.40987 9.48248 3.54965 9.85034C4.06613 9.52254 4.67762 9.33321 5.33333 9.33321C6.45605 9.33321 7.44913 9.88828 8.05313 10.7389L13.1787 7.78014L13.8454 8.93488L8.5932 11.9672C8.64133 12.1927 8.66667 12.4267 8.66667 12.6665C8.66667 12.895 8.64367 13.1181 8.59993 13.3337L14 13.3332V14.6665L2.66703 14.6673C2.2482 14.1101 2 13.4173 2 12.6665C2 11.9951 2.19855 11.3699 2.54014 10.8467C2.19517 10.1964 2 9.45428 2 8.66654C2 6.66968 3.25421 4.96575 5.01785 4.29953L4.75598 3.84519C4.38779 3.20747 4.60629 2.39202 5.24402 2.02382L6.97607 1.02382C7.6138 0.655637 8.42927 0.874138 8.79747 1.51186ZM5.33333 10.6665C4.22877 10.6665 3.33333 11.562 3.33333 12.6665C3.33333 12.9003 3.37343 13.1247 3.44711 13.3331H7.21953C7.29327 13.1247 7.33333 12.9003 7.33333 12.6665C7.33333 11.562 6.4379 10.6665 5.33333 10.6665ZM7.64273 2.17852L5.91068 3.17852L7.744 6.35395L9.47607 5.35395L7.64273 2.17852Z" fill="#344054"/>
</svg>''',
ToolLabelEnum.PRODUCTIVITY: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M6.64807 11.9999H9.35062C9.43862 11.1989 9.84742 10.5376 10.5111 9.81499C10.5858 9.73365 11.0652 9.23752 11.1221 9.16665C11.6872 8.46199 11.9993 7.58992 11.9993 6.66659C11.9993 4.45745 10.2085 2.66659 7.99935 2.66659C5.79021 2.66659 3.99935 4.45745 3.99935 6.66659C3.99935 7.58945 4.31118 8.46105 4.87576 9.16552C4.93271 9.23659 5.41322 9.73405 5.48704 9.81445C6.15112 10.5375 6.56004 11.1989 6.64807 11.9999ZM9.33268 13.3333H6.66602V13.9999H9.33268V13.3333ZM3.83532 9.99939C3.10365 9.08639 2.66602 7.92759 2.66602 6.66659C2.66602 3.72107 5.05383 1.33325 7.99935 1.33325C10.9449 1.33325 13.3327 3.72107 13.3327 6.66659C13.3327 7.92825 12.8945 9.08759 12.1622 10.0009C11.7487 10.5165 10.666 11.3333 10.666 12.3333V13.9999C10.666 14.7363 10.0691 15.3333 9.33268 15.3333H6.66602C5.92964 15.3333 5.33268 14.7363 5.33268 13.9999V12.3333C5.33268 11.3333 4.24907 10.5157 3.83532 9.99939ZM8.66602 6.66979H10.3327L7.33268 10.6698V8.00312H5.66602L8.66602 3.99992V6.66979Z" fill="#344054"/>
</svg>''',
ToolLabelEnum.EDUCATION: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M14 2.66683H4.66667C3.93029 2.66683 3.33333 3.26378 3.33333 4.00016C3.33333 4.73654 3.93029 5.3335 4.66667 5.3335H14V14.0002C14 14.3684 13.7015 14.6668 13.3333 14.6668H4.66667C3.19391 14.6668 2 13.4729 2 12.0002V4.00016C2 2.5274 3.19391 1.3335 4.66667 1.3335H13.3333C13.7015 1.3335 14 1.63198 14 2.00016V2.66683ZM3.33333 12.0002C3.33333 12.7366 3.93029 13.3335 4.66667 13.3335H12.6667V6.66683H4.66667C4.18095 6.66683 3.72557 6.53697 3.33333 6.31008V12.0002ZM13.3333 4.66683H4.66667C4.29848 4.66683 4 4.36835 4 4.00016C4 3.63198 4.29848 3.3335 4.66667 3.3335H13.3333V4.66683Z" fill="#344054"/>
</svg>''',
ToolLabelEnum.BUSINESS: '''<svg xmlns="http://www.w3.org/2000/svg" width="14" height="14" viewBox="0 0 14 14" fill="none">
<path d="M3.66732 3.33341V1.33341C3.66732 0.965228 3.9658 0.666748 4.33398 0.666748H9.66732C10.0355 0.666748 10.334 0.965228 10.334 1.33341V3.33341H13.0007C13.3689 3.33341 13.6673 3.63189 13.6673 4.00008V13.3334C13.6673 13.7016 13.3689 14.0001 13.0007 14.0001H1.00065C0.632464 14.0001 0.333984 13.7016 0.333984 13.3334V4.00008C0.333984 3.63189 0.632464 3.33341 1.00065 3.33341H3.66732ZM12.334 8.66675H1.66732V12.6667H12.334V8.66675ZM12.334 4.66675H1.66732V7.33341H3.66732V6.00008H5.00065V7.33341H9.00065V6.00008H10.334V7.33341H12.334V4.66675ZM5.00065 2.00008V3.33341H9.00065V2.00008H5.00065Z" fill="#344054"/>
</svg>''',
ToolLabelEnum.ENTERTAINMENT: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M11.3327 2.66675C13.5418 2.66675 15.3327 4.45761 15.3327 6.66675V9.33342C15.3327 11.5425 13.5418 13.3334 11.3327 13.3334H4.66602C2.45688 13.3334 0.666016 11.5425 0.666016 9.33342V6.66675C0.666016 4.45761 2.45688 2.66675 4.66602 2.66675H11.3327ZM11.3327 4.00008H4.66602C3.23788 4.00008 2.07196 5.12273 2.00262 6.53365L1.99935 6.66675V9.33342C1.99935 10.7615 3.122 11.9275 4.53292 11.9968L4.66602 12.0001H11.3327C12.7608 12.0001 13.9267 10.8774 13.9961 9.46648L13.9993 9.33342V6.66675C13.9993 5.23861 12.8767 4.07269 11.4657 4.00335L11.3327 4.00008ZM6.66602 6.00008V7.33342H7.99935V8.66675H6.66535L6.66602 10.0001H5.33268L5.33202 8.66675H3.99935V7.33342H5.33268V6.00008H6.66602ZM11.9993 8.66675V10.0001H10.666V8.66675H11.9993ZM10.666 6.00008V7.33342H9.33268V6.00008H10.666Z" fill="#344054"/>
</svg>''',
ToolLabelEnum.UTILITIES: '''<svg xmlns="http://www.w3.org/2000/svg" width="13" height="15" viewBox="0 0 13 15" fill="none">
<path d="M12.3346 0.333252C12.7028 0.333252 13.0013 0.631732 13.0013 0.999919V4.33325C13.0013 4.70144 12.7028 4.99992 12.3346 4.99992H9.0013V13.6666C9.0013 14.0348 8.70284 14.3333 8.33463 14.3333H5.66797C5.29978 14.3333 5.0013 14.0348 5.0013 13.6666V4.99992H1.33464C0.966449 4.99992 0.667969 4.70144 0.667969 4.33325V2.74527C0.667969 2.49276 0.810635 2.26192 1.0365 2.14899L4.66797 0.333252H12.3346ZM9.0013 1.66659H4.98273L2.0013 3.1573V3.66659H6.33464V12.9999H7.66797V3.66659H9.0013V1.66659ZM11.668 1.66659H10.3346V3.66659H11.668V1.66659Z" fill="#344054"/>
</svg>''',
ToolLabelEnum.OTHER: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M8.00052 0.666748L4.00065 7.33342H12.0007L8.00052 0.666748ZM8.00052 3.25828L9.64572 6.00008H6.35553L8.00052 3.25828ZM4.50065 13.3334C3.48813 13.3334 2.66732 12.5126 2.66732 11.5001C2.66732 10.4875 3.48813 9.66675 4.50065 9.66675C5.51317 9.66675 6.33398 10.4875 6.33398 11.5001C6.33398 12.5126 5.51317 13.3334 4.50065 13.3334ZM4.50065 14.6667C6.24955 14.6667 7.66732 13.249 7.66732 11.5001C7.66732 9.75115 6.24955 8.33342 4.50065 8.33342C2.75175 8.33342 1.33398 9.75115 1.33398 11.5001C1.33398 13.249 2.75175 14.6667 4.50065 14.6667ZM10.0007 10.3334V13.0001H12.6673V10.3334H10.0007ZM8.66732 14.3334V9.00008H14.0007V14.3334H8.66732Z" fill="#344054"/>
</svg>'''
}
default_tool_label_dict = {
ToolLabelEnum.SEARCH: ToolLabel(name='search', label=I18nObject(en_US='Search', zh_Hans='搜索'), icon=ICONS[ToolLabelEnum.SEARCH]),
ToolLabelEnum.IMAGE: ToolLabel(name='image', label=I18nObject(en_US='Image', zh_Hans='图片'), icon=ICONS[ToolLabelEnum.IMAGE]),
ToolLabelEnum.VIDEOS: ToolLabel(name='videos', label=I18nObject(en_US='Videos', zh_Hans='视频'), icon=ICONS[ToolLabelEnum.VIDEOS]),
ToolLabelEnum.WEATHER: ToolLabel(name='weather', label=I18nObject(en_US='Weather', zh_Hans='天气'), icon=ICONS[ToolLabelEnum.WEATHER]),
ToolLabelEnum.FINANCE: ToolLabel(name='finance', label=I18nObject(en_US='Finance', zh_Hans='金融'), icon=ICONS[ToolLabelEnum.FINANCE]),
ToolLabelEnum.DESIGN: ToolLabel(name='design', label=I18nObject(en_US='Design', zh_Hans='设计'), icon=ICONS[ToolLabelEnum.DESIGN]),
ToolLabelEnum.TRAVEL: ToolLabel(name='travel', label=I18nObject(en_US='Travel', zh_Hans='旅行'), icon=ICONS[ToolLabelEnum.TRAVEL]),
ToolLabelEnum.SOCIAL: ToolLabel(name='social', label=I18nObject(en_US='Social', zh_Hans='社交'), icon=ICONS[ToolLabelEnum.SOCIAL]),
ToolLabelEnum.NEWS: ToolLabel(name='news', label=I18nObject(en_US='News', zh_Hans='新闻'), icon=ICONS[ToolLabelEnum.NEWS]),
ToolLabelEnum.MEDICAL: ToolLabel(name='medical', label=I18nObject(en_US='Medical', zh_Hans='医疗'), icon=ICONS[ToolLabelEnum.MEDICAL]),
ToolLabelEnum.PRODUCTIVITY: ToolLabel(name='productivity', label=I18nObject(en_US='Productivity', zh_Hans='生产力'), icon=ICONS[ToolLabelEnum.PRODUCTIVITY]),
ToolLabelEnum.EDUCATION: ToolLabel(name='education', label=I18nObject(en_US='Education', zh_Hans='教育'), icon=ICONS[ToolLabelEnum.EDUCATION]),
ToolLabelEnum.BUSINESS: ToolLabel(name='business', label=I18nObject(en_US='Business', zh_Hans='商业'), icon=ICONS[ToolLabelEnum.BUSINESS]),
ToolLabelEnum.ENTERTAINMENT: ToolLabel(name='entertainment', label=I18nObject(en_US='Entertainment', zh_Hans='娱乐'), icon=ICONS[ToolLabelEnum.ENTERTAINMENT]),
ToolLabelEnum.UTILITIES: ToolLabel(name='utilities', label=I18nObject(en_US='Utilities', zh_Hans='工具'), icon=ICONS[ToolLabelEnum.UTILITIES]),
ToolLabelEnum.OTHER: ToolLabel(name='other', label=I18nObject(en_US='Other', zh_Hans='其他'), icon=ICONS[ToolLabelEnum.OTHER]),
}
default_tool_labels = [v for k, v in default_tool_label_dict.items()]
default_tool_label_name_list = [label.name for label in default_tool_labels]

View File

@ -1,2 +0,0 @@
class InvokeModelError(Exception):
pass

View File

@ -1,7 +1,6 @@
from typing import Any
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiBasedToolBundle
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import (
ApiProviderAuthType,
ToolCredentialsOption,
@ -15,11 +14,11 @@ from extensions.ext_database import db
from models.tools import ApiToolProvider
class ApiBasedToolProviderController(ToolProviderController):
class ApiToolProviderController(ToolProviderController):
provider_id: str
@staticmethod
def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> 'ApiBasedToolProviderController':
def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> 'ApiToolProviderController':
credentials_schema = {
'auth_type': ToolProviderCredentials(
name='auth_type',
@ -79,9 +78,11 @@ class ApiBasedToolProviderController(ToolProviderController):
else:
raise ValueError(f'invalid auth type {auth_type}')
return ApiBasedToolProviderController(**{
user_name = db_provider.user.name if db_provider.user_id else ''
return ApiToolProviderController(**{
'identity': {
'author': db_provider.user.name if db_provider.user_id and db_provider.user else '',
'author': user_name,
'name': db_provider.name,
'label': {
'en_US': db_provider.name,
@ -98,16 +99,10 @@ class ApiBasedToolProviderController(ToolProviderController):
})
@property
def app_type(self) -> ToolProviderType:
return ToolProviderType.API_BASED
def provider_type(self) -> ToolProviderType:
return ToolProviderType.API
def _validate_credentials(self, tool_name: str, credentials: dict[str, Any]) -> None:
pass
def validate_parameters(self, tool_name: str, tool_parameters: dict[str, Any]) -> None:
pass
def _parse_tool_bundle(self, tool_bundle: ApiBasedToolBundle) -> ApiTool:
def _parse_tool_bundle(self, tool_bundle: ApiToolBundle) -> ApiTool:
"""
parse tool bundle to tool
@ -136,7 +131,7 @@ class ApiBasedToolProviderController(ToolProviderController):
'parameters' : tool_bundle.parameters if tool_bundle.parameters else [],
})
def load_bundled_tools(self, tools: list[ApiBasedToolBundle]) -> list[ApiTool]:
def load_bundled_tools(self, tools: list[ApiToolBundle]) -> list[ApiTool]:
"""
load bundled tools

View File

@ -11,10 +11,10 @@ from models.tools import PublishedAppTool
logger = logging.getLogger(__name__)
class AppBasedToolProviderEntity(ToolProviderController):
class AppToolProviderEntity(ToolProviderController):
@property
def app_type(self) -> ToolProviderType:
return ToolProviderType.APP_BASED
def provider_type(self) -> ToolProviderType:
return ToolProviderType.APP
def _validate_credentials(self, tool_name: str, credentials: dict[str, Any]) -> None:
pass

View File

@ -1,6 +1,6 @@
import os.path
from core.tools.entities.user_entities import UserToolProvider
from core.tools.entities.api_entities import UserToolProvider
from core.utils.position_helper import get_position_map, sort_by_position_map

View File

@ -1,3 +1,4 @@
from core.tools.entities.values import ToolLabelEnum
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.aippt.tools.aippt import AIPPTGenerateTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
@ -9,3 +10,9 @@ class AIPPTProvider(BuiltinToolProviderController):
AIPPTGenerateTool._get_api_token(credentials, user_id='__dify_system__')
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))
def _get_tool_labels(self) -> list[ToolLabelEnum]:
return [
ToolLabelEnum.PRODUCTIVITY,
ToolLabelEnum.DESIGN,
]

View File

@ -1,3 +1,4 @@
from core.tools.entities.values import ToolLabelEnum
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.arxiv.tools.arxiv_search import ArxivSearchTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
@ -7,7 +8,7 @@ class ArxivProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict) -> None:
try:
ArxivSearchTool().fork_tool_runtime(
meta={
runtime={
"credentials": credentials,
}
).invoke(
@ -18,3 +19,8 @@ class ArxivProvider(BuiltinToolProviderController):
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))
def _get_tool_labels(self) -> list[ToolLabelEnum]:
return [
ToolLabelEnum.SEARCH,
]

View File

@ -1,5 +1,6 @@
from typing import Any
from core.tools.entities.values import ToolLabelEnum
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.azuredalle.tools.dalle3 import DallE3Tool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
@ -9,7 +10,7 @@ class AzureDALLEProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
try:
DallE3Tool().fork_tool_runtime(
meta={
runtime={
"credentials": credentials,
}
).invoke(
@ -22,3 +23,8 @@ class AzureDALLEProvider(BuiltinToolProviderController):
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))
def _get_tool_labels(self) -> list[ToolLabelEnum]:
return [
ToolLabelEnum.IMAGE
]

View File

@ -1,5 +1,6 @@
from typing import Any
from core.tools.entities.values import ToolLabelEnum
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.bing.tools.bing_web_search import BingSearchTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
@ -9,7 +10,7 @@ class BingProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
try:
BingSearchTool().fork_tool_runtime(
meta={
runtime={
"credentials": credentials,
}
).validate_credentials(
@ -21,3 +22,8 @@ class BingProvider(BuiltinToolProviderController):
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))
def _get_tool_labels(self) -> list[ToolLabelEnum]:
return [
ToolLabelEnum.SEARCH
]

View File

@ -1,5 +1,6 @@
from typing import Any
from core.tools.entities.values import ToolLabelEnum
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.brave.tools.brave_search import BraveSearchTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
@ -9,7 +10,7 @@ class BraveProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
try:
BraveSearchTool().fork_tool_runtime(
meta={
runtime={
"credentials": credentials,
}
).invoke(
@ -20,3 +21,8 @@ class BraveProvider(BuiltinToolProviderController):
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))
def _get_tool_labels(self) -> list[ToolLabelEnum]:
return [
ToolLabelEnum.SEARCH,
]

View File

@ -2,6 +2,7 @@ import matplotlib.pyplot as plt
from fontTools.ttLib import TTFont
from matplotlib.font_manager import findSystemFonts
from core.tools.entities.values import ToolLabelEnum
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.chart.tools.line import LinearChartTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
@ -44,7 +45,7 @@ class ChartProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict) -> None:
try:
LinearChartTool().fork_tool_runtime(
meta={
runtime={
"credentials": credentials,
}
).invoke(
@ -55,3 +56,8 @@ class ChartProvider(BuiltinToolProviderController):
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))
def _get_tool_labels(self) -> list[ToolLabelEnum]:
return [
ToolLabelEnum.DESIGN, ToolLabelEnum.PRODUCTIVITY, ToolLabelEnum.UTILITIES
]

View File

@ -1,8 +1,14 @@
from typing import Any
from core.tools.entities.values import ToolLabelEnum
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
class CodeToolProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
pass
def _get_tool_labels(self) -> list[ToolLabelEnum]:
return [
ToolLabelEnum.PRODUCTIVITY
]

View File

@ -1,5 +1,6 @@
from typing import Any
from core.tools.entities.values import ToolLabelEnum
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.dalle.tools.dalle2 import DallE2Tool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
@ -9,7 +10,7 @@ class DALLEProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
try:
DallE2Tool().fork_tool_runtime(
meta={
runtime={
"credentials": credentials,
}
).invoke(
@ -22,3 +23,8 @@ class DALLEProvider(BuiltinToolProviderController):
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))
def _get_tool_labels(self) -> list[ToolLabelEnum]:
return [
ToolLabelEnum.IMAGE, ToolLabelEnum.PRODUCTIVITY
]

View File

@ -1,3 +1,4 @@
from core.tools.entities.values import ToolLabelEnum
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.devdocs.tools.searchDevDocs import SearchDevDocsTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
@ -7,7 +8,7 @@ class DevDocsProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict) -> None:
try:
SearchDevDocsTool().fork_tool_runtime(
meta={
runtime={
"credentials": credentials,
}
).invoke(
@ -19,3 +20,8 @@ class DevDocsProvider(BuiltinToolProviderController):
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))
def _get_tool_labels(self) -> list[ToolLabelEnum]:
return [
ToolLabelEnum.SEARCH, ToolLabelEnum.PRODUCTIVITY
]

View File

@ -1,3 +1,4 @@
from core.tools.entities.values import ToolLabelEnum
from core.tools.provider.builtin.dingtalk.tools.dingtalk_group_bot import DingTalkGroupBotTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
@ -6,3 +7,8 @@ class DingTalkProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict) -> None:
DingTalkGroupBotTool()
pass
def _get_tool_labels(self) -> list[ToolLabelEnum]:
return [
ToolLabelEnum.SOCIAL
]

View File

@ -1,3 +1,4 @@
from core.tools.entities.values import ToolLabelEnum
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.duckduckgo.tools.duckduckgo_search import DuckDuckGoSearchTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
@ -7,7 +8,7 @@ class DuckDuckGoProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict) -> None:
try:
DuckDuckGoSearchTool().fork_tool_runtime(
meta={
runtime={
"credentials": credentials,
}
).invoke(
@ -18,3 +19,8 @@ class DuckDuckGoProvider(BuiltinToolProviderController):
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))
def _get_tool_labels(self) -> list[ToolLabelEnum]:
return [
ToolLabelEnum.SEARCH
]

View File

@ -1,3 +1,4 @@
from core.tools.entities.values import ToolLabelEnum
from core.tools.provider.builtin.feishu.tools.feishu_group_bot import FeishuGroupBotTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
@ -6,3 +7,8 @@ class FeishuProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict) -> None:
FeishuGroupBotTool()
pass
def _get_tool_labels(self) -> list[ToolLabelEnum]:
return [
ToolLabelEnum.SOCIAL
]

View File

@ -1,3 +1,4 @@
from core.tools.entities.values import ToolLabelEnum
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.firecrawl.tools.crawl import CrawlTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
@ -8,7 +9,7 @@ class FirecrawlProvider(BuiltinToolProviderController):
try:
# Example validation using the Crawl tool
CrawlTool().fork_tool_runtime(
meta={"credentials": credentials}
runtime={"credentials": credentials}
).invoke(
user_id='',
tool_parameters={
@ -21,3 +22,8 @@ class FirecrawlProvider(BuiltinToolProviderController):
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))
def _get_tool_labels(self) -> list[ToolLabelEnum]:
return [
ToolLabelEnum.SEARCH, ToolLabelEnum.UTILITIES
]

View File

@ -2,6 +2,7 @@ import urllib.parse
import requests
from core.tools.entities.values import ToolLabelEnum
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
@ -24,3 +25,9 @@ class GaodeProvider(BuiltinToolProviderController):
raise ToolProviderCredentialValidationError("Gaode API Key is invalid. {}".format(e))
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))
def _get_tool_labels(self) -> list[ToolLabelEnum]:
return [
ToolLabelEnum.UTILITIES, ToolLabelEnum.PRODUCTIVITY,
ToolLabelEnum.WEATHER, ToolLabelEnum.TRAVEL
]

View File

@ -1,5 +1,6 @@
import requests
from core.tools.entities.values import ToolLabelEnum
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
@ -30,3 +31,8 @@ class GihubProvider(BuiltinToolProviderController):
raise ToolProviderCredentialValidationError("Github API Key and Api Version is invalid. {}".format(e))
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))
def _get_tool_labels(self) -> list[ToolLabelEnum]:
return [
ToolLabelEnum.UTILITIES
]

View File

@ -1,5 +1,6 @@
from typing import Any
from core.tools.entities.values import ToolLabelEnum
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.google.tools.google_search import GoogleSearchTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
@ -9,7 +10,7 @@ class GoogleProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
try:
GoogleSearchTool().fork_tool_runtime(
meta={
runtime={
"credentials": credentials,
}
).invoke(
@ -21,3 +22,8 @@ class GoogleProvider(BuiltinToolProviderController):
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))
def _get_tool_labels(self) -> list[ToolLabelEnum]:
return [
ToolLabelEnum.SEARCH
]

View File

@ -1,5 +1,6 @@
from typing import Any
from core.tools.entities.values import ToolLabelEnum
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
@ -10,3 +11,8 @@ class GoogleProvider(BuiltinToolProviderController):
pass
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))
def _get_tool_labels(self) -> list[ToolLabelEnum]:
return [
ToolLabelEnum.SEARCH, ToolLabelEnum.PRODUCTIVITY
]

View File

@ -1,5 +1,6 @@
from typing import Any
from core.tools.entities.values import ToolLabelEnum
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.judge0ce.tools.executeCode import ExecuteCodeTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
@ -9,7 +10,7 @@ class Judge0CEProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
try:
ExecuteCodeTool().fork_tool_runtime(
meta={
runtime={
"credentials": credentials,
}
).invoke(
@ -21,3 +22,8 @@ class Judge0CEProvider(BuiltinToolProviderController):
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))
def _get_tool_labels(self) -> list[ToolLabelEnum]:
return [
ToolLabelEnum.OTHER, ToolLabelEnum.UTILITIES
]

View File

@ -1,5 +1,6 @@
from typing import Any
from core.tools.entities.values import ToolLabelEnum
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.maths.tools.eval_expression import EvaluateExpressionTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
@ -16,3 +17,8 @@ class MathsProvider(BuiltinToolProviderController):
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))
def _get_tool_labels(self) -> list[ToolLabelEnum]:
return [
ToolLabelEnum.UTILITIES, ToolLabelEnum.PRODUCTIVITY
]

View File

@ -1,5 +1,6 @@
import requests
from core.tools.entities.values import ToolLabelEnum
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
@ -34,3 +35,8 @@ class OpenweatherProvider(BuiltinToolProviderController):
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))
def _get_tool_labels(self) -> list[ToolLabelEnum]:
return [
ToolLabelEnum.WEATHER
]

View File

@ -1,3 +1,4 @@
from core.tools.entities.values import ToolLabelEnum
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.pubmed.tools.pubmed_search import PubMedSearchTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
@ -7,7 +8,7 @@ class PubMedProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict) -> None:
try:
PubMedSearchTool().fork_tool_runtime(
meta={
runtime={
"credentials": credentials,
}
).invoke(
@ -18,3 +19,8 @@ class PubMedProvider(BuiltinToolProviderController):
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))
def _get_tool_labels(self) -> list[ToolLabelEnum]:
return [
ToolLabelEnum.MEDICAL, ToolLabelEnum.SEARCH
]

View File

@ -1,5 +1,6 @@
from typing import Any
from core.tools.entities.values import ToolLabelEnum
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.qrcode.tools.qrcode_generator import QRCodeGeneratorTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
@ -14,3 +15,8 @@ class QRCodeProvider(BuiltinToolProviderController):
})
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))
def _get_tool_labels(self) -> list[ToolLabelEnum]:
return [
ToolLabelEnum.UTILITIES
]

View File

@ -1,5 +1,6 @@
from typing import Any
from core.tools.entities.values import ToolLabelEnum
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.searxng.tools.searxng_search import SearXNGSearchTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
@ -9,7 +10,7 @@ class SearXNGProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
try:
SearXNGSearchTool().fork_tool_runtime(
meta={
runtime={
"credentials": credentials,
}
).invoke(
@ -23,3 +24,8 @@ class SearXNGProvider(BuiltinToolProviderController):
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))
def _get_tool_labels(self) -> list[ToolLabelEnum]:
return [
ToolLabelEnum.SEARCH, ToolLabelEnum.PRODUCTIVITY
]

View File

@ -1,3 +1,4 @@
from core.tools.entities.values import ToolLabelEnum
from core.tools.provider.builtin.slack.tools.slack_webhook import SlackWebhookTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
@ -6,3 +7,8 @@ class SlackProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict) -> None:
SlackWebhookTool()
pass
def _get_tool_labels(self) -> list[ToolLabelEnum]:
return [
ToolLabelEnum.SOCIAL
]

View File

@ -1,5 +1,6 @@
import json
from core.tools.entities.values import ToolLabelEnum
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.spark.tools.spark_img_generation import spark_response
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
@ -38,3 +39,8 @@ class SparkProvider(BuiltinToolProviderController):
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))
def _get_tool_labels(self) -> list[ToolLabelEnum]:
return [
ToolLabelEnum.IMAGE
]

View File

@ -1,5 +1,6 @@
from typing import Any
from core.tools.entities.values import ToolLabelEnum
from core.tools.provider.builtin.stability.tools.base import BaseStabilityAuthorization
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
@ -13,3 +14,8 @@ class StabilityToolProvider(BuiltinToolProviderController, BaseStabilityAuthoriz
This method is responsible for validating the credentials.
"""
self.sd_validate_credentials(credentials)
def _get_tool_labels(self) -> list[ToolLabelEnum]:
return [
ToolLabelEnum.IMAGE
]

View File

@ -1,5 +1,6 @@
from typing import Any
from core.tools.entities.values import ToolLabelEnum
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.stablediffusion.tools.stable_diffusion import StableDiffusionTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
@ -9,9 +10,14 @@ class StableDiffusionProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
try:
StableDiffusionTool().fork_tool_runtime(
meta={
runtime={
"credentials": credentials,
}
).validate_models()
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))
def _get_tool_labels(self) -> list[ToolLabelEnum]:
return [
ToolLabelEnum.IMAGE
]

View File

@ -1,3 +1,4 @@
from core.tools.entities.values import ToolLabelEnum
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.stackexchange.tools.searchStackExQuestions import SearchStackExQuestionsTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
@ -7,7 +8,7 @@ class StackExchangeProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict) -> None:
try:
SearchStackExQuestionsTool().fork_tool_runtime(
meta={
runtime={
"credentials": credentials,
}
).invoke(
@ -23,3 +24,8 @@ class StackExchangeProvider(BuiltinToolProviderController):
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))
def _get_tool_labels(self) -> list[ToolLabelEnum]:
return [
ToolLabelEnum.SEARCH, ToolLabelEnum.UTILITIES
]

View File

@ -1,5 +1,6 @@
from typing import Any
from core.tools.entities.values import ToolLabelEnum
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.tavily.tools.tavily_search import TavilySearchTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
@ -9,7 +10,7 @@ class TavilyProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
try:
TavilySearchTool().fork_tool_runtime(
meta={
runtime={
"credentials": credentials,
}
).invoke(
@ -27,3 +28,8 @@ class TavilyProvider(BuiltinToolProviderController):
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))
def _get_tool_labels(self) -> list[ToolLabelEnum]:
return [
ToolLabelEnum.SEARCH
]

View File

@ -1,5 +1,6 @@
from typing import Any
from core.tools.entities.values import ToolLabelEnum
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.time.tools.current_time import CurrentTimeTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
@ -14,3 +15,8 @@ class WikiPediaProvider(BuiltinToolProviderController):
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))
def _get_tool_labels(self) -> list[ToolLabelEnum]:
return [
ToolLabelEnum.UTILITIES
]

View File

@ -2,6 +2,7 @@ from typing import Any
import requests
from core.tools.entities.values import ToolLabelEnum
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
@ -32,3 +33,8 @@ class TrelloProvider(BuiltinToolProviderController):
except requests.exceptions.RequestException as e:
# Handle other exceptions, such as connection errors
raise ToolProviderCredentialValidationError("Error validating Trello credentials")
def _get_tool_labels(self) -> list[ToolLabelEnum]:
return [
ToolLabelEnum.PRODUCTIVITY
]

View File

@ -3,6 +3,7 @@ from typing import Any
from twilio.base.exceptions import TwilioRestException
from twilio.rest import Client
from core.tools.entities.values import ToolLabelEnum
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
@ -27,3 +28,8 @@ class TwilioProvider(BuiltinToolProviderController):
raise ToolProviderCredentialValidationError(f"Missing required credential: {e}") from e
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))
def _get_tool_labels(self) -> list[ToolLabelEnum]:
return [
ToolLabelEnum.SOCIAL
]

View File

@ -1,5 +1,6 @@
from typing import Any
from core.tools.entities.values import ToolLabelEnum
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.vectorizer.tools.vectorizer import VectorizerTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
@ -9,7 +10,7 @@ class VectorizerProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
try:
VectorizerTool().fork_tool_runtime(
meta={
runtime={
"credentials": credentials,
}
).invoke(
@ -21,3 +22,8 @@ class VectorizerProvider(BuiltinToolProviderController):
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))
def _get_tool_labels(self) -> list[ToolLabelEnum]:
return [
ToolLabelEnum.PRODUCTIVITY, ToolLabelEnum.IMAGE
]

View File

@ -1,5 +1,6 @@
from typing import Any
from core.tools.entities.values import ToolLabelEnum
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.webscraper.tools.webscraper import WebscraperTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
@ -9,7 +10,7 @@ class WebscraperProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
try:
WebscraperTool().fork_tool_runtime(
meta={
runtime={
"credentials": credentials,
}
).invoke(
@ -21,3 +22,8 @@ class WebscraperProvider(BuiltinToolProviderController):
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))
def _get_tool_labels(self) -> list[ToolLabelEnum]:
return [
ToolLabelEnum.PRODUCTIVITY
]

View File

@ -1,3 +1,4 @@
from core.tools.entities.values import ToolLabelEnum
from core.tools.provider.builtin.wecom.tools.wecom_group_bot import WecomGroupBotTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
@ -6,3 +7,8 @@ class WecomProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict) -> None:
WecomGroupBotTool()
pass
def _get_tool_labels(self) -> list[ToolLabelEnum]:
return [
ToolLabelEnum.SOCIAL
]

View File

@ -1,3 +1,4 @@
from core.tools.entities.values import ToolLabelEnum
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.wikipedia.tools.wikipedia_search import WikiPediaSearchTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
@ -7,7 +8,7 @@ class WikiPediaProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict) -> None:
try:
WikiPediaSearchTool().fork_tool_runtime(
meta={
runtime={
"credentials": credentials,
}
).invoke(
@ -18,3 +19,8 @@ class WikiPediaProvider(BuiltinToolProviderController):
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))
def _get_tool_labels(self) -> list[ToolLabelEnum]:
return [
ToolLabelEnum.SEARCH
]

View File

@ -1,5 +1,6 @@
from typing import Any
from core.tools.entities.values import ToolLabelEnum
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.wolframalpha.tools.wolframalpha import WolframAlphaTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
@ -9,7 +10,7 @@ class GoogleProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
try:
WolframAlphaTool().fork_tool_runtime(
meta={
runtime={
"credentials": credentials,
}
).invoke(
@ -20,3 +21,8 @@ class GoogleProvider(BuiltinToolProviderController):
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))
def _get_tool_labels(self) -> list[ToolLabelEnum]:
return [
ToolLabelEnum.PRODUCTIVITY, ToolLabelEnum.UTILITIES
]

View File

@ -1,3 +1,4 @@
from core.tools.entities.values import ToolLabelEnum
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.yahoo.tools.ticker import YahooFinanceSearchTickerTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
@ -7,7 +8,7 @@ class YahooFinanceProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict) -> None:
try:
YahooFinanceSearchTickerTool().fork_tool_runtime(
meta={
runtime={
"credentials": credentials,
}
).invoke(
@ -18,3 +19,8 @@ class YahooFinanceProvider(BuiltinToolProviderController):
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))
def _get_tool_labels(self) -> list[ToolLabelEnum]:
return [
ToolLabelEnum.BUSINESS, ToolLabelEnum.FINANCE
]

View File

@ -1,3 +1,4 @@
from core.tools.entities.values import ToolLabelEnum
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.youtube.tools.videos import YoutubeVideosAnalyticsTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
@ -7,7 +8,7 @@ class YahooFinanceProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict) -> None:
try:
YoutubeVideosAnalyticsTool().fork_tool_runtime(
meta={
runtime={
"credentials": credentials,
}
).invoke(
@ -20,3 +21,8 @@ class YahooFinanceProvider(BuiltinToolProviderController):
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))
def _get_tool_labels(self) -> list[ToolLabelEnum]:
return [
ToolLabelEnum.VIDEOS
]

View File

@ -2,8 +2,9 @@ from abc import abstractmethod
from os import listdir, path
from typing import Any
from core.tools.entities.api_entities import UserToolProviderCredentials
from core.tools.entities.tool_entities import ToolParameter, ToolProviderCredentials, ToolProviderType
from core.tools.entities.user_entities import UserToolProviderCredentials
from core.tools.entities.values import ToolLabelEnum, default_tool_label_dict
from core.tools.errors import (
ToolNotFoundError,
ToolParameterValidationError,
@ -19,7 +20,7 @@ from core.utils.module_import_helper import load_single_subclass_from_source
class BuiltinToolProviderController(ToolProviderController):
def __init__(self, **data: Any) -> None:
if self.app_type == ToolProviderType.API_BASED or self.app_type == ToolProviderType.APP_BASED:
if self.provider_type == ToolProviderType.API or self.provider_type == ToolProviderType.APP:
super().__init__(**data)
return
@ -129,7 +130,7 @@ class BuiltinToolProviderController(ToolProviderController):
len(self.credentials_schema) != 0
@property
def app_type(self) -> ToolProviderType:
def provider_type(self) -> ToolProviderType:
"""
returns the type of the provider
@ -137,6 +138,22 @@ class BuiltinToolProviderController(ToolProviderController):
"""
return ToolProviderType.BUILT_IN
@property
def tool_labels(self) -> list[str]:
"""
returns the labels of the provider
:return: labels of the provider
"""
label_enums = self._get_tool_labels()
return [default_tool_label_dict[label].name for label in label_enums]
def _get_tool_labels(self) -> list[ToolLabelEnum]:
"""
returns the labels of the provider
"""
return []
def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: dict[str, Any]) -> None:
"""
validate the parameters of the tool and set the default value if needed

View File

@ -3,13 +3,13 @@ from typing import Any, Optional
from pydantic import BaseModel
from core.tools.entities.api_entities import UserToolProviderCredentials
from core.tools.entities.tool_entities import (
ToolParameter,
ToolProviderCredentials,
ToolProviderIdentity,
ToolProviderType,
)
from core.tools.entities.user_entities import UserToolProviderCredentials
from core.tools.errors import ToolNotFoundError, ToolParameterValidationError, ToolProviderCredentialValidationError
from core.tools.tool.tool import Tool
@ -67,7 +67,7 @@ class ToolProviderController(BaseModel, ABC):
return tool.parameters
@property
def app_type(self) -> ToolProviderType:
def provider_type(self) -> ToolProviderType:
"""
returns the type of the provider
@ -198,25 +198,3 @@ class ToolProviderController(BaseModel, ABC):
credentials[credential_name] = default_value
def validate_credentials(self, credentials: dict[str, Any]) -> None:
"""
validate the credentials of the provider
:param tool_name: the name of the tool, defined in `get_tools`
:param credentials: the credentials of the tool
"""
# validate credentials format
self.validate_credentials_format(credentials)
# validate credentials
self._validate_credentials(credentials)
@abstractmethod
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
"""
validate the credentials of the provider
:param tool_name: the name of the tool, defined in `get_tools`
:param credentials: the credentials of the tool
"""
pass

View File

@ -0,0 +1,230 @@
from typing import Optional
from core.app.app_config.entities import VariableEntity
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.model_runtime.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import (
ToolDescription,
ToolIdentity,
ToolParameter,
ToolParameterOption,
ToolProviderType,
)
from core.tools.provider.tool_provider import ToolProviderController
from core.tools.tool.workflow_tool import WorkflowTool
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
from extensions.ext_database import db
from models.model import App, AppMode
from models.tools import WorkflowToolProvider
from models.workflow import Workflow
class WorkflowToolProviderController(ToolProviderController):
provider_id: str
@classmethod
def from_db(cls, db_provider: WorkflowToolProvider) -> 'WorkflowToolProviderController':
app = db_provider.app
if not app:
raise ValueError('app not found')
controller = WorkflowToolProviderController(**{
'identity': {
'author': db_provider.user.name if db_provider.user_id and db_provider.user else '',
'name': db_provider.label,
'label': {
'en_US': db_provider.label,
'zh_Hans': db_provider.label
},
'description': {
'en_US': db_provider.description,
'zh_Hans': db_provider.description
},
'icon': db_provider.icon,
},
'credentials_schema': {},
'provider_id': db_provider.id or '',
})
# init tools
controller.tools = [controller._get_db_provider_tool(db_provider, app)]
return controller
@property
def provider_type(self) -> ToolProviderType:
return ToolProviderType.WORKFLOW
def _get_db_provider_tool(self, db_provider: WorkflowToolProvider, app: App) -> WorkflowTool:
"""
get db provider tool
:param db_provider: the db provider
:param app: the app
:return: the tool
"""
workflow: Workflow = db.session.query(Workflow).filter(
Workflow.app_id == db_provider.app_id,
Workflow.version == db_provider.version
).first()
if not workflow:
raise ValueError('workflow not found')
# fetch start node
graph: dict = workflow.graph_dict
features_dict: dict = workflow.features_dict
features = WorkflowAppConfigManager.convert_features(
config_dict=features_dict,
app_mode=AppMode.WORKFLOW
)
parameters = db_provider.parameter_configurations
variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph)
def fetch_workflow_variable(variable_name: str) -> VariableEntity:
return next(filter(lambda x: x.variable == variable_name, variables), None)
user = db_provider.user
workflow_tool_parameters = []
for parameter in parameters:
variable = fetch_workflow_variable(parameter.name)
if variable:
parameter_type = None
options = None
if variable.type in [
VariableEntity.Type.TEXT_INPUT,
VariableEntity.Type.PARAGRAPH,
]:
parameter_type = ToolParameter.ToolParameterType.STRING
elif variable.type in [
VariableEntity.Type.SELECT
]:
parameter_type = ToolParameter.ToolParameterType.SELECT
elif variable.type in [
VariableEntity.Type.NUMBER
]:
parameter_type = ToolParameter.ToolParameterType.NUMBER
else:
raise ValueError(f'unsupported variable type {variable.type}')
if variable.type == VariableEntity.Type.SELECT and variable.options:
options = [
ToolParameterOption(
value=option,
label=I18nObject(
en_US=option,
zh_Hans=option
)
) for option in variable.options
]
workflow_tool_parameters.append(
ToolParameter(
name=parameter.name,
label=I18nObject(
en_US=variable.label,
zh_Hans=variable.label
),
human_description=I18nObject(
en_US=parameter.description,
zh_Hans=parameter.description
),
type=parameter_type,
form=parameter.form,
llm_description=parameter.description,
required=variable.required,
options=options,
default=variable.default
)
)
elif features.file_upload:
workflow_tool_parameters.append(
ToolParameter(
name=parameter.name,
label=I18nObject(
en_US=parameter.name,
zh_Hans=parameter.name
),
human_description=I18nObject(
en_US=parameter.description,
zh_Hans=parameter.description
),
type=ToolParameter.ToolParameterType.FILE,
llm_description=parameter.description,
required=False,
form=parameter.form,
)
)
else:
raise ValueError('variable not found')
return WorkflowTool(
identity=ToolIdentity(
author=user.name if user else '',
name=db_provider.name,
label=I18nObject(
en_US=db_provider.label,
zh_Hans=db_provider.label
),
provider=self.provider_id,
icon=db_provider.icon,
),
description=ToolDescription(
human=I18nObject(
en_US=db_provider.description,
zh_Hans=db_provider.description
),
llm=db_provider.description,
),
parameters=workflow_tool_parameters,
is_team_authorization=True,
workflow_app_id=app.id,
workflow_entities={
'app': app,
'workflow': workflow,
},
version=db_provider.version,
workflow_call_depth=0,
label=db_provider.label
)
def get_tools(self, user_id: str, tenant_id: str) -> list[WorkflowTool]:
"""
fetch tools from database
:param user_id: the user id
:param tenant_id: the tenant id
:return: the tools
"""
if self.tools is not None:
return self.tools
db_providers: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.app_id == self.provider_id,
).first()
if not db_providers:
return []
self.tools = [self._get_db_provider_tool(db_providers, db_providers.app)]
return self.tools
def get_tool(self, tool_name: str) -> Optional[WorkflowTool]:
"""
get tool by name
:param tool_name: the name of the tool
:return: the tool
"""
if self.tools is None:
return None
for tool in self.tools:
if tool.identity.name == tool_name:
return tool
return None

View File

@ -8,9 +8,8 @@ import httpx
import requests
import core.helper.ssrf_proxy as ssrf_proxy
from core.tools.entities.tool_bundle import ApiBasedToolBundle
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
from core.tools.entities.user_entities import UserToolProvider
from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError
from core.tools.tool.tool import Tool
@ -20,12 +19,12 @@ API_TOOL_DEFAULT_TIMEOUT = (
)
class ApiTool(Tool):
api_bundle: ApiBasedToolBundle
api_bundle: ApiToolBundle
"""
Api tool
"""
def fork_tool_runtime(self, meta: dict[str, Any]) -> 'Tool':
def fork_tool_runtime(self, runtime: dict[str, Any]) -> 'Tool':
"""
fork a new tool with meta data
@ -37,7 +36,7 @@ class ApiTool(Tool):
parameters=self.parameters.copy() if self.parameters else None,
description=self.description.copy() if self.description else None,
api_bundle=self.api_bundle.copy() if self.api_bundle else None,
runtime=Tool.Runtime(**meta)
runtime=Tool.Runtime(**runtime)
)
def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False) -> str:
@ -55,7 +54,7 @@ class ApiTool(Tool):
return self.validate_and_parse_response(response)
def tool_provider_type(self) -> ToolProviderType:
return UserToolProvider.ProviderType.API
return ToolProviderType.API
def assembling_request(self, parameters: dict[str, Any]) -> dict[str, Any]:
headers = {}

View File

@ -2,9 +2,8 @@
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.entities.user_entities import UserToolProvider
from core.tools.model.tool_model_manager import ToolModelManager
from core.tools.tool.tool import Tool
from core.tools.utils.model_invocation_utils import ModelInvocationUtils
from core.tools.utils.web_reader_tool import get_url
_SUMMARY_PROMPT = """You are a professional language researcher, you are interested in the language
@ -34,7 +33,7 @@ class BuiltinTool(Tool):
:return: the model result
"""
# invoke model
return ToolModelManager.invoke(
return ModelInvocationUtils.invoke(
user_id=user_id,
tenant_id=self.runtime.tenant_id,
tool_type='builtin',
@ -43,7 +42,7 @@ class BuiltinTool(Tool):
)
def tool_provider_type(self) -> ToolProviderType:
return UserToolProvider.ProviderType.BUILTIN
return ToolProviderType.BUILT_IN
def get_max_tokens(self) -> int:
"""
@ -52,7 +51,7 @@ class BuiltinTool(Tool):
:param model_config: the model config
:return: the max tokens
"""
return ToolModelManager.get_max_llm_context_tokens(
return ModelInvocationUtils.get_max_llm_context_tokens(
tenant_id=self.runtime.tenant_id,
)
@ -63,7 +62,7 @@ class BuiltinTool(Tool):
:param prompt_messages: the prompt messages
:return: the tokens
"""
return ToolModelManager.calculate_tokens(
return ModelInvocationUtils.calculate_tokens(
tenant_id=self.runtime.tenant_id,
prompt_messages=prompt_messages
)

View File

@ -4,9 +4,12 @@ from typing import Any, Optional, Union
from pydantic import BaseModel, validator
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.file_obj import FileVar
from core.tools.entities.tool_entities import (
ToolDescription,
ToolIdentity,
ToolInvokeFrom,
ToolInvokeMessage,
ToolParameter,
ToolProviderType,
@ -25,10 +28,7 @@ class Tool(BaseModel, ABC):
@validator('parameters', pre=True, always=True)
def set_parameters(cls, v, values):
if not v:
return []
return v
return v or []
class Runtime(BaseModel):
"""
@ -41,6 +41,8 @@ class Tool(BaseModel, ABC):
tenant_id: str = None
tool_id: str = None
invoke_from: InvokeFrom = None
tool_invoke_from: ToolInvokeFrom = None
credentials: dict[str, Any] = None
runtime_parameters: dict[str, Any] = None
@ -53,7 +55,7 @@ class Tool(BaseModel, ABC):
class VARIABLE_KEY(Enum):
IMAGE = 'image'
def fork_tool_runtime(self, meta: dict[str, Any]) -> 'Tool':
def fork_tool_runtime(self, runtime: dict[str, Any]) -> 'Tool':
"""
fork a new tool with meta data
@ -64,7 +66,7 @@ class Tool(BaseModel, ABC):
identity=self.identity.copy() if self.identity else None,
parameters=self.parameters.copy() if self.parameters else None,
description=self.description.copy() if self.description else None,
runtime=Tool.Runtime(**meta),
runtime=Tool.Runtime(**runtime),
)
@abstractmethod
@ -208,17 +210,17 @@ class Tool(BaseModel, ABC):
if response.type == ToolInvokeMessage.MessageType.TEXT:
result += response.message
elif response.type == ToolInvokeMessage.MessageType.LINK:
result += f"result link: {response.message}. please tell user to check it."
result += f"result link: {response.message}. please tell user to check it. \n"
elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
response.type == ToolInvokeMessage.MessageType.IMAGE:
result += "image has been created and sent to user already, you do not need to create it, just tell the user to check it now."
result += "image has been created and sent to user already, you do not need to create it, just tell the user to check it now. \n"
elif response.type == ToolInvokeMessage.MessageType.BLOB:
if len(response.message) > 114:
result += str(response.message[:114]) + '...'
else:
result += str(response.message)
else:
result += f"tool response: {response.message}."
result += f"tool response: {response.message}. \n"
return result
@ -343,6 +345,14 @@ class Tool(BaseModel, ABC):
message=image,
save_as=save_as)
def create_file_var_message(self, file_var: FileVar) -> ToolInvokeMessage:
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.FILE_VAR,
message='',
meta={
'file_var': file_var
},
save_as='')
def create_link_message(self, link: str, save_as: str = '') -> ToolInvokeMessage:
"""
create a link message

View File

@ -0,0 +1,200 @@
import json
import logging
from copy import deepcopy
from typing import Any, Union
from core.file.file_obj import FileTransferMethod, FileVar
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderType
from core.tools.tool.tool import Tool
from extensions.ext_database import db
from models.account import Account
from models.model import App, EndUser
from models.workflow import Workflow
logger = logging.getLogger(__name__)
class WorkflowTool(Tool):
workflow_app_id: str
version: str
workflow_entities: dict[str, Any]
workflow_call_depth: int
label: str
"""
Workflow tool.
"""
def tool_provider_type(self) -> ToolProviderType:
"""
get the tool provider type
:return: the tool provider type
"""
return ToolProviderType.WORKFLOW
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) \
-> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
invoke the tool
"""
app = self._get_app(app_id=self.workflow_app_id)
workflow = self._get_workflow(app_id=self.workflow_app_id, version=self.version)
# transform the tool parameters
tool_parameters, files = self._transform_args(tool_parameters)
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
generator = WorkflowAppGenerator()
result = generator.generate(
app_model=app,
workflow=workflow,
user=self._get_user(user_id),
args={
'inputs': tool_parameters,
'files': files
},
invoke_from=self.runtime.invoke_from,
stream=False,
call_depth=self.workflow_call_depth + 1,
)
data = result.get('data', {})
if data.get('error'):
raise Exception(data.get('error'))
result = []
outputs = data.get('outputs', {})
outputs, files = self._extract_files(outputs)
for file in files:
result.append(self.create_file_var_message(file))
result.append(self.create_text_message(json.dumps(outputs)))
return result
def _get_user(self, user_id: str) -> Union[EndUser, Account]:
"""
get the user by user id
"""
user = db.session.query(EndUser).filter(EndUser.id == user_id).first()
if not user:
user = db.session.query(Account).filter(Account.id == user_id).first()
if not user:
raise ValueError('user not found')
return user
def fork_tool_runtime(self, runtime: dict[str, Any]) -> 'WorkflowTool':
"""
fork a new tool with meta data
:param meta: the meta data of a tool call processing, tenant_id is required
:return: the new tool
"""
return self.__class__(
identity=deepcopy(self.identity),
parameters=deepcopy(self.parameters),
description=deepcopy(self.description),
runtime=Tool.Runtime(**runtime),
workflow_app_id=self.workflow_app_id,
workflow_entities=self.workflow_entities,
workflow_call_depth=self.workflow_call_depth,
version=self.version,
label=self.label
)
def _get_workflow(self, app_id: str, version: str) -> Workflow:
"""
get the workflow by app id and version
"""
if not version:
workflow = db.session.query(Workflow).filter(
Workflow.app_id == app_id,
Workflow.version != 'draft'
).order_by(Workflow.created_at.desc()).first()
else:
workflow = db.session.query(Workflow).filter(
Workflow.app_id == app_id,
Workflow.version == version
).first()
if not workflow:
raise ValueError('workflow not found or not published')
return workflow
def _get_app(self, app_id: str) -> App:
"""
get the app by app id
"""
app = db.session.query(App).filter(App.id == app_id).first()
if not app:
raise ValueError('app not found')
return app
def _transform_args(self, tool_parameters: dict) -> tuple[dict, list[dict]]:
"""
transform the tool parameters
:param tool_parameters: the tool parameters
:return: tool_parameters, files
"""
parameter_rules = self.get_all_runtime_parameters()
parameters_result = {}
files = []
for parameter in parameter_rules:
if parameter.type == ToolParameter.ToolParameterType.FILE:
file = tool_parameters.get(parameter.name)
if file:
try:
file_var_list = [FileVar(**f) for f in file]
for file_var in file_var_list:
file_dict = {
'transfer_method': file_var.transfer_method.value,
'type': file_var.type.value,
}
if file_var.transfer_method == FileTransferMethod.TOOL_FILE:
file_dict['tool_file_id'] = file_var.related_id
elif file_var.transfer_method == FileTransferMethod.LOCAL_FILE:
file_dict['upload_file_id'] = file_var.related_id
elif file_var.transfer_method == FileTransferMethod.REMOTE_URL:
file_dict['url'] = file_var.preview_url
files.append(file_dict)
except Exception as e:
logger.exception(e)
else:
parameters_result[parameter.name] = tool_parameters.get(parameter.name)
return parameters_result, files
def _extract_files(self, outputs: dict) -> tuple[dict, list[FileVar]]:
"""
extract files from the result
:param result: the result
:return: the result, files
"""
files = []
result = {}
for key, value in outputs.items():
if isinstance(value, list):
has_file = False
for item in value:
if isinstance(item, dict) and item.get('__variant') == 'FileVar':
try:
files.append(FileVar(**item))
has_file = True
except Exception as e:
pass
if has_file:
continue
result[key] = value
return result, files

View File

@ -1,7 +1,10 @@
from copy import deepcopy
from datetime import datetime, timezone
from mimetypes import guess_type
from typing import Union
from yarl import URL
from core.app.entities.app_invoke_entities import InvokeFrom
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
@ -17,6 +20,7 @@ from core.tools.errors import (
ToolProviderNotFoundError,
)
from core.tools.tool.tool import Tool
from core.tools.tool.workflow_tool import WorkflowTool
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from extensions.ext_database import db
from models.model import Message, MessageFile
@ -115,7 +119,8 @@ class ToolEngine:
@staticmethod
def workflow_invoke(tool: Tool, tool_parameters: dict,
user_id: str, workflow_id: str,
workflow_tool_callback: DifyWorkflowCallbackHandler) \
workflow_tool_callback: DifyWorkflowCallbackHandler,
workflow_call_depth: int) \
-> list[ToolInvokeMessage]:
"""
Workflow invokes the tool with the given arguments.
@ -127,6 +132,9 @@ class ToolEngine:
tool_inputs=tool_parameters
)
if isinstance(tool, WorkflowTool):
tool.workflow_call_depth = workflow_call_depth + 1
response = tool.invoke(user_id, tool_parameters)
# hit the callback handler
@ -195,8 +203,24 @@ class ToolEngine:
for response in tool_response:
if response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
response.type == ToolInvokeMessage.MessageType.IMAGE:
mimetype = None
if response.meta.get('mime_type'):
mimetype = response.meta.get('mime_type')
else:
try:
url = URL(response.message)
extension = url.suffix
guess_type_result, _ = guess_type(f'a{extension}')
if guess_type_result:
mimetype = guess_type_result
except Exception:
pass
if not mimetype:
mimetype = 'image/jpeg'
result.append(ToolInvokeMessageBinary(
mimetype=response.meta.get('mime_type', 'octet/stream'),
mimetype=response.meta.get('mime_type', 'image/jpeg'),
url=response.message,
save_as=response.save_as,
))

View File

@ -0,0 +1,96 @@
from core.tools.entities.values import default_tool_label_name_list
from core.tools.provider.api_tool_provider import ApiToolProviderController
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
from core.tools.provider.tool_provider import ToolProviderController
from core.tools.provider.workflow_tool_provider import WorkflowToolProviderController
from extensions.ext_database import db
from models.tools import ToolLabelBinding
class ToolLabelManager:
@classmethod
def filter_tool_labels(cls, tool_labels: list[str]) -> list[str]:
"""
Filter tool labels
"""
tool_labels = [label for label in tool_labels if label in default_tool_label_name_list]
return list(set(tool_labels))
@classmethod
def update_tool_labels(cls, controller: ToolProviderController, labels: list[str]):
"""
Update tool labels
"""
labels = cls.filter_tool_labels(labels)
if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
provider_id = controller.provider_id
else:
raise ValueError('Unsupported tool type')
# delete old labels
db.session.query(ToolLabelBinding).filter(
ToolLabelBinding.tool_id == provider_id
).delete()
# insert new labels
for label in labels:
db.session.add(ToolLabelBinding(
tool_id=provider_id,
tool_type=controller.provider_type.value,
label_name=label,
))
db.session.commit()
@classmethod
def get_tool_labels(cls, controller: ToolProviderController) -> list[str]:
"""
Get tool labels
"""
if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
provider_id = controller.provider_id
elif isinstance(controller, BuiltinToolProviderController):
return controller.tool_labels
else:
raise ValueError('Unsupported tool type')
labels: list[ToolLabelBinding] = db.session.query(ToolLabelBinding.label_name).filter(
ToolLabelBinding.tool_id == provider_id,
ToolLabelBinding.tool_type == controller.provider_type.value,
).all()
return [label.label_name for label in labels]
@classmethod
def get_tools_labels(cls, tool_providers: list[ToolProviderController]) -> dict[str, list[str]]:
"""
Get tools labels
:param tool_providers: list of tool providers
:return: dict of tool labels
:key: tool id
:value: list of tool labels
"""
if not tool_providers:
return {}
for controller in tool_providers:
if not isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
raise ValueError('Unsupported tool type')
provider_ids = [controller.provider_id for controller in tool_providers]
labels: list[ToolLabelBinding] = db.session.query(ToolLabelBinding).filter(
ToolLabelBinding.tool_id.in_(provider_ids)
).all()
tool_labels = {
label.tool_id: [] for label in labels
}
for label in labels:
tool_labels[label.tool_id].append(label.label_name)
return tool_labels

View File

@ -9,21 +9,24 @@ from typing import Any, Union
from flask import current_app
from core.agent.entities import AgentToolEntity
from core.app.entities.app_invoke_entities import InvokeFrom
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools import *
from core.tools.entities.api_entities import UserToolProvider, UserToolProviderTypeLiteral
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import (
ApiProviderAuthType,
ToolInvokeFrom,
ToolParameter,
)
from core.tools.entities.user_entities import UserToolProvider
from core.tools.errors import ToolProviderNotFoundError
from core.tools.provider.api_tool_provider import ApiBasedToolProviderController
from core.tools.provider.api_tool_provider import ApiToolProviderController
from core.tools.provider.builtin._positions import BuiltinToolProviderSort
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
from core.tools.tool.api_tool import ApiTool
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.tool.tool import Tool
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.configuration import (
ToolConfigurationManager,
ToolParameterConfigurationManager,
@ -31,8 +34,8 @@ from core.tools.utils.configuration import (
from core.utils.module_import_helper import load_single_subclass_from_source
from core.workflow.nodes.tool.entities import ToolEntity
from extensions.ext_database import db
from models.tools import ApiToolProvider, BuiltinToolProvider
from services.tools_transform_service import ToolTransformService
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
from services.tools.tools_transform_service import ToolTransformService
logger = logging.getLogger(__name__)
@ -99,7 +102,12 @@ class ToolManager:
raise ToolProviderNotFoundError(f'provider type {provider_type} not found')
@classmethod
def get_tool_runtime(cls, provider_type: str, provider_name: str, tool_name: str, tenant_id: str) \
def get_tool_runtime(cls, provider_type: str,
provider_id: str,
tool_name: str,
tenant_id: str,
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT) \
-> Union[BuiltinTool, ApiTool]:
"""
get the tool runtime
@ -111,51 +119,76 @@ class ToolManager:
:return: the tool
"""
if provider_type == 'builtin':
builtin_tool = cls.get_builtin_tool(provider_name, tool_name)
builtin_tool = cls.get_builtin_tool(provider_id, tool_name)
# check if the builtin tool need credentials
provider_controller = cls.get_builtin_provider(provider_name)
provider_controller = cls.get_builtin_provider(provider_id)
if not provider_controller.need_credentials:
return builtin_tool.fork_tool_runtime(meta={
return builtin_tool.fork_tool_runtime(runtime={
'tenant_id': tenant_id,
'credentials': {},
'invoke_from': invoke_from,
'tool_invoke_from': tool_invoke_from,
})
# get credentials
builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider_name,
BuiltinToolProvider.provider == provider_id,
).first()
if builtin_provider is None:
raise ToolProviderNotFoundError(f'builtin provider {provider_name} not found')
raise ToolProviderNotFoundError(f'builtin provider {provider_id} not found')
# decrypt the credentials
credentials = builtin_provider.credentials
controller = cls.get_builtin_provider(provider_name)
controller = cls.get_builtin_provider(provider_id)
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller)
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
return builtin_tool.fork_tool_runtime(meta={
return builtin_tool.fork_tool_runtime(runtime={
'tenant_id': tenant_id,
'credentials': decrypted_credentials,
'runtime_parameters': {}
'runtime_parameters': {},
'invoke_from': invoke_from,
'tool_invoke_from': tool_invoke_from,
})
elif provider_type == 'api':
if tenant_id is None:
raise ValueError('tenant id is required for api provider')
api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_name)
api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id)
# decrypt the credentials
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=api_provider)
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
return api_provider.get_tool(tool_name).fork_tool_runtime(meta={
return api_provider.get_tool(tool_name).fork_tool_runtime(runtime={
'tenant_id': tenant_id,
'credentials': decrypted_credentials,
'invoke_from': invoke_from,
'tool_invoke_from': tool_invoke_from,
})
elif provider_type == 'workflow':
workflow_provider = db.session.query(WorkflowToolProvider).filter(
WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.id == provider_id
).first()
if workflow_provider is None:
raise ToolProviderNotFoundError(f'workflow provider {provider_id} not found')
controller = ToolTransformService.workflow_provider_to_controller(
db_provider=workflow_provider
)
return controller.get_tools(user_id=None, tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(runtime={
'tenant_id': tenant_id,
'credentials': {},
'invoke_from': invoke_from,
'tool_invoke_from': tool_invoke_from,
})
elif provider_type == 'app':
raise NotImplementedError('app provider not implemented')
@ -207,18 +240,25 @@ class ToolManager:
return parameter_value
@classmethod
def get_agent_tool_runtime(cls, tenant_id: str, app_id: str, agent_tool: AgentToolEntity) -> Tool:
def get_agent_tool_runtime(cls, tenant_id: str, app_id: str, agent_tool: AgentToolEntity, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER) -> Tool:
"""
get the agent tool runtime
"""
tool_entity = cls.get_tool_runtime(
provider_type=agent_tool.provider_type, provider_name=agent_tool.provider_id,
provider_type=agent_tool.provider_type,
provider_id=agent_tool.provider_id,
tool_name=agent_tool.tool_name,
tenant_id=tenant_id,
invoke_from=invoke_from,
tool_invoke_from=ToolInvokeFrom.AGENT
)
runtime_parameters = {}
parameters = tool_entity.get_all_runtime_parameters()
for parameter in parameters:
# check file types
if parameter.type == ToolParameter.ToolParameterType.FILE:
raise ValueError(f"file type parameter {parameter.name} not supported in agent")
if parameter.form == ToolParameter.ToolParameterForm.FORM:
# save tool parameter to tool entity memory
value = cls._init_runtime_parameter(parameter, agent_tool.tool_parameters)
@ -238,15 +278,17 @@ class ToolManager:
return tool_entity
@classmethod
def get_workflow_tool_runtime(cls, tenant_id: str, app_id: str, node_id: str, workflow_tool: ToolEntity):
def get_workflow_tool_runtime(cls, tenant_id: str, app_id: str, node_id: str, workflow_tool: ToolEntity, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER) -> Tool:
"""
get the workflow tool runtime
"""
tool_entity = cls.get_tool_runtime(
provider_type=workflow_tool.provider_type,
provider_name=workflow_tool.provider_id,
provider_id=workflow_tool.provider_id,
tool_name=workflow_tool.tool_name,
tenant_id=tenant_id,
invoke_from=invoke_from,
tool_invoke_from=ToolInvokeFrom.WORKFLOW
)
runtime_parameters = {}
parameters = tool_entity.get_all_runtime_parameters()
@ -371,9 +413,17 @@ class ToolManager:
return cls._builtin_tools_labels[tool_name]
@classmethod
def user_list_providers(cls, user_id: str, tenant_id: str) -> list[UserToolProvider]:
def user_list_providers(cls, user_id: str, tenant_id: str, typ: UserToolProviderTypeLiteral) -> list[UserToolProvider]:
result_providers: dict[str, UserToolProvider] = {}
filters = []
if not typ:
filters.extend(['builtin', 'api', 'workflow'])
else:
filters.append(typ)
if 'builtin' in filters:
# get builtin providers
builtin_providers = cls.list_builtin_providers()
@ -397,25 +447,57 @@ class ToolManager:
result_providers[provider.identity.name] = user_provider
# get db api providers
if 'api' in filters:
db_api_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider). \
filter(ApiToolProvider.tenant_id == tenant_id).all()
for db_api_provider in db_api_providers:
provider_controller = ToolTransformService.api_provider_to_controller(
db_provider=db_api_provider,
)
api_provider_controllers = [{
'provider': provider,
'controller': ToolTransformService.api_provider_to_controller(provider)
} for provider in db_api_providers]
# get labels
labels = ToolLabelManager.get_tools_labels([x['controller'] for x in api_provider_controllers])
for api_provider_controller in api_provider_controllers:
user_provider = ToolTransformService.api_provider_to_user_provider(
provider_controller=provider_controller,
db_provider=db_api_provider,
decrypt_credentials=False
provider_controller=api_provider_controller['controller'],
db_provider=api_provider_controller['provider'],
decrypt_credentials=False,
labels=labels.get(api_provider_controller['controller'].provider_id, [])
)
result_providers[db_api_provider.name] = user_provider
result_providers[f'api_provider.{user_provider.name}'] = user_provider
if 'workflow' in filters:
# get workflow providers
workflow_providers: list[WorkflowToolProvider] = db.session.query(WorkflowToolProvider). \
filter(WorkflowToolProvider.tenant_id == tenant_id).all()
workflow_provider_controllers = []
for provider in workflow_providers:
try:
workflow_provider_controllers.append(
ToolTransformService.workflow_provider_to_controller(db_provider=provider)
)
except Exception as e:
# app has been deleted
pass
labels = ToolLabelManager.get_tools_labels(workflow_provider_controllers)
for provider_controller in workflow_provider_controllers:
user_provider = ToolTransformService.workflow_provider_to_user_provider(
provider_controller=provider_controller,
labels=labels.get(provider_controller.provider_id, []),
)
result_providers[f'workflow_provider.{user_provider.name}'] = user_provider
return BuiltinToolProviderSort.sort(list(result_providers.values()))
@classmethod
def get_api_provider_controller(cls, tenant_id: str, provider_id: str) -> tuple[
ApiBasedToolProviderController, dict[str, Any]]:
ApiToolProviderController, dict[str, Any]]:
"""
get the api provider
@ -431,7 +513,7 @@ class ToolManager:
if provider is None:
raise ToolProviderNotFoundError(f'api provider {provider_id} not found')
controller = ApiBasedToolProviderController.from_db(
controller = ApiToolProviderController.from_db(
provider,
ApiProviderAuthType.API_KEY if provider.credentials['auth_type'] == 'api_key' else
ApiProviderAuthType.NONE
@ -462,7 +544,7 @@ class ToolManager:
credentials = {}
# package tool provider controller
controller = ApiBasedToolProviderController.from_db(
controller = ApiToolProviderController.from_db(
provider, ApiProviderAuthType.API_KEY if credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE
)
# init tool configuration
@ -479,6 +561,9 @@ class ToolManager:
"content": "\ud83d\ude01"
}
# add tool labels
labels = ToolLabelManager.get_tool_labels(controller)
return jsonable_encoder({
'schema_type': provider.schema_type,
'schema': provider.schema,
@ -487,7 +572,8 @@ class ToolManager:
'description': provider.description,
'credentials': masked_credentials,
'privacy_policy': provider.privacy_policy,
'custom_disclaimer': provider.custom_disclaimer
'custom_disclaimer': provider.custom_disclaimer,
'labels': labels,
})
@classmethod
@ -519,6 +605,15 @@ class ToolManager:
"background": "#252525",
"content": "\ud83d\ude01"
}
elif provider_type == 'workflow':
provider: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.id == provider_id
).first()
if provider is None:
raise ToolProviderNotFoundError(f'workflow provider {provider_id} not found')
return json.loads(provider.icon)
else:
raise ValueError(f"provider type {provider_type} not found")

View File

@ -73,7 +73,7 @@ class ToolConfigurationManager(BaseModel):
"""
cache = ToolProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=f'{self.provider_controller.app_type.value}.{self.provider_controller.identity.name}',
identity_id=f'{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}',
cache_type=ToolProviderCredentialsCacheType.PROVIDER
)
cached_credentials = cache.get()
@ -96,7 +96,7 @@ class ToolConfigurationManager(BaseModel):
def delete_tool_credentials_cache(self):
cache = ToolProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=f'{self.provider_controller.app_type.value}.{self.provider_controller.identity.name}',
identity_id=f'{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}',
cache_type=ToolProviderCredentialsCacheType.PROVIDER
)
cache.delete()

View File

@ -1,14 +1,15 @@
import logging
from mimetypes import guess_extension
from core.file.file_obj import FileTransferMethod, FileType, FileVar
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool_file_manager import ToolFileManager
logger = logging.getLogger(__name__)
class ToolFileMessageTransformer:
@staticmethod
def transform_tool_invoke_messages(messages: list[ToolInvokeMessage],
@classmethod
def transform_tool_invoke_messages(cls, messages: list[ToolInvokeMessage],
user_id: str,
tenant_id: str,
conversation_id: str) -> list[ToolInvokeMessage]:
@ -62,7 +63,7 @@ class ToolFileMessageTransformer:
mimetype=mimetype
)
url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".bin"}'
url = cls.get_tool_file_url(file.id, guess_extension(file.mimetype))
# check if file is image
if 'image' in mimetype:
@ -79,7 +80,30 @@ class ToolFileMessageTransformer:
save_as=message.save_as,
meta=message.meta.copy() if message.meta is not None else {},
))
elif message.type == ToolInvokeMessage.MessageType.FILE_VAR:
file_var: FileVar = message.meta.get('file_var')
if file_var:
if file_var.transfer_method == FileTransferMethod.TOOL_FILE:
url = cls.get_tool_file_url(file_var.related_id, file_var.extension)
if file_var.type == FileType.IMAGE:
result.append(ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
message=url,
save_as=message.save_as,
meta=message.meta.copy() if message.meta is not None else {},
))
else:
result.append(ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.LINK,
message=url,
save_as=message.save_as,
meta=message.meta.copy() if message.meta is not None else {},
))
else:
result.append(message)
return result
@classmethod
def get_tool_file_url(cls, tool_file_id: str, extension: str) -> str:
return f'/files/tools/{tool_file_id}{extension or ".bin"}'

View File

@ -20,12 +20,14 @@ from core.model_runtime.errors.invoke import (
)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel, ModelPropertyKey
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.model.errors import InvokeModelError
from extensions.ext_database import db
from models.tools import ToolModelInvoke
class ToolModelManager:
class InvokeModelError(Exception):
pass
class ModelInvocationUtils:
@staticmethod
def get_max_llm_context_tokens(
tenant_id: str,

View File

@ -9,14 +9,14 @@ from requests import get
from yaml import YAMLError, safe_load
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiBasedToolBundle
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolParameter
from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolProviderNotFoundError
class ApiBasedToolSchemaParser:
@staticmethod
def parse_openapi_to_tool_bundle(openapi: dict, extra_info: dict = None, warning: dict = None) -> list[ApiBasedToolBundle]:
def parse_openapi_to_tool_bundle(openapi: dict, extra_info: dict = None, warning: dict = None) -> list[ApiToolBundle]:
warning = warning if warning is not None else {}
extra_info = extra_info if extra_info is not None else {}
@ -145,7 +145,7 @@ class ApiBasedToolSchemaParser:
interface['operation']['operationId'] = f'{path}_{interface["method"]}'
bundles.append(ApiBasedToolBundle(
bundles.append(ApiToolBundle(
server_url=server_url + interface['path'],
method=interface['method'],
summary=interface['operation']['description'] if 'description' in interface['operation'] else
@ -176,7 +176,7 @@ class ApiBasedToolSchemaParser:
return ToolParameter.ToolParameterType.STRING
@staticmethod
def parse_openapi_yaml_to_tool_bundle(yaml: str, extra_info: dict = None, warning: dict = None) -> list[ApiBasedToolBundle]:
def parse_openapi_yaml_to_tool_bundle(yaml: str, extra_info: dict = None, warning: dict = None) -> list[ApiToolBundle]:
"""
parse openapi yaml to tool bundle
@ -258,7 +258,7 @@ class ApiBasedToolSchemaParser:
return openapi
@staticmethod
def parse_openai_plugin_json_to_tool_bundle(json: str, extra_info: dict = None, warning: dict = None) -> list[ApiBasedToolBundle]:
def parse_openai_plugin_json_to_tool_bundle(json: str, extra_info: dict = None, warning: dict = None) -> list[ApiToolBundle]:
"""
parse openapi plugin yaml to tool bundle
@ -290,7 +290,7 @@ class ApiBasedToolSchemaParser:
return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle(response.text, extra_info=extra_info, warning=warning)
@staticmethod
def auto_parse_to_tool_bundle(content: str, extra_info: dict = None, warning: dict = None) -> tuple[list[ApiBasedToolBundle], str]:
def auto_parse_to_tool_bundle(content: str, extra_info: dict = None, warning: dict = None) -> tuple[list[ApiToolBundle], str]:
"""
auto parse to tool bundle

View File

@ -0,0 +1,48 @@
from core.app.app_config.entities import VariableEntity
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
class WorkflowToolConfigurationUtils:
@classmethod
def check_parameter_configurations(cls, configurations: list[dict]):
"""
check parameter configurations
"""
for configuration in configurations:
if not WorkflowToolParameterConfiguration(**configuration):
raise ValueError('invalid parameter configuration')
@classmethod
def get_workflow_graph_variables(cls, graph: dict) -> list[VariableEntity]:
"""
get workflow graph variables
"""
nodes = graph.get('nodes', [])
start_node = next(filter(lambda x: x.get('data', {}).get('type') == 'start', nodes), None)
if not start_node:
return []
return [
VariableEntity(**variable) for variable in start_node.get('data', {}).get('variables', [])
]
@classmethod
def check_is_synced(cls,
variables: list[VariableEntity],
tool_configurations: list[WorkflowToolParameterConfiguration]) -> None:
"""
check is synced
raise ValueError if not synced
"""
variable_names = [variable.variable for variable in variables]
if len(tool_configurations) != len(variables):
raise ValueError('parameter configuration mismatch, please republish the tool to update')
for parameter in tool_configurations:
if parameter.name not in variable_names:
raise ValueError('parameter configuration mismatch, please republish the tool to update')
return True

View File

@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Optional
from typing import Any, Optional
from core.app.entities.queue_entities import AppQueueEvent
from core.workflow.entities.base_node_data_entities import BaseNodeData
@ -72,6 +72,42 @@ class BaseWorkflowCallback(ABC):
"""
raise NotImplementedError
@abstractmethod
def on_workflow_iteration_started(self,
node_id: str,
node_type: NodeType,
node_run_index: int = 1,
node_data: Optional[BaseNodeData] = None,
inputs: dict = None,
predecessor_node_id: Optional[str] = None,
metadata: Optional[dict] = None) -> None:
"""
Publish iteration started
"""
raise NotImplementedError
@abstractmethod
def on_workflow_iteration_next(self, node_id: str,
node_type: NodeType,
index: int,
node_run_index: int,
output: Optional[Any],
) -> None:
"""
Publish iteration next
"""
raise NotImplementedError
@abstractmethod
def on_workflow_iteration_completed(self, node_id: str,
node_type: NodeType,
node_run_index: int,
outputs: dict) -> None:
"""
Publish iteration completed
"""
raise NotImplementedError
@abstractmethod
def on_event(self, event: AppQueueEvent) -> None:
"""

View File

@ -7,3 +7,16 @@ from pydantic import BaseModel
class BaseNodeData(ABC, BaseModel):
title: str
desc: Optional[str] = None
class BaseIterationNodeData(BaseNodeData):
start_node_id: str
class BaseIterationState(BaseModel):
iteration_node_id: str
index: int
inputs: dict
class MetaData(BaseModel):
pass
metadata: MetaData

View File

@ -21,7 +21,11 @@ class NodeType(Enum):
QUESTION_CLASSIFIER = 'question-classifier'
HTTP_REQUEST = 'http-request'
TOOL = 'tool'
VARIABLE_AGGREGATOR = 'variable-aggregator'
VARIABLE_ASSIGNER = 'variable-assigner'
LOOP = 'loop'
ITERATION = 'iteration'
PARAMETER_EXTRACTOR = 'parameter-extractor'
@classmethod
def value_of(cls, value: str) -> 'NodeType':
@ -68,6 +72,8 @@ class NodeRunMetadataKey(Enum):
TOTAL_PRICE = 'total_price'
CURRENCY = 'currency'
TOOL_INFO = 'tool_info'
ITERATION_ID = 'iteration_id'
ITERATION_INDEX = 'iteration_index'
class NodeRunResult(BaseModel):

View File

@ -90,3 +90,12 @@ class VariablePool:
raise ValueError(f'Invalid value type: {target_value_type.value}')
return value
def clear_node_variables(self, node_id: str) -> None:
"""
Clear node variables
:param node_id: node id
:return:
"""
if node_id in self.variables_mapping:
self.variables_mapping.pop(node_id)

View File

@ -1,5 +1,9 @@
from typing import Optional
from pydantic import BaseModel
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.base_node_data_entities import BaseIterationState
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode, UserFrom
@ -22,6 +26,9 @@ class WorkflowRunState:
workflow_type: WorkflowType
user_id: str
user_from: UserFrom
invoke_from: InvokeFrom
workflow_call_depth: int
start_at: float
variable_pool: VariablePool
@ -30,20 +37,37 @@ class WorkflowRunState:
workflow_nodes_and_results: list[WorkflowNodeAndResult]
class NodeRun(BaseModel):
node_id: str
iteration_node_id: str
workflow_node_runs: list[NodeRun]
workflow_node_steps: int
current_iteration_state: Optional[BaseIterationState]
def __init__(self, workflow: Workflow,
start_at: float,
variable_pool: VariablePool,
user_id: str,
user_from: UserFrom):
user_from: UserFrom,
invoke_from: InvokeFrom,
workflow_call_depth: int):
self.workflow_id = workflow.id
self.tenant_id = workflow.tenant_id
self.app_id = workflow.app_id
self.workflow_type = WorkflowType.value_of(workflow.type)
self.user_id = user_id
self.user_from = user_from
self.invoke_from = invoke_from
self.workflow_call_depth = workflow_call_depth
self.start_at = start_at
self.variable_pool = variable_pool
self.total_tokens = 0
self.workflow_nodes_and_results = []
self.current_iteration_state = None
self.workflow_node_steps = 1
self.workflow_node_runs = []

View File

@ -2,8 +2,9 @@ from abc import ABC, abstractmethod
from enum import Enum
from typing import Optional
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.base_node_data_entities import BaseIterationState, BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
@ -37,6 +38,9 @@ class BaseNode(ABC):
workflow_id: str
user_id: str
user_from: UserFrom
invoke_from: InvokeFrom
workflow_call_depth: int
node_id: str
node_data: BaseNodeData
@ -49,13 +53,17 @@ class BaseNode(ABC):
workflow_id: str,
user_id: str,
user_from: UserFrom,
invoke_from: InvokeFrom,
config: dict,
callbacks: list[BaseWorkflowCallback] = None) -> None:
callbacks: list[BaseWorkflowCallback] = None,
workflow_call_depth: int = 0) -> None:
self.tenant_id = tenant_id
self.app_id = app_id
self.workflow_id = workflow_id
self.user_id = user_id
self.user_from = user_from
self.invoke_from = invoke_from
self.workflow_call_depth = workflow_call_depth
self.node_id = config.get("id")
if not self.node_id:
@ -140,3 +148,38 @@ class BaseNode(ABC):
:return:
"""
return self._node_type
class BaseIterationNode(BaseNode):
@abstractmethod
def _run(self, variable_pool: VariablePool) -> BaseIterationState:
"""
Run node
:param variable_pool: variable pool
:return:
"""
raise NotImplementedError
def run(self, variable_pool: VariablePool) -> BaseIterationState:
"""
Run node entry
:param variable_pool: variable pool
:return:
"""
return self._run(variable_pool=variable_pool)
def get_next_iteration(self, variable_pool: VariablePool, state: BaseIterationState) -> NodeRunResult | str:
"""
Get next iteration start node id based on the graph.
:param graph: graph
:return: next node id
"""
return self._get_next_iteration(variable_pool, state)
@abstractmethod
def _get_next_iteration(self, variable_pool: VariablePool, state: BaseIterationState) -> NodeRunResult | str:
"""
Get next iteration start node id based on the graph.
:param graph: graph
:return: next node id
"""
raise NotImplementedError

View File

@ -0,0 +1,39 @@
from typing import Any, Optional
from core.workflow.entities.base_node_data_entities import BaseIterationNodeData, BaseIterationState
class IterationNodeData(BaseIterationNodeData):
"""
Iteration Node Data.
"""
parent_loop_id: Optional[str] # redundant field, not used currently
iterator_selector: list[str] # variable selector
output_selector: list[str] # output selector
class IterationState(BaseIterationState):
"""
Iteration State.
"""
outputs: list[Any] = None
current_output: Optional[Any] = None
class MetaData(BaseIterationState.MetaData):
"""
Data.
"""
iterator_length: int
def get_last_output(self) -> Optional[Any]:
"""
Get last output.
"""
if self.outputs:
return self.outputs[-1]
return None
def get_current_output(self) -> Optional[Any]:
"""
Get current output.
"""
return self.current_output

View File

@ -0,0 +1,119 @@
from typing import cast
from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.entities.base_node_data_entities import BaseIterationState
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseIterationNode
from core.workflow.nodes.iteration.entities import IterationNodeData, IterationState
from models.workflow import WorkflowNodeExecutionStatus
class IterationNode(BaseIterationNode):
"""
Iteration Node.
"""
_node_data_cls = IterationNodeData
_node_type = NodeType.ITERATION
def _run(self, variable_pool: VariablePool) -> BaseIterationState:
"""
Run the node.
"""
iterator = variable_pool.get_variable_value(cast(IterationNodeData, self.node_data).iterator_selector)
if not isinstance(iterator, list):
raise ValueError(f"Invalid iterator value: {iterator}, please provide a list.")
state = IterationState(iteration_node_id=self.node_id, index=-1, inputs={
'iterator_selector': iterator
}, outputs=[], metadata=IterationState.MetaData(
iterator_length=len(iterator) if iterator is not None else 0
))
self._set_current_iteration_variable(variable_pool, state)
return state
def _get_next_iteration(self, variable_pool: VariablePool, state: IterationState) -> NodeRunResult | str:
"""
Get next iteration start node id based on the graph.
:param graph: graph
:return: next node id
"""
# resolve current output
self._resolve_current_output(variable_pool, state)
# move to next iteration
self._next_iteration(variable_pool, state)
node_data = cast(IterationNodeData, self.node_data)
if self._reached_iteration_limit(variable_pool, state):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={
'output': jsonable_encoder(state.outputs)
}
)
return node_data.start_node_id
def _set_current_iteration_variable(self, variable_pool: VariablePool, state: IterationState):
"""
Set current iteration variable.
:variable_pool: variable pool
"""
node_data = cast(IterationNodeData, self.node_data)
variable_pool.append_variable(self.node_id, ['index'], state.index)
# get the iterator value
iterator = variable_pool.get_variable_value(node_data.iterator_selector)
if iterator is None or not isinstance(iterator, list):
return
if state.index < len(iterator):
variable_pool.append_variable(self.node_id, ['item'], iterator[state.index])
def _next_iteration(self, variable_pool: VariablePool, state: IterationState):
"""
Move to next iteration.
:param variable_pool: variable pool
"""
state.index += 1
self._set_current_iteration_variable(variable_pool, state)
def _reached_iteration_limit(self, variable_pool: VariablePool, state: IterationState):
"""
Check if iteration limit is reached.
:return: True if iteration limit is reached, False otherwise
"""
node_data = cast(IterationNodeData, self.node_data)
iterator = variable_pool.get_variable_value(node_data.iterator_selector)
if iterator is None or not isinstance(iterator, list):
return True
return state.index >= len(iterator)
def _resolve_current_output(self, variable_pool: VariablePool, state: IterationState):
"""
Resolve current output.
:param variable_pool: variable pool
"""
output_selector = cast(IterationNodeData, self.node_data).output_selector
output = variable_pool.get_variable_value(output_selector)
# clear the output for this iteration
variable_pool.append_variable(self.node_id, output_selector[1:], None)
state.current_output = output
if output is not None:
state.outputs.append(output)
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: IterationNodeData) -> dict[str, list[str]]:
"""
Extract variable selector to variable mapping
:param node_data: node data
:return:
"""
return {
'input_selector': node_data.iterator_selector,
}

View File

View File

@ -0,0 +1,13 @@
from core.workflow.entities.base_node_data_entities import BaseIterationNodeData, BaseIterationState
class LoopNodeData(BaseIterationNodeData):
"""
Loop Node Data.
"""
class LoopState(BaseIterationState):
"""
Loop State.
"""

View File

@ -0,0 +1,20 @@
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseIterationNode
from core.workflow.nodes.loop.entities import LoopNodeData, LoopState
class LoopNode(BaseIterationNode):
"""
Loop Node.
"""
_node_data_cls = LoopNodeData
_node_type = NodeType.LOOP
def _run(self, variable_pool: VariablePool) -> LoopState:
return super()._run(variable_pool)
def _get_next_iteration(self, variable_loop: VariablePool) -> NodeRunResult | str:
"""
Get next iteration start node id based on the graph.
"""

Some files were not shown because too many files have changed in this diff Show More