diff --git a/.gitignore b/.gitignore index 22aba65364..29374d13dc 100644 --- a/.gitignore +++ b/.gitignore @@ -153,6 +153,9 @@ docker-legacy/volumes/etcd/* docker-legacy/volumes/minio/* docker-legacy/volumes/milvus/* docker-legacy/volumes/chroma/* +docker-legacy/volumes/opensearch/data/* +docker-legacy/volumes/pgvectors/data/* +docker-legacy/volumes/pgvector/data/* docker/volumes/app/storage/* docker/volumes/certbot/* @@ -164,6 +167,12 @@ docker/volumes/etcd/* docker/volumes/minio/* docker/volumes/milvus/* docker/volumes/chroma/* +docker/volumes/opensearch/data/* +docker/volumes/myscale/data/* +docker/volumes/myscale/log/* +docker/volumes/unstructured/* +docker/volumes/pgvector/data/* +docker/volumes/pgvecto_rs/data/* docker/nginx/conf.d/default.conf docker/middleware.env diff --git a/api/app.py b/api/app.py index ad219ca0d6..91a49337fc 100644 --- a/api/app.py +++ b/api/app.py @@ -164,7 +164,7 @@ def initialize_extensions(app): @login_manager.request_loader def load_user_from_request(request_from_flask_login): """Load user based on the request.""" - if request.blueprint not in ["console", "inner_api"]: + if request.blueprint not in {"console", "inner_api"}: return None # Check if the user_id contains a dot, indicating the old format auth_header = request.headers.get("Authorization", "") diff --git a/api/commands.py b/api/commands.py index 887270b43e..3a6b4963cf 100644 --- a/api/commands.py +++ b/api/commands.py @@ -140,9 +140,9 @@ def reset_encrypt_key_pair(): @click.command("vdb-migrate", help="migrate vector db.") @click.option("--scope", default="all", prompt=False, help="The scope of vector database to migrate, Default is All.") def vdb_migrate(scope: str): - if scope in ["knowledge", "all"]: + if scope in {"knowledge", "all"}: migrate_knowledge_vector_database() - if scope in ["annotation", "all"]: + if scope in {"annotation", "all"}: migrate_annotation_vector_database() diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py 
index 7332758e83..c1ef05a488 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -94,7 +94,7 @@ class ChatMessageTextApi(Resource): message_id = args.get("message_id", None) text = args.get("text", None) if ( - app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] + app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value} and app_model.workflow and app_model.workflow.features_dict ): diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index 1df0f5de9d..ad0c0580ae 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -71,7 +71,7 @@ class OAuthCallback(Resource): account = _generate_account(provider, user_info) # Check account status - if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value: + if account.status in {AccountStatus.BANNED.value, AccountStatus.CLOSED.value}: return {"error": "Account is banned or closed."}, 403 if account.status == AccountStatus.PENDING.value: diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 076f3cd44d..829ef11e52 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -354,7 +354,7 @@ class DocumentIndexingEstimateApi(DocumentResource): document_id = str(document_id) document = self.get_document(dataset_id, document_id) - if document.indexing_status in ["completed", "error"]: + if document.indexing_status in {"completed", "error"}: raise DocumentAlreadyFinishedError() data_process_rule = document.dataset_process_rule @@ -421,7 +421,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): info_list = [] extract_settings = [] for document in documents: - if document.indexing_status in ["completed", "error"]: + if document.indexing_status in {"completed", "error"}: raise 
DocumentAlreadyFinishedError() data_source_info = document.data_source_info_dict # format document files info @@ -665,7 +665,7 @@ class DocumentProcessingApi(DocumentResource): db.session.commit() elif action == "resume": - if document.indexing_status not in ["paused", "error"]: + if document.indexing_status not in {"paused", "error"}: raise InvalidActionError("Document not in paused or error state.") document.paused_by = None diff --git a/api/controllers/console/explore/audio.py b/api/controllers/console/explore/audio.py index 2eb7e04490..9690677f61 100644 --- a/api/controllers/console/explore/audio.py +++ b/api/controllers/console/explore/audio.py @@ -81,7 +81,7 @@ class ChatTextApi(InstalledAppResource): message_id = args.get("message_id", None) text = args.get("text", None) if ( - app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] + app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value} and app_model.workflow and app_model.workflow.features_dict ): diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index c039e8bca5..f464692098 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -92,7 +92,7 @@ class ChatApi(InstalledAppResource): def post(self, installed_app): app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() parser = reqparse.RequestParser() @@ -140,7 +140,7 @@ class ChatStopApi(InstalledAppResource): def post(self, installed_app, task_id): app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() 
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) diff --git a/api/controllers/console/explore/conversation.py b/api/controllers/console/explore/conversation.py index 2918024b64..6f9d7769b9 100644 --- a/api/controllers/console/explore/conversation.py +++ b/api/controllers/console/explore/conversation.py @@ -20,7 +20,7 @@ class ConversationListApi(InstalledAppResource): def get(self, installed_app): app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() parser = reqparse.RequestParser() @@ -50,7 +50,7 @@ class ConversationApi(InstalledAppResource): def delete(self, installed_app, c_id): app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() conversation_id = str(c_id) @@ -68,7 +68,7 @@ class ConversationRenameApi(InstalledAppResource): def post(self, installed_app, c_id): app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() conversation_id = str(c_id) @@ -90,7 +90,7 @@ class ConversationPinApi(InstalledAppResource): def patch(self, installed_app, c_id): app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() conversation_id = str(c_id) @@ -107,7 +107,7 @@ class ConversationUnPinApi(InstalledAppResource): def patch(self, installed_app, c_id): 
app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() conversation_id = str(c_id) diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index 3f1e64a247..408afc33a0 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -31,7 +31,7 @@ class InstalledAppsListApi(Resource): "app_owner_tenant_id": installed_app.app_owner_tenant_id, "is_pinned": installed_app.is_pinned, "last_used_at": installed_app.last_used_at, - "editable": current_user.role in ["owner", "admin"], + "editable": current_user.role in {"owner", "admin"}, "uninstallable": current_tenant_id == installed_app.app_owner_tenant_id, } for installed_app in installed_apps diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index f5eb185172..0e0238556c 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -40,7 +40,7 @@ class MessageListApi(InstalledAppResource): app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() parser = reqparse.RequestParser() @@ -125,7 +125,7 @@ class MessageSuggestedQuestionApi(InstalledAppResource): def get(self, installed_app, message_id): app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() message_id = str(message_id) diff --git 
a/api/controllers/console/explore/parameter.py b/api/controllers/console/explore/parameter.py index ad55b04043..aab7dd7888 100644 --- a/api/controllers/console/explore/parameter.py +++ b/api/controllers/console/explore/parameter.py @@ -43,7 +43,7 @@ class AppParameterApi(InstalledAppResource): """Retrieve app parameters.""" app_model = installed_app.app - if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]: + if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: workflow = app_model.workflow if workflow is None: raise AppUnavailableError() diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py index 623f0b8b74..af3ebc099b 100644 --- a/api/controllers/console/workspace/workspace.py +++ b/api/controllers/console/workspace/workspace.py @@ -194,7 +194,7 @@ class WebappLogoWorkspaceApi(Resource): raise TooManyFilesError() extension = file.filename.split(".")[-1] - if extension.lower() not in ["svg", "png"]: + if extension.lower() not in {"svg", "png"}: raise UnsupportedFileTypeError() try: diff --git a/api/controllers/service_api/app/app.py b/api/controllers/service_api/app/app.py index ecc2d73deb..f7c091217b 100644 --- a/api/controllers/service_api/app/app.py +++ b/api/controllers/service_api/app/app.py @@ -42,7 +42,7 @@ class AppParameterApi(Resource): @marshal_with(parameters_fields) def get(self, app_model: App): """Retrieve app parameters.""" - if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]: + if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: workflow = app_model.workflow if workflow is None: raise AppUnavailableError() diff --git a/api/controllers/service_api/app/audio.py b/api/controllers/service_api/app/audio.py index 8d8ca8d78c..5db4163647 100644 --- a/api/controllers/service_api/app/audio.py +++ b/api/controllers/service_api/app/audio.py @@ -79,7 +79,7 @@ class TextApi(Resource): message_id = 
args.get("message_id", None) text = args.get("text", None) if ( - app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] + app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value} and app_model.workflow and app_model.workflow.features_dict ): diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index f1771baf31..8d8e356c4c 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -96,7 +96,7 @@ class ChatApi(Resource): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) def post(self, app_model: App, end_user: EndUser): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() parser = reqparse.RequestParser() @@ -144,7 +144,7 @@ class ChatStopApi(Resource): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) def post(self, app_model: App, end_user: EndUser, task_id): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id) diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index 734027a1c5..527ef4ecd3 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -18,7 +18,7 @@ class ConversationApi(Resource): @marshal_with(conversation_infinite_scroll_pagination_fields) def get(self, app_model: App, end_user: EndUser): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, 
AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() parser = reqparse.RequestParser() @@ -52,7 +52,7 @@ class ConversationDetailApi(Resource): @marshal_with(simple_conversation_fields) def delete(self, app_model: App, end_user: EndUser, c_id): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() conversation_id = str(c_id) @@ -69,7 +69,7 @@ class ConversationRenameApi(Resource): @marshal_with(simple_conversation_fields) def post(self, app_model: App, end_user: EndUser, c_id): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() conversation_id = str(c_id) diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py index b39aaf7dd8..e54e6f4903 100644 --- a/api/controllers/service_api/app/message.py +++ b/api/controllers/service_api/app/message.py @@ -76,7 +76,7 @@ class MessageListApi(Resource): @marshal_with(message_infinite_scroll_pagination_fields) def get(self, app_model: App, end_user: EndUser): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() parser = reqparse.RequestParser() @@ -117,7 +117,7 @@ class MessageSuggestedApi(Resource): def get(self, app_model: App, end_user: EndUser, message_id): message_id = str(message_id) app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, 
AppMode.ADVANCED_CHAT}: raise NotChatAppError() try: diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index aabca93338..20b4e4674c 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -41,7 +41,7 @@ class AppParameterApi(WebApiResource): @marshal_with(parameters_fields) def get(self, app_model: App, end_user): """Retrieve app parameters.""" - if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]: + if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: workflow = app_model.workflow if workflow is None: raise AppUnavailableError() diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py index 49c467dbe1..23550efe2e 100644 --- a/api/controllers/web/audio.py +++ b/api/controllers/web/audio.py @@ -78,7 +78,7 @@ class TextApi(WebApiResource): message_id = args.get("message_id", None) text = args.get("text", None) if ( - app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] + app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value} and app_model.workflow and app_model.workflow.features_dict ): diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index 0837eedfb0..115492b796 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -87,7 +87,7 @@ class CompletionStopApi(WebApiResource): class ChatApi(WebApiResource): def post(self, app_model, end_user): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() parser = reqparse.RequestParser() @@ -136,7 +136,7 @@ class ChatApi(WebApiResource): class ChatStopApi(WebApiResource): def post(self, app_model, end_user, task_id): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not 
in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id) diff --git a/api/controllers/web/conversation.py b/api/controllers/web/conversation.py index 6bbfa94c27..c3b0cd4f44 100644 --- a/api/controllers/web/conversation.py +++ b/api/controllers/web/conversation.py @@ -18,7 +18,7 @@ class ConversationListApi(WebApiResource): @marshal_with(conversation_infinite_scroll_pagination_fields) def get(self, app_model, end_user): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() parser = reqparse.RequestParser() @@ -56,7 +56,7 @@ class ConversationListApi(WebApiResource): class ConversationApi(WebApiResource): def delete(self, app_model, end_user, c_id): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() conversation_id = str(c_id) @@ -73,7 +73,7 @@ class ConversationRenameApi(WebApiResource): @marshal_with(simple_conversation_fields) def post(self, app_model, end_user, c_id): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() conversation_id = str(c_id) @@ -92,7 +92,7 @@ class ConversationRenameApi(WebApiResource): class ConversationPinApi(WebApiResource): def patch(self, app_model, end_user, c_id): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() conversation_id = str(c_id) @@ 
-108,7 +108,7 @@ class ConversationPinApi(WebApiResource): class ConversationUnPinApi(WebApiResource): def patch(self, app_model, end_user, c_id): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() conversation_id = str(c_id) diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index 56aaaa930a..0d4047f4ef 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -78,7 +78,7 @@ class MessageListApi(WebApiResource): @marshal_with(message_infinite_scroll_pagination_fields) def get(self, app_model, end_user): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() parser = reqparse.RequestParser() @@ -160,7 +160,7 @@ class MessageMoreLikeThisApi(WebApiResource): class MessageSuggestedQuestionApi(WebApiResource): def get(self, app_model, end_user, message_id): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotCompletionAppError() message_id = str(message_id) diff --git a/api/core/agent/output_parser/cot_output_parser.py b/api/core/agent/output_parser/cot_output_parser.py index 1a161677dd..d04e38777a 100644 --- a/api/core/agent/output_parser/cot_output_parser.py +++ b/api/core/agent/output_parser/cot_output_parser.py @@ -90,7 +90,7 @@ class CotAgentOutputParser: if not in_code_block and not in_json: if delta.lower() == action_str[action_idx] and action_idx == 0: - if last_character not in ["\n", " ", ""]: + if last_character not in {"\n", " ", ""}: index += steps yield delta continue @@ -117,7 +117,7 @@ class 
CotAgentOutputParser: action_idx = 0 if delta.lower() == thought_str[thought_idx] and thought_idx == 0: - if last_character not in ["\n", " ", ""]: + if last_character not in {"\n", " ", ""}: index += steps yield delta continue diff --git a/api/core/app/app_config/base_app_config_manager.py b/api/core/app/app_config/base_app_config_manager.py index 0fd2a779a4..24d80f9cdd 100644 --- a/api/core/app/app_config/base_app_config_manager.py +++ b/api/core/app/app_config/base_app_config_manager.py @@ -29,7 +29,7 @@ class BaseAppConfigManager: additional_features.show_retrieve_source = RetrievalResourceConfigManager.convert(config=config_dict) additional_features.file_upload = FileUploadConfigManager.convert( - config=config_dict, is_vision=app_mode in [AppMode.CHAT, AppMode.COMPLETION, AppMode.AGENT_CHAT] + config=config_dict, is_vision=app_mode in {AppMode.CHAT, AppMode.COMPLETION, AppMode.AGENT_CHAT} ) additional_features.opening_statement, additional_features.suggested_questions = ( diff --git a/api/core/app/app_config/easy_ui_based_app/agent/manager.py b/api/core/app/app_config/easy_ui_based_app/agent/manager.py index 6e89f19508..f503543d7b 100644 --- a/api/core/app/app_config/easy_ui_based_app/agent/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/agent/manager.py @@ -18,7 +18,7 @@ class AgentConfigManager: if agent_strategy == "function_call": strategy = AgentEntity.Strategy.FUNCTION_CALLING - elif agent_strategy == "cot" or agent_strategy == "react": + elif agent_strategy in {"cot", "react"}: strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT else: # old configs, try to detect default strategy @@ -43,10 +43,10 @@ class AgentConfigManager: agent_tools.append(AgentToolEntity(**agent_tool_properties)) - if "strategy" in config["agent_mode"] and config["agent_mode"]["strategy"] not in [ + if "strategy" in config["agent_mode"] and config["agent_mode"]["strategy"] not in { "react_router", "router", - ]: + }: agent_prompt = agent_dict.get("prompt", None) or {} # 
check model mode model_mode = config.get("model", {}).get("mode", "completion") diff --git a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py index ff131b62e2..a22395b8e3 100644 --- a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py @@ -167,7 +167,7 @@ class DatasetConfigManager: config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value has_datasets = False - if config["agent_mode"]["strategy"] in [PlanningStrategy.ROUTER.value, PlanningStrategy.REACT_ROUTER.value]: + if config["agent_mode"]["strategy"] in {PlanningStrategy.ROUTER.value, PlanningStrategy.REACT_ROUTER.value}: for tool in config["agent_mode"]["tools"]: key = list(tool.keys())[0] if key == "dataset": diff --git a/api/core/app/app_config/easy_ui_based_app/variables/manager.py b/api/core/app/app_config/easy_ui_based_app/variables/manager.py index e70522f21d..a1bfde3208 100644 --- a/api/core/app/app_config/easy_ui_based_app/variables/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/variables/manager.py @@ -42,12 +42,12 @@ class BasicVariablesConfigManager: variable=variable["variable"], type=variable["type"], config=variable["config"] ) ) - elif variable_type in [ + elif variable_type in { VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH, VariableEntityType.NUMBER, VariableEntityType.SELECT, - ]: + }: variable = variables[variable_type] variable_entities.append( VariableEntity( @@ -97,7 +97,7 @@ class BasicVariablesConfigManager: variables = [] for item in config["user_input_form"]: key = list(item.keys())[0] - if key not in ["text-input", "select", "paragraph", "number", "external_data_tool"]: + if key not in {"text-input", "select", "paragraph", "number", "external_data_tool"}: raise ValueError("Keys in user_input_form list can only be 'text-input', 'paragraph' or 'select'") form_item = item[key] diff --git 
a/api/core/app/app_config/features/file_upload/manager.py b/api/core/app/app_config/features/file_upload/manager.py index 5f7fc99151..7a275cb532 100644 --- a/api/core/app/app_config/features/file_upload/manager.py +++ b/api/core/app/app_config/features/file_upload/manager.py @@ -54,14 +54,14 @@ class FileUploadConfigManager: if is_vision: detail = config["file_upload"]["image"]["detail"] - if detail not in ["high", "low"]: + if detail not in {"high", "low"}: raise ValueError("detail must be in ['high', 'low']") transfer_methods = config["file_upload"]["image"]["transfer_methods"] if not isinstance(transfer_methods, list): raise ValueError("transfer_methods must be of list type") for method in transfer_methods: - if method not in ["remote_url", "local_file"]: + if method not in {"remote_url", "local_file"}: raise ValueError("transfer_methods must be in ['remote_url', 'local_file']") return config, ["file_upload"] diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index c4cdba6441..1bca1e1b71 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -73,7 +73,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): raise ValueError("Workflow not initialized") user_id = None - if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: + if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}: end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first() if end_user: user_id = end_user.session_id @@ -175,7 +175,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): user_id=self.application_generate_entity.user_id, user_from=( UserFrom.ACCOUNT - if self.application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] + if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else 
UserFrom.END_USER ), invoke_from=self.application_generate_entity.invoke_from, diff --git a/api/core/app/apps/base_app_generate_response_converter.py b/api/core/app/apps/base_app_generate_response_converter.py index 73025d99d0..c6855ac854 100644 --- a/api/core/app/apps/base_app_generate_response_converter.py +++ b/api/core/app/apps/base_app_generate_response_converter.py @@ -16,7 +16,7 @@ class AppGenerateResponseConverter(ABC): def convert( cls, response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], invoke_from: InvokeFrom ) -> dict[str, Any] | Generator[str, Any, None]: - if invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]: + if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}: if isinstance(response, AppBlockingResponse): return cls.convert_blocking_full_response(response) else: diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py index ce6f7d4338..15be7000fc 100644 --- a/api/core/app/apps/base_app_generator.py +++ b/api/core/app/apps/base_app_generator.py @@ -22,11 +22,11 @@ class BaseAppGenerator: return var.default or "" if ( var.type - in ( + in { VariableEntityType.TEXT_INPUT, VariableEntityType.SELECT, VariableEntityType.PARAGRAPH, - ) + } and user_input_value and not isinstance(user_input_value, str) ): @@ -44,7 +44,7 @@ class BaseAppGenerator: options = var.options or [] if user_input_value not in options: raise ValueError(f"{var.variable} in input form must be one of the following: {options}") - elif var.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH): + elif var.type in {VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH}: if var.max_length and user_input_value and len(user_input_value) > var.max_length: raise ValueError(f"{var.variable} in input form must be less than {var.max_length} characters") diff --git a/api/core/app/apps/base_app_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py index f3c3199354..4c4d282e99 
100644 --- a/api/core/app/apps/base_app_queue_manager.py +++ b/api/core/app/apps/base_app_queue_manager.py @@ -32,7 +32,7 @@ class AppQueueManager: self._user_id = user_id self._invoke_from = invoke_from - user_prefix = "account" if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end-user" + user_prefix = "account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user" redis_client.setex( AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}" ) @@ -118,7 +118,7 @@ class AppQueueManager: if result is None: return - user_prefix = "account" if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end-user" + user_prefix = "account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user" if result.decode("utf-8") != f"{user_prefix}-{user_id}": return diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index f629c5c8b7..c4db95cbd0 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -148,7 +148,7 @@ class MessageBasedAppGenerator(BaseAppGenerator): # get from source end_user_id = None account_id = None - if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: + if application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}: from_source = "api" end_user_id = application_generate_entity.user_id else: @@ -165,11 +165,11 @@ class MessageBasedAppGenerator(BaseAppGenerator): model_provider = application_generate_entity.model_conf.provider model_id = application_generate_entity.model_conf.model override_model_configs = None - if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS and app_config.app_mode in [ + if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS and app_config.app_mode in { AppMode.AGENT_CHAT, AppMode.CHAT, 
AppMode.COMPLETION, - ]: + }: override_model_configs = app_config.app_model_config_dict # get conversation introduction diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index 81c8463dd5..22ec228fa7 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -53,7 +53,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): app_config = cast(WorkflowAppConfig, app_config) user_id = None - if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: + if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}: end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first() if end_user: user_id = end_user.session_id @@ -113,7 +113,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): user_id=self.application_generate_entity.user_id, user_from=( UserFrom.ACCOUNT - if self.application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] + if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else UserFrom.END_USER ), invoke_from=self.application_generate_entity.invoke_from, diff --git a/api/core/app/features/annotation_reply/annotation_reply.py b/api/core/app/features/annotation_reply/annotation_reply.py index 2e37a126c3..77b6bb554c 100644 --- a/api/core/app/features/annotation_reply/annotation_reply.py +++ b/api/core/app/features/annotation_reply/annotation_reply.py @@ -63,7 +63,7 @@ class AnnotationReplyFeature: score = documents[0].metadata["score"] annotation = AppAnnotationService.get_annotation_by_id(annotation_id) if annotation: - if invoke_from in [InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP]: + if invoke_from in {InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP}: from_source = "api" else: from_source = "console" diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py 
b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index 659503301e..8f834b6458 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -372,7 +372,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan self._message, application_generate_entity=self._application_generate_entity, conversation=self._conversation, - is_first_message=self._application_generate_entity.app_config.app_mode in [AppMode.AGENT_CHAT, AppMode.CHAT] + is_first_message=self._application_generate_entity.app_config.app_mode in {AppMode.AGENT_CHAT, AppMode.CHAT} and self._application_generate_entity.conversation_id is None, extras=self._application_generate_entity.extras, ) diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index a030d5dcbf..f10189798f 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -383,7 +383,7 @@ class WorkflowCycleManage: :param workflow_node_execution: workflow node execution :return: """ - if workflow_node_execution.node_type in [NodeType.ITERATION.value, NodeType.LOOP.value]: + if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: return None response = NodeStartStreamResponse( @@ -430,7 +430,7 @@ class WorkflowCycleManage: :param workflow_node_execution: workflow node execution :return: """ - if workflow_node_execution.node_type in [NodeType.ITERATION.value, NodeType.LOOP.value]: + if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: return None return NodeFinishStreamResponse( diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py index 6d5393ce5c..7cf472d984 100644 --- a/api/core/callback_handler/index_tool_callback_handler.py +++ 
b/api/core/callback_handler/index_tool_callback_handler.py @@ -29,7 +29,7 @@ class DatasetIndexToolCallbackHandler: source="app", source_app_id=self._app_id, created_by_role=( - "account" if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end_user" + "account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user" ), created_by=self._user_id, ) diff --git a/api/core/embedding/cached_embedding.py b/api/core/embedding/cached_embedding.py index 4cc793b0d7..8ce12fd59f 100644 --- a/api/core/embedding/cached_embedding.py +++ b/api/core/embedding/cached_embedding.py @@ -65,7 +65,7 @@ class CacheEmbedding(Embeddings): except IntegrityError: db.session.rollback() except Exception as e: - logging.exception("Failed transform embedding: ", e) + logging.exception("Failed transform embedding: %s", e) cache_embeddings = [] try: for i, embedding in zip(embedding_queue_indices, embedding_queue_embeddings): @@ -85,7 +85,7 @@ class CacheEmbedding(Embeddings): db.session.rollback() except Exception as ex: db.session.rollback() - logger.error("Failed to embed documents: ", ex) + logger.error("Failed to embed documents: %s", ex) raise ex return text_embeddings @@ -116,10 +116,7 @@ class CacheEmbedding(Embeddings): # Transform to string encoded_str = encoded_vector.decode("utf-8") redis_client.setex(embedding_cache_key, 600, encoded_str) - - except IntegrityError: - db.session.rollback() - except: - logging.exception("Failed to add embedding to redis") + except Exception as ex: + logging.exception("Failed to add embedding to redis %s", ex) return embedding_results diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index eeb1dbfda0..af20df41b1 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -292,7 +292,7 @@ class IndexingRunner: self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict ) -> list[Document]: # load file - if 
dataset_document.data_source_type not in ["upload_file", "notion_import", "website_crawl"]: + if dataset_document.data_source_type not in {"upload_file", "notion_import", "website_crawl"}: return [] data_source_info = dataset_document.data_source_info_dict diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index a14d237a12..d3185c3b11 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -52,7 +52,7 @@ class TokenBufferMemory: files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all() if files: file_extra_config = None - if self.conversation.mode not in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]: + if self.conversation.mode not in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config) else: if message.workflow_run_id: diff --git a/api/core/model_runtime/entities/model_entities.py b/api/core/model_runtime/entities/model_entities.py index d898ef1490..52ea787c3a 100644 --- a/api/core/model_runtime/entities/model_entities.py +++ b/api/core/model_runtime/entities/model_entities.py @@ -27,17 +27,17 @@ class ModelType(Enum): :return: model type """ - if origin_model_type == "text-generation" or origin_model_type == cls.LLM.value: + if origin_model_type in {"text-generation", cls.LLM.value}: return cls.LLM - elif origin_model_type == "embeddings" or origin_model_type == cls.TEXT_EMBEDDING.value: + elif origin_model_type in {"embeddings", cls.TEXT_EMBEDDING.value}: return cls.TEXT_EMBEDDING - elif origin_model_type == "reranking" or origin_model_type == cls.RERANK.value: + elif origin_model_type in {"reranking", cls.RERANK.value}: return cls.RERANK - elif origin_model_type == "speech2text" or origin_model_type == cls.SPEECH2TEXT.value: + elif origin_model_type in {"speech2text", cls.SPEECH2TEXT.value}: return cls.SPEECH2TEXT - elif origin_model_type == "tts" or 
origin_model_type == cls.TTS.value: + elif origin_model_type in {"tts", cls.TTS.value}: return cls.TTS - elif origin_model_type == "text2img" or origin_model_type == cls.TEXT2IMG.value: + elif origin_model_type in {"text2img", cls.TEXT2IMG.value}: return cls.TEXT2IMG elif origin_model_type == cls.MODERATION.value: return cls.MODERATION diff --git a/api/core/model_runtime/model_providers/anthropic/llm/llm.py b/api/core/model_runtime/model_providers/anthropic/llm/llm.py index ff741e0240..46e1b415b8 100644 --- a/api/core/model_runtime/model_providers/anthropic/llm/llm.py +++ b/api/core/model_runtime/model_providers/anthropic/llm/llm.py @@ -494,7 +494,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): mime_type = data_split[0].replace("data:", "") base64_data = data_split[1] - if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]: + if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}: raise ValueError( f"Unsupported image type {mime_type}, " f"only support image/jpeg, image/png, image/gif, and image/webp" diff --git a/api/core/model_runtime/model_providers/azure_openai/tts/tts.py b/api/core/model_runtime/model_providers/azure_openai/tts/tts.py index 8db044b24d..af178703a0 100644 --- a/api/core/model_runtime/model_providers/azure_openai/tts/tts.py +++ b/api/core/model_runtime/model_providers/azure_openai/tts/tts.py @@ -85,14 +85,14 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel): for i in range(len(sentences)) ] for future in futures: - yield from future.result().__enter__().iter_bytes(1024) + yield from future.result().__enter__().iter_bytes(1024) # noqa:PLC2801 else: response = client.audio.speech.with_streaming_response.create( model=model, voice=voice, response_format="mp3", input=content_text.strip() ) - yield from response.__enter__().iter_bytes(1024) + yield from response.__enter__().iter_bytes(1024) # noqa:PLC2801 except Exception as ex: raise InvokeBadRequestError(str(ex)) diff --git 
a/api/core/model_runtime/model_providers/bedrock/llm/eu.anthropic.claude-3-haiku-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/eu.anthropic.claude-3-haiku-v1.yaml index fe5f54de13..24a65ef1bb 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/eu.anthropic.claude-3-haiku-v1.yaml +++ b/api/core/model_runtime/model_providers/bedrock/llm/eu.anthropic.claude-3-haiku-v1.yaml @@ -1,6 +1,6 @@ model: eu.anthropic.claude-3-haiku-20240307-v1:0 label: - en_US: Claude 3 Haiku(Cross Region Inference) + en_US: Claude 3 Haiku(EU.Cross Region Inference) model_type: llm features: - agent-thought diff --git a/api/core/model_runtime/model_providers/bedrock/llm/eu.anthropic.claude-3-sonnet-v1.5.yaml b/api/core/model_runtime/model_providers/bedrock/llm/eu.anthropic.claude-3-sonnet-v1.5.yaml index 9f8d029a57..e3d25c7d8f 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/eu.anthropic.claude-3-sonnet-v1.5.yaml +++ b/api/core/model_runtime/model_providers/bedrock/llm/eu.anthropic.claude-3-sonnet-v1.5.yaml @@ -1,6 +1,6 @@ model: eu.anthropic.claude-3-5-sonnet-20240620-v1:0 label: - en_US: Claude 3.5 Sonnet(Cross Region Inference) + en_US: Claude 3.5 Sonnet(EU.Cross Region Inference) model_type: llm features: - agent-thought diff --git a/api/core/model_runtime/model_providers/bedrock/llm/eu.anthropic.claude-3-sonnet-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/eu.anthropic.claude-3-sonnet-v1.yaml index bfaf5abb8e..9a06a4ad6d 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/eu.anthropic.claude-3-sonnet-v1.yaml +++ b/api/core/model_runtime/model_providers/bedrock/llm/eu.anthropic.claude-3-sonnet-v1.yaml @@ -1,6 +1,6 @@ model: eu.anthropic.claude-3-sonnet-20240229-v1:0 label: - en_US: Claude 3 Sonnet(Cross Region Inference) + en_US: Claude 3 Sonnet(EU.Cross Region Inference) model_type: llm features: - agent-thought diff --git a/api/core/model_runtime/model_providers/bedrock/llm/llm.py 
b/api/core/model_runtime/model_providers/bedrock/llm/llm.py index c34c20ced3..77bab0c294 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/llm.py +++ b/api/core/model_runtime/model_providers/bedrock/llm/llm.py @@ -1,8 +1,8 @@ # standard import import base64 -import io import json import logging +import mimetypes from collections.abc import Generator from typing import Optional, Union, cast @@ -17,7 +17,6 @@ from botocore.exceptions import ( ServiceNotInRegionError, UnknownServiceError, ) -from PIL.Image import Image # local import from core.model_runtime.callbacks.base_callback import Callback @@ -443,8 +442,9 @@ class BedrockLargeLanguageModel(LargeLanguageModel): try: url = message_content.data image_content = requests.get(url).content - with Image.open(io.BytesIO(image_content)) as img: - mime_type = f"image/{img.format.lower()}" + if "?" in url: + url = url.split("?")[0] + mime_type, _ = mimetypes.guess_type(url) base64_data = base64.b64encode(image_content).decode("utf-8") except Exception as ex: raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}") @@ -454,7 +454,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel): base64_data = data_split[1] image_content = base64.b64decode(base64_data) - if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]: + if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}: raise ValueError( f"Unsupported image type {mime_type}, " f"only support image/jpeg, image/png, image/gif, and image/webp" @@ -886,16 +886,16 @@ class BedrockLargeLanguageModel(LargeLanguageModel): if error_code == "AccessDeniedException": return InvokeAuthorizationError(error_msg) - elif error_code in ["ResourceNotFoundException", "ValidationException"]: + elif error_code in {"ResourceNotFoundException", "ValidationException"}: return InvokeBadRequestError(error_msg) - elif error_code in ["ThrottlingException", "ServiceQuotaExceededException"]: + elif error_code in 
{"ThrottlingException", "ServiceQuotaExceededException"}: return InvokeRateLimitError(error_msg) - elif error_code in [ + elif error_code in { "ModelTimeoutException", "ModelErrorException", "InternalServerException", "ModelNotReadyException", - ]: + }: return InvokeServerUnavailableError(error_msg) elif error_code == "ModelStreamErrorException": return InvokeConnectionError(error_msg) diff --git a/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-haiku-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-haiku-v1.yaml index 58c1f05779..9247f46974 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-haiku-v1.yaml +++ b/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-haiku-v1.yaml @@ -1,6 +1,6 @@ model: us.anthropic.claude-3-haiku-20240307-v1:0 label: - en_US: Claude 3 Haiku(Cross Region Inference) + en_US: Claude 3 Haiku(US.Cross Region Inference) model_type: llm features: - agent-thought diff --git a/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-opus-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-opus-v1.yaml index 6b9e1ec067..f9854d51f0 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-opus-v1.yaml +++ b/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-opus-v1.yaml @@ -1,6 +1,6 @@ model: us.anthropic.claude-3-opus-20240229-v1:0 label: - en_US: Claude 3 Opus(Cross Region Inference) + en_US: Claude 3 Opus(US.Cross Region Inference) model_type: llm features: - agent-thought diff --git a/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-sonnet-v1.5.yaml b/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-sonnet-v1.5.yaml index f1e0d6c5a2..fbcab2d5f3 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-sonnet-v1.5.yaml +++ 
b/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-sonnet-v1.5.yaml @@ -1,6 +1,6 @@ model: us.anthropic.claude-3-5-sonnet-20240620-v1:0 label: - en_US: Claude 3.5 Sonnet(Cross Region Inference) + en_US: Claude 3.5 Sonnet(US.Cross Region Inference) model_type: llm features: - agent-thought diff --git a/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-sonnet-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-sonnet-v1.yaml index dce50bf4b5..9f5a1501f0 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-sonnet-v1.yaml +++ b/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-sonnet-v1.yaml @@ -1,6 +1,6 @@ model: us.anthropic.claude-3-sonnet-20240229-v1:0 label: - en_US: Claude 3 Sonnet(Cross Region Inference) + en_US: Claude 3 Sonnet(US.Cross Region Inference) model_type: llm features: - agent-thought diff --git a/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py index 2d898e3aaa..251170d1ae 100644 --- a/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py @@ -186,16 +186,16 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel): if error_code == "AccessDeniedException": return InvokeAuthorizationError(error_msg) - elif error_code in ["ResourceNotFoundException", "ValidationException"]: + elif error_code in {"ResourceNotFoundException", "ValidationException"}: return InvokeBadRequestError(error_msg) - elif error_code in ["ThrottlingException", "ServiceQuotaExceededException"]: + elif error_code in {"ThrottlingException", "ServiceQuotaExceededException"}: return InvokeRateLimitError(error_msg) - elif error_code in [ + elif error_code in { "ModelTimeoutException", "ModelErrorException", "InternalServerException", 
"ModelNotReadyException", - ]: + }: return InvokeServerUnavailableError(error_msg) elif error_code == "ModelStreamErrorException": return InvokeConnectionError(error_msg) diff --git a/api/core/model_runtime/model_providers/google/llm/llm.py b/api/core/model_runtime/model_providers/google/llm/llm.py index b10d0edba3..3fc6787a44 100644 --- a/api/core/model_runtime/model_providers/google/llm/llm.py +++ b/api/core/model_runtime/model_providers/google/llm/llm.py @@ -6,10 +6,10 @@ from collections.abc import Generator from typing import Optional, Union, cast import google.ai.generativelanguage as glm -import google.api_core.exceptions as exceptions import google.generativeai as genai -import google.generativeai.client as client import requests +from google.api_core import exceptions +from google.generativeai import client from google.generativeai.types import ContentType, GenerateContentResponse, HarmBlockThreshold, HarmCategory from google.generativeai.types.content_types import to_part from PIL import Image diff --git a/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py b/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py index 48ab477c50..9d29237fdd 100644 --- a/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py +++ b/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py @@ -77,7 +77,7 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel if "huggingfacehub_api_type" not in credentials: raise CredentialsValidateFailedError("Huggingface Hub Endpoint Type must be provided.") - if credentials["huggingfacehub_api_type"] not in ("inference_endpoints", "hosted_inference_api"): + if credentials["huggingfacehub_api_type"] not in {"inference_endpoints", "hosted_inference_api"}: raise CredentialsValidateFailedError("Huggingface Hub Endpoint Type is invalid.") if "huggingfacehub_api_token" not in credentials: @@ -94,7 +94,7 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, 
LargeLanguageModel credentials["huggingfacehub_api_token"], model ) - if credentials["task_type"] not in ("text2text-generation", "text-generation"): + if credentials["task_type"] not in {"text2text-generation", "text-generation"}: raise CredentialsValidateFailedError( "Huggingface Hub Task Type must be one of text2text-generation, text-generation." ) diff --git a/api/core/model_runtime/model_providers/huggingface_tei/rerank/rerank.py b/api/core/model_runtime/model_providers/huggingface_tei/rerank/rerank.py index c128c35f6d..74a1dfc3ff 100644 --- a/api/core/model_runtime/model_providers/huggingface_tei/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/huggingface_tei/rerank/rerank.py @@ -49,8 +49,7 @@ class HuggingfaceTeiRerankModel(RerankModel): return RerankResult(model=model, docs=[]) server_url = credentials["server_url"] - if server_url.endswith("/"): - server_url = server_url[:-1] + server_url = server_url.removesuffix("/") try: results = TeiHelper.invoke_rerank(server_url, query, docs) diff --git a/api/core/model_runtime/model_providers/huggingface_tei/tei_helper.py b/api/core/model_runtime/model_providers/huggingface_tei/tei_helper.py index 288637495f..81ab249214 100644 --- a/api/core/model_runtime/model_providers/huggingface_tei/tei_helper.py +++ b/api/core/model_runtime/model_providers/huggingface_tei/tei_helper.py @@ -75,7 +75,7 @@ class TeiHelper: if len(model_type.keys()) < 1: raise RuntimeError("model_type is empty") model_type = list(model_type.keys())[0] - if model_type not in ["embedding", "reranker"]: + if model_type not in {"embedding", "reranker"}: raise RuntimeError(f"invalid model_type: {model_type}") max_input_length = response_json.get("max_input_length", 512) diff --git a/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py index 2d04abb277..55f3c25804 100644 --- 
a/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py @@ -42,8 +42,7 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel): """ server_url = credentials["server_url"] - if server_url.endswith("/"): - server_url = server_url[:-1] + server_url = server_url.removesuffix("/") # get model properties context_size = self._get_context_size(model, credentials) @@ -119,8 +118,7 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel): num_tokens = 0 server_url = credentials["server_url"] - if server_url.endswith("/"): - server_url = server_url[:-1] + server_url = server_url.removesuffix("/") batch_tokens = TeiHelper.invoke_tokenize(server_url, texts) num_tokens = sum(len(tokens) for tokens in batch_tokens) diff --git a/api/core/model_runtime/model_providers/hunyuan/llm/_position.yaml b/api/core/model_runtime/model_providers/hunyuan/llm/_position.yaml index 2c1b981f85..ca8600a534 100644 --- a/api/core/model_runtime/model_providers/hunyuan/llm/_position.yaml +++ b/api/core/model_runtime/model_providers/hunyuan/llm/_position.yaml @@ -2,3 +2,4 @@ - hunyuan-standard - hunyuan-standard-256k - hunyuan-pro +- hunyuan-turbo diff --git a/api/core/model_runtime/model_providers/hunyuan/llm/hunyuan-turbo.yaml b/api/core/model_runtime/model_providers/hunyuan/llm/hunyuan-turbo.yaml new file mode 100644 index 0000000000..4837fed4ba --- /dev/null +++ b/api/core/model_runtime/model_providers/hunyuan/llm/hunyuan-turbo.yaml @@ -0,0 +1,38 @@ +model: hunyuan-turbo +label: + zh_Hans: hunyuan-turbo + en_US: hunyuan-turbo +model_type: llm +features: + - agent-thought + - tool-call + - multi-tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 32000 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: max_tokens + use_template: max_tokens + default: 1024 + min: 1 + max: 
32000 + - name: enable_enhance + label: + zh_Hans: 功能增强 + en_US: Enable Enhancement + type: boolean + help: + zh_Hans: 功能增强(如搜索)开关,关闭时将直接由主模型生成回复内容,可以降低响应时延(对于流式输出时的首字时延尤为明显)。但在少数场景里,回复效果可能会下降。 + en_US: Allow the model to perform external search to enhance the generation results. + required: false + default: true +pricing: + input: '0.015' + output: '0.05' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/jina/jina.py b/api/core/model_runtime/model_providers/jina/jina.py index 33977b6a33..186a0a0fa7 100644 --- a/api/core/model_runtime/model_providers/jina/jina.py +++ b/api/core/model_runtime/model_providers/jina/jina.py @@ -18,9 +18,9 @@ class JinaProvider(ModelProvider): try: model_instance = self.get_model_instance(ModelType.TEXT_EMBEDDING) - # Use `jina-embeddings-v2-base-en` model for validate, + # Use `jina-embeddings-v3` model for validate, # no matter what model you pass in, text completion model or chat model - model_instance.validate_credentials(model="jina-embeddings-v2-base-en", credentials=credentials) + model_instance.validate_credentials(model="jina-embeddings-v3", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: diff --git a/api/core/model_runtime/model_providers/jina/rerank/rerank.py b/api/core/model_runtime/model_providers/jina/rerank/rerank.py index d8394f7a4c..79ca68914f 100644 --- a/api/core/model_runtime/model_providers/jina/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/jina/rerank/rerank.py @@ -48,8 +48,7 @@ class JinaRerankModel(RerankModel): return RerankResult(model=model, docs=[]) base_url = credentials.get("base_url", "https://api.jina.ai/v1") - if base_url.endswith("/"): - base_url = base_url[:-1] + base_url = base_url.removesuffix("/") try: response = httpx.post( diff --git a/api/core/model_runtime/model_providers/jina/text_embedding/jina-embeddings-v3.yaml 
b/api/core/model_runtime/model_providers/jina/text_embedding/jina-embeddings-v3.yaml new file mode 100644 index 0000000000..4e5374dc9d --- /dev/null +++ b/api/core/model_runtime/model_providers/jina/text_embedding/jina-embeddings-v3.yaml @@ -0,0 +1,9 @@ +model: jina-embeddings-v3 +model_type: text-embedding +model_properties: + context_size: 8192 + max_chunks: 2048 +pricing: + input: '0.001' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py index 7ed3e4d384..ceb79567d5 100644 --- a/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py @@ -44,8 +44,7 @@ class JinaTextEmbeddingModel(TextEmbeddingModel): raise CredentialsValidateFailedError("api_key is required") base_url = credentials.get("base_url", self.api_base) - if base_url.endswith("/"): - base_url = base_url[:-1] + base_url = base_url.removesuffix("/") url = base_url + "/embeddings" headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"} @@ -57,6 +56,9 @@ class JinaTextEmbeddingModel(TextEmbeddingModel): data = {"model": model, "input": [transform_jina_input_text(model, text) for text in texts]} + if model == "jina-embeddings-v3": + data["task"] = "text-matching" + try: response = post(url, headers=headers, data=dumps(data)) except Exception as e: diff --git a/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py b/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py index 96f99c8929..88cc0e8e0f 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py +++ b/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py @@ -100,9 +100,9 @@ class MinimaxChatCompletion: return self._handle_chat_generate_response(response) def _handle_error(self, code: int, msg: str): 
- if code == 1000 or code == 1001 or code == 1013 or code == 1027: + if code in {1000, 1001, 1013, 1027}: raise InternalServerError(msg) - elif code == 1002 or code == 1039: + elif code in {1002, 1039}: raise RateLimitReachedError(msg) elif code == 1004: raise InvalidAuthenticationError(msg) diff --git a/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py b/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py index 0a2a67a56d..8b8fdbb6bd 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py +++ b/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py @@ -105,9 +105,9 @@ class MinimaxChatCompletionPro: return self._handle_chat_generate_response(response) def _handle_error(self, code: int, msg: str): - if code == 1000 or code == 1001 or code == 1013 or code == 1027: + if code in {1000, 1001, 1013, 1027}: raise InternalServerError(msg) - elif code == 1002 or code == 1039: + elif code in {1002, 1039}: raise RateLimitReachedError(msg) elif code == 1004: raise InvalidAuthenticationError(msg) diff --git a/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py index 02a53708be..76fd1342bd 100644 --- a/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py @@ -114,7 +114,7 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel): raise CredentialsValidateFailedError("Invalid api key") def _handle_error(self, code: int, msg: str): - if code == 1000 or code == 1001: + if code in {1000, 1001}: raise InternalServerError(msg) elif code == 1002: raise RateLimitReachedError(msg) diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-0613.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-0613.yaml index 31dc53e89f..a1ad07d712 100644 
--- a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-0613.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-0613.yaml @@ -31,3 +31,4 @@ pricing: output: '0.002' unit: '0.001' currency: USD +deprecated: true diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-16k-0613.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-16k-0613.yaml index 4a0e2ef191..4e30279284 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-16k-0613.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-16k-0613.yaml @@ -31,3 +31,4 @@ pricing: output: '0.004' unit: '0.001' currency: USD +deprecated: true diff --git a/api/core/model_runtime/model_providers/openai/llm/llm.py b/api/core/model_runtime/model_providers/openai/llm/llm.py index 95533ccfaf..d42fce528a 100644 --- a/api/core/model_runtime/model_providers/openai/llm/llm.py +++ b/api/core/model_runtime/model_providers/openai/llm/llm.py @@ -125,7 +125,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): model_mode = self.get_model_mode(base_model, credentials) # transform response format - if "response_format" in model_parameters and model_parameters["response_format"] in ["JSON", "XML"]: + if "response_format" in model_parameters and model_parameters["response_format"] in {"JSON", "XML"}: stop = stop or [] if model_mode == LLMMode.CHAT: # chat model @@ -615,10 +615,12 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): block_as_stream = False if model.startswith("o1"): - block_as_stream = True - stream = False - if "stream_options" in extra_model_kwargs: - del extra_model_kwargs["stream_options"] + if stream: + block_as_stream = True + stream = False + + if "stream_options" in extra_model_kwargs: + del extra_model_kwargs["stream_options"] if "stop" in extra_model_kwargs: del extra_model_kwargs["stop"] diff --git 
a/api/core/model_runtime/model_providers/openai/llm/o1-mini-2024-09-12.yaml b/api/core/model_runtime/model_providers/openai/llm/o1-mini-2024-09-12.yaml index 07a3bc9a7a..0ade7f8ded 100644 --- a/api/core/model_runtime/model_providers/openai/llm/o1-mini-2024-09-12.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/o1-mini-2024-09-12.yaml @@ -11,9 +11,9 @@ model_properties: parameter_rules: - name: max_tokens use_template: max_tokens - default: 65563 + default: 65536 min: 1 - max: 65563 + max: 65536 - name: response_format label: zh_Hans: 回复格式 diff --git a/api/core/model_runtime/model_providers/openai/llm/o1-mini.yaml b/api/core/model_runtime/model_providers/openai/llm/o1-mini.yaml index 3e83529201..60816c5d1e 100644 --- a/api/core/model_runtime/model_providers/openai/llm/o1-mini.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/o1-mini.yaml @@ -11,9 +11,9 @@ model_properties: parameter_rules: - name: max_tokens use_template: max_tokens - default: 65563 + default: 65536 min: 1 - max: 65563 + max: 65536 - name: response_format label: zh_Hans: 回复格式 diff --git a/api/core/model_runtime/model_providers/openai/tts/tts.py b/api/core/model_runtime/model_providers/openai/tts/tts.py index b50b43199f..a14c91639b 100644 --- a/api/core/model_runtime/model_providers/openai/tts/tts.py +++ b/api/core/model_runtime/model_providers/openai/tts/tts.py @@ -89,14 +89,14 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel): for i in range(len(sentences)) ] for future in futures: - yield from future.result().__enter__().iter_bytes(1024) + yield from future.result().__enter__().iter_bytes(1024) # noqa:PLC2801 else: response = client.audio.speech.with_streaming_response.create( model=model, voice=voice, response_format="mp3", input=content_text.strip() ) - yield from response.__enter__().iter_bytes(1024) + yield from response.__enter__().iter_bytes(1024) # noqa:PLC2801 except Exception as ex: raise InvokeBadRequestError(str(ex)) diff --git 
a/api/core/model_runtime/model_providers/openrouter/llm/llm.py b/api/core/model_runtime/model_providers/openrouter/llm/llm.py index 71b5745f7d..b6bb249a04 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/llm.py +++ b/api/core/model_runtime/model_providers/openrouter/llm/llm.py @@ -12,7 +12,6 @@ class OpenRouterLargeLanguageModel(OAIAPICompatLargeLanguageModel): credentials["endpoint_url"] = "https://openrouter.ai/api/v1" credentials["mode"] = self.get_model_mode(model).value credentials["function_calling_type"] = "tool_call" - return def _invoke( self, diff --git a/api/core/model_runtime/model_providers/replicate/llm/llm.py b/api/core/model_runtime/model_providers/replicate/llm/llm.py index daef8949fb..3641b35dc0 100644 --- a/api/core/model_runtime/model_providers/replicate/llm/llm.py +++ b/api/core/model_runtime/model_providers/replicate/llm/llm.py @@ -154,7 +154,7 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel): ) for key, value in input_properties: - if key not in ["system_prompt", "prompt"] and "stop" not in key: + if key not in {"system_prompt", "prompt"} and "stop" not in key: value_type = value.get("type") if not value_type: diff --git a/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py index f6b7754d74..71b6fb99c4 100644 --- a/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py @@ -86,7 +86,7 @@ class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel): ) for input_property in input_properties: - if input_property[0] in ("text", "texts", "inputs"): + if input_property[0] in {"text", "texts", "inputs"}: text_input_key = input_property[0] return text_input_key @@ -96,7 +96,7 @@ class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel): def 
_generate_embeddings_by_text_input_key( client: ReplicateClient, replicate_model_version: str, text_input_key: str, texts: list[str] ) -> list[list[float]]: - if text_input_key in ("text", "inputs"): + if text_input_key in {"text", "inputs"}: embeddings = [] for text in texts: result = client.run(replicate_model_version, input={text_input_key: text}) diff --git a/api/core/model_runtime/model_providers/siliconflow/rerank/rerank.py b/api/core/model_runtime/model_providers/siliconflow/rerank/rerank.py index 6f652e9d52..58b033d28a 100644 --- a/api/core/model_runtime/model_providers/siliconflow/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/siliconflow/rerank/rerank.py @@ -30,8 +30,7 @@ class SiliconflowRerankModel(RerankModel): return RerankResult(model=model, docs=[]) base_url = credentials.get("base_url", "https://api.siliconflow.cn/v1") - if base_url.endswith("/"): - base_url = base_url[:-1] + base_url = base_url.removesuffix("/") try: response = httpx.post( base_url + "/rerank", diff --git a/api/core/model_runtime/model_providers/tongyi/llm/llm.py b/api/core/model_runtime/model_providers/tongyi/llm/llm.py index cd7718361f..1d4eba6668 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/llm.py +++ b/api/core/model_runtime/model_providers/tongyi/llm/llm.py @@ -89,7 +89,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel): :param tools: tools for tool calling :return: """ - if model in ["qwen-turbo-chat", "qwen-plus-chat"]: + if model in {"qwen-turbo-chat", "qwen-plus-chat"}: model = model.replace("-chat", "") if model == "farui-plus": model = "qwen-farui-plus" @@ -157,7 +157,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel): mode = self.get_model_mode(model, credentials) - if model in ["qwen-turbo-chat", "qwen-plus-chat"]: + if model in {"qwen-turbo-chat", "qwen-plus-chat"}: model = model.replace("-chat", "") extra_model_kwargs = {} @@ -201,7 +201,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel): :param prompt_messages: prompt 
messages :return: llm response """ - if response.status_code != 200 and response.status_code != HTTPStatus.OK: + if response.status_code not in {200, HTTPStatus.OK}: raise ServiceUnavailableError(response.message) # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( @@ -240,7 +240,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel): full_text = "" tool_calls = [] for index, response in enumerate(responses): - if response.status_code != 200 and response.status_code != HTTPStatus.OK: + if response.status_code not in {200, HTTPStatus.OK}: raise ServiceUnavailableError( f"Failed to invoke model {model}, status code: {response.status_code}, " f"message: {response.message}" diff --git a/api/core/model_runtime/model_providers/upstage/llm/llm.py b/api/core/model_runtime/model_providers/upstage/llm/llm.py index 74524e81e2..a18ee90624 100644 --- a/api/core/model_runtime/model_providers/upstage/llm/llm.py +++ b/api/core/model_runtime/model_providers/upstage/llm/llm.py @@ -93,7 +93,7 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): """ Code block mode wrapper for invoking large language model """ - if "response_format" in model_parameters and model_parameters["response_format"] in ["JSON", "XML"]: + if "response_format" in model_parameters and model_parameters["response_format"] in {"JSON", "XML"}: stop = stop or [] self._transform_chat_json_prompts( model=model, diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py b/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py index dad5002f35..da69b7cdf3 100644 --- a/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py +++ b/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py @@ -5,7 +5,6 @@ import logging from collections.abc import Generator from typing import Optional, Union, cast -import google.api_core.exceptions as exceptions import google.auth.transport.requests import vertexai.generative_models as glm from 
anthropic import AnthropicVertex, Stream @@ -17,6 +16,7 @@ from anthropic.types import ( MessageStopEvent, MessageStreamEvent, ) +from google.api_core import exceptions from google.cloud import aiplatform from google.oauth2 import service_account from PIL import Image @@ -346,7 +346,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): mime_type = data_split[0].replace("data:", "") base64_data = data_split[1] - if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]: + if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}: raise ValueError( f"Unsupported image type {mime_type}, " f"only support image/jpeg, image/png, image/gif, and image/webp" diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/auth.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/auth.py index 97c77de8d3..c22bf8e76d 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/auth.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/auth.py @@ -96,7 +96,6 @@ class Signer: signing_key = Signer.get_signing_secret_key_v4(credentials.sk, md.date, md.region, md.service) sign = Util.to_hex(Util.hmac_sha256(signing_key, signing_str)) request.headers["Authorization"] = Signer.build_auth_header_v4(sign, md, credentials) - return @staticmethod def hashed_canonical_request_v4(request, meta): @@ -105,7 +104,7 @@ class Signer: signed_headers = {} for key in request.headers: - if key in ["Content-Type", "Content-Md5", "Host"] or key.startswith("X-"): + if key in {"Content-Type", "Content-Md5", "Host"} or key.startswith("X-"): signed_headers[key.lower()] = request.headers[key] if "host" in signed_headers: diff --git a/api/core/model_runtime/model_providers/wenxin/llm/llm.py b/api/core/model_runtime/model_providers/wenxin/llm/llm.py index ec3556f7da..f7c160b6b4 100644 --- a/api/core/model_runtime/model_providers/wenxin/llm/llm.py +++ 
b/api/core/model_runtime/model_providers/wenxin/llm/llm.py @@ -69,7 +69,7 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel): """ Code block mode wrapper for invoking large language model """ - if "response_format" in model_parameters and model_parameters["response_format"] in ["JSON", "XML"]: + if "response_format" in model_parameters and model_parameters["response_format"] in {"JSON", "XML"}: response_format = model_parameters["response_format"] stop = stop or [] self._transform_json_prompts( diff --git a/api/core/model_runtime/model_providers/xinference/llm/llm.py b/api/core/model_runtime/model_providers/xinference/llm/llm.py index f8f6c6b12d..4fadda5df5 100644 --- a/api/core/model_runtime/model_providers/xinference/llm/llm.py +++ b/api/core/model_runtime/model_providers/xinference/llm/llm.py @@ -459,8 +459,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): if "server_url" not in credentials: raise CredentialsValidateFailedError("server_url is required in credentials") - if credentials["server_url"].endswith("/"): - credentials["server_url"] = credentials["server_url"][:-1] + credentials["server_url"] = credentials["server_url"].removesuffix("/") api_key = credentials.get("api_key") or "abc" diff --git a/api/core/model_runtime/model_providers/xinference/rerank/rerank.py b/api/core/model_runtime/model_providers/xinference/rerank/rerank.py index 1582fe43b9..8f18bc42d2 100644 --- a/api/core/model_runtime/model_providers/xinference/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/xinference/rerank/rerank.py @@ -50,8 +50,7 @@ class XinferenceRerankModel(RerankModel): server_url = credentials["server_url"] model_uid = credentials["model_uid"] api_key = credentials.get("api_key") - if server_url.endswith("/"): - server_url = server_url[:-1] + server_url = server_url.removesuffix("/") auth_headers = {"Authorization": f"Bearer {api_key}"} if api_key else {} params = {"documents": docs, "query": query, "top_n": top_n, "return_documents": 
True} @@ -98,8 +97,7 @@ class XinferenceRerankModel(RerankModel): if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]: raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") - if credentials["server_url"].endswith("/"): - credentials["server_url"] = credentials["server_url"][:-1] + credentials["server_url"] = credentials["server_url"].removesuffix("/") # initialize client client = Client( diff --git a/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py b/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py index 18efde758c..a6c5b8a0a5 100644 --- a/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py @@ -45,8 +45,7 @@ class XinferenceSpeech2TextModel(Speech2TextModel): if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]: raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") - if credentials["server_url"].endswith("/"): - credentials["server_url"] = credentials["server_url"][:-1] + credentials["server_url"] = credentials["server_url"].removesuffix("/") # initialize client client = Client( @@ -116,8 +115,7 @@ class XinferenceSpeech2TextModel(Speech2TextModel): server_url = credentials["server_url"] model_uid = credentials["model_uid"] api_key = credentials.get("api_key") - if server_url.endswith("/"): - server_url = server_url[:-1] + server_url = server_url.removesuffix("/") auth_headers = {"Authorization": f"Bearer {api_key}"} if api_key else {} try: diff --git a/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py index ac704e7de8..8043af1d6c 100644 --- a/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py +++ 
b/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py @@ -45,8 +45,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): server_url = credentials["server_url"] model_uid = credentials["model_uid"] api_key = credentials.get("api_key") - if server_url.endswith("/"): - server_url = server_url[:-1] + server_url = server_url.removesuffix("/") auth_headers = {"Authorization": f"Bearer {api_key}"} if api_key else {} try: @@ -118,8 +117,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): if extra_args.max_tokens: credentials["max_tokens"] = extra_args.max_tokens - if server_url.endswith("/"): - server_url = server_url[:-1] + server_url = server_url.removesuffix("/") client = Client( base_url=server_url, diff --git a/api/core/model_runtime/model_providers/xinference/tts/tts.py b/api/core/model_runtime/model_providers/xinference/tts/tts.py index d29e8ed8f9..10538b5788 100644 --- a/api/core/model_runtime/model_providers/xinference/tts/tts.py +++ b/api/core/model_runtime/model_providers/xinference/tts/tts.py @@ -73,8 +73,7 @@ class XinferenceText2SpeechModel(TTSModel): if "/" in credentials["model_uid"] or "?" 
in credentials["model_uid"] or "#" in credentials["model_uid"]: raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") - if credentials["server_url"].endswith("/"): - credentials["server_url"] = credentials["server_url"][:-1] + credentials["server_url"] = credentials["server_url"].removesuffix("/") extra_param = XinferenceHelper.get_xinference_extra_parameter( server_url=credentials["server_url"], @@ -189,8 +188,7 @@ class XinferenceText2SpeechModel(TTSModel): :param voice: model timbre :return: text translated to audio file """ - if credentials["server_url"].endswith("/"): - credentials["server_url"] = credentials["server_url"][:-1] + credentials["server_url"] = credentials["server_url"].removesuffix("/") try: api_key = credentials.get("api_key") diff --git a/api/core/model_runtime/model_providers/xinference/xinference_helper.py b/api/core/model_runtime/model_providers/xinference/xinference_helper.py index 1e05da9c56..619ee1492a 100644 --- a/api/core/model_runtime/model_providers/xinference/xinference_helper.py +++ b/api/core/model_runtime/model_providers/xinference/xinference_helper.py @@ -103,7 +103,7 @@ class XinferenceHelper: model_handle_type = "embedding" elif response_json.get("model_type") == "audio": model_handle_type = "audio" - if model_family and model_family in ["ChatTTS", "CosyVoice", "FishAudio"]: + if model_family and model_family in {"ChatTTS", "CosyVoice", "FishAudio"}: model_ability.append("text-to-audio") else: model_ability.append("audio-to-text") diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py index f76e51fee9..ea331701ab 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py +++ b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py @@ -186,10 +186,10 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): new_prompt_messages: list[PromptMessage] = [] for prompt_message in prompt_messages: 
copy_prompt_message = prompt_message.copy() - if copy_prompt_message.role in [PromptMessageRole.USER, PromptMessageRole.SYSTEM, PromptMessageRole.TOOL]: + if copy_prompt_message.role in {PromptMessageRole.USER, PromptMessageRole.SYSTEM, PromptMessageRole.TOOL}: if isinstance(copy_prompt_message.content, list): # check if model is 'glm-4v' - if model not in ("glm-4v", "glm-4v-plus"): + if model not in {"glm-4v", "glm-4v-plus"}: # not support list message continue # get image and @@ -209,10 +209,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): ): new_prompt_messages[-1].content += "\n\n" + copy_prompt_message.content else: - if ( - copy_prompt_message.role == PromptMessageRole.USER - or copy_prompt_message.role == PromptMessageRole.TOOL - ): + if copy_prompt_message.role in {PromptMessageRole.USER, PromptMessageRole.TOOL}: new_prompt_messages.append(copy_prompt_message) elif copy_prompt_message.role == PromptMessageRole.SYSTEM: new_prompt_message = SystemPromptMessage(content=copy_prompt_message.content) @@ -226,7 +223,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): else: new_prompt_messages.append(copy_prompt_message) - if model == "glm-4v" or model == "glm-4v-plus": + if model in {"glm-4v", "glm-4v-plus"}: params = self._construct_glm_4v_parameter(model, new_prompt_messages, model_parameters) else: params = {"model": model, "messages": [], **model_parameters} @@ -270,11 +267,11 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): # chatglm model for prompt_message in new_prompt_messages: # merge system message to user message - if ( - prompt_message.role == PromptMessageRole.SYSTEM - or prompt_message.role == PromptMessageRole.TOOL - or prompt_message.role == PromptMessageRole.USER - ): + if prompt_message.role in { + PromptMessageRole.SYSTEM, + PromptMessageRole.TOOL, + PromptMessageRole.USER, + }: if len(params["messages"]) > 0 and params["messages"][-1]["role"] == "user": 
params["messages"][-1]["content"] += "\n\n" + prompt_message.content else: diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_sse_client.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_sse_client.py index 3566c6b332..ec2745d059 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_sse_client.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_sse_client.py @@ -127,8 +127,7 @@ class SSELineParser: field, _p, value = line.partition(":") - if value.startswith(" "): - value = value[1:] + value = value.removeprefix(" ") if field == "data": self._data.append(value) elif field == "event": diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/__init__.py index af0991892e..416f516ef7 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/__init__.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/__init__.py @@ -1,5 +1,4 @@ from __future__ import annotations -from .fine_tuning_job import FineTuningJob as FineTuningJob -from .fine_tuning_job import ListOfFineTuningJob as ListOfFineTuningJob -from .fine_tuning_job_event import FineTuningJobEvent as FineTuningJobEvent +from .fine_tuning_job import FineTuningJob, ListOfFineTuningJob +from .fine_tuning_job_event import FineTuningJobEvent diff --git a/api/core/model_runtime/schema_validators/common_validator.py b/api/core/model_runtime/schema_validators/common_validator.py index c05edb72e3..029ec1a581 100644 --- a/api/core/model_runtime/schema_validators/common_validator.py +++ b/api/core/model_runtime/schema_validators/common_validator.py @@ -75,7 +75,7 @@ class CommonValidator: if not isinstance(value, str): raise ValueError(f"Variable {credential_form_schema.variable} should be string") - if credential_form_schema.type in [FormType.SELECT, 
FormType.RADIO]: + if credential_form_schema.type in {FormType.SELECT, FormType.RADIO}: # If the value is in options, no validation is performed if credential_form_schema.options: if value not in [option.value for option in credential_form_schema.options]: @@ -83,7 +83,7 @@ class CommonValidator: if credential_form_schema.type == FormType.SWITCH: # If the value is not in ['true', 'false'], an exception is thrown - if value.lower() not in ["true", "false"]: + if value.lower() not in {"true", "false"}: raise ValueError(f"Variable {credential_form_schema.variable} should be true or false") value = True if value.lower() == "true" else False diff --git a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py index f8300cc271..8d57855120 100644 --- a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py +++ b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py @@ -51,7 +51,7 @@ class ElasticSearchVector(BaseVector): def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch: try: parsed_url = urlparse(config.host) - if parsed_url.scheme in ["http", "https"]: + if parsed_url.scheme in {"http", "https"}: hosts = f"{config.host}:{config.port}" else: hosts = f"http://{config.host}:{config.port}" @@ -94,7 +94,7 @@ class ElasticSearchVector(BaseVector): return uuids def text_exists(self, id: str) -> bool: - return self._client.exists(index=self._collection_name, id=id).__bool__() + return bool(self._client.exists(index=self._collection_name, id=id)) def delete_by_ids(self, ids: list[str]) -> None: for id in ids: diff --git a/api/core/rag/datasource/vdb/myscale/myscale_vector.py b/api/core/rag/datasource/vdb/myscale/myscale_vector.py index 1bd5bcd3e4..2320a69a30 100644 --- a/api/core/rag/datasource/vdb/myscale/myscale_vector.py +++ b/api/core/rag/datasource/vdb/myscale/myscale_vector.py @@ -35,7 +35,7 @@ class MyScaleVector(BaseVector): 
super().__init__(collection_name) self._config = config self._metric = metric - self._vec_order = SortOrder.ASC if metric.upper() in ["COSINE", "L2"] else SortOrder.DESC + self._vec_order = SortOrder.ASC if metric.upper() in {"COSINE", "L2"} else SortOrder.DESC self._client = get_client( host=config.host, port=config.port, @@ -92,7 +92,7 @@ class MyScaleVector(BaseVector): @staticmethod def escape_str(value: Any) -> str: - return "".join(" " if c in ("\\", "'") else c for c in str(value)) + return "".join(" " if c in {"\\", "'"} else c for c in str(value)) def text_exists(self, id: str) -> bool: results = self._client.query(f"SELECT id FROM {self._config.database}.{self._collection_name} WHERE id='{id}'") diff --git a/api/core/rag/datasource/vdb/oracle/oraclevector.py b/api/core/rag/datasource/vdb/oracle/oraclevector.py index b974fa80a4..77ec45b4d3 100644 --- a/api/core/rag/datasource/vdb/oracle/oraclevector.py +++ b/api/core/rag/datasource/vdb/oracle/oraclevector.py @@ -223,15 +223,7 @@ class OracleVector(BaseVector): words = pseg.cut(query) current_entity = "" for word, pos in words: - if ( - pos == "nr" - or pos == "Ng" - or pos == "eng" - or pos == "nz" - or pos == "n" - or pos == "ORG" - or pos == "v" - ): # nr: 人名, ns: 地名, nt: 机构名 + if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}: # nr: 人名, ns: 地名, nt: 机构名 current_entity += word else: if current_entity: diff --git a/api/core/rag/extractor/extract_processor.py b/api/core/rag/extractor/extract_processor.py index 3181656f59..fe7eaa32e6 100644 --- a/api/core/rag/extractor/extract_processor.py +++ b/api/core/rag/extractor/extract_processor.py @@ -98,17 +98,17 @@ class ExtractProcessor: unstructured_api_url = dify_config.UNSTRUCTURED_API_URL unstructured_api_key = dify_config.UNSTRUCTURED_API_KEY if etl_type == "Unstructured": - if file_extension == ".xlsx" or file_extension == ".xls": + if file_extension in {".xlsx", ".xls"}: extractor = ExcelExtractor(file_path) elif file_extension == ".pdf": extractor = 
PdfExtractor(file_path) - elif file_extension in [".md", ".markdown"]: + elif file_extension in {".md", ".markdown"}: extractor = ( UnstructuredMarkdownExtractor(file_path, unstructured_api_url) if is_automatic else MarkdownExtractor(file_path, autodetect_encoding=True) ) - elif file_extension in [".htm", ".html"]: + elif file_extension in {".htm", ".html"}: extractor = HtmlExtractor(file_path) elif file_extension == ".docx": extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by) @@ -134,13 +134,13 @@ class ExtractProcessor: else TextExtractor(file_path, autodetect_encoding=True) ) else: - if file_extension == ".xlsx" or file_extension == ".xls": + if file_extension in {".xlsx", ".xls"}: extractor = ExcelExtractor(file_path) elif file_extension == ".pdf": extractor = PdfExtractor(file_path) - elif file_extension in [".md", ".markdown"]: + elif file_extension in {".md", ".markdown"}: extractor = MarkdownExtractor(file_path, autodetect_encoding=True) - elif file_extension in [".htm", ".html"]: + elif file_extension in {".htm", ".html"}: extractor = HtmlExtractor(file_path) elif file_extension == ".docx": extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by) diff --git a/api/core/rag/extractor/firecrawl/firecrawl_app.py b/api/core/rag/extractor/firecrawl/firecrawl_app.py index 054ce5f4b2..17c2087a0a 100644 --- a/api/core/rag/extractor/firecrawl/firecrawl_app.py +++ b/api/core/rag/extractor/firecrawl/firecrawl_app.py @@ -32,7 +32,7 @@ class FirecrawlApp: else: raise Exception(f'Failed to scrape URL. Error: {response["error"]}') - elif response.status_code in [402, 409, 500]: + elif response.status_code in {402, 409, 500}: error_message = response.json().get("error", "Unknown error occurred") raise Exception(f"Failed to scrape URL. Status code: {response.status_code}. 
Error: {error_message}") else: diff --git a/api/core/rag/extractor/notion_extractor.py b/api/core/rag/extractor/notion_extractor.py index 0ee24983a4..87a4ce08bf 100644 --- a/api/core/rag/extractor/notion_extractor.py +++ b/api/core/rag/extractor/notion_extractor.py @@ -103,12 +103,12 @@ class NotionExtractor(BaseExtractor): multi_select_list = property_value[type] for multi_select in multi_select_list: value.append(multi_select["name"]) - elif type == "rich_text" or type == "title": + elif type in {"rich_text", "title"}: if len(property_value[type]) > 0: value = property_value[type][0]["plain_text"] else: value = "" - elif type == "select" or type == "status": + elif type in {"select", "status"}: if property_value[type]: value = property_value[type]["name"] else: diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 12868d6ae4..286ecd4c03 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -115,7 +115,7 @@ class DatasetRetrieval: available_datasets.append(dataset) all_documents = [] - user_from = "account" if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end_user" + user_from = "account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user" if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE: all_documents = self.single_retrieve( app_id, @@ -426,7 +426,7 @@ class DatasetRetrieval: retrieval_method=retrieval_model["search_method"], dataset_id=dataset.id, query=query, - top_k=top_k, + top_k=retrieval_model.get("top_k") or 2, score_threshold=retrieval_model.get("score_threshold", 0.0) if retrieval_model["score_threshold_enabled"] else 0.0, diff --git a/api/core/rag/splitter/text_splitter.py b/api/core/rag/splitter/text_splitter.py index 161c36607d..7dd62f8de1 100644 --- a/api/core/rag/splitter/text_splitter.py +++ b/api/core/rag/splitter/text_splitter.py @@ -35,7 +35,7 @@ def 
_split_text_with_regex(text: str, separator: str, keep_separator: bool) -> l splits = re.split(separator, text) else: splits = list(text) - return [s for s in splits if (s != "" and s != "\n")] + return [s for s in splits if (s not in {"", "\n"})] class TextSplitter(BaseDocumentTransformer, ABC): diff --git a/api/core/tools/provider/app_tool_provider.py b/api/core/tools/provider/app_tool_provider.py index 01544d7e56..09f328cd1f 100644 --- a/api/core/tools/provider/app_tool_provider.py +++ b/api/core/tools/provider/app_tool_provider.py @@ -68,7 +68,7 @@ class AppToolProviderEntity(ToolProviderController): label = input_form[form_type]["label"] variable_name = input_form[form_type]["variable_name"] options = input_form[form_type].get("options", []) - if form_type == "paragraph" or form_type == "text-input": + if form_type in {"paragraph", "text-input"}: tool["parameters"].append( ToolParameter( name=variable_name, diff --git a/api/core/tools/provider/builtin/aippt/tools/aippt.py b/api/core/tools/provider/builtin/aippt/tools/aippt.py index a2d69fbcd1..dd9371f70d 100644 --- a/api/core/tools/provider/builtin/aippt/tools/aippt.py +++ b/api/core/tools/provider/builtin/aippt/tools/aippt.py @@ -168,7 +168,7 @@ class AIPPTGenerateTool(BuiltinTool): pass elif event == "close": break - elif event == "error" or event == "filter": + elif event in {"error", "filter"}: raise Exception(f"Failed to generate outline: {data}") return outline @@ -213,7 +213,7 @@ class AIPPTGenerateTool(BuiltinTool): pass elif event == "close": break - elif event == "error" or event == "filter": + elif event in {"error", "filter"}: raise Exception(f"Failed to generate content: {data}") return content diff --git a/api/core/tools/provider/builtin/azuredalle/tools/dalle3.py b/api/core/tools/provider/builtin/azuredalle/tools/dalle3.py index 7462824be1..cfa3cfb092 100644 --- a/api/core/tools/provider/builtin/azuredalle/tools/dalle3.py +++ b/api/core/tools/provider/builtin/azuredalle/tools/dalle3.py @@ -39,11 
+39,11 @@ class DallE3Tool(BuiltinTool): n = tool_parameters.get("n", 1) # get quality quality = tool_parameters.get("quality", "standard") - if quality not in ["standard", "hd"]: + if quality not in {"standard", "hd"}: return self.create_text_message("Invalid quality") # get style style = tool_parameters.get("style", "vivid") - if style not in ["natural", "vivid"]: + if style not in {"natural", "vivid"}: return self.create_text_message("Invalid style") # set extra body seed_id = tool_parameters.get("seed_id", self._generate_random_id(8)) diff --git a/api/core/tools/provider/builtin/brave/brave.yaml b/api/core/tools/provider/builtin/brave/brave.yaml index 93d315f839..2b0dcc0188 100644 --- a/api/core/tools/provider/builtin/brave/brave.yaml +++ b/api/core/tools/provider/builtin/brave/brave.yaml @@ -29,3 +29,11 @@ credentials_for_provider: zh_Hans: 从 Brave 获取您的 Brave Search API key pt_BR: Get your Brave Search API key from Brave url: https://brave.com/search/api/ + base_url: + type: text-input + required: false + label: + en_US: Brave server's Base URL + zh_Hans: Brave服务器的API URL + placeholder: + en_US: https://api.search.brave.com/res/v1/web/search diff --git a/api/core/tools/provider/builtin/brave/tools/brave_search.py b/api/core/tools/provider/builtin/brave/tools/brave_search.py index 94a4d92844..c34362ae52 100644 --- a/api/core/tools/provider/builtin/brave/tools/brave_search.py +++ b/api/core/tools/provider/builtin/brave/tools/brave_search.py @@ -7,6 +7,8 @@ from pydantic import BaseModel, Field from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool +BRAVE_BASE_URL = "https://api.search.brave.com/res/v1/web/search" + class BraveSearchWrapper(BaseModel): """Wrapper around the Brave search engine.""" @@ -15,8 +17,10 @@ class BraveSearchWrapper(BaseModel): """The API key to use for the Brave search engine.""" search_kwargs: dict = Field(default_factory=dict) """Additional keyword arguments to pass to the 
search request.""" - base_url: str = "https://api.search.brave.com/res/v1/web/search" + base_url: str = BRAVE_BASE_URL """The base URL for the Brave search engine.""" + ensure_ascii: bool = True + """Ensure the JSON output is ASCII encoded.""" def run(self, query: str) -> str: """Query the Brave search engine and return the results as a JSON string. @@ -36,7 +40,7 @@ class BraveSearchWrapper(BaseModel): } for item in web_search_results ] - return json.dumps(final_results) + return json.dumps(final_results, ensure_ascii=self.ensure_ascii) def _search_request(self, query: str) -> list[dict]: headers = { @@ -68,7 +72,9 @@ class BraveSearch(BaseModel): search_wrapper: BraveSearchWrapper @classmethod - def from_api_key(cls, api_key: str, search_kwargs: Optional[dict] = None, **kwargs: Any) -> "BraveSearch": + def from_api_key( + cls, api_key: str, base_url: str, search_kwargs: Optional[dict] = None, ensure_ascii: bool = True, **kwargs: Any + ) -> "BraveSearch": """Create a tool from an api key. Args: @@ -79,7 +85,9 @@ class BraveSearch(BaseModel): Returns: A tool. 
""" - wrapper = BraveSearchWrapper(api_key=api_key, search_kwargs=search_kwargs or {}) + wrapper = BraveSearchWrapper( + api_key=api_key, base_url=base_url, search_kwargs=search_kwargs or {}, ensure_ascii=ensure_ascii + ) return cls(search_wrapper=wrapper, **kwargs) def _run( @@ -109,11 +117,18 @@ class BraveSearchTool(BuiltinTool): query = tool_parameters.get("query", "") count = tool_parameters.get("count", 3) api_key = self.runtime.credentials["brave_search_api_key"] + base_url = self.runtime.credentials.get("base_url", BRAVE_BASE_URL) + ensure_ascii = tool_parameters.get("ensure_ascii", True) + + if len(base_url) == 0: + base_url = BRAVE_BASE_URL if not query: return self.create_text_message("Please input query") - tool = BraveSearch.from_api_key(api_key=api_key, search_kwargs={"count": count}) + tool = BraveSearch.from_api_key( + api_key=api_key, base_url=base_url, search_kwargs={"count": count}, ensure_ascii=ensure_ascii + ) results = tool._run(query) diff --git a/api/core/tools/provider/builtin/brave/tools/brave_search.yaml b/api/core/tools/provider/builtin/brave/tools/brave_search.yaml index b2a734c12d..5222a375f8 100644 --- a/api/core/tools/provider/builtin/brave/tools/brave_search.yaml +++ b/api/core/tools/provider/builtin/brave/tools/brave_search.yaml @@ -39,3 +39,15 @@ parameters: pt_BR: O número de resultados de pesquisa a serem retornados, permitindo que os usuários controlem a amplitude de sua saída de pesquisa. llm_description: Specifies the amount of search results to be displayed, offering users the ability to adjust the scope of their search findings. 
form: llm + - name: ensure_ascii + type: boolean + default: true + label: + en_US: Ensure ASCII + zh_Hans: 确保 ASCII + pt_BR: Ensure ASCII + human_description: + en_US: Ensure the JSON output is ASCII encoded + zh_Hans: 确保输出的 JSON 是 ASCII 编码 + pt_BR: Ensure the JSON output is ASCII encoded + form: form diff --git a/api/core/tools/provider/builtin/code/tools/simple_code.py b/api/core/tools/provider/builtin/code/tools/simple_code.py index 017fe548f7..632c9fc7f1 100644 --- a/api/core/tools/provider/builtin/code/tools/simple_code.py +++ b/api/core/tools/provider/builtin/code/tools/simple_code.py @@ -14,7 +14,7 @@ class SimpleCode(BuiltinTool): language = tool_parameters.get("language", CodeLanguage.PYTHON3) code = tool_parameters.get("code", "") - if language not in [CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT]: + if language not in {CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT}: raise ValueError(f"Only python3 and javascript are supported, not {language}") result = CodeExecutor.execute_code(language, "", code) diff --git a/api/core/tools/provider/builtin/cogview/tools/cogview3.py b/api/core/tools/provider/builtin/cogview/tools/cogview3.py index 9776bd7dd1..9039708588 100644 --- a/api/core/tools/provider/builtin/cogview/tools/cogview3.py +++ b/api/core/tools/provider/builtin/cogview/tools/cogview3.py @@ -34,11 +34,11 @@ class CogView3Tool(BuiltinTool): n = tool_parameters.get("n", 1) # get quality quality = tool_parameters.get("quality", "standard") - if quality not in ["standard", "hd"]: + if quality not in {"standard", "hd"}: return self.create_text_message("Invalid quality") # get style style = tool_parameters.get("style", "vivid") - if style not in ["natural", "vivid"]: + if style not in {"natural", "vivid"}: return self.create_text_message("Invalid style") # set extra body seed_id = tool_parameters.get("seed_id", self._generate_random_id(8)) diff --git a/api/core/tools/provider/builtin/comfyui/_assets/icon.png 
b/api/core/tools/provider/builtin/comfyui/_assets/icon.png new file mode 100644 index 0000000000..958ec5c5cf Binary files /dev/null and b/api/core/tools/provider/builtin/comfyui/_assets/icon.png differ diff --git a/api/core/tools/provider/builtin/comfyui/comfyui.py b/api/core/tools/provider/builtin/comfyui/comfyui.py new file mode 100644 index 0000000000..7013a0b93c --- /dev/null +++ b/api/core/tools/provider/builtin/comfyui/comfyui.py @@ -0,0 +1,17 @@ +from typing import Any + +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.provider.builtin.comfyui.tools.comfyui_stable_diffusion import ComfyuiStableDiffusionTool +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController + + +class ComfyUIProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict[str, Any]) -> None: + try: + ComfyuiStableDiffusionTool().fork_tool_runtime( + runtime={ + "credentials": credentials, + } + ).validate_models() + except Exception as e: + raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/comfyui/comfyui.yaml b/api/core/tools/provider/builtin/comfyui/comfyui.yaml new file mode 100644 index 0000000000..066fd85308 --- /dev/null +++ b/api/core/tools/provider/builtin/comfyui/comfyui.yaml @@ -0,0 +1,42 @@ +identity: + author: Qun + name: comfyui + label: + en_US: ComfyUI + zh_Hans: ComfyUI + pt_BR: ComfyUI + description: + en_US: ComfyUI is a tool for generating images which can be deployed locally. + zh_Hans: ComfyUI 是一个可以在本地部署的图片生成的工具。 + pt_BR: ComfyUI is a tool for generating images which can be deployed locally. 
+ icon: icon.png + tags: + - image +credentials_for_provider: + base_url: + type: text-input + required: true + label: + en_US: Base URL + zh_Hans: ComfyUI服务器的Base URL + pt_BR: Base URL + placeholder: + en_US: Please input your ComfyUI server's Base URL + zh_Hans: 请输入你的 ComfyUI 服务器的 Base URL + pt_BR: Please input your ComfyUI server's Base URL + model: + type: text-input + required: true + label: + en_US: Model with suffix + zh_Hans: 模型, 需要带后缀 + pt_BR: Model with suffix + placeholder: + en_US: Please input your model + zh_Hans: 请输入你的模型名称 + pt_BR: Please input your model + help: + en_US: The checkpoint name of the ComfyUI server, e.g. xxx.safetensors + zh_Hans: ComfyUI服务器的模型名称, 比如 xxx.safetensors + pt_BR: The checkpoint name of the ComfyUI server, e.g. xxx.safetensors + url: https://docs.dify.ai/tutorials/tool-configuration/comfyui diff --git a/api/core/tools/provider/builtin/comfyui/tools/comfyui_stable_diffusion.py b/api/core/tools/provider/builtin/comfyui/tools/comfyui_stable_diffusion.py new file mode 100644 index 0000000000..b9b52c0b4d --- /dev/null +++ b/api/core/tools/provider/builtin/comfyui/tools/comfyui_stable_diffusion.py @@ -0,0 +1,475 @@ +import json +import os +import random +import uuid +from copy import deepcopy +from enum import Enum +from typing import Any, Union + +import websocket +from httpx import get, post +from yarl import URL + +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolParameterOption +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.tool.builtin_tool import BuiltinTool + +SD_TXT2IMG_OPTIONS = {} +LORA_NODE = { + "inputs": {"lora_name": "", "strength_model": 1, "strength_clip": 1, "model": ["11", 0], "clip": ["11", 1]}, + "class_type": "LoraLoader", + "_meta": {"title": "Load LoRA"}, +} +FluxGuidanceNode = { + "inputs": {"guidance": 3.5, "conditioning": ["6", 0]}, + "class_type": "FluxGuidance", + 
"_meta": {"title": "FluxGuidance"}, +} + + +class ModelType(Enum): + SD15 = 1 + SDXL = 2 + SD3 = 3 + FLUX = 4 + + +class ComfyuiStableDiffusionTool(BuiltinTool): + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + invoke tools + """ + # base url + base_url = self.runtime.credentials.get("base_url", "") + if not base_url: + return self.create_text_message("Please input base_url") + + if tool_parameters.get("model"): + self.runtime.credentials["model"] = tool_parameters["model"] + + model = self.runtime.credentials.get("model", None) + if not model: + return self.create_text_message("Please input model") + + # prompt + prompt = tool_parameters.get("prompt", "") + if not prompt: + return self.create_text_message("Please input prompt") + + # get negative prompt + negative_prompt = tool_parameters.get("negative_prompt", "") + + # get size + width = tool_parameters.get("width", 1024) + height = tool_parameters.get("height", 1024) + + # get steps + steps = tool_parameters.get("steps", 1) + + # get sampler_name + sampler_name = tool_parameters.get("sampler_name", "euler") + + # scheduler + scheduler = tool_parameters.get("scheduler", "normal") + + # get cfg + cfg = tool_parameters.get("cfg", 7.0) + + # get model type + model_type = tool_parameters.get("model_type", ModelType.SD15.name) + + # get lora + # supports up to 3 loras + lora_list = [] + lora_strength_list = [] + if tool_parameters.get("lora_1"): + lora_list.append(tool_parameters["lora_1"]) + lora_strength_list.append(tool_parameters.get("lora_strength_1", 1)) + if tool_parameters.get("lora_2"): + lora_list.append(tool_parameters["lora_2"]) + lora_strength_list.append(tool_parameters.get("lora_strength_2", 1)) + if tool_parameters.get("lora_3"): + lora_list.append(tool_parameters["lora_3"]) + lora_strength_list.append(tool_parameters.get("lora_strength_3", 1)) + + return self.text2img( + base_url=base_url, + model=model, + 
model_type=model_type, + prompt=prompt, + negative_prompt=negative_prompt, + width=width, + height=height, + steps=steps, + sampler_name=sampler_name, + scheduler=scheduler, + cfg=cfg, + lora_list=lora_list, + lora_strength_list=lora_strength_list, + ) + + def get_checkpoints(self) -> list[str]: + """ + get checkpoints + """ + try: + base_url = self.runtime.credentials.get("base_url", None) + if not base_url: + return [] + api_url = str(URL(base_url) / "models" / "checkpoints") + response = get(url=api_url, timeout=(2, 10)) + if response.status_code != 200: + return [] + else: + return response.json() + except Exception as e: + return [] + + def get_loras(self) -> list[str]: + """ + get loras + """ + try: + base_url = self.runtime.credentials.get("base_url", None) + if not base_url: + return [] + api_url = str(URL(base_url) / "models" / "loras") + response = get(url=api_url, timeout=(2, 10)) + if response.status_code != 200: + return [] + else: + return response.json() + except Exception as e: + return [] + + def get_sample_methods(self) -> tuple[list[str], list[str]]: + """ + get sample method + """ + try: + base_url = self.runtime.credentials.get("base_url", None) + if not base_url: + return [], [] + api_url = str(URL(base_url) / "object_info" / "KSampler") + response = get(url=api_url, timeout=(2, 10)) + if response.status_code != 200: + return [], [] + else: + data = response.json()["KSampler"]["input"]["required"] + return data["sampler_name"][0], data["scheduler"][0] + except Exception as e: + return [], [] + + def validate_models(self) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + validate models + """ + try: + base_url = self.runtime.credentials.get("base_url", None) + if not base_url: + raise ToolProviderCredentialValidationError("Please input base_url") + model = self.runtime.credentials.get("model", None) + if not model: + raise ToolProviderCredentialValidationError("Please input model") + + api_url = str(URL(base_url) / "models" / 
"checkpoints") + response = get(url=api_url, timeout=(2, 10)) + if response.status_code != 200: + raise ToolProviderCredentialValidationError("Failed to get models") + else: + models = response.json() + if len([d for d in models if d == model]) > 0: + return self.create_text_message(json.dumps(models)) + else: + raise ToolProviderCredentialValidationError(f"model {model} does not exist") + except Exception as e: + raise ToolProviderCredentialValidationError(f"Failed to get models, {e}") + + def get_history(self, base_url, prompt_id): + """ + get history + """ + url = str(URL(base_url) / "history") + respond = get(url, params={"prompt_id": prompt_id}, timeout=(2, 10)) + return respond.json() + + def download_image(self, base_url, filename, subfolder, folder_type): + """ + download image + """ + url = str(URL(base_url) / "view") + response = get(url, params={"filename": filename, "subfolder": subfolder, "type": folder_type}, timeout=(2, 10)) + return response.content + + def queue_prompt_image(self, base_url, client_id, prompt): + """ + send prompt task and rotate + """ + # initiate task execution + url = str(URL(base_url) / "prompt") + respond = post(url, data=json.dumps({"client_id": client_id, "prompt": prompt}), timeout=(2, 10)) + prompt_id = respond.json()["prompt_id"] + + ws = websocket.WebSocket() + if "https" in base_url: + ws_url = base_url.replace("https", "ws") + else: + ws_url = base_url.replace("http", "ws") + ws.connect(str(URL(f"{ws_url}") / "ws") + f"?clientId={client_id}", timeout=120) + + # websocket rotate execution status + output_images = {} + while True: + out = ws.recv() + if isinstance(out, str): + message = json.loads(out) + if message["type"] == "executing": + data = message["data"] + if data["node"] is None and data["prompt_id"] == prompt_id: + break # Execution is done + elif message["type"] == "status": + data = message["data"] + if data["status"]["exec_info"]["queue_remaining"] == 0 and data.get("sid"): + break # Execution is done + 
else: + continue # previews are binary data + + # download image when execution finished + history = self.get_history(base_url, prompt_id)[prompt_id] + for o in history["outputs"]: + for node_id in history["outputs"]: + node_output = history["outputs"][node_id] + if "images" in node_output: + images_output = [] + for image in node_output["images"]: + image_data = self.download_image(base_url, image["filename"], image["subfolder"], image["type"]) + images_output.append(image_data) + output_images[node_id] = images_output + + ws.close() + + return output_images + + def text2img( + self, + base_url: str, + model: str, + model_type: str, + prompt: str, + negative_prompt: str, + width: int, + height: int, + steps: int, + sampler_name: str, + scheduler: str, + cfg: float, + lora_list: list, + lora_strength_list: list, + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + generate image + """ + if not SD_TXT2IMG_OPTIONS: + current_dir = os.path.dirname(os.path.realpath(__file__)) + with open(os.path.join(current_dir, "txt2img.json")) as file: + SD_TXT2IMG_OPTIONS.update(json.load(file)) + + draw_options = deepcopy(SD_TXT2IMG_OPTIONS) + draw_options["3"]["inputs"]["steps"] = steps + draw_options["3"]["inputs"]["sampler_name"] = sampler_name + draw_options["3"]["inputs"]["scheduler"] = scheduler + draw_options["3"]["inputs"]["cfg"] = cfg + # generate different image when using same prompt next time + draw_options["3"]["inputs"]["seed"] = random.randint(0, 100000000) + draw_options["4"]["inputs"]["ckpt_name"] = model + draw_options["5"]["inputs"]["width"] = width + draw_options["5"]["inputs"]["height"] = height + draw_options["6"]["inputs"]["text"] = prompt + draw_options["7"]["inputs"]["text"] = negative_prompt + # if the model is SD3 or FLUX series, the Latent class should be corresponding to SD3 Latent + if model_type in (ModelType.SD3.name, ModelType.FLUX.name): + draw_options["5"]["class_type"] = "EmptySD3LatentImage" + + if lora_list: + # last Lora node 
link to KSampler node + draw_options["3"]["inputs"]["model"][0] = "10" + # last Lora node link to positive and negative Clip node + draw_options["6"]["inputs"]["clip"][0] = "10" + draw_options["7"]["inputs"]["clip"][0] = "10" + # every Lora node link to next Lora node, and Checkpoints node link to first Lora node + for i, (lora, strength) in enumerate(zip(lora_list, lora_strength_list), 10): + if i - 10 == len(lora_list) - 1: + next_node_id = "4" + else: + next_node_id = str(i + 1) + lora_node = deepcopy(LORA_NODE) + lora_node["inputs"]["lora_name"] = lora + lora_node["inputs"]["strength_model"] = strength + lora_node["inputs"]["strength_clip"] = strength + lora_node["inputs"]["model"][0] = next_node_id + lora_node["inputs"]["clip"][0] = next_node_id + draw_options[str(i)] = lora_node + + # FLUX need to add FluxGuidance Node + if model_type == ModelType.FLUX.name: + last_node_id = str(10 + len(lora_list)) + draw_options[last_node_id] = deepcopy(FluxGuidanceNode) + draw_options[last_node_id]["inputs"]["conditioning"][0] = "6" + draw_options["3"]["inputs"]["positive"][0] = last_node_id + + try: + client_id = str(uuid.uuid4()) + result = self.queue_prompt_image(base_url, client_id, prompt=draw_options) + + # get first image + image = b"" + for node in result: + for img in result[node]: + if img: + image = img + break + + return self.create_blob_message( + blob=image, meta={"mime_type": "image/png"}, save_as=self.VARIABLE_KEY.IMAGE.value + ) + + except Exception as e: + return self.create_text_message(f"Failed to generate image: {str(e)}") + + def get_runtime_parameters(self) -> list[ToolParameter]: + parameters = [ + ToolParameter( + name="prompt", + label=I18nObject(en_US="Prompt", zh_Hans="Prompt"), + human_description=I18nObject( + en_US="Image prompt, you can check the official documentation of Stable Diffusion", + zh_Hans="图像提示词,您可以查看 Stable Diffusion 的官方文档", + ), + type=ToolParameter.ToolParameterType.STRING, + form=ToolParameter.ToolParameterForm.LLM, + 
llm_description="Image prompt of Stable Diffusion, you should describe the image " + "you want to generate as a list of words as possible as detailed, " + "the prompt must be written in English.", + required=True, + ), + ] + if self.runtime.credentials: + try: + models = self.get_checkpoints() + if len(models) != 0: + parameters.append( + ToolParameter( + name="model", + label=I18nObject(en_US="Model", zh_Hans="Model"), + human_description=I18nObject( + en_US="Model of Stable Diffusion or FLUX, " + "you can check the official documentation of Stable Diffusion or FLUX", + zh_Hans="Stable Diffusion 或者 FLUX 的模型,您可以查看 Stable Diffusion 的官方文档", + ), + type=ToolParameter.ToolParameterType.SELECT, + form=ToolParameter.ToolParameterForm.FORM, + llm_description="Model of Stable Diffusion or FLUX, " + "you can check the official documentation of Stable Diffusion or FLUX", + required=True, + default=models[0], + options=[ + ToolParameterOption(value=i, label=I18nObject(en_US=i, zh_Hans=i)) for i in models + ], + ) + ) + loras = self.get_loras() + if len(loras) != 0: + for n in range(1, 4): + parameters.append( + ToolParameter( + name=f"lora_{n}", + label=I18nObject(en_US=f"Lora {n}", zh_Hans=f"Lora {n}"), + human_description=I18nObject( + en_US="Lora of Stable Diffusion, " + "you can check the official documentation of Stable Diffusion", + zh_Hans="Stable Diffusion 的 Lora 模型,您可以查看 Stable Diffusion 的官方文档", + ), + type=ToolParameter.ToolParameterType.SELECT, + form=ToolParameter.ToolParameterForm.FORM, + llm_description="Lora of Stable Diffusion, " + "you can check the official documentation of " + "Stable Diffusion", + required=False, + options=[ + ToolParameterOption(value=i, label=I18nObject(en_US=i, zh_Hans=i)) for i in loras + ], + ) + ) + sample_methods, schedulers = self.get_sample_methods() + if len(sample_methods) != 0: + parameters.append( + ToolParameter( + name="sampler_name", + label=I18nObject(en_US="Sampling method", zh_Hans="Sampling method"), + 
human_description=I18nObject( + en_US="Sampling method of Stable Diffusion, " + "you can check the official documentation of Stable Diffusion", + zh_Hans="Stable Diffusion 的Sampling method,您可以查看 Stable Diffusion 的官方文档", + ), + type=ToolParameter.ToolParameterType.SELECT, + form=ToolParameter.ToolParameterForm.FORM, + llm_description="Sampling method of Stable Diffusion, " + "you can check the official documentation of Stable Diffusion", + required=True, + default=sample_methods[0], + options=[ + ToolParameterOption(value=i, label=I18nObject(en_US=i, zh_Hans=i)) + for i in sample_methods + ], + ) + ) + if len(schedulers) != 0: + parameters.append( + ToolParameter( + name="scheduler", + label=I18nObject(en_US="Scheduler", zh_Hans="Scheduler"), + human_description=I18nObject( + en_US="Scheduler of Stable Diffusion, " + "you can check the official documentation of Stable Diffusion", + zh_Hans="Stable Diffusion 的Scheduler,您可以查看 Stable Diffusion 的官方文档", + ), + type=ToolParameter.ToolParameterType.SELECT, + form=ToolParameter.ToolParameterForm.FORM, + llm_description="Scheduler of Stable Diffusion, " + "you can check the official documentation of Stable Diffusion", + required=True, + default=schedulers[0], + options=[ + ToolParameterOption(value=i, label=I18nObject(en_US=i, zh_Hans=i)) for i in schedulers + ], + ) + ) + parameters.append( + ToolParameter( + name="model_type", + label=I18nObject(en_US="Model Type", zh_Hans="Model Type"), + human_description=I18nObject( + en_US="Model Type of Stable Diffusion or Flux, " + "you can check the official documentation of Stable Diffusion or Flux", + zh_Hans="Stable Diffusion 或 FLUX 的模型类型," + "您可以查看 Stable Diffusion 或 Flux 的官方文档", + ), + type=ToolParameter.ToolParameterType.SELECT, + form=ToolParameter.ToolParameterForm.FORM, + llm_description="Model Type of Stable Diffusion or Flux, " + "you can check the official documentation of Stable Diffusion or Flux", + required=True, + default=ModelType.SD15.name, + options=[ + 
ToolParameterOption(value=i, label=I18nObject(en_US=i, zh_Hans=i)) + for i in ModelType.__members__ + ], + ) + ) + except: + pass + + return parameters diff --git a/api/core/tools/provider/builtin/comfyui/tools/comfyui_stable_diffusion.yaml b/api/core/tools/provider/builtin/comfyui/tools/comfyui_stable_diffusion.yaml new file mode 100644 index 0000000000..4f4a6942b3 --- /dev/null +++ b/api/core/tools/provider/builtin/comfyui/tools/comfyui_stable_diffusion.yaml @@ -0,0 +1,212 @@ +identity: + name: txt2img workflow + author: Qun + label: + en_US: Txt2Img Workflow + zh_Hans: Txt2Img Workflow + pt_BR: Txt2Img Workflow +description: + human: + en_US: a pre-defined comfyui workflow that can use one model and up to 3 loras to generate images. Support SD1.5, SDXL, SD3 and FLUX which contain text encoders/clip, but does not support models that require a triple clip loader. + zh_Hans: 一个预定义的 ComfyUI 工作流,可以使用一个模型和最多3个loras来生成图像。支持包含文本编码器/clip的SD1.5、SDXL、SD3和FLUX,但不支持需要三重 clip 加载器的模型。 + pt_BR: a pre-defined comfyui workflow that can use one model and up to 3 loras to generate images. Support SD1.5, SDXL, SD3 and FLUX which contain text encoders/clip, but does not support models that require a triple clip loader. + llm: draw the image you want based on your prompt. +parameters: + - name: prompt + type: string + required: true + label: + en_US: Prompt + zh_Hans: 提示词 + pt_BR: Prompt + human_description: + en_US: Image prompt, you can check the official documentation of Stable Diffusion or FLUX + zh_Hans: 图像提示词,您可以查看 Stable Diffusion 或者 FLUX 的官方文档 + pt_BR: Image prompt, you can check the official documentation of Stable Diffusion or FLUX + llm_description: Image prompt of Stable Diffusion, you should describe the image you want to generate as a list of words as possible as detailed, the prompt must be written in English.
+ form: llm + - name: model + type: string + required: true + label: + en_US: Model Name + zh_Hans: 模型名称 + pt_BR: Model Name + human_description: + en_US: Model Name + zh_Hans: 模型名称 + pt_BR: Model Name + form: form + - name: model_type + type: string + required: true + label: + en_US: Model Type + zh_Hans: 模型类型 + pt_BR: Model Type + human_description: + en_US: Model Type + zh_Hans: 模型类型 + pt_BR: Model Type + form: form + - name: lora_1 + type: string + required: false + label: + en_US: Lora 1 + zh_Hans: Lora 1 + pt_BR: Lora 1 + human_description: + en_US: Lora 1 + zh_Hans: Lora 1 + pt_BR: Lora 1 + form: form + - name: lora_strength_1 + type: number + required: false + label: + en_US: Lora Strength 1 + zh_Hans: Lora Strength 1 + pt_BR: Lora Strength 1 + human_description: + en_US: Lora Strength 1 + zh_Hans: Lora模型的权重 + pt_BR: Lora Strength 1 + form: form + - name: steps + type: number + required: false + label: + en_US: Steps + zh_Hans: Steps + pt_BR: Steps + human_description: + en_US: Steps + zh_Hans: Steps + pt_BR: Steps + form: form + default: 20 + - name: width + type: number + required: false + label: + en_US: Width + zh_Hans: Width + pt_BR: Width + human_description: + en_US: Width + zh_Hans: Width + pt_BR: Width + form: form + default: 1024 + - name: height + type: number + required: false + label: + en_US: Height + zh_Hans: Height + pt_BR: Height + human_description: + en_US: Height + zh_Hans: Height + pt_BR: Height + form: form + default: 1024 + - name: negative_prompt + type: string + required: false + label: + en_US: Negative prompt + zh_Hans: Negative prompt + pt_BR: Negative prompt + human_description: + en_US: Negative prompt + zh_Hans: Negative prompt + pt_BR: Negative prompt + form: form + default: bad art, ugly, deformed, watermark, duplicated, discontinuous lines + - name: cfg + type: number + required: false + label: + en_US: CFG Scale + zh_Hans: CFG Scale + pt_BR: CFG Scale + human_description: + en_US: CFG Scale + zh_Hans: 提示词相关性(CFG Scale) + 
pt_BR: CFG Scale + form: form + default: 7.0 + - name: sampler_name + type: string + required: false + label: + en_US: Sampling method + zh_Hans: Sampling method + pt_BR: Sampling method + human_description: + en_US: Sampling method + zh_Hans: Sampling method + pt_BR: Sampling method + form: form + - name: scheduler + type: string + required: false + label: + en_US: Scheduler + zh_Hans: Scheduler + pt_BR: Scheduler + human_description: + en_US: Scheduler + zh_Hans: Scheduler + pt_BR: Scheduler + form: form + - name: lora_2 + type: string + required: false + label: + en_US: Lora 2 + zh_Hans: Lora 2 + pt_BR: Lora 2 + human_description: + en_US: Lora 2 + zh_Hans: Lora 2 + pt_BR: Lora 2 + form: form + - name: lora_strength_2 + type: number + required: false + label: + en_US: Lora Strength 2 + zh_Hans: Lora Strength 2 + pt_BR: Lora Strength 2 + human_description: + en_US: Lora Strength 2 + zh_Hans: Lora模型的权重 + pt_BR: Lora Strength 2 + form: form + - name: lora_3 + type: string + required: false + label: + en_US: Lora 3 + zh_Hans: Lora 3 + pt_BR: Lora 3 + human_description: + en_US: Lora 3 + zh_Hans: Lora 3 + pt_BR: Lora 3 + form: form + - name: lora_strength_3 + type: number + required: false + label: + en_US: Lora Strength 3 + zh_Hans: Lora Strength 3 + pt_BR: Lora Strength 3 + human_description: + en_US: Lora Strength 3 + zh_Hans: Lora模型的权重 + pt_BR: Lora Strength 3 + form: form diff --git a/api/core/tools/provider/builtin/comfyui/tools/txt2img.json b/api/core/tools/provider/builtin/comfyui/tools/txt2img.json new file mode 100644 index 0000000000..8ea869ff10 --- /dev/null +++ b/api/core/tools/provider/builtin/comfyui/tools/txt2img.json @@ -0,0 +1,107 @@ +{ + "3": { + "inputs": { + "seed": 156680208700286, + "steps": 20, + "cfg": 8, + "sampler_name": "euler", + "scheduler": "normal", + "denoise": 1, + "model": [ + "4", + 0 + ], + "positive": [ + "6", + 0 + ], + "negative": [ + "7", + 0 + ], + "latent_image": [ + "5", + 0 + ] + }, + "class_type": "KSampler", + "_meta": { 
+ "title": "KSampler" + } + }, + "4": { + "inputs": { + "ckpt_name": "3dAnimationDiffusion_v10.safetensors" + }, + "class_type": "CheckpointLoaderSimple", + "_meta": { + "title": "Load Checkpoint" + } + }, + "5": { + "inputs": { + "width": 512, + "height": 512, + "batch_size": 1 + }, + "class_type": "EmptyLatentImage", + "_meta": { + "title": "Empty Latent Image" + } + }, + "6": { + "inputs": { + "text": "beautiful scenery nature glass bottle landscape, , purple galaxy bottle,", + "clip": [ + "4", + 1 + ] + }, + "class_type": "CLIPTextEncode", + "_meta": { + "title": "CLIP Text Encode (Prompt)" + } + }, + "7": { + "inputs": { + "text": "text, watermark", + "clip": [ + "4", + 1 + ] + }, + "class_type": "CLIPTextEncode", + "_meta": { + "title": "CLIP Text Encode (Prompt)" + } + }, + "8": { + "inputs": { + "samples": [ + "3", + 0 + ], + "vae": [ + "4", + 2 + ] + }, + "class_type": "VAEDecode", + "_meta": { + "title": "VAE Decode" + } + }, + "9": { + "inputs": { + "filename_prefix": "ComfyUI", + "images": [ + "8", + 0 + ] + }, + "class_type": "SaveImage", + "_meta": { + "title": "Save Image" + } + } +} \ No newline at end of file diff --git a/api/core/tools/provider/builtin/dalle/tools/dalle3.py b/api/core/tools/provider/builtin/dalle/tools/dalle3.py index bcfa2212b6..a8c647d71e 100644 --- a/api/core/tools/provider/builtin/dalle/tools/dalle3.py +++ b/api/core/tools/provider/builtin/dalle/tools/dalle3.py @@ -49,11 +49,11 @@ class DallE3Tool(BuiltinTool): n = tool_parameters.get("n", 1) # get quality quality = tool_parameters.get("quality", "standard") - if quality not in ["standard", "hd"]: + if quality not in {"standard", "hd"}: return self.create_text_message("Invalid quality") # get style style = tool_parameters.get("style", "vivid") - if style not in ["natural", "vivid"]: + if style not in {"natural", "vivid"}: return self.create_text_message("Invalid style") # call openapi dalle3 diff --git 
a/api/core/tools/provider/builtin/feishu_document/tools/get_document_raw_content.py b/api/core/tools/provider/builtin/feishu_document/tools/get_document_content.py similarity index 79% rename from api/core/tools/provider/builtin/feishu_document/tools/get_document_raw_content.py rename to api/core/tools/provider/builtin/feishu_document/tools/get_document_content.py index 83073e0822..c94a5f70ed 100644 --- a/api/core/tools/provider/builtin/feishu_document/tools/get_document_raw_content.py +++ b/api/core/tools/provider/builtin/feishu_document/tools/get_document_content.py @@ -12,6 +12,8 @@ class GetDocumentRawContentTool(BuiltinTool): client = FeishuRequest(app_id, app_secret) document_id = tool_parameters.get("document_id") + mode = tool_parameters.get("mode") + lang = tool_parameters.get("lang", 0) - res = client.get_document_raw_content(document_id) + res = client.get_document_content(document_id, mode, lang) return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_document/tools/get_document_content.yaml b/api/core/tools/provider/builtin/feishu_document/tools/get_document_content.yaml new file mode 100644 index 0000000000..51eda73a60 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_document/tools/get_document_content.yaml @@ -0,0 +1,49 @@ +identity: + name: get_document_content + author: Doug Lea + label: + en_US: Get Document Content + zh_Hans: 获取飞书云文档的内容 +description: + human: + en_US: Get document content + zh_Hans: 获取飞书云文档的内容 + llm: A tool for retrieving content from Feishu cloud documents. +parameters: + - name: document_id + type: string + required: true + label: + en_US: document_id + zh_Hans: 飞书文档的唯一标识 + human_description: + en_US: Unique identifier for a Feishu document. You can also input the document's URL. 
+ zh_Hans: 飞书文档的唯一标识,支持输入文档的 URL。 + llm_description: 飞书文档的唯一标识,支持输入文档的 URL。 + form: llm + + - name: mode + type: string + required: false + label: + en_US: mode + zh_Hans: 文档返回格式 + human_description: + en_US: Format of the document return, optional values are text, markdown, can be empty, default is markdown. + zh_Hans: 文档返回格式,可选值有 text、markdown,可以为空,默认值为 markdown。 + llm_description: 文档返回格式,可选值有 text、markdown,可以为空,默认值为 markdown。 + form: llm + + - name: lang + type: number + required: false + default: 0 + label: + en_US: lang + zh_Hans: 指定@用户的语言 + human_description: + en_US: | + Specifies the language for MentionUser, optional values are [0, 1]. 0: User's default name, 1: User's English name, default is 0. + zh_Hans: 指定返回的 MentionUser,即 @用户 的语言,可选值有 [0,1]。0:该用户的默认名称,1:该用户的英文名称,默认值为 0。 + llm_description: 指定返回的 MentionUser,即 @用户 的语言,可选值有 [0,1]。0:该用户的默认名称,1:该用户的英文名称,默认值为 0。 + form: llm diff --git a/api/core/tools/provider/builtin/feishu_document/tools/get_document_raw_content.yaml b/api/core/tools/provider/builtin/feishu_document/tools/get_document_raw_content.yaml deleted file mode 100644 index e5b0937e03..0000000000 --- a/api/core/tools/provider/builtin/feishu_document/tools/get_document_raw_content.yaml +++ /dev/null @@ -1,23 +0,0 @@ -identity: - name: get_document_raw_content - author: Doug Lea - label: - en_US: Get Document Raw Content - zh_Hans: 获取文档纯文本内容 -description: - human: - en_US: Get document raw content - zh_Hans: 获取文档纯文本内容 - llm: A tool for getting the plain text content of Feishu documents -parameters: - - name: document_id - type: string - required: true - label: - en_US: document_id - zh_Hans: 飞书文档的唯一标识 - human_description: - en_US: Unique ID of Feishu document document_id - zh_Hans: 飞书文档的唯一标识 document_id - llm_description: 飞书文档的唯一标识 document_id - form: llm diff --git a/api/core/tools/provider/builtin/feishu_document/tools/list_document_block.yaml b/api/core/tools/provider/builtin/feishu_document/tools/list_document_block.yaml deleted file mode 100644 
index d51e5a837c..0000000000 --- a/api/core/tools/provider/builtin/feishu_document/tools/list_document_block.yaml +++ /dev/null @@ -1,48 +0,0 @@ -identity: - name: list_document_block - author: Doug Lea - label: - en_US: List Document Block - zh_Hans: 获取飞书文档所有块 -description: - human: - en_US: List document block - zh_Hans: 获取飞书文档所有块的富文本内容并分页返回。 - llm: A tool to get all blocks of Feishu documents -parameters: - - name: document_id - type: string - required: true - label: - en_US: document_id - zh_Hans: 飞书文档的唯一标识 - human_description: - en_US: Unique ID of Feishu document document_id - zh_Hans: 飞书文档的唯一标识 document_id - llm_description: 飞书文档的唯一标识 document_id - form: llm - - - name: page_size - type: number - required: false - default: 500 - label: - en_US: page_size - zh_Hans: 分页大小 - human_description: - en_US: Paging size, the default and maximum value is 500. - zh_Hans: 分页大小, 默认值和最大值为 500。 - llm_description: 分页大小, 表示一次请求最多返回多少条数据,默认值和最大值为 500。 - form: llm - - - name: page_token - type: string - required: false - label: - en_US: page_token - zh_Hans: 分页标记 - human_description: - en_US: Pagination tag, used to paginate query results so that more items can be obtained in the next traversal. 
- zh_Hans: 分页标记,用于分页查询结果,以便下次遍历时获取更多项。 - llm_description: 分页标记,第一次请求不填,表示从头开始遍历;分页查询结果还有更多项时会同时返回新的 page_token,下次遍历可采用该 page_token 获取查询结果。 - form: llm diff --git a/api/core/tools/provider/builtin/feishu_document/tools/list_document_block.py b/api/core/tools/provider/builtin/feishu_document/tools/list_document_blocks.py similarity index 90% rename from api/core/tools/provider/builtin/feishu_document/tools/list_document_block.py rename to api/core/tools/provider/builtin/feishu_document/tools/list_document_blocks.py index 8c0c4a3c97..572a7abf28 100644 --- a/api/core/tools/provider/builtin/feishu_document/tools/list_document_block.py +++ b/api/core/tools/provider/builtin/feishu_document/tools/list_document_blocks.py @@ -15,5 +15,5 @@ class ListDocumentBlockTool(BuiltinTool): page_size = tool_parameters.get("page_size", 500) page_token = tool_parameters.get("page_token", "") - res = client.list_document_block(document_id, page_token, page_size) + res = client.list_document_blocks(document_id, page_token, page_size) return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_document/tools/list_document_blocks.yaml b/api/core/tools/provider/builtin/feishu_document/tools/list_document_blocks.yaml new file mode 100644 index 0000000000..019ac98390 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_document/tools/list_document_blocks.yaml @@ -0,0 +1,74 @@ +identity: + name: list_document_blocks + author: Doug Lea + label: + en_US: List Document Blocks + zh_Hans: 获取飞书文档所有块 +description: + human: + en_US: List document blocks + zh_Hans: 获取飞书文档所有块的富文本内容并分页返回 + llm: A tool to get all blocks of Feishu documents +parameters: + - name: document_id + type: string + required: true + label: + en_US: document_id + zh_Hans: 飞书文档的唯一标识 + human_description: + en_US: Unique identifier for a Feishu document. You can also input the document's URL. 
+ zh_Hans: 飞书文档的唯一标识,支持输入文档的 URL。 + llm_description: 飞书文档的唯一标识,支持输入文档的 URL。 + form: llm + + - name: user_id_type + type: select + required: false + options: + - value: open_id + label: + en_US: open_id + zh_Hans: open_id + - value: union_id + label: + en_US: union_id + zh_Hans: union_id + - value: user_id + label: + en_US: user_id + zh_Hans: user_id + default: "open_id" + label: + en_US: user_id_type + zh_Hans: 用户 ID 类型 + human_description: + en_US: User ID type, optional values are open_id, union_id, user_id, with a default value of open_id. + zh_Hans: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 + llm_description: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 + form: llm + + - name: page_size + type: number + required: false + default: "500" + label: + en_US: page_size + zh_Hans: 分页大小 + human_description: + en_US: Paging size, the default and maximum value is 500. + zh_Hans: 分页大小, 默认值和最大值为 500。 + llm_description: 分页大小, 表示一次请求最多返回多少条数据,默认值和最大值为 500。 + form: llm + + - name: page_token + type: string + required: false + label: + en_US: page_token + zh_Hans: 分页标记 + human_description: + en_US: Pagination token used to navigate through query results, allowing retrieval of additional items in subsequent requests. 
+ zh_Hans: 分页标记,用于分页查询结果,以便下次遍历时获取更多项。 + llm_description: 分页标记,第一次请求不填,表示从头开始遍历;分页查询结果还有更多项时会同时返回新的 page_token,下次遍历可采用该 page_token 获取查询结果。 + form: llm diff --git a/api/core/tools/provider/builtin/feishu_document/tools/write_document.yaml b/api/core/tools/provider/builtin/feishu_document/tools/write_document.yaml index 8ee219d4a7..4282e3dcf3 100644 --- a/api/core/tools/provider/builtin/feishu_document/tools/write_document.yaml +++ b/api/core/tools/provider/builtin/feishu_document/tools/write_document.yaml @@ -17,33 +17,35 @@ parameters: en_US: document_id zh_Hans: 飞书文档的唯一标识 human_description: - en_US: Unique ID of Feishu document document_id - zh_Hans: 飞书文档的唯一标识 document_id - llm_description: 飞书文档的唯一标识 document_id + en_US: Unique identifier for a Feishu document. You can also input the document's URL. + zh_Hans: 飞书文档的唯一标识,支持输入文档的 URL。 + llm_description: 飞书文档的唯一标识,支持输入文档的 URL。 form: llm - name: content type: string required: true label: - en_US: document content - zh_Hans: 文档内容 + en_US: Plain text or Markdown content + zh_Hans: 纯文本或 Markdown 内容 human_description: - en_US: Document content, supports markdown syntax, can be empty. - zh_Hans: 文档内容,支持 markdown 语法,可以为空。 - llm_description: + en_US: Plain text or Markdown content. Note that embedded tables in the document should not have merged cells. + zh_Hans: 纯文本或 Markdown 内容。注意文档的内嵌套表格不允许有单元格合并。 + llm_description: 纯文本或 Markdown 内容,注意文档的内嵌套表格不允许有单元格合并。 form: llm - name: position - type: select - required: true - default: start + type: string + required: false label: - en_US: Choose where to add content - zh_Hans: 选择添加内容的位置 + en_US: position + zh_Hans: 添加位置 human_description: - en_US: Please fill in start or end to add content at the beginning or end of the document respectively. - zh_Hans: 请填入 start 或 end, 分别表示在文档开头(start)或结尾(end)添加内容。 + en_US: | + Enumeration values: start or end. Use 'start' to add content at the beginning of the document, and 'end' to add content at the end. The default value is 'end'. 
+ zh_Hans: 枚举值:start 或 end。使用 'start' 在文档开头添加内容,使用 'end' 在文档结尾添加内容,默认值为 'end'。 + llm_description: | + 枚举值 start、end,start: 在文档开头添加内容;end: 在文档结尾添加内容,默认值为 end。 form: llm options: - value: start @@ -54,3 +56,4 @@ parameters: label: en_US: end zh_Hans: 在文档结尾添加内容 + default: start diff --git a/api/core/tools/provider/builtin/firecrawl/firecrawl_appx.py b/api/core/tools/provider/builtin/firecrawl/firecrawl_appx.py index a0e4cdf933..d9fb6f04bc 100644 --- a/api/core/tools/provider/builtin/firecrawl/firecrawl_appx.py +++ b/api/core/tools/provider/builtin/firecrawl/firecrawl_appx.py @@ -37,9 +37,8 @@ class FirecrawlApp: for i in range(retries): try: response = requests.request(method, url, json=data, headers=headers) - response.raise_for_status() return response.json() - except requests.exceptions.RequestException as e: + except requests.exceptions.RequestException: if i < retries - 1: time.sleep(backoff_factor * (2**i)) else: @@ -47,7 +46,7 @@ class FirecrawlApp: return None def scrape_url(self, url: str, **kwargs): - endpoint = f"{self.base_url}/v0/scrape" + endpoint = f"{self.base_url}/v1/scrape" data = {"url": url, **kwargs} logger.debug(f"Sent request to {endpoint=} body={data}") response = self._request("POST", endpoint, data) @@ -55,39 +54,41 @@ class FirecrawlApp: raise HTTPError("Failed to scrape URL after multiple retries") return response - def search(self, query: str, **kwargs): - endpoint = f"{self.base_url}/v0/search" - data = {"query": query, **kwargs} + def map(self, url: str, **kwargs): + endpoint = f"{self.base_url}/v1/map" + data = {"url": url, **kwargs} logger.debug(f"Sent request to {endpoint=} body={data}") response = self._request("POST", endpoint, data) if response is None: - raise HTTPError("Failed to perform search after multiple retries") + raise HTTPError("Failed to perform map after multiple retries") return response def crawl_url( self, url: str, wait: bool = True, poll_interval: int = 5, idempotency_key: str | None = None, **kwargs ): - endpoint 
= f"{self.base_url}/v0/crawl" + endpoint = f"{self.base_url}/v1/crawl" headers = self._prepare_headers(idempotency_key) data = {"url": url, **kwargs} logger.debug(f"Sent request to {endpoint=} body={data}") response = self._request("POST", endpoint, data, headers) if response is None: raise HTTPError("Failed to initiate crawl after multiple retries") - job_id: str = response["jobId"] + elif response.get("success") == False: + raise HTTPError(f'Failed to crawl: {response.get("error")}') + job_id: str = response["id"] if wait: return self._monitor_job_status(job_id=job_id, poll_interval=poll_interval) return response def check_crawl_status(self, job_id: str): - endpoint = f"{self.base_url}/v0/crawl/status/{job_id}" + endpoint = f"{self.base_url}/v1/crawl/{job_id}" response = self._request("GET", endpoint) if response is None: raise HTTPError(f"Failed to check status for job {job_id} after multiple retries") return response def cancel_crawl_job(self, job_id: str): - endpoint = f"{self.base_url}/v0/crawl/cancel/{job_id}" + endpoint = f"{self.base_url}/v1/crawl/{job_id}" response = self._request("DELETE", endpoint) if response is None: raise HTTPError(f"Failed to cancel job {job_id} after multiple retries") @@ -116,6 +117,6 @@ def get_json_params(tool_parameters: dict[str, Any], key): # support both single quotes and double quotes param = param.replace("'", '"') param = json.loads(param) - except: + except Exception: raise ValueError(f"Invalid {key} format.") return param diff --git a/api/core/tools/provider/builtin/firecrawl/tools/crawl.py b/api/core/tools/provider/builtin/firecrawl/tools/crawl.py index 94717cbbfb..9675b8eb91 100644 --- a/api/core/tools/provider/builtin/firecrawl/tools/crawl.py +++ b/api/core/tools/provider/builtin/firecrawl/tools/crawl.py @@ -8,39 +8,38 @@ from core.tools.tool.builtin_tool import BuiltinTool class CrawlTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: """ - the crawlerOptions and 
pageOptions comes from doc here: + the api doc: https://docs.firecrawl.dev/api-reference/endpoint/crawl """ app = FirecrawlApp( api_key=self.runtime.credentials["firecrawl_api_key"], base_url=self.runtime.credentials["base_url"] ) - crawlerOptions = {} - pageOptions = {} + + scrapeOptions = {} + payload = {} wait_for_results = tool_parameters.get("wait_for_results", True) - crawlerOptions["excludes"] = get_array_params(tool_parameters, "excludes") - crawlerOptions["includes"] = get_array_params(tool_parameters, "includes") - crawlerOptions["returnOnlyUrls"] = tool_parameters.get("returnOnlyUrls", False) - crawlerOptions["maxDepth"] = tool_parameters.get("maxDepth") - crawlerOptions["mode"] = tool_parameters.get("mode") - crawlerOptions["ignoreSitemap"] = tool_parameters.get("ignoreSitemap", False) - crawlerOptions["limit"] = tool_parameters.get("limit", 5) - crawlerOptions["allowBackwardCrawling"] = tool_parameters.get("allowBackwardCrawling", False) - crawlerOptions["allowExternalContentLinks"] = tool_parameters.get("allowExternalContentLinks", False) + payload["excludePaths"] = get_array_params(tool_parameters, "excludePaths") + payload["includePaths"] = get_array_params(tool_parameters, "includePaths") + payload["maxDepth"] = tool_parameters.get("maxDepth") + payload["ignoreSitemap"] = tool_parameters.get("ignoreSitemap", False) + payload["limit"] = tool_parameters.get("limit", 5) + payload["allowBackwardLinks"] = tool_parameters.get("allowBackwardLinks", False) + payload["allowExternalLinks"] = tool_parameters.get("allowExternalLinks", False) + payload["webhook"] = tool_parameters.get("webhook") - pageOptions["headers"] = get_json_params(tool_parameters, "headers") - pageOptions["includeHtml"] = tool_parameters.get("includeHtml", False) - pageOptions["includeRawHtml"] = tool_parameters.get("includeRawHtml", False) - pageOptions["onlyIncludeTags"] = get_array_params(tool_parameters, "onlyIncludeTags") - pageOptions["removeTags"] = 
get_array_params(tool_parameters, "removeTags") - pageOptions["onlyMainContent"] = tool_parameters.get("onlyMainContent", False) - pageOptions["replaceAllPathsWithAbsolutePaths"] = tool_parameters.get("replaceAllPathsWithAbsolutePaths", False) - pageOptions["screenshot"] = tool_parameters.get("screenshot", False) - pageOptions["waitFor"] = tool_parameters.get("waitFor", 0) + scrapeOptions["formats"] = get_array_params(tool_parameters, "formats") + scrapeOptions["headers"] = get_json_params(tool_parameters, "headers") + scrapeOptions["includeTags"] = get_array_params(tool_parameters, "includeTags") + scrapeOptions["excludeTags"] = get_array_params(tool_parameters, "excludeTags") + scrapeOptions["onlyMainContent"] = tool_parameters.get("onlyMainContent", False) + scrapeOptions["waitFor"] = tool_parameters.get("waitFor", 0) + scrapeOptions = {k: v for k, v in scrapeOptions.items() if v not in {None, ""}} + payload["scrapeOptions"] = scrapeOptions or None - crawl_result = app.crawl_url( - url=tool_parameters["url"], wait=wait_for_results, crawlerOptions=crawlerOptions, pageOptions=pageOptions - ) + payload = {k: v for k, v in payload.items() if v not in {None, ""}} + + crawl_result = app.crawl_url(url=tool_parameters["url"], wait=wait_for_results, **payload) return self.create_json_message(crawl_result) diff --git a/api/core/tools/provider/builtin/firecrawl/tools/crawl.yaml b/api/core/tools/provider/builtin/firecrawl/tools/crawl.yaml index 0c5399f973..0d7dbcac20 100644 --- a/api/core/tools/provider/builtin/firecrawl/tools/crawl.yaml +++ b/api/core/tools/provider/builtin/firecrawl/tools/crawl.yaml @@ -31,8 +31,21 @@ parameters: en_US: If you choose not to wait, it will directly return a job ID. You can use this job ID to check the crawling results or cancel the crawling task, which is usually very useful for a large-scale crawling task. 
zh_Hans: 如果选择不等待,则会直接返回一个job_id,可以通过job_id查询爬取结果或取消爬取任务,这通常对于一个大型爬取任务来说非常有用。 form: form -############## Crawl Options ####################### - - name: includes +############## Payload ####################### + - name: excludePaths + type: string + label: + en_US: URL patterns to exclude + zh_Hans: 要排除的URL模式 + placeholder: + en_US: Use commas to separate multiple tags + zh_Hans: 多个标签时使用半角逗号分隔 + human_description: + en_US: | + Pages matching these patterns will be skipped. Example: blog/*, about/* + zh_Hans: 匹配这些模式的页面将被跳过。示例:blog/*, about/* + form: form + - name: includePaths type: string required: false label: @@ -46,30 +59,6 @@ parameters: Only pages matching these patterns will be crawled. Example: blog/*, about/* zh_Hans: 只有与这些模式匹配的页面才会被爬取。示例:blog/*, about/* form: form - - name: excludes - type: string - label: - en_US: URL patterns to exclude - zh_Hans: 要排除的URL模式 - placeholder: - en_US: Use commas to separate multiple tags - zh_Hans: 多个标签时使用半角逗号分隔 - human_description: - en_US: | - Pages matching these patterns will be skipped. Example: blog/*, about/* - zh_Hans: 匹配这些模式的页面将被跳过。示例:blog/*, about/* - form: form - - name: returnOnlyUrls - type: boolean - default: false - label: - en_US: return Only Urls - zh_Hans: 仅返回URL - human_description: - en_US: | - If true, returns only the URLs as a list on the crawl status. Attention: the return response will be a list of URLs inside the data, not a list of documents. - zh_Hans: 只返回爬取到的网页链接,而不是网页内容本身。 - form: form - name: maxDepth type: number label: @@ -80,27 +69,10 @@ parameters: zh_Hans: 相对于输入的URL,爬取的最大深度。maxDepth为0时,仅抓取输入的URL。maxDepth为1时,抓取输入的URL以及所有一级深层页面。maxDepth为2时,抓取输入的URL以及所有两级深层页面。更高值遵循相同模式。 form: form min: 0 - - name: mode - type: select - required: false - form: form - options: - - value: default - label: - en_US: default - - value: fast - label: - en_US: fast - default: default - label: - en_US: Crawl Mode - zh_Hans: 爬取模式 - human_description: - en_US: The crawling mode to use. 
Fast mode crawls 4x faster websites without sitemap, but may not be as accurate and shouldn't be used in heavy js-rendered websites. - zh_Hans: 使用fast模式将不会使用其站点地图,比普通模式快4倍,但是可能不够准确,也不适用于大量js渲染的网站。 + default: 2 - name: ignoreSitemap type: boolean - default: false + default: true label: en_US: ignore Sitemap zh_Hans: 忽略站点地图 @@ -120,7 +92,7 @@ parameters: form: form min: 1 default: 5 - - name: allowBackwardCrawling + - name: allowBackwardLinks type: boolean default: false label: @@ -130,7 +102,7 @@ parameters: en_US: Enables the crawler to navigate from a specific URL to previously linked pages. For instance, from 'example.com/product/123' back to 'example.com/product' zh_Hans: 使爬虫能够从特定URL导航到之前链接的页面。例如,从'example.com/product/123'返回到'example.com/product' form: form - - name: allowExternalContentLinks + - name: allowExternalLinks type: boolean default: false label: @@ -140,7 +112,30 @@ parameters: en_US: Allows the crawler to follow links to external websites. zh_Hans: form: form -############## Page Options ####################### + - name: webhook + type: string + label: + en_US: webhook + human_description: + en_US: | + The URL to send the webhook to. This will trigger for crawl started (crawl.started) ,every page crawled (crawl.page) and when the crawl is completed (crawl.completed or crawl.failed). The response will be the same as the /scrape endpoint. + zh_Hans: 发送Webhook的URL。这将在开始爬取(crawl.started)、每爬取一个页面(crawl.page)以及爬取完成(crawl.completed或crawl.failed)时触发。响应将与/scrape端点相同。 + form: form +############## Scrape Options ####################### + - name: formats + type: string + label: + en_US: Formats + zh_Hans: 结果的格式 + placeholder: + en_US: Use commas to separate multiple tags + zh_Hans: 多个标签时使用半角逗号分隔 + human_description: + en_US: | + Formats to include in the output. 
Available options: markdown, html, rawHtml, links, screenshot + zh_Hans: | + 输出中应包含的格式。可以填入: markdown, html, rawHtml, links, screenshot + form: form - name: headers type: string label: @@ -155,30 +150,10 @@ parameters: en_US: Please enter an object that can be serialized in JSON zh_Hans: 请输入可以json序列化的对象 form: form - - name: includeHtml - type: boolean - default: false - label: - en_US: include Html - zh_Hans: 包含HTML - human_description: - en_US: Include the HTML version of the content on page. Will output a html key in the response. - zh_Hans: 返回中包含一个HTML版本的内容,将以html键返回。 - form: form - - name: includeRawHtml - type: boolean - default: false - label: - en_US: include Raw Html - zh_Hans: 包含原始HTML - human_description: - en_US: Include the raw HTML content of the page. Will output a rawHtml key in the response. - zh_Hans: 返回中包含一个原始HTML版本的内容,将以rawHtml键返回。 - form: form - - name: onlyIncludeTags + - name: includeTags type: string label: - en_US: only Include Tags + en_US: Include Tags zh_Hans: 仅抓取这些标签 placeholder: en_US: Use commas to separate multiple tags @@ -189,6 +164,20 @@ parameters: zh_Hans: | 仅在最终输出中包含HTML页面的这些标签,可以通过标签名、类或ID来设定,使用逗号分隔值。示例:script, .ad, #footer form: form + - name: excludeTags + type: string + label: + en_US: Exclude Tags + zh_Hans: 要移除这些标签 + human_description: + en_US: | + Tags, classes and ids to remove from the page. Use comma separated values. Example: script, .ad, #footer + zh_Hans: | + 要在最终输出中移除HTML页面的这些标签,可以通过标签名、类或ID来设定,使用逗号分隔值。示例:script, .ad, #footer + placeholder: + en_US: Use commas to separate multiple tags + zh_Hans: 多个标签时使用半角逗号分隔 + form: form - name: onlyMainContent type: boolean default: false @@ -199,40 +188,6 @@ parameters: en_US: Only return the main content of the page excluding headers, navs, footers, etc. zh_Hans: 只返回页面的主要内容,不包括头部、导航栏、尾部等。 form: form - - name: removeTags - type: string - label: - en_US: remove Tags - zh_Hans: 要移除这些标签 - human_description: - en_US: | - Tags, classes and ids to remove from the page. 
Use comma separated values. Example: script, .ad, #footer - zh_Hans: | - 要在最终输出中移除HTML页面的这些标签,可以通过标签名、类或ID来设定,使用逗号分隔值。示例:script, .ad, #footer - placeholder: - en_US: Use commas to separate multiple tags - zh_Hans: 多个标签时使用半角逗号分隔 - form: form - - name: replaceAllPathsWithAbsolutePaths - type: boolean - default: false - label: - en_US: All AbsolutePaths - zh_Hans: 使用绝对路径 - human_description: - en_US: Replace all relative paths with absolute paths for images and links. - zh_Hans: 将所有图片和链接的相对路径替换为绝对路径。 - form: form - - name: screenshot - type: boolean - default: false - label: - en_US: screenshot - zh_Hans: 截图 - human_description: - en_US: Include a screenshot of the top of the page that you are scraping. - zh_Hans: 提供正在抓取的页面的顶部的截图。 - form: form - name: waitFor type: number min: 0 diff --git a/api/core/tools/provider/builtin/firecrawl/tools/map.py b/api/core/tools/provider/builtin/firecrawl/tools/map.py new file mode 100644 index 0000000000..bdfb5faeb8 --- /dev/null +++ b/api/core/tools/provider/builtin/firecrawl/tools/map.py @@ -0,0 +1,25 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.provider.builtin.firecrawl.firecrawl_appx import FirecrawlApp +from core.tools.tool.builtin_tool import BuiltinTool + + +class MapTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + """ + the api doc: + https://docs.firecrawl.dev/api-reference/endpoint/map + """ + app = FirecrawlApp( + api_key=self.runtime.credentials["firecrawl_api_key"], base_url=self.runtime.credentials["base_url"] + ) + payload = {} + payload["search"] = tool_parameters.get("search") + payload["ignoreSitemap"] = tool_parameters.get("ignoreSitemap", True) + payload["includeSubdomains"] = tool_parameters.get("includeSubdomains", False) + payload["limit"] = tool_parameters.get("limit", 5000) + + map_result = app.map(url=tool_parameters["url"], **payload) + + return self.create_json_message(map_result) 
diff --git a/api/core/tools/provider/builtin/firecrawl/tools/map.yaml b/api/core/tools/provider/builtin/firecrawl/tools/map.yaml new file mode 100644 index 0000000000..9913756983 --- /dev/null +++ b/api/core/tools/provider/builtin/firecrawl/tools/map.yaml @@ -0,0 +1,59 @@ +identity: + name: map + author: hjlarry + label: + en_US: Map + zh_Hans: 地图式快爬 +description: + human: + en_US: Input a website and get all the urls on the website - extremely fast + zh_Hans: 输入一个网站,快速获取网站上的所有网址。 + llm: Input a website and get all the urls on the website - extremely fast +parameters: + - name: url + type: string + required: true + label: + en_US: Start URL + zh_Hans: 起始URL + human_description: + en_US: The base URL to start crawling from. + zh_Hans: 要爬取网站的起始URL。 + llm_description: The URL of the website that needs to be crawled. This is a required parameter. + form: llm + - name: search + type: string + label: + en_US: search + zh_Hans: 搜索查询 + human_description: + en_US: Search query to use for mapping. During the Alpha phase, the 'smart' part of the search functionality is limited to 100 search results. However, if map finds more results, there is no limit applied. + zh_Hans: 用于映射的搜索查询。在Alpha阶段,搜索功能的“智能”部分限制为最多100个搜索结果。然而,如果地图找到了更多结果,则不施加任何限制。 + llm_description: Search query to use for mapping. During the Alpha phase, the 'smart' part of the search functionality is limited to 100 search results. However, if map finds more results, there is no limit applied. + form: llm +############## Page Options ####################### + - name: ignoreSitemap + type: boolean + default: true + label: + en_US: ignore Sitemap + zh_Hans: 忽略站点地图 + human_description: + en_US: Ignore the website sitemap when crawling.
+ zh_Hans: 爬取时忽略网站站点地图。 + form: form + - name: includeSubdomains + type: boolean + default: false + label: + en_US: include Subdomains + zh_Hans: 包含子域名 + form: form + - name: limit + type: number + min: 0 + default: 5000 + label: + en_US: Maximum results + zh_Hans: 最大结果数量 + form: form diff --git a/api/core/tools/provider/builtin/firecrawl/tools/scrape.py b/api/core/tools/provider/builtin/firecrawl/tools/scrape.py index 962570bf73..538b4a1fcb 100644 --- a/api/core/tools/provider/builtin/firecrawl/tools/scrape.py +++ b/api/core/tools/provider/builtin/firecrawl/tools/scrape.py @@ -6,34 +6,34 @@ from core.tools.tool.builtin_tool import BuiltinTool class ScrapeTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]: """ - the pageOptions and extractorOptions comes from doc here: + the api doc: https://docs.firecrawl.dev/api-reference/endpoint/scrape """ app = FirecrawlApp( api_key=self.runtime.credentials["firecrawl_api_key"], base_url=self.runtime.credentials["base_url"] ) - pageOptions = {} - extractorOptions = {} + payload = {} + extract = {} - pageOptions["headers"] = get_json_params(tool_parameters, "headers") - pageOptions["includeHtml"] = tool_parameters.get("includeHtml", False) - pageOptions["includeRawHtml"] = tool_parameters.get("includeRawHtml", False) - pageOptions["onlyIncludeTags"] = get_array_params(tool_parameters, "onlyIncludeTags") - pageOptions["removeTags"] = get_array_params(tool_parameters, "removeTags") - pageOptions["onlyMainContent"] = tool_parameters.get("onlyMainContent", False) - pageOptions["replaceAllPathsWithAbsolutePaths"] = tool_parameters.get("replaceAllPathsWithAbsolutePaths", False) - pageOptions["screenshot"] = tool_parameters.get("screenshot", False) - pageOptions["waitFor"] = tool_parameters.get("waitFor", 0) + payload["formats"] = get_array_params(tool_parameters, "formats") + 
payload["onlyMainContent"] = tool_parameters.get("onlyMainContent", True) + payload["includeTags"] = get_array_params(tool_parameters, "includeTags") + payload["excludeTags"] = get_array_params(tool_parameters, "excludeTags") + payload["headers"] = get_json_params(tool_parameters, "headers") + payload["waitFor"] = tool_parameters.get("waitFor", 0) + payload["timeout"] = tool_parameters.get("timeout", 30000) - extractorOptions["mode"] = tool_parameters.get("mode", "") - extractorOptions["extractionPrompt"] = tool_parameters.get("extractionPrompt", "") - extractorOptions["extractionSchema"] = get_json_params(tool_parameters, "extractionSchema") + extract["schema"] = get_json_params(tool_parameters, "schema") + extract["systemPrompt"] = tool_parameters.get("systemPrompt") + extract["prompt"] = tool_parameters.get("prompt") + extract = {k: v for k, v in extract.items() if v not in {None, ""}} + payload["extract"] = extract or None - crawl_result = app.scrape_url( - url=tool_parameters["url"], pageOptions=pageOptions, extractorOptions=extractorOptions - ) + payload = {k: v for k, v in payload.items() if v not in {None, ""}} - return self.create_json_message(crawl_result) + crawl_result = app.scrape_url(url=tool_parameters["url"], **payload) + markdown_result = crawl_result.get("data", {}).get("markdown", "") + return [self.create_text_message(markdown_result), self.create_json_message(crawl_result)] diff --git a/api/core/tools/provider/builtin/firecrawl/tools/scrape.yaml b/api/core/tools/provider/builtin/firecrawl/tools/scrape.yaml index 598429de5e..8f1f1348a4 100644 --- a/api/core/tools/provider/builtin/firecrawl/tools/scrape.yaml +++ b/api/core/tools/provider/builtin/firecrawl/tools/scrape.yaml @@ -6,8 +6,8 @@ identity: zh_Hans: 单页面抓取 description: human: - en_US: Extract data from a single URL. - zh_Hans: 从单个URL抓取数据。 + en_US: Turn any url into clean data. + zh_Hans: 将任何网址转换为干净的数据。 llm: This tool is designed to scrape URL and output the content in Markdown format. 
parameters: - name: url @@ -21,7 +21,59 @@ parameters: zh_Hans: 要抓取并提取数据的网站URL。 llm_description: The URL of the website that needs to be crawled. This is a required parameter. form: llm -############## Page Options ####################### +############## Payload ####################### + - name: formats + type: string + label: + en_US: Formats + zh_Hans: 结果的格式 + placeholder: + en_US: Use commas to separate multiple tags + zh_Hans: 多个标签时使用半角逗号分隔 + human_description: + en_US: | + Formats to include in the output. Available options: markdown, html, rawHtml, links, screenshot, extract, screenshot@fullPage + zh_Hans: | + 输出中应包含的格式。可以填入: markdown, html, rawHtml, links, screenshot, extract, screenshot@fullPage + form: form + - name: onlyMainContent + type: boolean + default: false + label: + en_US: only Main Content + zh_Hans: 仅抓取主要内容 + human_description: + en_US: Only return the main content of the page excluding headers, navs, footers, etc. + zh_Hans: 只返回页面的主要内容,不包括头部、导航栏、尾部等。 + form: form + - name: includeTags + type: string + label: + en_US: Include Tags + zh_Hans: 仅抓取这些标签 + placeholder: + en_US: Use commas to separate multiple tags + zh_Hans: 多个标签时使用半角逗号分隔 + human_description: + en_US: | + Only include tags, classes and ids from the page in the final output. Use comma separated values. Example: script, .ad, #footer + zh_Hans: | + 仅在最终输出中包含HTML页面的这些标签,可以通过标签名、类或ID来设定,使用逗号分隔值。示例:script, .ad, #footer + form: form + - name: excludeTags + type: string + label: + en_US: Exclude Tags + zh_Hans: 要移除这些标签 + human_description: + en_US: | + Tags, classes and ids to remove from the page. Use comma separated values. 
Example: script, .ad, #footer + zh_Hans: | + 要在最终输出中移除HTML页面的这些标签,可以通过标签名、类或ID来设定,使用逗号分隔值。示例:script, .ad, #footer + placeholder: + en_US: Use commas to separate multiple tags + zh_Hans: 多个标签时使用半角逗号分隔 + form: form - name: headers type: string label: @@ -36,87 +88,10 @@ parameters: en_US: Please enter an object that can be serialized in JSON zh_Hans: 请输入可以json序列化的对象 form: form - - name: includeHtml - type: boolean - default: false - label: - en_US: include Html - zh_Hans: 包含HTML - human_description: - en_US: Include the HTML version of the content on page. Will output a html key in the response. - zh_Hans: 返回中包含一个HTML版本的内容,将以html键返回。 - form: form - - name: includeRawHtml - type: boolean - default: false - label: - en_US: include Raw Html - zh_Hans: 包含原始HTML - human_description: - en_US: Include the raw HTML content of the page. Will output a rawHtml key in the response. - zh_Hans: 返回中包含一个原始HTML版本的内容,将以rawHtml键返回。 - form: form - - name: onlyIncludeTags - type: string - label: - en_US: only Include Tags - zh_Hans: 仅抓取这些标签 - placeholder: - en_US: Use commas to separate multiple tags - zh_Hans: 多个标签时使用半角逗号分隔 - human_description: - en_US: | - Only include tags, classes and ids from the page in the final output. Use comma separated values. Example: script, .ad, #footer - zh_Hans: | - 仅在最终输出中包含HTML页面的这些标签,可以通过标签名、类或ID来设定,使用逗号分隔值。示例:script, .ad, #footer - form: form - - name: onlyMainContent - type: boolean - default: false - label: - en_US: only Main Content - zh_Hans: 仅抓取主要内容 - human_description: - en_US: Only return the main content of the page excluding headers, navs, footers, etc. - zh_Hans: 只返回页面的主要内容,不包括头部、导航栏、尾部等。 - form: form - - name: removeTags - type: string - label: - en_US: remove Tags - zh_Hans: 要移除这些标签 - human_description: - en_US: | - Tags, classes and ids to remove from the page. Use comma separated values. 
Example: script, .ad, #footer - zh_Hans: | - 要在最终输出中移除HTML页面的这些标签,可以通过标签名、类或ID来设定,使用逗号分隔值。示例:script, .ad, #footer - placeholder: - en_US: Use commas to separate multiple tags - zh_Hans: 多个标签时使用半角逗号分隔 - form: form - - name: replaceAllPathsWithAbsolutePaths - type: boolean - default: false - label: - en_US: All AbsolutePaths - zh_Hans: 使用绝对路径 - human_description: - en_US: Replace all relative paths with absolute paths for images and links. - zh_Hans: 将所有图片和链接的相对路径替换为绝对路径。 - form: form - - name: screenshot - type: boolean - default: false - label: - en_US: screenshot - zh_Hans: 截图 - human_description: - en_US: Include a screenshot of the top of the page that you are scraping. - zh_Hans: 提供正在抓取的页面的顶部的截图。 - form: form - name: waitFor type: number min: 0 + default: 0 label: en_US: wait For zh_Hans: 等待时间 @@ -124,57 +99,54 @@ parameters: en_US: Wait x amount of milliseconds for the page to load to fetch content. zh_Hans: 等待x毫秒以使页面加载并获取内容。 form: form + - name: timeout + type: number + min: 0 + default: 30000 + label: + en_US: Timeout + human_description: + en_US: Timeout in milliseconds for the request. + zh_Hans: 请求的超时时间(以毫秒为单位)。 + form: form ############## Extractor Options ####################### - - name: mode - type: select - options: - - value: markdown - label: - en_US: markdown - - value: llm-extraction - label: - en_US: llm-extraction - - value: llm-extraction-from-raw-html - label: - en_US: llm-extraction-from-raw-html - - value: llm-extraction-from-markdown - label: - en_US: llm-extraction-from-markdown - label: - en_US: Extractor Mode - zh_Hans: 提取模式 - human_description: - en_US: | - The extraction mode to use. 'markdown': Returns the scraped markdown content, does not perform LLM extraction. 'llm-extraction': Extracts information from the cleaned and parsed content using LLM. 
- zh_Hans: 使用的提取模式。“markdown”:返回抓取的markdown内容,不执行LLM提取。“llm-extractioin”:使用LLM按Extractor Schema从内容中提取信息。 - form: form - - name: extractionPrompt - type: string - label: - en_US: Extractor Prompt - zh_Hans: 提取时的提示词 - human_description: - en_US: A prompt describing what information to extract from the page, applicable for LLM extraction modes. - zh_Hans: 当使用LLM提取模式时,用于给LLM描述提取规则。 - form: form - - name: extractionSchema + - name: schema type: string label: en_US: Extractor Schema zh_Hans: 提取时的结构 placeholder: en_US: Please enter an object that can be serialized in JSON + zh_Hans: 请输入可以json序列化的对象 human_description: en_US: | - The schema for the data to be extracted, required only for LLM extraction modes. Example: { + The schema for the data to be extracted. Example: { "type": "object", "properties": {"company_mission": {"type": "string"}}, "required": ["company_mission"] } zh_Hans: | - 当使用LLM提取模式时,使用该结构去提取,示例:{ + 使用该结构去提取,示例:{ "type": "object", "properties": {"company_mission": {"type": "string"}}, "required": ["company_mission"] } form: form + - name: systemPrompt + type: string + label: + en_US: Extractor System Prompt + zh_Hans: 提取时的系统提示词 + human_description: + en_US: The system prompt to use for the extraction. + zh_Hans: 用于提取的系统提示。 + form: form + - name: prompt + type: string + label: + en_US: Extractor Prompt + zh_Hans: 提取时的提示词 + human_description: + en_US: The prompt to use for the extraction without a schema. 
+ zh_Hans: 用于无schema时提取的提示词 + form: form diff --git a/api/core/tools/provider/builtin/firecrawl/tools/search.py b/api/core/tools/provider/builtin/firecrawl/tools/search.py deleted file mode 100644 index f077e7d8ea..0000000000 --- a/api/core/tools/provider/builtin/firecrawl/tools/search.py +++ /dev/null @@ -1,27 +0,0 @@ -from typing import Any - -from core.tools.entities.tool_entities import ToolInvokeMessage -from core.tools.provider.builtin.firecrawl.firecrawl_appx import FirecrawlApp -from core.tools.tool.builtin_tool import BuiltinTool - - -class SearchTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: - """ - the pageOptions and searchOptions comes from doc here: - https://docs.firecrawl.dev/api-reference/endpoint/search - """ - app = FirecrawlApp( - api_key=self.runtime.credentials["firecrawl_api_key"], base_url=self.runtime.credentials["base_url"] - ) - pageOptions = {} - pageOptions["onlyMainContent"] = tool_parameters.get("onlyMainContent", False) - pageOptions["fetchPageContent"] = tool_parameters.get("fetchPageContent", True) - pageOptions["includeHtml"] = tool_parameters.get("includeHtml", False) - pageOptions["includeRawHtml"] = tool_parameters.get("includeRawHtml", False) - searchOptions = {"limit": tool_parameters.get("limit")} - search_result = app.search( - query=tool_parameters["keyword"], pageOptions=pageOptions, searchOptions=searchOptions - ) - - return self.create_json_message(search_result) diff --git a/api/core/tools/provider/builtin/firecrawl/tools/search.yaml b/api/core/tools/provider/builtin/firecrawl/tools/search.yaml deleted file mode 100644 index 29df0cfaaa..0000000000 --- a/api/core/tools/provider/builtin/firecrawl/tools/search.yaml +++ /dev/null @@ -1,75 +0,0 @@ -identity: - name: search - author: ahasasjeb - label: - en_US: Search - zh_Hans: 搜索 -description: - human: - en_US: Search, and output in Markdown format - zh_Hans: 搜索,并且以Markdown格式输出 - llm: This tool can perform online 
searches and convert the results to Markdown format. -parameters: - - name: keyword - type: string - required: true - label: - en_US: keyword - zh_Hans: 关键词 - human_description: - en_US: Input keywords to use Firecrawl API for search. - zh_Hans: 输入关键词即可使用Firecrawl API进行搜索。 - llm_description: Efficiently extract keywords from user text. - form: llm -############## Page Options ####################### - - name: onlyMainContent - type: boolean - default: false - label: - en_US: only Main Content - zh_Hans: 仅抓取主要内容 - human_description: - en_US: Only return the main content of the page excluding headers, navs, footers, etc. - zh_Hans: 只返回页面的主要内容,不包括头部、导航栏、尾部等。 - form: form - - name: fetchPageContent - type: boolean - default: true - label: - en_US: fetch Page Content - zh_Hans: 抓取页面内容 - human_description: - en_US: Fetch the content of each page. If false, defaults to a basic fast serp API. - zh_Hans: 获取每个页面的内容。如果为否,则使用基本的快速搜索结果页面API。 - form: form - - name: includeHtml - type: boolean - default: false - label: - en_US: include Html - zh_Hans: 包含HTML - human_description: - en_US: Include the HTML version of the content on page. Will output a html key in the response. - zh_Hans: 返回中包含一个HTML版本的内容,将以html键返回。 - form: form - - name: includeRawHtml - type: boolean - default: false - label: - en_US: include Raw Html - zh_Hans: 包含原始HTML - human_description: - en_US: Include the raw HTML content of the page. Will output a rawHtml key in the response. - zh_Hans: 返回中包含一个原始HTML版本的内容,将以rawHtml键返回。 - form: form -############## Search Options ####################### - - name: limit - type: number - min: 0 - label: - en_US: Maximum results - zh_Hans: 最大结果数量 - human_description: - en_US: Maximum number of results. Max is 20 during beta. 
- zh_Hans: 最大结果数量。在测试阶段,最大为20。 - form: form diff --git a/api/core/tools/provider/builtin/hap/tools/add_worksheet_record.py b/api/core/tools/provider/builtin/hap/tools/add_worksheet_record.py index f2288ed81c..597adc91db 100644 --- a/api/core/tools/provider/builtin/hap/tools/add_worksheet_record.py +++ b/api/core/tools/provider/builtin/hap/tools/add_worksheet_record.py @@ -30,7 +30,7 @@ class AddWorksheetRecordTool(BuiltinTool): elif not host.startswith(("http://", "https://")): return self.create_text_message("Invalid parameter Host Address") else: - host = f"{host[:-1] if host.endswith('/') else host}/api" + host = f"{host.removesuffix('/')}/api" url = f"{host}/v2/open/worksheet/addRow" headers = {"Content-Type": "application/json"} diff --git a/api/core/tools/provider/builtin/hap/tools/delete_worksheet_record.py b/api/core/tools/provider/builtin/hap/tools/delete_worksheet_record.py index 1df5f6d5cf..5d42af4c49 100644 --- a/api/core/tools/provider/builtin/hap/tools/delete_worksheet_record.py +++ b/api/core/tools/provider/builtin/hap/tools/delete_worksheet_record.py @@ -29,7 +29,7 @@ class DeleteWorksheetRecordTool(BuiltinTool): elif not host.startswith(("http://", "https://")): return self.create_text_message("Invalid parameter Host Address") else: - host = f"{host[:-1] if host.endswith('/') else host}/api" + host = f"{host.removesuffix('/')}/api" url = f"{host}/v2/open/worksheet/deleteRow" headers = {"Content-Type": "application/json"} diff --git a/api/core/tools/provider/builtin/hap/tools/get_worksheet_fields.py b/api/core/tools/provider/builtin/hap/tools/get_worksheet_fields.py index 40e1af043b..6887b8b4e9 100644 --- a/api/core/tools/provider/builtin/hap/tools/get_worksheet_fields.py +++ b/api/core/tools/provider/builtin/hap/tools/get_worksheet_fields.py @@ -27,7 +27,7 @@ class GetWorksheetFieldsTool(BuiltinTool): elif not host.startswith(("http://", "https://")): return self.create_text_message("Invalid parameter Host Address") else: - host = f"{host[:-1] if 
host.endswith('/') else host}/api" + host = f"{host.removesuffix('/')}/api" url = f"{host}/v2/open/worksheet/getWorksheetInfo" headers = {"Content-Type": "application/json"} @@ -133,9 +133,9 @@ class GetWorksheetFieldsTool(BuiltinTool): def _extract_options(self, control: dict) -> list: options = [] - if control["type"] in [9, 10, 11]: + if control["type"] in {9, 10, 11}: options.extend([{"key": opt["key"], "value": opt["value"]} for opt in control.get("options", [])]) - elif control["type"] in [28, 36]: + elif control["type"] in {28, 36}: itemnames = control["advancedSetting"].get("itemnames") if itemnames and itemnames.startswith("[{"): try: diff --git a/api/core/tools/provider/builtin/hap/tools/get_worksheet_pivot_data.py b/api/core/tools/provider/builtin/hap/tools/get_worksheet_pivot_data.py index 01c1af9b3e..26d7116869 100644 --- a/api/core/tools/provider/builtin/hap/tools/get_worksheet_pivot_data.py +++ b/api/core/tools/provider/builtin/hap/tools/get_worksheet_pivot_data.py @@ -38,7 +38,7 @@ class GetWorksheetPivotDataTool(BuiltinTool): elif not host.startswith(("http://", "https://")): return self.create_text_message("Invalid parameter Host Address") else: - host = f"{host[:-1] if host.endswith('/') else host}/api" + host = f"{host.removesuffix('/')}/api" url = f"{host}/report/getPivotData" headers = {"Content-Type": "application/json"} diff --git a/api/core/tools/provider/builtin/hap/tools/list_worksheet_records.py b/api/core/tools/provider/builtin/hap/tools/list_worksheet_records.py index 171895a306..d6ac3688b7 100644 --- a/api/core/tools/provider/builtin/hap/tools/list_worksheet_records.py +++ b/api/core/tools/provider/builtin/hap/tools/list_worksheet_records.py @@ -30,7 +30,7 @@ class ListWorksheetRecordsTool(BuiltinTool): elif not (host.startswith("http://") or host.startswith("https://")): return self.create_text_message("Invalid parameter Host Address") else: - host = f"{host[:-1] if host.endswith('/') else host}/api" + host = 
f"{host.removesuffix('/')}/api" url_fields = f"{host}/v2/open/worksheet/getWorksheetInfo" headers = {"Content-Type": "application/json"} @@ -183,11 +183,11 @@ class ListWorksheetRecordsTool(BuiltinTool): type_id = field.get("typeId") if type_id == 10: value = value if isinstance(value, str) else "、".join(value) - elif type_id in [28, 36]: + elif type_id in {28, 36}: value = field.get("options", {}).get(value, value) - elif type_id in [26, 27, 48, 14]: + elif type_id in {26, 27, 48, 14}: value = self.process_value(value) - elif type_id in [35, 29]: + elif type_id in {35, 29}: value = self.parse_cascade_or_associated(field, value) elif type_id == 40: value = self.parse_location(value) diff --git a/api/core/tools/provider/builtin/hap/tools/list_worksheets.py b/api/core/tools/provider/builtin/hap/tools/list_worksheets.py index 4dba2df1f1..4e852c0028 100644 --- a/api/core/tools/provider/builtin/hap/tools/list_worksheets.py +++ b/api/core/tools/provider/builtin/hap/tools/list_worksheets.py @@ -24,7 +24,7 @@ class ListWorksheetsTool(BuiltinTool): elif not (host.startswith("http://") or host.startswith("https://")): return self.create_text_message("Invalid parameter Host Address") else: - host = f"{host[:-1] if host.endswith('/') else host}/api" + host = f"{host.removesuffix('/')}/api" url = f"{host}/v1/open/app/get" result_type = tool_parameters.get("result_type", "") diff --git a/api/core/tools/provider/builtin/hap/tools/update_worksheet_record.py b/api/core/tools/provider/builtin/hap/tools/update_worksheet_record.py index 32abb18f9a..971f3d37f6 100644 --- a/api/core/tools/provider/builtin/hap/tools/update_worksheet_record.py +++ b/api/core/tools/provider/builtin/hap/tools/update_worksheet_record.py @@ -33,7 +33,7 @@ class UpdateWorksheetRecordTool(BuiltinTool): elif not host.startswith(("http://", "https://")): return self.create_text_message("Invalid parameter Host Address") else: - host = f"{host[:-1] if host.endswith('/') else host}/api" + host = 
f"{host.removesuffix('/')}/api" url = f"{host}/v2/open/worksheet/editRow" headers = {"Content-Type": "application/json"} diff --git a/api/core/tools/provider/builtin/novitaai/tools/novitaai_modelquery.py b/api/core/tools/provider/builtin/novitaai/tools/novitaai_modelquery.py index 9ca14b327c..a200ee8123 100644 --- a/api/core/tools/provider/builtin/novitaai/tools/novitaai_modelquery.py +++ b/api/core/tools/provider/builtin/novitaai/tools/novitaai_modelquery.py @@ -35,7 +35,7 @@ class NovitaAiModelQueryTool(BuiltinTool): models_data=[], headers=headers, params=params, - recursive=not (result_type == "first sd_name" or result_type == "first name sd_name pair"), + recursive=result_type not in {"first sd_name", "first name sd_name pair"}, ) result_str = "" diff --git a/api/core/tools/provider/builtin/searchapi/tools/google.py b/api/core/tools/provider/builtin/searchapi/tools/google.py index 16ae14549d..17e2978194 100644 --- a/api/core/tools/provider/builtin/searchapi/tools/google.py +++ b/api/core/tools/provider/builtin/searchapi/tools/google.py @@ -38,7 +38,7 @@ class SearchAPI: return { "engine": "google", "q": query, - **{key: value for key, value in kwargs.items() if value not in [None, ""]}, + **{key: value for key, value in kwargs.items() if value not in {None, ""}}, } @staticmethod diff --git a/api/core/tools/provider/builtin/searchapi/tools/google_jobs.py b/api/core/tools/provider/builtin/searchapi/tools/google_jobs.py index d29cb0ae3f..c478bc108b 100644 --- a/api/core/tools/provider/builtin/searchapi/tools/google_jobs.py +++ b/api/core/tools/provider/builtin/searchapi/tools/google_jobs.py @@ -38,7 +38,7 @@ class SearchAPI: return { "engine": "google_jobs", "q": query, - **{key: value for key, value in kwargs.items() if value not in [None, ""]}, + **{key: value for key, value in kwargs.items() if value not in {None, ""}}, } @staticmethod diff --git a/api/core/tools/provider/builtin/searchapi/tools/google_news.py 
b/api/core/tools/provider/builtin/searchapi/tools/google_news.py index 8458c8c958..562bc01964 100644 --- a/api/core/tools/provider/builtin/searchapi/tools/google_news.py +++ b/api/core/tools/provider/builtin/searchapi/tools/google_news.py @@ -38,7 +38,7 @@ class SearchAPI: return { "engine": "google_news", "q": query, - **{key: value for key, value in kwargs.items() if value not in [None, ""]}, + **{key: value for key, value in kwargs.items() if value not in {None, ""}}, } @staticmethod diff --git a/api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py b/api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py index 09725cf8a2..1867cf7be7 100644 --- a/api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py +++ b/api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py @@ -38,7 +38,7 @@ class SearchAPI: "engine": "youtube_transcripts", "video_id": video_id, "lang": language or "en", - **{key: value for key, value in kwargs.items() if value not in [None, ""]}, + **{key: value for key, value in kwargs.items() if value not in {None, ""}}, } @staticmethod diff --git a/api/core/tools/provider/builtin/siliconflow/tools/flux.py b/api/core/tools/provider/builtin/siliconflow/tools/flux.py index 1b846624bd..0d16ff385e 100644 --- a/api/core/tools/provider/builtin/siliconflow/tools/flux.py +++ b/api/core/tools/provider/builtin/siliconflow/tools/flux.py @@ -5,7 +5,10 @@ import requests from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool -FLUX_URL = "https://api.siliconflow.cn/v1/black-forest-labs/FLUX.1-schnell/text-to-image" +FLUX_URL = { + "schnell": "https://api.siliconflow.cn/v1/black-forest-labs/FLUX.1-schnell/text-to-image", + "dev": "https://api.siliconflow.cn/v1/image/generations", +} class FluxTool(BuiltinTool): @@ -24,8 +27,12 @@ class FluxTool(BuiltinTool): "seed": tool_parameters.get("seed"), "num_inference_steps": 
tool_parameters.get("num_inference_steps", 20), } + model = tool_parameters.get("model", "schnell") + url = FLUX_URL.get(model) + if model == "dev": + payload["model"] = "black-forest-labs/FLUX.1-dev" - response = requests.post(FLUX_URL, json=payload, headers=headers) + response = requests.post(url, json=payload, headers=headers) if response.status_code != 200: return self.create_text_message(f"Got Error Response:{response.text}") diff --git a/api/core/tools/provider/builtin/siliconflow/tools/flux.yaml b/api/core/tools/provider/builtin/siliconflow/tools/flux.yaml index 2a0698700c..d06b9bf3e1 100644 --- a/api/core/tools/provider/builtin/siliconflow/tools/flux.yaml +++ b/api/core/tools/provider/builtin/siliconflow/tools/flux.yaml @@ -6,8 +6,8 @@ identity: icon: icon.svg description: human: - en_US: Generate image via SiliconFlow's flux schnell. - llm: This tool is used to generate image from prompt via SiliconFlow's flux schnell model. + en_US: Generate image via SiliconFlow's flux model. + llm: This tool is used to generate image from prompt via SiliconFlow's flux model. parameters: - name: prompt type: string @@ -17,9 +17,24 @@ parameters: zh_Hans: 提示词 human_description: en_US: The text prompt used to generate the image. - zh_Hans: 用于生成图片的文字提示词 + zh_Hans: 建议用英文的生成图片提示词以获得更好的生成效果。 llm_description: this prompt text will be used to generate image. 
form: llm + - name: model + type: select + required: true + options: + - value: schnell + label: + en_US: Flux.1-schnell + - value: dev + label: + en_US: Flux.1-dev + default: schnell + label: + en_US: Choose Image Model + zh_Hans: 选择生成图片的模型 + form: form - name: image_size type: select required: true diff --git a/api/core/tools/provider/builtin/spider/spiderApp.py b/api/core/tools/provider/builtin/spider/spiderApp.py index 3972e560c4..4bc446a1a0 100644 --- a/api/core/tools/provider/builtin/spider/spiderApp.py +++ b/api/core/tools/provider/builtin/spider/spiderApp.py @@ -214,7 +214,7 @@ class Spider: return requests.delete(url, headers=headers, stream=stream) def _handle_error(self, response, action): - if response.status_code in [402, 409, 500]: + if response.status_code in {402, 409, 500}: error_message = response.json().get("error", "Unknown error occurred") raise Exception(f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}") else: diff --git a/api/core/tools/provider/builtin/stability/tools/text2image.py b/api/core/tools/provider/builtin/stability/tools/text2image.py index 9f415ceb55..6bcf315484 100644 --- a/api/core/tools/provider/builtin/stability/tools/text2image.py +++ b/api/core/tools/provider/builtin/stability/tools/text2image.py @@ -32,7 +32,7 @@ class StableDiffusionTool(BuiltinTool, BaseStabilityAuthorization): model = tool_parameters.get("model", "core") - if model in ["sd3", "sd3-turbo"]: + if model in {"sd3", "sd3-turbo"}: payload["model"] = tool_parameters.get("model") if model != "sd3-turbo": diff --git a/api/core/tools/provider/builtin/vanna/tools/vanna.py b/api/core/tools/provider/builtin/vanna/tools/vanna.py index a6efb0f79a..c90d766e48 100644 --- a/api/core/tools/provider/builtin/vanna/tools/vanna.py +++ b/api/core/tools/provider/builtin/vanna/tools/vanna.py @@ -38,7 +38,7 @@ class VannaTool(BuiltinTool): vn = VannaDefault(model=model, api_key=api_key) db_type = tool_parameters.get("db_type", "") - if db_type in 
["Postgres", "MySQL", "Hive", "ClickHouse"]: + if db_type in {"Postgres", "MySQL", "Hive", "ClickHouse"}: if not db_name: return self.create_text_message("Please input database name") if not username: diff --git a/api/core/tools/provider/builtin_tool_provider.py b/api/core/tools/provider/builtin_tool_provider.py index 6b64dd1b4e..ff022812ef 100644 --- a/api/core/tools/provider/builtin_tool_provider.py +++ b/api/core/tools/provider/builtin_tool_provider.py @@ -19,7 +19,7 @@ from core.tools.utils.yaml_utils import load_yaml_file class BuiltinToolProviderController(ToolProviderController): def __init__(self, **data: Any) -> None: - if self.provider_type == ToolProviderType.API or self.provider_type == ToolProviderType.APP: + if self.provider_type in {ToolProviderType.API, ToolProviderType.APP}: super().__init__(**data) return diff --git a/api/core/tools/provider/tool_provider.py b/api/core/tools/provider/tool_provider.py index f4008eedce..05c88b904e 100644 --- a/api/core/tools/provider/tool_provider.py +++ b/api/core/tools/provider/tool_provider.py @@ -153,10 +153,10 @@ class ToolProviderController(BaseModel, ABC): # check type credential_schema = credentials_need_to_validate[credential_name] - if ( - credential_schema == ToolProviderCredentials.CredentialsType.SECRET_INPUT - or credential_schema == ToolProviderCredentials.CredentialsType.TEXT_INPUT - ): + if credential_schema.type in { + ToolProviderCredentials.CredentialsType.SECRET_INPUT, + ToolProviderCredentials.CredentialsType.TEXT_INPUT, + }: if not isinstance(credentials[credential_name], str): raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string") @@ -184,11 +184,11 @@ class ToolProviderController(BaseModel, ABC): if credential_schema.default is not None: default_value = credential_schema.default # parse default value into the correct type - if ( - credential_schema.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT - or credential_schema.type == 
ToolProviderCredentials.CredentialsType.TEXT_INPUT - or credential_schema.type == ToolProviderCredentials.CredentialsType.SELECT - ): + if credential_schema.type in { + ToolProviderCredentials.CredentialsType.SECRET_INPUT, + ToolProviderCredentials.CredentialsType.TEXT_INPUT, + ToolProviderCredentials.CredentialsType.SELECT, + }: default_value = str(default_value) credentials[credential_name] = default_value diff --git a/api/core/tools/tool/api_tool.py b/api/core/tools/tool/api_tool.py index bf336b48f3..c779d704c3 100644 --- a/api/core/tools/tool/api_tool.py +++ b/api/core/tools/tool/api_tool.py @@ -5,7 +5,7 @@ from urllib.parse import urlencode import httpx -import core.helper.ssrf_proxy as ssrf_proxy +from core.helper import ssrf_proxy from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError @@ -191,7 +191,7 @@ class ApiTool(Tool): else: body = body - if method in ("get", "head", "post", "put", "delete", "patch"): + if method in {"get", "head", "post", "put", "delete", "patch"}: response = getattr(ssrf_proxy, method)( url, params=params, @@ -224,9 +224,9 @@ class ApiTool(Tool): elif option["type"] == "string": return str(value) elif option["type"] == "boolean": - if str(value).lower() in ["true", "1"]: + if str(value).lower() in {"true", "1"}: return True - elif str(value).lower() in ["false", "0"]: + elif str(value).lower() in {"false", "0"}: return False else: continue # Not a boolean, try next option diff --git a/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py index 6073b8e92e..ab7b40a253 100644 --- a/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py @@ -165,7 +165,10 @@ class 
DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): if dataset.indexing_technique == "economy": # use keyword table query documents = RetrievalService.retrieve( - retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=self.top_k + retrieval_method="keyword_search", + dataset_id=dataset.id, + query=query, + top_k=retrieval_model.get("top_k") or 2, ) if documents: all_documents.extend(documents) @@ -176,7 +179,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): retrieval_method=retrieval_model["search_method"], dataset_id=dataset.id, query=query, - top_k=self.top_k, + top_k=retrieval_model.get("top_k") or 2, score_threshold=retrieval_model.get("score_threshold", 0.0) if retrieval_model["score_threshold_enabled"] else 0.0, diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index 645f0861fa..9912114dd6 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -189,10 +189,7 @@ class ToolEngine: result += response.message elif response.type == ToolInvokeMessage.MessageType.LINK: result += f"result link: {response.message}. please tell user to check it." - elif ( - response.type == ToolInvokeMessage.MessageType.IMAGE_LINK - or response.type == ToolInvokeMessage.MessageType.IMAGE - ): + elif response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}: result += ( "image has been created and sent to user already, you do not need to create it," " just tell the user to check it now." 
@@ -212,10 +209,7 @@ class ToolEngine: result = [] for response in tool_response: - if ( - response.type == ToolInvokeMessage.MessageType.IMAGE_LINK - or response.type == ToolInvokeMessage.MessageType.IMAGE - ): + if response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}: mimetype = None if response.meta.get("mime_type"): mimetype = response.meta.get("mime_type") @@ -297,7 +291,7 @@ class ToolEngine: belongs_to="assistant", url=message.url, upload_file_id=None, - created_by_role=("account" if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end_user"), + created_by_role=("account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user"), created_by=user_id, ) diff --git a/api/core/tools/utils/feishu_api_utils.py b/api/core/tools/utils/feishu_api_utils.py index 44803d7d65..ffdb06498f 100644 --- a/api/core/tools/utils/feishu_api_utils.py +++ b/api/core/tools/utils/feishu_api_utils.py @@ -76,9 +76,9 @@ class FeishuRequest: url = "https://lark-plugin-api.solutionsuite.cn/lark-plugin/document/write_document" payload = {"document_id": document_id, "content": content, "position": position} res = self._send_request(url, payload=payload) - return res.get("data") + return res - def get_document_raw_content(self, document_id: str) -> dict: + def get_document_content(self, document_id: str, mode: str, lang: int = 0) -> dict: """ API url: https://open.larkoffice.com/document/server-docs/docs/docs/docx-v1/document/raw_content Example Response: @@ -92,16 +92,18 @@ class FeishuRequest: """ # noqa: E501 params = { "document_id": document_id, + "mode": mode, + "lang": lang, } - url = "https://lark-plugin-api.solutionsuite.cn/lark-plugin/document/get_document_raw_content" + url = "https://lark-plugin-api.solutionsuite.cn/lark-plugin/document/get_document_content" res = self._send_request(url, method="get", params=params) return res.get("data").get("content") - def list_document_block(self, document_id: str, 
page_token: str, page_size: int = 500) -> dict: + def list_document_blocks(self, document_id: str, page_token: str, page_size: int = 500) -> dict: """ API url: https://open.larkoffice.com/document/server-docs/docs/docs/docx-v1/document/list """ - url = "https://lark-plugin-api.solutionsuite.cn/lark-plugin/document/list_document_block" + url = "https://lark-plugin-api.solutionsuite.cn/lark-plugin/document/list_document_blocks" params = { "document_id": document_id, "page_size": page_size, diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index bf040d91d3..3cfab207ba 100644 --- a/api/core/tools/utils/message_transformer.py +++ b/api/core/tools/utils/message_transformer.py @@ -19,7 +19,7 @@ class ToolFileMessageTransformer: result = [] for message in messages: - if message.type == ToolInvokeMessage.MessageType.TEXT or message.type == ToolInvokeMessage.MessageType.LINK: + if message.type in {ToolInvokeMessage.MessageType.TEXT, ToolInvokeMessage.MessageType.LINK}: result.append(message) elif message.type == ToolInvokeMessage.MessageType.IMAGE: # try to download image diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py index 210b84b29a..9ead4f8e5c 100644 --- a/api/core/tools/utils/parser.py +++ b/api/core/tools/utils/parser.py @@ -165,7 +165,7 @@ class ApiBasedToolSchemaParser: elif "schema" in parameter and "type" in parameter["schema"]: typ = parameter["schema"]["type"] - if typ == "integer" or typ == "number": + if typ in {"integer", "number"}: return ToolParameter.ToolParameterType.NUMBER elif typ == "boolean": return ToolParameter.ToolParameterType.BOOLEAN diff --git a/api/core/tools/utils/web_reader_tool.py b/api/core/tools/utils/web_reader_tool.py index e57cae9f16..1ced7d0488 100644 --- a/api/core/tools/utils/web_reader_tool.py +++ b/api/core/tools/utils/web_reader_tool.py @@ -313,7 +313,7 @@ def normalize_whitespace(text): def is_leaf(element): - return element.name in ["p", "li"] + 
return element.name in {"p", "li"} def is_text(element): diff --git a/api/core/workflow/graph_engine/entities/graph.py b/api/core/workflow/graph_engine/entities/graph.py index 20efe91a59..1175f4af2a 100644 --- a/api/core/workflow/graph_engine/entities/graph.py +++ b/api/core/workflow/graph_engine/entities/graph.py @@ -405,21 +405,22 @@ class Graph(BaseModel): if condition_edge_mappings: for condition_hash, graph_edges in condition_edge_mappings.items(): - current_parallel = cls._get_current_parallel( - parallel_mapping=parallel_mapping, - graph_edge=graph_edge, - parallel=condition_parallels.get(condition_hash), - parent_parallel=parent_parallel, - ) + for graph_edge in graph_edges: + current_parallel: GraphParallel | None = cls._get_current_parallel( + parallel_mapping=parallel_mapping, + graph_edge=graph_edge, + parallel=condition_parallels.get(condition_hash), + parent_parallel=parent_parallel, + ) - cls._recursively_add_parallels( - edge_mapping=edge_mapping, - reverse_edge_mapping=reverse_edge_mapping, - start_node_id=graph_edge.target_node_id, - parallel_mapping=parallel_mapping, - node_parallel_mapping=node_parallel_mapping, - parent_parallel=current_parallel, - ) + cls._recursively_add_parallels( + edge_mapping=edge_mapping, + reverse_edge_mapping=reverse_edge_mapping, + start_node_id=graph_edge.target_node_id, + parallel_mapping=parallel_mapping, + node_parallel_mapping=node_parallel_mapping, + parent_parallel=current_parallel, + ) else: for graph_edge in target_node_edges: current_parallel = cls._get_current_parallel( @@ -688,23 +689,11 @@ class Graph(BaseModel): parallel_start_node_ids[graph_edge.source_node_id].append(branch_node_id) - parallel_start_node_id = None - for p_start_node_id, branch_node_ids in parallel_start_node_ids.items(): + for _, branch_node_ids in parallel_start_node_ids.items(): if set(branch_node_ids) == set(routes_node_ids.keys()): - parallel_start_node_id = p_start_node_id return True - if not parallel_start_node_id: - raise 
Exception("Parallel start node id not found") - - for graph_edge in reverse_edge_mapping[start_node_id]: - if ( - graph_edge.source_node_id not in all_routes_node_ids - or graph_edge.source_node_id != parallel_start_node_id - ): - return False - - return True + return False @classmethod def _is_node2_after_node1(cls, node1_id: str, node2_id: str, edge_mapping: dict[str, list[GraphEdge]]) -> bool: diff --git a/api/core/workflow/graph_engine/entities/runtime_route_state.py b/api/core/workflow/graph_engine/entities/runtime_route_state.py index 8fc8047426..bb24b51112 100644 --- a/api/core/workflow/graph_engine/entities/runtime_route_state.py +++ b/api/core/workflow/graph_engine/entities/runtime_route_state.py @@ -51,7 +51,7 @@ class RouteNodeState(BaseModel): :param run_result: run result """ - if self.status in [RouteNodeState.Status.SUCCESS, RouteNodeState.Status.FAILED]: + if self.status in {RouteNodeState.Status.SUCCESS, RouteNodeState.Status.FAILED}: raise Exception(f"Route state {self.id} already finished") if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: diff --git a/api/core/workflow/nodes/answer/answer_stream_generate_router.py b/api/core/workflow/nodes/answer/answer_stream_generate_router.py index e31a1479a8..bbd1f88867 100644 --- a/api/core/workflow/nodes/answer/answer_stream_generate_router.py +++ b/api/core/workflow/nodes/answer/answer_stream_generate_router.py @@ -148,11 +148,12 @@ class AnswerStreamGeneratorRouter: for edge in reverse_edges: source_node_id = edge.source_node_id source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type") - if source_node_type in ( + if source_node_type in { NodeType.ANSWER.value, NodeType.IF_ELSE.value, NodeType.QUESTION_CLASSIFIER.value, - ): + NodeType.ITERATION.value, + }: answer_dependencies[answer_node_id].append(source_node_id) else: cls._recursive_fetch_answer_dependencies( diff --git a/api/core/workflow/nodes/end/end_stream_generate_router.py 
b/api/core/workflow/nodes/end/end_stream_generate_router.py index a38d982393..9a7d2ecde3 100644 --- a/api/core/workflow/nodes/end/end_stream_generate_router.py +++ b/api/core/workflow/nodes/end/end_stream_generate_router.py @@ -136,10 +136,10 @@ class EndStreamGeneratorRouter: for edge in reverse_edges: source_node_id = edge.source_node_id source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type") - if source_node_type in ( + if source_node_type in { NodeType.IF_ELSE.value, NodeType.QUESTION_CLASSIFIER, - ): + }: end_dependencies[end_node_id].append(source_node_id) else: cls._recursive_fetch_end_dependencies( diff --git a/api/core/workflow/nodes/http_request/http_executor.py b/api/core/workflow/nodes/http_request/http_executor.py index 49102dc3ab..f8ab4e3132 100644 --- a/api/core/workflow/nodes/http_request/http_executor.py +++ b/api/core/workflow/nodes/http_request/http_executor.py @@ -6,8 +6,8 @@ from urllib.parse import urlencode import httpx -import core.helper.ssrf_proxy as ssrf_proxy from configs import dify_config +from core.helper import ssrf_proxy from core.workflow.entities.variable_entities import VariableSelector from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.http_request.entities import ( @@ -176,7 +176,7 @@ class HttpExecutor: elif node_data.body.type == "x-www-form-urlencoded" and not content_type_is_set: self.headers["Content-Type"] = "application/x-www-form-urlencoded" - if node_data.body.type in ["form-data", "x-www-form-urlencoded"]: + if node_data.body.type in {"form-data", "x-www-form-urlencoded"}: body = self._to_dict(body_data) if node_data.body.type == "form-data": @@ -187,7 +187,7 @@ class HttpExecutor: self.headers["Content-Type"] = f"multipart/form-data; boundary={self.boundary}" else: self.body = urlencode(body) - elif node_data.body.type in ["json", "raw-text"]: + elif node_data.body.type in {"json", "raw-text"}: self.body = body_data elif node_data.body.type == "none": 
self.body = "" @@ -258,7 +258,7 @@ class HttpExecutor: "follow_redirects": True, } - if self.method in ("get", "head", "post", "put", "delete", "patch"): + if self.method in {"get", "head", "post", "put", "delete", "patch"}: response = getattr(ssrf_proxy, self.method)(data=self.body, files=self.files, **kwargs) else: raise ValueError(f"Invalid http method {self.method}") diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index 4d944e93db..6f20745daf 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -20,11 +20,9 @@ from core.workflow.graph_engine.entities.event import ( NodeRunSucceededEvent, ) from core.workflow.graph_engine.entities.graph import Graph -from core.workflow.graph_engine.entities.run_condition import RunCondition from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.event import RunCompletedEvent, RunEvent from core.workflow.nodes.iteration.entities import IterationNodeData -from core.workflow.utils.condition.entities import Condition from models.workflow import WorkflowNodeExecutionStatus logger = logging.getLogger(__name__) @@ -68,38 +66,6 @@ class IterationNode(BaseNode): if not iteration_graph: raise ValueError("iteration graph not found") - leaf_node_ids = iteration_graph.get_leaf_node_ids() - iteration_leaf_node_ids = [] - for leaf_node_id in leaf_node_ids: - node_config = iteration_graph.node_id_config_mapping.get(leaf_node_id) - if not node_config: - continue - - leaf_node_iteration_id = node_config.get("data", {}).get("iteration_id") - if not leaf_node_iteration_id: - continue - - if leaf_node_iteration_id != self.node_id: - continue - - iteration_leaf_node_ids.append(leaf_node_id) - - # add condition of end nodes to root node - iteration_graph.add_extra_edge( - source_node_id=leaf_node_id, - target_node_id=root_node_id, - run_condition=RunCondition( - type="condition", - 
conditions=[ - Condition( - variable_selector=[self.node_id, "index"], - comparison_operator="<", - value=str(len(iterator_list_value)), - ) - ], - ), - ) - variable_pool = self.graph_runtime_state.variable_pool # append iteration variable (item, index) to variable pool @@ -149,91 +115,90 @@ class IterationNode(BaseNode): outputs: list[Any] = [] try: - # run workflow - rst = graph_engine.run() - for event in rst: - if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id: - event.in_iteration_id = self.node_id + for _ in range(len(iterator_list_value)): + # run workflow + rst = graph_engine.run() + for event in rst: + if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id: + event.in_iteration_id = self.node_id - if ( - isinstance(event, BaseNodeEvent) - and event.node_type == NodeType.ITERATION_START - and not isinstance(event, NodeRunStreamChunkEvent) - ): - continue + if ( + isinstance(event, BaseNodeEvent) + and event.node_type == NodeType.ITERATION_START + and not isinstance(event, NodeRunStreamChunkEvent) + ): + continue - if isinstance(event, NodeRunSucceededEvent): - if event.route_node_state.node_run_result: - metadata = event.route_node_state.node_run_result.metadata - if not metadata: - metadata = {} + if isinstance(event, NodeRunSucceededEvent): + if event.route_node_state.node_run_result: + metadata = event.route_node_state.node_run_result.metadata + if not metadata: + metadata = {} - if NodeRunMetadataKey.ITERATION_ID not in metadata: - metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id - metadata[NodeRunMetadataKey.ITERATION_INDEX] = variable_pool.get_any( - [self.node_id, "index"] - ) - event.route_node_state.node_run_result.metadata = metadata + if NodeRunMetadataKey.ITERATION_ID not in metadata: + metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id + metadata[NodeRunMetadataKey.ITERATION_INDEX] = variable_pool.get_any( + [self.node_id, "index"] + ) + 
event.route_node_state.node_run_result.metadata = metadata - yield event - - # handle iteration run result - if event.route_node_state.node_id in iteration_leaf_node_ids: - # append to iteration output variable list - current_iteration_output = variable_pool.get_any(self.node_data.output_selector) - outputs.append(current_iteration_output) - - # remove all nodes outputs from variable pool - for node_id in iteration_graph.node_ids: - variable_pool.remove_node(node_id) - - # move to next iteration - current_index = variable_pool.get([self.node_id, "index"]) - if current_index is None: - raise ValueError(f"iteration {self.node_id} current index not found") - - next_index = int(current_index.to_object()) + 1 - variable_pool.add([self.node_id, "index"], next_index) - - if next_index < len(iterator_list_value): - variable_pool.add([self.node_id, "item"], iterator_list_value[next_index]) - - yield IterationRunNextEvent( - iteration_id=self.id, - iteration_node_id=self.node_id, - iteration_node_type=self.node_type, - iteration_node_data=self.node_data, - index=next_index, - pre_iteration_output=jsonable_encoder(current_iteration_output) - if current_iteration_output - else None, - ) - elif isinstance(event, BaseGraphEvent): - if isinstance(event, GraphRunFailedEvent): - # iteration run failed - yield IterationRunFailedEvent( - iteration_id=self.id, - iteration_node_id=self.node_id, - iteration_node_type=self.node_type, - iteration_node_data=self.node_data, - start_at=start_at, - inputs=inputs, - outputs={"output": jsonable_encoder(outputs)}, - steps=len(iterator_list_value), - metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, - error=event.error, - ) - - yield RunCompletedEvent( - run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, + yield event + elif isinstance(event, BaseGraphEvent): + if isinstance(event, GraphRunFailedEvent): + # iteration run failed + yield IterationRunFailedEvent( + iteration_id=self.id, + 
iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + start_at=start_at, + inputs=inputs, + outputs={"output": jsonable_encoder(outputs)}, + steps=len(iterator_list_value), + metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, error=event.error, ) - ) - break - else: - event = cast(InNodeEvent, event) - yield event + + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=event.error, + ) + ) + return + else: + event = cast(InNodeEvent, event) + yield event + + # append to iteration output variable list + current_iteration_output = variable_pool.get_any(self.node_data.output_selector) + outputs.append(current_iteration_output) + + # remove all nodes outputs from variable pool + for node_id in iteration_graph.node_ids: + variable_pool.remove_node(node_id) + + # move to next iteration + current_index = variable_pool.get([self.node_id, "index"]) + if current_index is None: + raise ValueError(f"iteration {self.node_id} current index not found") + + next_index = int(current_index.to_object()) + 1 + variable_pool.add([self.node_id, "index"], next_index) + + if next_index < len(iterator_list_value): + variable_pool.add([self.node_id, "item"], iterator_list_value[next_index]) + + yield IterationRunNextEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + index=next_index, + pre_iteration_output=jsonable_encoder(current_iteration_output) + if current_iteration_output + else None, + ) yield IterationRunSucceededEvent( iteration_id=self.id, diff --git a/api/core/workflow/nodes/parameter_extractor/entities.py b/api/core/workflow/nodes/parameter_extractor/entities.py index 802ed31e27..5697d7c049 100644 --- a/api/core/workflow/nodes/parameter_extractor/entities.py +++ b/api/core/workflow/nodes/parameter_extractor/entities.py @@ -33,7 +33,7 @@ class 
ParameterConfig(BaseModel): def validate_name(cls, value) -> str: if not value: raise ValueError("Parameter name is required") - if value in ["__reason", "__is_success"]: + if value in {"__reason", "__is_success"}: raise ValueError("Invalid parameter name, __reason and __is_success are reserved") return value @@ -66,7 +66,7 @@ class ParameterExtractorNodeData(BaseNodeData): for parameter in self.parameters: parameter_schema = {"description": parameter.description} - if parameter.type in ["string", "select"]: + if parameter.type in {"string", "select"}: parameter_schema["type"] = "string" elif parameter.type.startswith("array"): parameter_schema["type"] = "array" diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index 131d26b19e..a6454bd1cd 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -467,7 +467,7 @@ class ParameterExtractorNode(LLMNode): # transformed_result[parameter.name] = bool(result[parameter.name].lower() == 'true') # elif isinstance(result[parameter.name], int): # transformed_result[parameter.name] = bool(result[parameter.name]) - elif parameter.type in ["string", "select"]: + elif parameter.type in {"string", "select"}: if isinstance(result[parameter.name], str): transformed_result[parameter.name] = result[parameter.name] elif parameter.type.startswith("array"): @@ -498,7 +498,7 @@ class ParameterExtractorNode(LLMNode): transformed_result[parameter.name] = 0 elif parameter.type == "bool": transformed_result[parameter.name] = False - elif parameter.type in ["string", "select"]: + elif parameter.type in {"string", "select"}: transformed_result[parameter.name] = "" elif parameter.type.startswith("array"): transformed_result[parameter.name] = [] @@ -516,9 +516,9 @@ class ParameterExtractorNode(LLMNode): """ stack = [] for i, c in 
enumerate(text): - if c == "{" or c == "[": + if c in {"{", "["}: stack.append(c) - elif c == "}" or c == "]": + elif c in {"}", "]"}: # check if stack is empty if not stack: return text[:i] @@ -560,7 +560,7 @@ class ParameterExtractorNode(LLMNode): result[parameter.name] = 0 elif parameter.type == "bool": result[parameter.name] = False - elif parameter.type in ["string", "select"]: + elif parameter.type in {"string", "select"}: result[parameter.name] = "" return result diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index e55adfc1f4..3b86b29cf8 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -163,10 +163,7 @@ class ToolNode(BaseNode): result = [] for response in tool_response: - if ( - response.type == ToolInvokeMessage.MessageType.IMAGE_LINK - or response.type == ToolInvokeMessage.MessageType.IMAGE - ): + if response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}: url = response.message ext = path.splitext(url)[1] mimetype = response.meta.get("mime_type", "image/jpeg") diff --git a/api/extensions/ext_sentry.py b/api/extensions/ext_sentry.py index 86e4720b3f..c2dc736038 100644 --- a/api/extensions/ext_sentry.py +++ b/api/extensions/ext_sentry.py @@ -6,6 +6,15 @@ from sentry_sdk.integrations.flask import FlaskIntegration from werkzeug.exceptions import HTTPException +def before_send(event, hint): + if "exc_info" in hint: + exc_type, exc_value, tb = hint["exc_info"] + if parse_error.defaultErrorResponse in str(exc_value): + return None + + return event + + def init_app(app): if app.config.get("SENTRY_DSN"): sentry_sdk.init( @@ -16,4 +25,5 @@ def init_app(app): profiles_sample_rate=app.config.get("SENTRY_PROFILES_SAMPLE_RATE", 1.0), environment=app.config.get("DEPLOY_ENV"), release=f"dify-{app.config.get('CURRENT_VERSION')}-{app.config.get('COMMIT_SHA')}", + before_send=before_send, ) diff --git 
a/api/extensions/storage/aliyun_storage.py b/api/extensions/storage/aliyun_storage.py index bee237fc17..2677912aa9 100644 --- a/api/extensions/storage/aliyun_storage.py +++ b/api/extensions/storage/aliyun_storage.py @@ -31,54 +31,34 @@ class AliyunStorage(BaseStorage): ) def save(self, filename, data): - if not self.folder or self.folder.endswith("/"): - filename = self.folder + filename - else: - filename = self.folder + "/" + filename - self.client.put_object(filename, data) + self.client.put_object(self.__wrapper_folder_filename(filename), data) def load_once(self, filename: str) -> bytes: - if not self.folder or self.folder.endswith("/"): - filename = self.folder + filename - else: - filename = self.folder + "/" + filename - - with closing(self.client.get_object(filename)) as obj: + with closing(self.client.get_object(self.__wrapper_folder_filename(filename))) as obj: data = obj.read() return data def load_stream(self, filename: str) -> Generator: def generate(filename: str = filename) -> Generator: - if not self.folder or self.folder.endswith("/"): - filename = self.folder + filename - else: - filename = self.folder + "/" + filename - - with closing(self.client.get_object(filename)) as obj: + with closing(self.client.get_object(self.__wrapper_folder_filename(filename))) as obj: while chunk := obj.read(4096): yield chunk return generate() def download(self, filename, target_filepath): - if not self.folder or self.folder.endswith("/"): - filename = self.folder + filename - else: - filename = self.folder + "/" + filename - - self.client.get_object_to_file(filename, target_filepath) + self.client.get_object_to_file(self.__wrapper_folder_filename(filename), target_filepath) def exists(self, filename): - if not self.folder or self.folder.endswith("/"): - filename = self.folder + filename - else: - filename = self.folder + "/" + filename - - return self.client.object_exists(filename) + return self.client.object_exists(self.__wrapper_folder_filename(filename)) def 
delete(self, filename): - if not self.folder or self.folder.endswith("/"): - filename = self.folder + filename - else: - filename = self.folder + "/" + filename - self.client.delete_object(filename) + self.client.delete_object(self.__wrapper_folder_filename(filename)) + + def __wrapper_folder_filename(self, filename) -> str: + if self.folder: + if self.folder.endswith("/"): + filename = self.folder + filename + else: + filename = self.folder + "/" + filename + return filename diff --git a/api/libs/oauth_data_source.py b/api/libs/oauth_data_source.py index 6da1a6d39b..05a73b09b7 100644 --- a/api/libs/oauth_data_source.py +++ b/api/libs/oauth_data_source.py @@ -158,7 +158,7 @@ class NotionOAuth(OAuthDataSource): page_icon = page_result["icon"] if page_icon: icon_type = page_icon["type"] - if icon_type == "external" or icon_type == "file": + if icon_type in {"external", "file"}: url = page_icon[icon_type]["url"] icon = {"type": "url", "url": url if url.startswith("http") else f"https://www.notion.so{url}"} else: @@ -191,7 +191,7 @@ class NotionOAuth(OAuthDataSource): page_icon = database_result["icon"] if page_icon: icon_type = page_icon["type"] - if icon_type == "external" or icon_type == "file": + if icon_type in {"external", "file"}: url = page_icon[icon_type]["url"] icon = {"type": "url", "url": url if url.startswith("http") else f"https://www.notion.so{url}"} else: diff --git a/api/libs/rsa.py b/api/libs/rsa.py index a578bf3e56..637bcc4a1d 100644 --- a/api/libs/rsa.py +++ b/api/libs/rsa.py @@ -4,9 +4,9 @@ from Crypto.Cipher import AES from Crypto.PublicKey import RSA from Crypto.Random import get_random_bytes -import libs.gmpy2_pkcs10aep_cipher as gmpy2_pkcs10aep_cipher from extensions.ext_redis import redis_client from extensions.ext_storage import storage +from libs import gmpy2_pkcs10aep_cipher def generate_key_pair(tenant_id): diff --git a/api/models/dataset.py b/api/models/dataset.py index 0da35910cd..a2d2a3454d 100644 --- a/api/models/dataset.py +++ 
b/api/models/dataset.py @@ -284,9 +284,9 @@ class Document(db.Model): status = None if self.indexing_status == "waiting": status = "queuing" - elif self.indexing_status not in ["completed", "error", "waiting"] and self.is_paused: + elif self.indexing_status not in {"completed", "error", "waiting"} and self.is_paused: status = "paused" - elif self.indexing_status in ["parsing", "cleaning", "splitting", "indexing"]: + elif self.indexing_status in {"parsing", "cleaning", "splitting", "indexing"}: status = "indexing" elif self.indexing_status == "error": status = "error" @@ -331,7 +331,7 @@ class Document(db.Model): "created_at": file_detail.created_at.timestamp(), } } - elif self.data_source_type == "notion_import" or self.data_source_type == "website_crawl": + elif self.data_source_type in {"notion_import", "website_crawl"}: return json.loads(self.data_source_info) return {} diff --git a/api/models/model.py b/api/models/model.py index a8b2e00ee4..ae0bc3210b 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -134,7 +134,7 @@ class App(db.Model): return False if self.app_model_config.agent_mode_dict.get("enabled", False) and self.app_model_config.agent_mode_dict.get( "strategy", "" - ) in ["function_call", "react"]: + ) in {"function_call", "react"}: self.mode = AppMode.AGENT_CHAT.value db.session.commit() return True @@ -1501,6 +1501,6 @@ class TraceAppConfig(db.Model): "tracing_provider": self.tracing_provider, "tracing_config": self.tracing_config_dict, "is_active": self.is_active, - "created_at": self.created_at.__str__() if self.created_at else None, - "updated_at": self.updated_at.__str__() if self.updated_at else None, + "created_at": str(self.created_at) if self.created_at else None, + "updated_at": str(self.updated_at) if self.updated_at else None, } diff --git a/api/poetry.lock b/api/poetry.lock index 36b52f68be..191db600e4 100644 --- a/api/poetry.lock +++ b/api/poetry.lock @@ -8003,29 +8003,29 @@ pyasn1 = ">=0.1.3" [[package]] name = "ruff" 
-version = "0.6.4" +version = "0.6.5" description = "An extremely fast Python linter and code formatter, written in Rust." optional = false python-versions = ">=3.7" files = [ - {file = "ruff-0.6.4-py3-none-linux_armv6l.whl", hash = "sha256:c4b153fc152af51855458e79e835fb6b933032921756cec9af7d0ba2aa01a258"}, - {file = "ruff-0.6.4-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:bedff9e4f004dad5f7f76a9d39c4ca98af526c9b1695068198b3bda8c085ef60"}, - {file = "ruff-0.6.4-py3-none-macosx_11_0_arm64.whl", hash = "sha256:d02a4127a86de23002e694d7ff19f905c51e338c72d8e09b56bfb60e1681724f"}, - {file = "ruff-0.6.4-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7862f42fc1a4aca1ea3ffe8a11f67819d183a5693b228f0bb3a531f5e40336fc"}, - {file = "ruff-0.6.4-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:eebe4ff1967c838a1a9618a5a59a3b0a00406f8d7eefee97c70411fefc353617"}, - {file = "ruff-0.6.4-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:932063a03bac394866683e15710c25b8690ccdca1cf192b9a98260332ca93408"}, - {file = "ruff-0.6.4-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:50e30b437cebef547bd5c3edf9ce81343e5dd7c737cb36ccb4fe83573f3d392e"}, - {file = "ruff-0.6.4-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c44536df7b93a587de690e124b89bd47306fddd59398a0fb12afd6133c7b3818"}, - {file = "ruff-0.6.4-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0ea086601b22dc5e7693a78f3fcfc460cceabfdf3bdc36dc898792aba48fbad6"}, - {file = "ruff-0.6.4-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0b52387d3289ccd227b62102c24714ed75fbba0b16ecc69a923a37e3b5e0aaaa"}, - {file = "ruff-0.6.4-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:0308610470fcc82969082fc83c76c0d362f562e2f0cdab0586516f03a4e06ec6"}, - {file = "ruff-0.6.4-py3-none-musllinux_1_2_armv7l.whl", hash = 
"sha256:803b96dea21795a6c9d5bfa9e96127cc9c31a1987802ca68f35e5c95aed3fc0d"}, - {file = "ruff-0.6.4-py3-none-musllinux_1_2_i686.whl", hash = "sha256:66dbfea86b663baab8fcae56c59f190caba9398df1488164e2df53e216248baa"}, - {file = "ruff-0.6.4-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:34d5efad480193c046c86608dbba2bccdc1c5fd11950fb271f8086e0c763a5d1"}, - {file = "ruff-0.6.4-py3-none-win32.whl", hash = "sha256:f0f8968feea5ce3777c0d8365653d5e91c40c31a81d95824ba61d871a11b8523"}, - {file = "ruff-0.6.4-py3-none-win_amd64.whl", hash = "sha256:549daccee5227282289390b0222d0fbee0275d1db6d514550d65420053021a58"}, - {file = "ruff-0.6.4-py3-none-win_arm64.whl", hash = "sha256:ac4b75e898ed189b3708c9ab3fc70b79a433219e1e87193b4f2b77251d058d14"}, - {file = "ruff-0.6.4.tar.gz", hash = "sha256:ac3b5bfbee99973f80aa1b7cbd1c9cbce200883bdd067300c22a6cc1c7fba212"}, + {file = "ruff-0.6.5-py3-none-linux_armv6l.whl", hash = "sha256:7e4e308f16e07c95fc7753fc1aaac690a323b2bb9f4ec5e844a97bb7fbebd748"}, + {file = "ruff-0.6.5-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:932cd69eefe4daf8c7d92bd6689f7e8182571cb934ea720af218929da7bd7d69"}, + {file = "ruff-0.6.5-py3-none-macosx_11_0_arm64.whl", hash = "sha256:3a8d42d11fff8d3143ff4da41742a98f8f233bf8890e9fe23077826818f8d680"}, + {file = "ruff-0.6.5-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a50af6e828ee692fb10ff2dfe53f05caecf077f4210fae9677e06a808275754f"}, + {file = "ruff-0.6.5-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:794ada3400a0d0b89e3015f1a7e01f4c97320ac665b7bc3ade24b50b54cb2972"}, + {file = "ruff-0.6.5-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:381413ec47f71ce1d1c614f7779d88886f406f1fd53d289c77e4e533dc6ea200"}, + {file = "ruff-0.6.5-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:52e75a82bbc9b42e63c08d22ad0ac525117e72aee9729a069d7c4f235fc4d276"}, + {file = "ruff-0.6.5-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", 
hash = "sha256:09c72a833fd3551135ceddcba5ebdb68ff89225d30758027280968c9acdc7810"}, + {file = "ruff-0.6.5-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:800c50371bdcb99b3c1551d5691e14d16d6f07063a518770254227f7f6e8c178"}, + {file = "ruff-0.6.5-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8e25ddd9cd63ba1f3bd51c1f09903904a6adf8429df34f17d728a8fa11174253"}, + {file = "ruff-0.6.5-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:7291e64d7129f24d1b0c947ec3ec4c0076e958d1475c61202497c6aced35dd19"}, + {file = "ruff-0.6.5-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:9ad7dfbd138d09d9a7e6931e6a7e797651ce29becd688be8a0d4d5f8177b4b0c"}, + {file = "ruff-0.6.5-py3-none-musllinux_1_2_i686.whl", hash = "sha256:005256d977021790cc52aa23d78f06bb5090dc0bfbd42de46d49c201533982ae"}, + {file = "ruff-0.6.5-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:482c1e6bfeb615eafc5899127b805d28e387bd87db38b2c0c41d271f5e58d8cc"}, + {file = "ruff-0.6.5-py3-none-win32.whl", hash = "sha256:cf4d3fa53644137f6a4a27a2b397381d16454a1566ae5335855c187fbf67e4f5"}, + {file = "ruff-0.6.5-py3-none-win_amd64.whl", hash = "sha256:3e42a57b58e3612051a636bc1ac4e6b838679530235520e8f095f7c44f706ff9"}, + {file = "ruff-0.6.5-py3-none-win_arm64.whl", hash = "sha256:51935067740773afdf97493ba9b8231279e9beef0f2a8079188c4776c25688e0"}, + {file = "ruff-0.6.5.tar.gz", hash = "sha256:4d32d87fab433c0cf285c3683dd4dae63be05fd7a1d65b3f5bf7cdd05a6b96fb"}, ] [[package]] @@ -10416,4 +10416,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "726af69ca5a577808dfe76dbce098de77ce358bf64862a4d27309cb1900cea0c" +content-hash = "9173a56b2efea12804c980511e1465fba43c7a3d83b1ad284ee149851ed67fc5" diff --git a/api/pyproject.toml b/api/pyproject.toml index 83aa35c542..8c10f1dad9 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -6,6 +6,9 @@ requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" [tool.ruff] 
+exclude=[ + "migrations/*", +] line-length = 120 [tool.ruff.lint] @@ -19,6 +22,13 @@ select = [ "I", # isort rules "N", # pep8-naming "PT", # flake8-pytest-style rules + "PLC0208", # iteration-over-set + "PLC2801", # unnecessary-dunder-call + "PLC0414", # useless-import-alias + "PLR0402", # manual-from-import + "PLR1711", # useless-return + "PLR1714", # repeated-equality-comparison + "PLR6201", # literal-membership "RUF019", # unnecessary-key-check "RUF100", # unused-noqa "RUF101", # redirected-noqa @@ -78,9 +88,6 @@ ignore = [ "libs/gmpy2_pkcs10aep_cipher.py" = [ "N803", # invalid-argument-name ] -"migrations/versions/*" = [ - "E501", # line-too-long -] "tests/*" = [ "F401", # unused-import "F811", # redefined-while-unused @@ -88,7 +95,6 @@ ignore = [ [tool.ruff.format] exclude = [ - "migrations/**/*", ] [tool.pytest_env] @@ -203,7 +209,7 @@ zhipuai = "1.0.7" # Before adding new dependency, consider place it in alphabet order (a-z) and suitable group. ############################################################ -# Related transparent dependencies with pinned verion +# Related transparent dependencies with pinned version # required by main implementations ############################################################ azure-ai-ml = "^1.19.0" @@ -277,4 +283,4 @@ optional = true [tool.poetry.group.lint.dependencies] dotenv-linter = "~0.5.0" -ruff = "~0.6.4" +ruff = "~0.6.5" diff --git a/api/services/account_service.py b/api/services/account_service.py index e839ae54ba..66ff5d2b7c 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -47,7 +47,7 @@ class AccountService: if not account: return None - if account.status in [AccountStatus.BANNED.value, AccountStatus.CLOSED.value]: + if account.status in {AccountStatus.BANNED.value, AccountStatus.CLOSED.value}: raise Unauthorized("Account is banned or closed.") current_tenant: TenantAccountJoin = TenantAccountJoin.query.filter_by( @@ -92,7 +92,7 @@ class AccountService: if not account: raise 
AccountLoginError("Invalid email or password.") - if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value: + if account.status in {AccountStatus.BANNED.value, AccountStatus.CLOSED.value}: raise AccountLoginError("Account is banned or closed.") if account.status == AccountStatus.PENDING.value: @@ -427,7 +427,7 @@ class TenantService: "remove": [TenantAccountRole.OWNER], "update": [TenantAccountRole.OWNER], } - if action not in ["add", "remove", "update"]: + if action not in {"add", "remove", "update"}: raise InvalidActionError("Invalid action.") if member: diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 2fe39b5224..54594e1175 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -90,7 +90,7 @@ class AppDslService: # import dsl and create app app_mode = AppMode.value_of(app_data.get("mode")) - if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]: + if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: app = cls._import_and_create_new_workflow_based_app( tenant_id=tenant_id, app_mode=app_mode, @@ -103,7 +103,7 @@ class AppDslService: icon_background=icon_background, use_icon_as_answer_icon=use_icon_as_answer_icon, ) - elif app_mode in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION]: + elif app_mode in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION}: app = cls._import_and_create_new_model_config_based_app( tenant_id=tenant_id, app_mode=app_mode, @@ -143,7 +143,7 @@ class AppDslService: # import dsl and overwrite app app_mode = AppMode.value_of(app_data.get("mode")) - if app_mode not in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]: + if app_mode not in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: raise ValueError("Only support import workflow in advanced-chat or workflow app.") if app_data.get("mode") != app_model.mode: @@ -177,7 +177,7 @@ class AppDslService: }, } - if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]: + if app_mode in 
{AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: cls._append_workflow_export_data( export_data=export_data, app_model=app_model, include_secret=include_secret ) diff --git a/api/services/app_service.py b/api/services/app_service.py index 1dacfea246..ac45d623e8 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -316,7 +316,7 @@ class AppService: meta = {"tool_icons": {}} - if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]: + if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: workflow = app_model.workflow if workflow is None: return meta diff --git a/api/services/audio_service.py b/api/services/audio_service.py index 05cd1c96a1..7a0cd5725b 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) class AudioService: @classmethod def transcript_asr(cls, app_model: App, file: FileStorage, end_user: Optional[str] = None): - if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]: + if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: workflow = app_model.workflow if workflow is None: raise ValueError("Speech to text is not enabled") @@ -83,7 +83,7 @@ class AudioService: def invoke_tts(text_content: str, app_model, voice: Optional[str] = None): with app.app_context(): - if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]: + if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: workflow = app_model.workflow if workflow is None: raise ValueError("TTS is not enabled") diff --git a/api/services/auth/firecrawl.py b/api/services/auth/firecrawl.py index 30e4ee57c0..afc491398f 100644 --- a/api/services/auth/firecrawl.py +++ b/api/services/auth/firecrawl.py @@ -37,7 +37,7 @@ class FirecrawlAuth(ApiKeyAuthBase): return requests.post(url, headers=headers, json=data) def _handle_error(self, response): - if response.status_code in [402, 409, 500]: + if response.status_code in {402, 409, 
500}: error_message = response.json().get("error", "Unknown error occurred") raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}") else: diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index fa017bfa42..30c010ef29 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -544,7 +544,7 @@ class DocumentService: @staticmethod def pause_document(document): - if document.indexing_status not in ["waiting", "parsing", "cleaning", "splitting", "indexing"]: + if document.indexing_status not in {"waiting", "parsing", "cleaning", "splitting", "indexing"}: raise DocumentIndexingError() # update document to be paused document.is_paused = True diff --git a/api/services/ops_service.py b/api/services/ops_service.py index d8e2b1689a..1160a1f275 100644 --- a/api/services/ops_service.py +++ b/api/services/ops_service.py @@ -31,16 +31,28 @@ class OpsService: if tracing_provider == "langfuse" and ( "project_key" not in decrypt_tracing_config or not decrypt_tracing_config.get("project_key") ): - project_key = OpsTraceManager.get_trace_config_project_key(decrypt_tracing_config, tracing_provider) - new_decrypt_tracing_config.update( - {"project_url": "{host}/project/{key}".format(host=decrypt_tracing_config.get("host"), key=project_key)} - ) + try: + project_key = OpsTraceManager.get_trace_config_project_key(decrypt_tracing_config, tracing_provider) + new_decrypt_tracing_config.update( + { + "project_url": "{host}/project/{key}".format( + host=decrypt_tracing_config.get("host"), key=project_key + ) + } + ) + except Exception: + new_decrypt_tracing_config.update( + {"project_url": "{host}/".format(host=decrypt_tracing_config.get("host"))} + ) if tracing_provider == "langsmith" and ( "project_url" not in decrypt_tracing_config or not decrypt_tracing_config.get("project_url") ): - project_url = OpsTraceManager.get_trace_config_project_url(decrypt_tracing_config, tracing_provider) - 
new_decrypt_tracing_config.update({"project_url": project_url}) + try: + project_url = OpsTraceManager.get_trace_config_project_url(decrypt_tracing_config, tracing_provider) + new_decrypt_tracing_config.update({"project_url": project_url}) + except Exception: + new_decrypt_tracing_config.update({"project_url": "https://smith.langchain.com/"}) trace_config_data.tracing_config = new_decrypt_tracing_config return trace_config_data.to_dict() diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index 6fb0f2f517..7ae1b9f231 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -33,7 +33,7 @@ class ToolTransformService: if provider_type == ToolProviderType.BUILT_IN.value: return url_prefix + "builtin/" + provider_name + "/icon" - elif provider_type in [ToolProviderType.API.value, ToolProviderType.WORKFLOW.value]: + elif provider_type in {ToolProviderType.API.value, ToolProviderType.WORKFLOW.value}: try: return json.loads(icon) except: diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 357ffd41c1..0ff81f1f7e 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -295,7 +295,7 @@ class WorkflowService: # chatbot convert to workflow mode workflow_converter = WorkflowConverter() - if app_model.mode not in [AppMode.CHAT.value, AppMode.COMPLETION.value]: + if app_model.mode not in {AppMode.CHAT.value, AppMode.COMPLETION.value}: raise ValueError(f"Current App mode: {app_model.mode} is not supported convert to workflow.") # convert to workflow diff --git a/api/tasks/recover_document_indexing_task.py b/api/tasks/recover_document_indexing_task.py index 21ea11d4dd..934eb7430c 100644 --- a/api/tasks/recover_document_indexing_task.py +++ b/api/tasks/recover_document_indexing_task.py @@ -29,7 +29,7 @@ def recover_document_indexing_task(dataset_id: str, document_id: str): try: indexing_runner = 
IndexingRunner() - if document.indexing_status in ["waiting", "parsing", "cleaning"]: + if document.indexing_status in {"waiting", "parsing", "cleaning"}: indexing_runner.run([document]) elif document.indexing_status == "splitting": indexing_runner.run_in_splitting_status(document) diff --git a/api/tests/integration_tests/model_runtime/__mock/google.py b/api/tests/integration_tests/model_runtime/__mock/google.py index bc0684086f..402bd9c2c2 100644 --- a/api/tests/integration_tests/model_runtime/__mock/google.py +++ b/api/tests/integration_tests/model_runtime/__mock/google.py @@ -1,15 +1,13 @@ from collections.abc import Generator -import google.generativeai.types.content_types as content_types import google.generativeai.types.generation_types as generation_config_types -import google.generativeai.types.safety_types as safety_types import pytest from _pytest.monkeypatch import MonkeyPatch from google.ai import generativelanguage as glm from google.ai.generativelanguage_v1beta.types import content as gag_content from google.generativeai import GenerativeModel from google.generativeai.client import _ClientManager, configure -from google.generativeai.types import GenerateContentResponse +from google.generativeai.types import GenerateContentResponse, content_types, safety_types from google.generativeai.types.generation_types import BaseGenerateContentResponse current_api_key = "" diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_chat.py b/api/tests/integration_tests/model_runtime/__mock/openai_chat.py index d9cd7b046e..439f7d56e9 100644 --- a/api/tests/integration_tests/model_runtime/__mock/openai_chat.py +++ b/api/tests/integration_tests/model_runtime/__mock/openai_chat.py @@ -6,7 +6,6 @@ from time import time # import monkeypatch from typing import Any, Literal, Optional, Union -import openai.types.chat.completion_create_params as completion_create_params from openai import AzureOpenAI, OpenAI from openai._types import NOT_GIVEN, NotGiven from 
openai.resources.chat.completions import Completions @@ -18,6 +17,7 @@ from openai.types.chat import ( ChatCompletionMessageToolCall, ChatCompletionToolChoiceOptionParam, ChatCompletionToolParam, + completion_create_params, ) from openai.types.chat.chat_completion import ChatCompletion as _ChatCompletion from openai.types.chat.chat_completion import Choice as _ChatCompletionChoice @@ -254,7 +254,7 @@ class MockChatClass: "gpt-3.5-turbo-16k-0613", ] azure_openai_models = ["gpt35", "gpt-4v", "gpt-35-turbo"] - if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._client.base_url.__str__()): + if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", str(self._client.base_url)): raise InvokeAuthorizationError("Invalid base url") if model in openai_models + azure_openai_models: if not re.match(r"sk-[a-zA-Z0-9]{24,}$", self._client.api_key) and type(self._client) == OpenAI: diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_completion.py b/api/tests/integration_tests/model_runtime/__mock/openai_completion.py index c27e89248f..14223668e0 100644 --- a/api/tests/integration_tests/model_runtime/__mock/openai_completion.py +++ b/api/tests/integration_tests/model_runtime/__mock/openai_completion.py @@ -112,7 +112,7 @@ class MockCompletionsClass: ] azure_openai_models = ["gpt-35-turbo-instruct"] - if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._client.base_url.__str__()): + if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", str(self._client.base_url)): raise InvokeAuthorizationError("Invalid base url") if model in openai_models + azure_openai_models: if not re.match(r"sk-[a-zA-Z0-9]{24,}$", self._client.api_key) and type(self._client) == OpenAI: diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_embeddings.py b/api/tests/integration_tests/model_runtime/__mock/openai_embeddings.py index 025913cb17..e27b9891f5 100644 --- a/api/tests/integration_tests/model_runtime/__mock/openai_embeddings.py +++ 
b/api/tests/integration_tests/model_runtime/__mock/openai_embeddings.py @@ -22,7 +22,7 @@ class MockEmbeddingsClass: if isinstance(input, str): input = [input] - if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._client.base_url.__str__()): + if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", str(self._client.base_url)): raise InvokeAuthorizationError("Invalid base url") if len(self._client.api_key) < 18: diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_moderation.py b/api/tests/integration_tests/model_runtime/__mock/openai_moderation.py index 270a88e85f..4262d40f3e 100644 --- a/api/tests/integration_tests/model_runtime/__mock/openai_moderation.py +++ b/api/tests/integration_tests/model_runtime/__mock/openai_moderation.py @@ -20,7 +20,7 @@ class MockModerationClass: if isinstance(input, str): input = [input] - if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._client.base_url.__str__()): + if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", str(self._client.base_url)): raise InvokeAuthorizationError("Invalid base url") if len(self._client.api_key) < 18: diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_speech2text.py b/api/tests/integration_tests/model_runtime/__mock/openai_speech2text.py index ef361e8613..a51dcab4be 100644 --- a/api/tests/integration_tests/model_runtime/__mock/openai_speech2text.py +++ b/api/tests/integration_tests/model_runtime/__mock/openai_speech2text.py @@ -20,7 +20,7 @@ class MockSpeech2TextClass: temperature: float | NotGiven = NOT_GIVEN, **kwargs: Any, ) -> Transcription: - if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._client.base_url.__str__()): + if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", str(self._client.base_url)): raise InvokeAuthorizationError("Invalid base url") if len(self._client.api_key) < 18: diff --git a/api/tests/integration_tests/model_runtime/__mock/xinference.py b/api/tests/integration_tests/model_runtime/__mock/xinference.py index 
777737187e..299523f4f5 100644 --- a/api/tests/integration_tests/model_runtime/__mock/xinference.py +++ b/api/tests/integration_tests/model_runtime/__mock/xinference.py @@ -42,7 +42,7 @@ class MockXinferenceClass: model_uid = url.split("/")[-1] or "" if not re.match( r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", model_uid - ) and model_uid not in ["generate", "chat", "embedding", "rerank"]: + ) and model_uid not in {"generate", "chat", "embedding", "rerank"}: response.status_code = 404 response._content = b"{}" return response @@ -53,7 +53,7 @@ class MockXinferenceClass: response._content = b"{}" return response - if model_uid in ["generate", "chat"]: + if model_uid in {"generate", "chat"}: response.status_code = 200 response._content = b"""{ "model_type": "LLM", diff --git a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_python3.py b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_python3.py index cbe4a5d335..25af312afa 100644 --- a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_python3.py +++ b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_python3.py @@ -1,4 +1,3 @@ -import json from textwrap import dedent from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py index cbe9c5914f..88435c4022 100644 --- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py +++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py @@ -411,5 +411,5 @@ def test_chat_parameter_extractor_with_memory(setup_anthropic_mock): if latest_role is not None: assert latest_role != prompt.get("role") - if prompt.get("role") in ["user", "assistant"]: + if prompt.get("role") in {"user", "assistant"}: latest_role = prompt.get("role") diff --git 
a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py index a2d71d61fc..197288adba 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py @@ -210,7 +210,7 @@ def test_run_parallel_in_workflow(mock_close, mock_remove): assert not isinstance(item, NodeRunFailedEvent) assert not isinstance(item, GraphRunFailedEvent) - if isinstance(item, BaseNodeEvent) and item.route_node_state.node_id in ["llm2", "llm3", "end1", "end2"]: + if isinstance(item, BaseNodeEvent) and item.route_node_state.node_id in {"llm2", "llm3", "end1", "end2"}: assert item.parallel_id is not None assert len(items) == 18 @@ -315,12 +315,12 @@ def test_run_parallel_in_chatflow(mock_close, mock_remove): assert not isinstance(item, NodeRunFailedEvent) assert not isinstance(item, GraphRunFailedEvent) - if isinstance(item, BaseNodeEvent) and item.route_node_state.node_id in [ + if isinstance(item, BaseNodeEvent) and item.route_node_state.node_id in { "answer2", "answer3", "answer4", "answer5", - ]: + }: assert item.parallel_id is not None assert len(items) == 23 diff --git a/docker/.env.example b/docker/.env.example index 7e4430a37d..c892c15636 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -568,6 +568,13 @@ SSRF_PROXY_HTTP_URL=http://ssrf_proxy:3128 # SSRF Proxy server HTTPS URL SSRF_PROXY_HTTPS_URL=http://ssrf_proxy:3128 +# ------------------------------ +# Environment Variables for web Service +# ------------------------------ + +# The timeout for the text generation in millisecond +TEXT_GENERATION_TIMEOUT_MS=60000 + # ------------------------------ # Environment Variables for db Service # ------------------------------ diff --git a/docker/docker-compose.middleware.yaml b/docker/docker-compose.middleware.yaml index dbfc1ea531..251c62fee1 100644 --- a/docker/docker-compose.middleware.yaml +++ 
b/docker/docker-compose.middleware.yaml @@ -41,7 +41,7 @@ services: # The DifySandbox sandbox: - image: langgenius/dify-sandbox:0.2.7 + image: langgenius/dify-sandbox:0.2.9 restart: always environment: # The DifySandbox configurations @@ -89,6 +89,7 @@ services: weaviate: image: semitechnologies/weaviate:1.19.0 profiles: + - "" - weaviate restart: always volumes: diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index b8e068fde0..0fbc695177 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -75,7 +75,7 @@ x-shared-env: &shared-api-worker-env ALIYUN_OSS_ENDPOINT: ${ALIYUN_OSS_ENDPOINT:-} ALIYUN_OSS_REGION: ${ALIYUN_OSS_REGION:-} ALIYUN_OSS_AUTH_VERSION: ${ALIYUN_OSS_AUTH_VERSION:-v4} - ALIYUN_OSS_PATHS: ${ALIYUN_OSS_PATH:-} + ALIYUN_OSS_PATH: ${ALIYUN_OSS_PATH:-} TENCENT_COS_BUCKET_NAME: ${TENCENT_COS_BUCKET_NAME:-} TENCENT_COS_SECRET_KEY: ${TENCENT_COS_SECRET_KEY:-} TENCENT_COS_SECRET_ID: ${TENCENT_COS_SECRET_ID:-} @@ -254,6 +254,7 @@ services: APP_API_URL: ${APP_API_URL:-} SENTRY_DSN: ${WEB_SENTRY_DSN:-} NEXT_TELEMETRY_DISABLED: ${NEXT_TELEMETRY_DISABLED:-0} + TEXT_GENERATION_TIMEOUT_MS: ${TEXT_GENERATION_TIMEOUT_MS:-60000} # The postgres database. 
db: @@ -292,7 +293,7 @@ services: # The DifySandbox sandbox: - image: langgenius/dify-sandbox:0.2.7 + image: langgenius/dify-sandbox:0.2.9 restart: always environment: # The DifySandbox configurations diff --git a/sdks/python-client/dify_client/client.py b/sdks/python-client/dify_client/client.py index a8da1d7cae..2be079bdf3 100644 --- a/sdks/python-client/dify_client/client.py +++ b/sdks/python-client/dify_client/client.py @@ -1,3 +1,4 @@ +import json import requests @@ -131,3 +132,284 @@ class WorkflowClient(DifyClient): def stop(self, task_id, user): data = {"user": user} return self._send_request("POST", f"/workflows/tasks/{task_id}/stop", data) + + def get_result(self, workflow_run_id): + return self._send_request("GET", f"/workflows/run/{workflow_run_id}") + + + +class KnowledgeBaseClient(DifyClient): + + def __init__(self, api_key, base_url: str = 'https://api.dify.ai/v1', dataset_id: str = None): + """ + Construct a KnowledgeBaseClient object. + + Args: + api_key (str): API key of Dify. + base_url (str, optional): Base URL of Dify API. Defaults to 'https://api.dify.ai/v1'. + dataset_id (str, optional): ID of the dataset. Defaults to None. You don't need this if you just want to + create a new dataset. or list datasets. otherwise you need to set this. + """ + super().__init__( + api_key=api_key, + base_url=base_url + ) + self.dataset_id = dataset_id + + def _get_dataset_id(self): + if self.dataset_id is None: + raise ValueError("dataset_id is not set") + return self.dataset_id + + def create_dataset(self, name: str, **kwargs): + return self._send_request('POST', '/datasets', {'name': name}, **kwargs) + + def list_datasets(self, page: int = 1, page_size: int = 20, **kwargs): + return self._send_request('GET', f'/datasets?page={page}&limit={page_size}', **kwargs) + + def create_document_by_text(self, name, text, extra_params: dict = None, **kwargs): + """ + Create a document by text. 
+ + :param name: Name of the document + :param text: Text content of the document + :param extra_params: extra parameters pass to the API, such as indexing_technique, process_rule. (optional) + e.g. + { + 'indexing_technique': 'high_quality', + 'process_rule': { + 'rules': { + 'pre_processing_rules': [ + {'id': 'remove_extra_spaces', 'enabled': True}, + {'id': 'remove_urls_emails', 'enabled': True} + ], + 'segmentation': { + 'separator': '\n', + 'max_tokens': 500 + } + }, + 'mode': 'custom' + } + } + :return: Response from the API + """ + data = { + 'indexing_technique': 'high_quality', + 'process_rule': { + 'mode': 'automatic' + }, + 'name': name, + 'text': text + } + if extra_params is not None and isinstance(extra_params, dict): + data.update(extra_params) + url = f"/datasets/{self._get_dataset_id()}/document/create_by_text" + return self._send_request("POST", url, json=data, **kwargs) + + def update_document_by_text(self, document_id, name, text, extra_params: dict = None, **kwargs): + """ + Update a document by text. + + :param document_id: ID of the document + :param name: Name of the document + :param text: Text content of the document + :param extra_params: extra parameters pass to the API, such as indexing_technique, process_rule. (optional) + e.g. 
+ { + 'indexing_technique': 'high_quality', + 'process_rule': { + 'rules': { + 'pre_processing_rules': [ + {'id': 'remove_extra_spaces', 'enabled': True}, + {'id': 'remove_urls_emails', 'enabled': True} + ], + 'segmentation': { + 'separator': '\n', + 'max_tokens': 500 + } + }, + 'mode': 'custom' + } + } + :return: Response from the API + """ + data = { + 'name': name, + 'text': text + } + if extra_params is not None and isinstance(extra_params, dict): + data.update(extra_params) + url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_text" + return self._send_request("POST", url, json=data, **kwargs) + + def create_document_by_file(self, file_path, original_document_id=None, extra_params: dict = None): + """ + Create a document by file. + + :param file_path: Path to the file + :param original_document_id: pass this ID if you want to replace the original document (optional) + :param extra_params: extra parameters pass to the API, such as indexing_technique, process_rule. (optional) + e.g. + { + 'indexing_technique': 'high_quality', + 'process_rule': { + 'rules': { + 'pre_processing_rules': [ + {'id': 'remove_extra_spaces', 'enabled': True}, + {'id': 'remove_urls_emails', 'enabled': True} + ], + 'segmentation': { + 'separator': '\n', + 'max_tokens': 500 + } + }, + 'mode': 'custom' + } + } + :return: Response from the API + """ + files = {"file": open(file_path, "rb")} + data = { + 'process_rule': { + 'mode': 'automatic' + }, + 'indexing_technique': 'high_quality' + } + if extra_params is not None and isinstance(extra_params, dict): + data.update(extra_params) + if original_document_id is not None: + data['original_document_id'] = original_document_id + url = f"/datasets/{self._get_dataset_id()}/document/create_by_file" + return self._send_request_with_files("POST", url, {"data": json.dumps(data)}, files) + + def update_document_by_file(self, document_id, file_path, extra_params: dict = None): + """ + Update a document by file. 
+ + :param document_id: ID of the document + :param file_path: Path to the file + :param extra_params: extra parameters pass to the API, such as indexing_technique, process_rule. (optional) + e.g. + { + 'indexing_technique': 'high_quality', + 'process_rule': { + 'rules': { + 'pre_processing_rules': [ + {'id': 'remove_extra_spaces', 'enabled': True}, + {'id': 'remove_urls_emails', 'enabled': True} + ], + 'segmentation': { + 'separator': '\n', + 'max_tokens': 500 + } + }, + 'mode': 'custom' + } + } + :return: + """ + files = {"file": open(file_path, "rb")} + data = {} + if extra_params is not None and isinstance(extra_params, dict): + data.update(extra_params) + url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_file" + return self._send_request_with_files("POST", url, {"data": json.dumps(data)}, files) + + def batch_indexing_status(self, batch_id: str, **kwargs): + """ + Get the status of the batch indexing. + + :param batch_id: ID of the batch uploading + :return: Response from the API + """ + url = f"/datasets/{self._get_dataset_id()}/documents/{batch_id}/indexing-status" + return self._send_request("GET", url, **kwargs) + + def delete_dataset(self): + """ + Delete this dataset. + + :return: Response from the API + """ + url = f"/datasets/{self._get_dataset_id()}" + return self._send_request("DELETE", url) + + def delete_document(self, document_id): + """ + Delete a document. + + :param document_id: ID of the document + :return: Response from the API + """ + url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}" + return self._send_request("DELETE", url) + + def list_documents(self, page: int = None, page_size: int = None, keyword: str = None, **kwargs): + """ + Get a list of documents in this dataset. 
+ + :return: Response from the API + """ + params = {} + if page is not None: + params['page'] = page + if page_size is not None: + params['limit'] = page_size + if keyword is not None: + params['keyword'] = keyword + url = f"/datasets/{self._get_dataset_id()}/documents" + return self._send_request("GET", url, params=params, **kwargs) + + def add_segments(self, document_id, segments, **kwargs): + """ + Add segments to a document. + + :param document_id: ID of the document + :param segments: List of segments to add, example: [{"content": "1", "answer": "1", "keyword": ["a"]}] + :return: Response from the API + """ + data = {"segments": segments} + url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments" + return self._send_request("POST", url, json=data, **kwargs) + + def query_segments(self, document_id, keyword: str = None, status: str = None, **kwargs): + """ + Query segments in this document. + + :param document_id: ID of the document + :param keyword: query keyword, optional + :param status: status of the segment, optional, e.g. completed + """ + url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments" + params = {} + if keyword is not None: + params['keyword'] = keyword + if status is not None: + params['status'] = status + if "params" in kwargs: + params.update(kwargs["params"]) + return self._send_request("GET", url, params=params, **kwargs) + + def delete_document_segment(self, document_id, segment_id): + """ + Delete a segment from a document. + + :param document_id: ID of the document + :param segment_id: ID of the segment + :return: Response from the API + """ + url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments/{segment_id}" + return self._send_request("DELETE", url) + + def update_document_segment(self, document_id, segment_id, segment_data, **kwargs): + """ + Update a segment in a document. 
+ + :param document_id: ID of the document + :param segment_id: ID of the segment + :param segment_data: Data of the segment, example: {"content": "1", "answer": "1", "keyword": ["a"], "enabled": True} + :return: Response from the API + """ + data = {"segment": segment_data} + url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments/{segment_id}" + return self._send_request("POST", url, json=data, **kwargs) diff --git a/sdks/python-client/setup.py b/sdks/python-client/setup.py index e74748377e..bb8ca46d97 100644 --- a/sdks/python-client/setup.py +++ b/sdks/python-client/setup.py @@ -5,7 +5,7 @@ with open("README.md", "r", encoding="utf-8") as fh: setup( name="dify-client", - version="0.1.10", + version="0.1.12", author="Dify", author_email="hello@dify.ai", description="A package for interacting with the Dify Service-API", diff --git a/sdks/python-client/tests/test_client.py b/sdks/python-client/tests/test_client.py index 5259d082ca..301e733b6b 100644 --- a/sdks/python-client/tests/test_client.py +++ b/sdks/python-client/tests/test_client.py @@ -1,10 +1,157 @@ import os +import time import unittest -from dify_client.client import ChatClient, CompletionClient, DifyClient +from dify_client.client import ChatClient, CompletionClient, DifyClient, KnowledgeBaseClient API_KEY = os.environ.get("API_KEY") APP_ID = os.environ.get("APP_ID") +API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.dify.ai/v1") +FILE_PATH_BASE = os.path.dirname(__file__) + + +class TestKnowledgeBaseClient(unittest.TestCase): + def setUp(self): + self.knowledge_base_client = KnowledgeBaseClient(API_KEY, base_url=API_BASE_URL) + self.README_FILE_PATH = os.path.abspath(os.path.join(FILE_PATH_BASE, "../README.md")) + self.dataset_id = None + self.document_id = None + self.segment_id = None + self.batch_id = None + + def _get_dataset_kb_client(self): + self.assertIsNotNone(self.dataset_id) + return KnowledgeBaseClient(API_KEY, base_url=API_BASE_URL, dataset_id=self.dataset_id) 
+ + def test_001_create_dataset(self): + response = self.knowledge_base_client.create_dataset(name="test_dataset") + data = response.json() + self.assertIn("id", data) + self.dataset_id = data["id"] + self.assertEqual("test_dataset", data["name"]) + + # the following tests require to be executed in order because they use + # the dataset/document/segment ids from the previous test + self._test_002_list_datasets() + self._test_003_create_document_by_text() + time.sleep(1) + self._test_004_update_document_by_text() + # self._test_005_batch_indexing_status() + time.sleep(1) + self._test_006_update_document_by_file() + time.sleep(1) + self._test_007_list_documents() + self._test_008_delete_document() + self._test_009_create_document_by_file() + time.sleep(1) + self._test_010_add_segments() + self._test_011_query_segments() + self._test_012_update_document_segment() + self._test_013_delete_document_segment() + self._test_014_delete_dataset() + + def _test_002_list_datasets(self): + response = self.knowledge_base_client.list_datasets() + data = response.json() + self.assertIn("data", data) + self.assertIn("total", data) + + def _test_003_create_document_by_text(self): + client = self._get_dataset_kb_client() + response = client.create_document_by_text("test_document", "test_text") + data = response.json() + self.assertIn("document", data) + self.document_id = data["document"]["id"] + self.batch_id = data["batch"] + + def _test_004_update_document_by_text(self): + client = self._get_dataset_kb_client() + self.assertIsNotNone(self.document_id) + response = client.update_document_by_text(self.document_id, "test_document_updated", "test_text_updated") + data = response.json() + self.assertIn("document", data) + self.assertIn("batch", data) + self.batch_id = data["batch"] + + def _test_005_batch_indexing_status(self): + client = self._get_dataset_kb_client() + response = client.batch_indexing_status(self.batch_id) + data = response.json() + 
self.assertEqual(response.status_code, 200) + + def _test_006_update_document_by_file(self): + client = self._get_dataset_kb_client() + self.assertIsNotNone(self.document_id) + response = client.update_document_by_file(self.document_id, self.README_FILE_PATH) + data = response.json() + self.assertIn("document", data) + self.assertIn("batch", data) + self.batch_id = data["batch"] + + def _test_007_list_documents(self): + client = self._get_dataset_kb_client() + response = client.list_documents() + data = response.json() + self.assertIn("data", data) + + def _test_008_delete_document(self): + client = self._get_dataset_kb_client() + self.assertIsNotNone(self.document_id) + response = client.delete_document(self.document_id) + data = response.json() + self.assertIn("result", data) + self.assertEqual("success", data["result"]) + + def _test_009_create_document_by_file(self): + client = self._get_dataset_kb_client() + response = client.create_document_by_file(self.README_FILE_PATH) + data = response.json() + self.assertIn("document", data) + self.document_id = data["document"]["id"] + self.batch_id = data["batch"] + + def _test_010_add_segments(self): + client = self._get_dataset_kb_client() + response = client.add_segments(self.document_id, [ + {"content": "test text segment 1"} + ]) + data = response.json() + self.assertIn("data", data) + self.assertGreater(len(data["data"]), 0) + segment = data["data"][0] + self.segment_id = segment["id"] + + def _test_011_query_segments(self): + client = self._get_dataset_kb_client() + response = client.query_segments(self.document_id) + data = response.json() + self.assertIn("data", data) + self.assertGreater(len(data["data"]), 0) + + def _test_012_update_document_segment(self): + client = self._get_dataset_kb_client() + self.assertIsNotNone(self.segment_id) + response = client.update_document_segment(self.document_id, self.segment_id, + {"content": "test text segment 1 updated"} + ) + data = response.json() + self.assertIn("data", 
data) + self.assertGreater(len(data["data"]), 0) + segment = data["data"] + self.assertEqual("test text segment 1 updated", segment["content"]) + + def _test_013_delete_document_segment(self): + client = self._get_dataset_kb_client() + self.assertIsNotNone(self.segment_id) + response = client.delete_document_segment(self.document_id, self.segment_id) + data = response.json() + self.assertIn("result", data) + self.assertEqual("success", data["result"]) + + def _test_014_delete_dataset(self): + client = self._get_dataset_kb_client() + response = client.delete_dataset() + self.assertEqual(204, response.status_code) class TestChatClient(unittest.TestCase): diff --git a/web/.env.example b/web/.env.example index 3045cef2f9..8e254082b3 100644 --- a/web/.env.example +++ b/web/.env.example @@ -19,3 +19,6 @@ NEXT_TELEMETRY_DISABLED=1 # Disable Upload Image as WebApp icon default is false NEXT_PUBLIC_UPLOAD_IMAGE_AS_ICON=false + +# The timeout for the text generation in millisecond +NEXT_PUBLIC_TEXT_GENERATION_TIMEOUT_MS=60000 diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout.tsx index e728749b85..96ee874d53 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout.tsx @@ -109,6 +109,11 @@ const AppDetailLayout: FC = (props) => { setAppDetail() fetchAppDetail({ url: '/apps', id: appId }).then((res) => { // redirection + const canIEditApp = isCurrentWorkspaceEditor + if (!canIEditApp && (pathname.endsWith('configuration') || pathname.endsWith('workflow') || pathname.endsWith('logs'))) { + router.replace(`/app/${appId}/overview`) + return + } if ((res.mode === 'workflow' || res.mode === 'advanced-chat') && (pathname).endsWith('configuration')) { router.replace(`/app/${appId}/workflow`) } @@ -118,7 +123,7 @@ const AppDetailLayout: FC = (props) => { else { setAppDetail({ ...res, enable_sso: false }) 
setNavigation(getNavigations(appId, isCurrentWorkspaceEditor, res.mode)) - if (systemFeatures.enable_web_sso_switch_component) { + if (systemFeatures.enable_web_sso_switch_component && canIEditApp) { fetchAppSSO({ appId }).then((ssoRes) => { setAppDetail({ ...res, enable_sso: ssoRes.enabled }) }) @@ -128,7 +133,7 @@ const AppDetailLayout: FC = (props) => { if (e.status === 404) router.replace('/apps') }) - }, [appId, isCurrentWorkspaceEditor, systemFeatures]) + }, [appId, isCurrentWorkspaceEditor, systemFeatures, getNavigations, pathname, router, setAppDetail]) useUnmount(() => { setAppDetail() diff --git a/web/app/components/app/configuration/toolbox/annotation/annotation-ctrl-btn/index.tsx b/web/app/components/app/configuration/toolbox/annotation/annotation-ctrl-btn/index.tsx index 111c380afc..809b907d62 100644 --- a/web/app/components/app/configuration/toolbox/annotation/annotation-ctrl-btn/index.tsx +++ b/web/app/components/app/configuration/toolbox/annotation/annotation-ctrl-btn/index.tsx @@ -73,7 +73,7 @@ const CacheCtrlBtn: FC = ({ setShowModal(false) } return ( -
+
{cached ? ( @@ -101,7 +101,6 @@ const CacheCtrlBtn: FC = ({ ? (
= ({ }
{/* Panel Header */} -
+
{isChatMode ? t('appLog.detail.conversationId') : t('appLog.detail.time')}
{isChatMode && ( @@ -725,7 +725,7 @@ const ConversationList: FC = ({ logs, appDetail, onRefresh }) onClose={onCloseDrawer} mask={isMobile} footer={null} - panelClassname='mt-16 mx-2 sm:mr-2 mb-4 !p-0 !max-w-[640px] rounded-xl' + panelClassname='mt-16 mx-2 sm:mr-2 mb-4 !p-0 !max-w-[640px] rounded-xl bg-background-gradient-bg-fill-chat-bg-1' > = ({ onSave, }) => { const systemFeatures = useContextSelector(AppContext, state => state.systemFeatures) + const { isCurrentWorkspaceEditor } = useAppContext() const { notify } = useToastContext() const [isShowMore, setIsShowMore] = useState(false) const { @@ -265,7 +266,7 @@ const SettingsModal: FC = ({ } asChild={false} > - setInputInfo({ ...inputInfo, enable_sso: v })}> + setInputInfo({ ...inputInfo, enable_sso: v })}>

{t(`${prefixSettings}.sso.description`)}

diff --git a/web/app/components/base/chat/chat/answer/index.tsx b/web/app/components/base/chat/chat/answer/index.tsx index 270cd553c2..5fe2a7bad5 100644 --- a/web/app/components/base/chat/chat/answer/index.tsx +++ b/web/app/components/base/chat/chat/answer/index.tsx @@ -8,7 +8,6 @@ import type { ChatConfig, ChatItem, } from '../../types' -import { useChatContext } from '../context' import Operation from './operation' import AgentContent from './agent-content' import BasicContent from './basic-content' @@ -16,13 +15,13 @@ import SuggestedQuestions from './suggested-questions' import More from './more' import WorkflowProcess from './workflow-process' import { AnswerTriangle } from '@/app/components/base/icons/src/vender/solid/general' -import { MessageFast } from '@/app/components/base/icons/src/vender/solid/communication' import LoadingAnim from '@/app/components/base/chat/chat/loading-anim' import Citation from '@/app/components/base/chat/chat/citation' import { EditTitle } from '@/app/components/app/annotation/edit-annotation-modal/edit-item' import type { Emoji } from '@/app/components/tools/types' import type { AppData } from '@/models/share' import AnswerIcon from '@/app/components/base/answer-icon' +import cn from '@/utils/classnames' type AnswerProps = { item: ChatItem @@ -61,26 +60,24 @@ const Answer: FC = ({ } = item const hasAgentThoughts = !!agent_thoughts?.length - const [containerWidth] = useState(0) + const [containerWidth, setContainerWidth] = useState(0) const [contentWidth, setContentWidth] = useState(0) const containerRef = useRef(null) const contentRef = useRef(null) - const { - config: chatContextConfig, - } = useChatContext() + const getContainerWidth = () => { + if (containerRef.current) + setContainerWidth(containerRef.current?.clientWidth + 16) + } + useEffect(() => { + getContainerWidth() + }, []) - const voiceRef = useRef(chatContextConfig?.text_to_speech?.voice) const getContentWidth = () => { if (contentRef.current) 
setContentWidth(contentRef.current?.clientWidth) } - useEffect(() => { - voiceRef.current = chatContextConfig?.text_to_speech?.voice - } - , [chatContextConfig?.text_to_speech?.voice]) - useEffect(() => { if (!responding) getContentWidth() @@ -89,36 +86,20 @@ const Answer: FC = ({ return (
- { - answerIcon || - } - { - responding && ( -
- -
- ) - } + {answerIcon || } + {responding && ( +
+ +
+ )}
-
+
- {annotation?.id && ( -
-
- -
-
- )} { !responding && ( = ({ /> )} { - !positionRight && annotation?.id && ( + annotation?.id && (
diff --git a/web/app/components/base/chat/chat/index.tsx b/web/app/components/base/chat/chat/index.tsx index 65e49eff67..68194193c4 100644 --- a/web/app/components/base/chat/chat/index.tsx +++ b/web/app/components/base/chat/chat/index.tsx @@ -109,9 +109,9 @@ const Chat: FC = ({ const userScrolledRef = useRef(false) const handleScrollToBottom = useCallback(() => { - if (chatContainerRef.current && !userScrolledRef.current) + if (chatList.length > 1 && chatContainerRef.current && !userScrolledRef.current) chatContainerRef.current.scrollTop = chatContainerRef.current.scrollHeight - }, []) + }, [chatList.length]) const handleWindowResize = useCallback(() => { if (chatContainerRef.current) diff --git a/web/app/components/base/image-uploader/image-preview.tsx b/web/app/components/base/image-uploader/image-preview.tsx index 41f29fda2e..e5bd4c1bbc 100644 --- a/web/app/components/base/image-uploader/image-preview.tsx +++ b/web/app/components/base/image-uploader/image-preview.tsx @@ -1,26 +1,42 @@ import type { FC } from 'react' -import { useRef } from 'react' +import React, { useCallback, useEffect, useRef, useState } from 'react' import { t } from 'i18next' import { createPortal } from 'react-dom' -import { RiCloseLine, RiExternalLinkLine } from '@remixicon/react' +import { RiAddBoxLine, RiCloseLine, RiDownloadCloud2Line, RiFileCopyLine, RiZoomInLine, RiZoomOutLine } from '@remixicon/react' import Tooltip from '@/app/components/base/tooltip' -import { randomString } from '@/utils' +import Toast from '@/app/components/base/toast' type ImagePreviewProps = { url: string title: string onCancel: () => void } + +const isBase64 = (str: string): boolean => { + try { + return btoa(atob(str)) === str + } + catch (err) { + return false + } +} + const ImagePreview: FC = ({ url, title, onCancel, }) => { - const selector = useRef(`copy-tooltip-${randomString(4)}`) + const [scale, setScale] = useState(1) + const [position, setPosition] = useState({ x: 0, y: 0 }) + const [isDragging, 
setIsDragging] = useState(false) + const imgRef = useRef(null) + const dragStartRef = useRef({ x: 0, y: 0 }) + const [isCopied, setIsCopied] = useState(false) + const containerRef = useRef(null) const openInNewTab = () => { // Open in a new window, considering the case when the page is inside an iframe - if (url.startsWith('http')) { + if (url.startsWith('http') || url.startsWith('https')) { window.open(url, '_blank') } else if (url.startsWith('data:image')) { @@ -29,34 +45,224 @@ const ImagePreview: FC = ({ win?.document.write(`${title}`) } else { - console.error('Unable to open image', url) + Toast.notify({ + type: 'error', + message: `Unable to open image: ${url}`, + }) + } + } + const downloadImage = () => { + // Open in a new window, considering the case when the page is inside an iframe + if (url.startsWith('http') || url.startsWith('https')) { + const a = document.createElement('a') + a.href = url + a.download = title + a.click() + } + else if (url.startsWith('data:image')) { + // Base64 image + const a = document.createElement('a') + a.href = url + a.download = title + a.click() + } + else { + Toast.notify({ + type: 'error', + message: `Unable to open image: ${url}`, + }) } } + const zoomIn = () => { + setScale(prevScale => Math.min(prevScale * 1.2, 15)) + } + + const zoomOut = () => { + setScale((prevScale) => { + const newScale = Math.max(prevScale / 1.2, 0.5) + if (newScale === 1) + setPosition({ x: 0, y: 0 }) // Reset position when fully zoomed out + + return newScale + }) + } + + const imageTobase64ToBlob = (base64: string, type = 'image/png'): Blob => { + const byteCharacters = atob(base64) + const byteArrays = [] + + for (let offset = 0; offset < byteCharacters.length; offset += 512) { + const slice = byteCharacters.slice(offset, offset + 512) + const byteNumbers = new Array(slice.length) + for (let i = 0; i < slice.length; i++) + byteNumbers[i] = slice.charCodeAt(i) + + const byteArray = new Uint8Array(byteNumbers) + byteArrays.push(byteArray) + } + 
+ return new Blob(byteArrays, { type }) + } + + const imageCopy = useCallback(() => { + const shareImage = async () => { + try { + const base64Data = url.split(',')[1] + const blob = imageTobase64ToBlob(base64Data, 'image/png') + + await navigator.clipboard.write([ + new ClipboardItem({ + [blob.type]: blob, + }), + ]) + setIsCopied(true) + + Toast.notify({ + type: 'success', + message: t('common.operation.imageCopied'), + }) + } + catch (err) { + console.error('Failed to copy image:', err) + + const link = document.createElement('a') + link.href = url + link.download = `${title}.png` + document.body.appendChild(link) + link.click() + document.body.removeChild(link) + + Toast.notify({ + type: 'info', + message: t('common.operation.imageDownloaded'), + }) + } + } + shareImage() + }, [title, url]) + + const handleWheel = useCallback((e: React.WheelEvent) => { + if (e.deltaY < 0) + zoomIn() + else + zoomOut() + }, []) + + const handleMouseDown = useCallback((e: React.MouseEvent) => { + if (scale > 1) { + setIsDragging(true) + dragStartRef.current = { x: e.clientX - position.x, y: e.clientY - position.y } + } + }, [scale, position]) + + const handleMouseMove = useCallback((e: React.MouseEvent) => { + if (isDragging && scale > 1) { + const deltaX = e.clientX - dragStartRef.current.x + const deltaY = e.clientY - dragStartRef.current.y + + // Calculate boundaries + const imgRect = imgRef.current?.getBoundingClientRect() + const containerRect = imgRef.current?.parentElement?.getBoundingClientRect() + + if (imgRect && containerRect) { + const maxX = (imgRect.width * scale - containerRect.width) / 2 + const maxY = (imgRect.height * scale - containerRect.height) / 2 + + setPosition({ + x: Math.max(-maxX, Math.min(maxX, deltaX)), + y: Math.max(-maxY, Math.min(maxY, deltaY)), + }) + } + } + }, [isDragging, scale]) + + const handleMouseUp = useCallback(() => { + setIsDragging(false) + }, []) + + useEffect(() => { + document.addEventListener('mouseup', handleMouseUp) + return () 
=> { + document.removeEventListener('mouseup', handleMouseUp) + } + }, [handleMouseUp]) + + useEffect(() => { + const handleKeyDown = (event: KeyboardEvent) => { + if (event.key === 'Escape') + onCancel() + } + + window.addEventListener('keydown', handleKeyDown) + + // Set focus to the container element + if (containerRef.current) + containerRef.current.focus() + + // Cleanup function + return () => { + window.removeEventListener('keydown', handleKeyDown) + } + }, [onCancel]) + return createPortal( -
e.stopPropagation()}> +
e.stopPropagation()} + onWheel={handleWheel} + onMouseDown={handleMouseDown} + onMouseMove={handleMouseMove} + onMouseUp={handleMouseUp} + style={{ cursor: scale > 1 ? 'move' : 'default' }} + tabIndex={-1}> {/* eslint-disable-next-line @next/next/no-img-element */} {title} -
- -
- + +
+ {isCopied + ? + : } +
+
+ +
+ +
+
+ +
+ +
+
+ +
+ +
+
+ +
+ +
+
+
- + className='absolute top-6 right-6 flex items-center justify-center w-8 h-8 bg-white/8 rounded-lg backdrop-blur-[2px] cursor-pointer' + onClick={onCancel}> +
, diff --git a/web/app/components/base/markdown.tsx b/web/app/components/base/markdown.tsx index d4e7dac4ae..443ee3410c 100644 --- a/web/app/components/base/markdown.tsx +++ b/web/app/components/base/markdown.tsx @@ -5,6 +5,7 @@ import RemarkMath from 'remark-math' import RemarkBreaks from 'remark-breaks' import RehypeKatex from 'rehype-katex' import RemarkGfm from 'remark-gfm' +import RehypeRaw from 'rehype-raw' import SyntaxHighlighter from 'react-syntax-highlighter' import { atelierHeathLight } from 'react-syntax-highlighter/dist/esm/styles/hljs' import type { RefObject } from 'react' @@ -18,6 +19,7 @@ import ImageGallery from '@/app/components/base/image-gallery' import { useChatContext } from '@/app/components/base/chat/chat/context' import VideoGallery from '@/app/components/base/video-gallery' import AudioGallery from '@/app/components/base/audio-gallery' +import SVGRenderer from '@/app/components/base/svg-gallery' // Available language https://github.com/react-syntax-highlighter/react-syntax-highlighter/blob/master/AVAILABLE_LANGUAGES_HLJS.MD const capitalizationLanguageNameMap: Record = { @@ -40,6 +42,7 @@ const capitalizationLanguageNameMap: Record = { powershell: 'PowerShell', json: 'JSON', latex: 'Latex', + svg: 'SVG', } const getCorrectCapitalizationLanguageName = (language: string) => { if (!language) @@ -107,6 +110,7 @@ const useLazyLoad = (ref: RefObject): boolean => { // Error: Minified React error 185; // visit https://reactjs.org/docs/error-decoder.html?invariant=185 for the full message // or use the non-minified dev environment for full errors and additional helpful warnings. + const CodeBlock: CodeComponent = memo(({ inline, className, children, ...props }) => { const [isSVG, setIsSVG] = useState(true) const match = /language-(\w+)/.exec(className || '') @@ -134,7 +138,7 @@ const CodeBlock: CodeComponent = memo(({ inline, className, children, ...props } >
{languageShowName}
- {language === 'mermaid' && } + {language === 'mermaid' && } {(language === 'mermaid' && isSVG) ? () - : ( - (language === 'echarts') - ? (
-
) + : (language === 'echarts' + ? (
) + : (language === 'svg' + ? () : ( {String(children).replace(/\n$/, '')} - ))} + )))}
) - : ( - - {children} - - ) + : ({children}) }, [chartData, children, className, inline, isSVG, language, languageShowName, match, props]) }) - CodeBlock.displayName = 'CodeBlock' const VideoBlock: CodeComponent = memo(({ node }) => { @@ -230,6 +227,7 @@ export function Markdown(props: { content: string; className?: string }) { remarkPlugins={[[RemarkGfm, RemarkMath, { singleDollarTextMath: false }], RemarkBreaks]} rehypePlugins={[ RehypeKatex, + RehypeRaw as any, // The Rehype plug-in is used to remove the ref attribute of an element () => { return (tree) => { @@ -244,6 +242,7 @@ export function Markdown(props: { content: string; className?: string }) { } }, ]} + disallowedElements={['script', 'iframe', 'head', 'html', 'meta', 'link', 'style', 'body']} components={{ code: CodeBlock, img: Img, @@ -266,19 +265,23 @@ export function Markdown(props: { content: string; className?: string }) { // This can happen when a component attempts to access an undefined object that references an unregistered map, causing the program to crash. export default class ErrorBoundary extends Component { - constructor(props) { + constructor(props: any) { super(props) this.state = { hasError: false } } - componentDidCatch(error, errorInfo) { + componentDidCatch(error: any, errorInfo: any) { this.setState({ hasError: true }) console.error(error, errorInfo) } render() { + // eslint-disable-next-line @typescript-eslint/ban-ts-comment + // @ts-expect-error if (this.state.hasError) - return
Oops! ECharts reported a runtime error.
(see the browser console for more information)
+ return
Oops! An error occurred. This could be due to an ECharts runtime error or invalid SVG content.
(see the browser console for more information)
+ // eslint-disable-next-line @typescript-eslint/ban-ts-comment + // @ts-expect-error return this.props.children } } diff --git a/web/app/components/base/svg-gallery/index.tsx b/web/app/components/base/svg-gallery/index.tsx new file mode 100644 index 0000000000..81e8e87655 --- /dev/null +++ b/web/app/components/base/svg-gallery/index.tsx @@ -0,0 +1,79 @@ +import { useEffect, useRef, useState } from 'react' +import { SVG } from '@svgdotjs/svg.js' +import ImagePreview from '@/app/components/base/image-uploader/image-preview' + +export const SVGRenderer = ({ content }: { content: string }) => { + const svgRef = useRef(null) + const [imagePreview, setImagePreview] = useState('') + const [windowSize, setWindowSize] = useState({ + width: typeof window !== 'undefined' ? window.innerWidth : 0, + height: typeof window !== 'undefined' ? window.innerHeight : 0, + }) + + const svgToDataURL = (svgElement: Element): string => { + const svgString = new XMLSerializer().serializeToString(svgElement) + const base64String = Buffer.from(svgString).toString('base64') + return `data:image/svg+xml;base64,${base64String}` + } + + useEffect(() => { + const handleResize = () => { + setWindowSize({ width: window.innerWidth, height: window.innerHeight }) + } + + window.addEventListener('resize', handleResize) + return () => window.removeEventListener('resize', handleResize) + }, []) + + useEffect(() => { + if (svgRef.current) { + try { + svgRef.current.innerHTML = '' + const draw = SVG().addTo(svgRef.current).size('100%', '100%') + + const parser = new DOMParser() + const svgDoc = parser.parseFromString(content, 'image/svg+xml') + const svgElement = svgDoc.documentElement + + if (!(svgElement instanceof SVGElement)) + throw new Error('Invalid SVG content') + + const originalWidth = parseInt(svgElement.getAttribute('width') || '400', 10) + const originalHeight = parseInt(svgElement.getAttribute('height') || '600', 10) + const scale = Math.min(windowSize.width / originalWidth, windowSize.height 
/ originalHeight, 1) + const scaledWidth = originalWidth * scale + const scaledHeight = originalHeight * scale + draw.size(scaledWidth, scaledHeight) + + const rootElement = draw.svg(content) + rootElement.scale(scale) + + rootElement.click(() => { + setImagePreview(svgToDataURL(svgElement as Element)) + }) + } + catch (error) { + if (svgRef.current) + svgRef.current.innerHTML = 'Error rendering SVG. Wait for the image content to complete.' + } + } + }, [content, windowSize]) + + return ( + <> +
+ {imagePreview && ( setImagePreview('')} />)} + + ) +} + +export default SVGRenderer diff --git a/web/app/components/header/account-setting/model-provider-page/model-modal/model-load-balancing-entry-modal.tsx b/web/app/components/header/account-setting/model-provider-page/model-modal/model-load-balancing-entry-modal.tsx index 86857f1ab2..1c318b9baf 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-modal/model-load-balancing-entry-modal.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-modal/model-load-balancing-entry-modal.tsx @@ -192,12 +192,12 @@ const ModelLoadBalancingEntryModal: FC = ({ }) const getSecretValues = useCallback((v: FormValue) => { return secretFormSchemas.reduce((prev, next) => { - if (v[next.variable] === initialFormSchemasValue[next.variable]) + if (isEditMode && v[next.variable] && v[next.variable] === initialFormSchemasValue[next.variable]) prev[next.variable] = '[__HIDDEN__]' return prev }, {} as Record) - }, [initialFormSchemasValue, secretFormSchemas]) + }, [initialFormSchemasValue, isEditMode, secretFormSchemas]) // const handleValueChange = ({ __model_type, __model_name, ...v }: FormValue) => { const handleValueChange = (v: FormValue) => { @@ -214,6 +214,7 @@ const ModelLoadBalancingEntryModal: FC = ({ ...value, ...getSecretValues(value), }, + entry?.id, ) if (res.status === ValidatedStatus.Success) { // notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) diff --git a/web/app/components/header/account-setting/model-provider-page/utils.ts b/web/app/components/header/account-setting/model-provider-page/utils.ts index 8cad399763..165926b2bb 100644 --- a/web/app/components/header/account-setting/model-provider-page/utils.ts +++ b/web/app/components/header/account-setting/model-provider-page/utils.ts @@ -56,14 +56,14 @@ export const validateCredentials = async (predefined: boolean, provider: string, } } -export const validateLoadBalancingCredentials = 
async (predefined: boolean, provider: string, v: FormValue): Promise<{ +export const validateLoadBalancingCredentials = async (predefined: boolean, provider: string, v: FormValue, id?: string): Promise<{ status: ValidatedStatus message?: string }> => { const { __model_name, __model_type, ...credentials } = v try { const res = await validateModelLoadBalancingCredentials({ - url: `/workspaces/current/model-providers/${provider}/models/load-balancing-configs/credentials-validate`, + url: `/workspaces/current/model-providers/${provider}/models/load-balancing-configs/${id ? `${id}/` : ''}credentials-validate`, body: { model: __model_name, model_type: __model_type, diff --git a/web/app/components/share/text-generation/result/index.tsx b/web/app/components/share/text-generation/result/index.tsx index 2d5546f9b4..a61302fc98 100644 --- a/web/app/components/share/text-generation/result/index.tsx +++ b/web/app/components/share/text-generation/result/index.tsx @@ -269,8 +269,10 @@ const Result: FC = ({ })) }, onWorkflowFinished: ({ data }) => { - if (isTimeout) + if (isTimeout) { + notify({ type: 'warning', message: t('appDebug.warningMessage.timeoutExceeded') }) return + } if (data.error) { notify({ type: 'error', message: data.error }) setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => { @@ -326,8 +328,10 @@ const Result: FC = ({ setCompletionRes(res.join('')) }, onCompleted: () => { - if (isTimeout) + if (isTimeout) { + notify({ type: 'warning', message: t('appDebug.warningMessage.timeoutExceeded') }) return + } setRespondingFalse() setMessageId(tempMessageId) onCompleted(getCompletionRes(), taskId, true) @@ -338,8 +342,10 @@ const Result: FC = ({ setCompletionRes(res.join('')) }, onError() { - if (isTimeout) + if (isTimeout) { + notify({ type: 'warning', message: t('appDebug.warningMessage.timeoutExceeded') }) return + } setRespondingFalse() onCompleted(getCompletionRes(), taskId, false) isEnd = true diff --git 
a/web/app/components/share/text-generation/run-once/index.tsx b/web/app/components/share/text-generation/run-once/index.tsx index e3b8d6a0a4..e74616b08d 100644 --- a/web/app/components/share/text-generation/run-once/index.tsx +++ b/web/app/components/share/text-generation/run-once/index.tsx @@ -1,4 +1,4 @@ -import type { FC } from 'react' +import type { FC, FormEvent } from 'react' import React from 'react' import { useTranslation } from 'react-i18next' import { @@ -39,11 +39,16 @@ const RunOnce: FC = ({ onInputsChange(newInputs) } + const onSubmit = (e: FormEvent) => { + e.preventDefault() + onSend() + } + return (
{/* input form */} -
+ {promptConfig.prompt_variables.map(item => (
@@ -65,12 +70,6 @@ const RunOnce: FC = ({ placeholder={`${item.name}${!item.required ? `(${t('appDebug.variableTable.optional')})` : ''}`} value={inputs[item.key]} onChange={(e) => { onInputsChange({ ...inputs, [item.key]: e.target.value }) }} - onKeyDown={(e) => { - if (e.key === 'Enter') { - e.preventDefault() - onSend() - } - }} maxLength={item.max_length || DEFAULT_VALUE_MAX_LEN} /> )} @@ -124,8 +123,8 @@ const RunOnce: FC = ({ {t('common.operation.clear')}
)} - needsDelay >
{ const nodes = useNodes() const node = useMemo(() => { - if (isSystemVar(valueSelector)) - return nodes.find(node => node.data.type === BlockEnum.Start) + if (isSystemVar(valueSelector)) { + const startNode = availableNodes?.find(n => n.data.type === BlockEnum.Start) + if (startNode) + return startNode + } + return getNodeInfoById(availableNodes || nodes, valueSelector[0]) + }, [nodes, valueSelector, availableNodes]) - return nodes.find(node => node.id === valueSelector[0]) - }, [nodes, valueSelector]) const isEnv = isENV(valueSelector) const isChatVar = isConversationVar(valueSelector) + const isValid = Boolean(node) || isEnv || isChatVar const variableName = isSystemVar(valueSelector) ? valueSelector.slice(0).join('.') : valueSelector.slice(1).join('.') + const { t } = useTranslation() return ( -
- {!isEnv && !isChatVar && ( - <> + +
+ {(!isEnv && !isChatVar && <> {node && ( - + <> + +
+ {node?.data.title} +
+ )} -
- {node?.data.title} -
- - )} - {isEnv && } - {isChatVar && } -
- {variableName} + )} + {isEnv && } + {isChatVar && } +
+ {variableName} +
+ { + varType && ( +
{capitalize(varType)}
+ ) + } + {!isValid && }
- { - varType && ( -
{capitalize(varType)}
- ) - } -
+
) } diff --git a/web/app/components/workflow/nodes/_base/components/variable/var-reference-picker.tsx b/web/app/components/workflow/nodes/_base/components/variable/var-reference-picker.tsx index 7fb4ad68d8..5e51173673 100644 --- a/web/app/components/workflow/nodes/_base/components/variable/var-reference-picker.tsx +++ b/web/app/components/workflow/nodes/_base/components/variable/var-reference-picker.tsx @@ -5,9 +5,10 @@ import { useTranslation } from 'react-i18next' import { RiArrowDownSLine, RiCloseLine, + RiErrorWarningFill, } from '@remixicon/react' import produce from 'immer' -import { useStoreApi } from 'reactflow' +import { useEdges, useStoreApi } from 'reactflow' import useAvailableVarList from '../../hooks/use-available-var-list' import VarReferencePopup from './var-reference-popup' import { getNodeInfoById, isConversationVar, isENV, isSystemVar } from './utils' @@ -33,6 +34,8 @@ import { VarType as VarKindType } from '@/app/components/workflow/nodes/tool/typ import TypeSelector from '@/app/components/workflow/nodes/_base/components/selector' import AddButton from '@/app/components/base/button/add-button' import Badge from '@/app/components/base/badge' +import Tooltip from '@/app/components/base/tooltip' + const TRIGGER_DEFAULT_WIDTH = 227 type Props = { @@ -77,6 +80,7 @@ const VarReferencePicker: FC = ({ const { getNodes, } = store.getState() + const edges = useEdges() const isChatMode = useIsChatMode() const { getCurrentVariableType } = useWorkflowVariables() @@ -179,7 +183,7 @@ const VarReferencePicker: FC = ({ const handleVarReferenceChange = useCallback((value: ValueSelector, varInfo: Var) => { // sys var not passed to backend const newValue = produce(value, (draft) => { - if (draft[1] && draft[1].startsWith('sys')) { + if (draft[1] && draft[1].startsWith('sys.')) { draft.shift() const paths = draft[0].split('.') paths.forEach((p, i) => { @@ -206,8 +210,16 @@ const VarReferencePicker: FC = ({ isConstant: !!isConstant, }) - const isEnv = isENV(value as 
ValueSelector) - const isChatVar = isConversationVar(value as ValueSelector) + const { isEnv, isChatVar, isValidVar } = useMemo(() => { + const isEnv = isENV(value as ValueSelector) + const isChatVar = isConversationVar(value as ValueSelector) + const isValidVar = Boolean(outputVarNode) || isEnv || isChatVar + return { + isEnv, + isChatVar, + isValidVar, + } + }, [value, edges, outputVarNode]) // 8(left/right-padding) + 14(icon) + 4 + 14 + 2 = 42 + 17 buff const availableWidth = triggerWidth - 56 @@ -285,39 +297,44 @@ const VarReferencePicker: FC = ({ className='grow h-full' >
-
- {hasValue - ? ( - <> - {isShowNodeName && !isEnv && !isChatVar && ( -
-
- + +
+ {hasValue + ? ( + <> + {isShowNodeName && !isEnv && !isChatVar && ( +
+
+ {outputVarNode?.type && } +
+
{outputVarNode?.title}
+
-
{outputVarNode?.title}
- + )} +
+ {!hasValue && } + {isEnv && } + {isChatVar && } +
{varName}
- )} -
- {!hasValue && } - {isEnv && } - {isChatVar && } -
{varName}
-
-
{type}
- - ) - :
{t('workflow.common.setVarValuePlaceholder')}
} -
+
{type}
+ {!isValidVar && } + + ) + :
{t('workflow.common.setVarValuePlaceholder')}
} +
+
diff --git a/web/app/components/workflow/nodes/if-else/components/condition-list/condition-item.tsx b/web/app/components/workflow/nodes/if-else/components/condition-list/condition-item.tsx index c6cb580118..c96f6b3ef6 100644 --- a/web/app/components/workflow/nodes/if-else/components/condition-list/condition-item.tsx +++ b/web/app/components/workflow/nodes/if-else/components/condition-list/condition-item.tsx @@ -80,6 +80,7 @@ const ConditionItem = ({
diff --git a/web/app/components/workflow/run/tracing-panel.tsx b/web/app/components/workflow/run/tracing-panel.tsx index 7684dc84b6..613c10198d 100644 --- a/web/app/components/workflow/run/tracing-panel.tsx +++ b/web/app/components/workflow/run/tracing-panel.tsx @@ -11,6 +11,7 @@ import { RiArrowDownSLine, RiMenu4Line, } from '@remixicon/react' +import { useTranslation } from 'react-i18next' import NodePanel from './node' import { BlockEnum, @@ -37,7 +38,7 @@ type TracingNodeProps = { hideNodeProcessDetail?: boolean } -function buildLogTree(nodes: NodeTracing[]): TracingNodeProps[] { +function buildLogTree(nodes: NodeTracing[], t: (key: string) => string): TracingNodeProps[] { const rootNodes: TracingNodeProps[] = [] const parallelStacks: { [key: string]: TracingNodeProps } = {} const levelCounts: { [key: string]: number } = {} @@ -58,7 +59,7 @@ function buildLogTree(nodes: NodeTracing[]): TracingNodeProps[] { const parentTitle = parentId ? parallelStacks[parentId]?.parallelTitle : '' const levelNumber = parentTitle ? parseInt(parentTitle.split('-')[1]) + 1 : 1 const letter = parallelChildCounts[levelKey]?.size > 1 ? String.fromCharCode(64 + levelCounts[levelKey]) : '' - return `PARALLEL-${levelNumber}${letter}` + return `${t('workflow.common.parallel')}-${levelNumber}${letter}` } const getBranchTitle = (parentId: string | null, branchNum: number): string => { @@ -67,7 +68,7 @@ function buildLogTree(nodes: NodeTracing[]): TracingNodeProps[] { const levelNumber = parentTitle ? parseInt(parentTitle.split('-')[1]) + 1 : 1 const letter = parallelChildCounts[levelKey]?.size > 1 ? 
String.fromCharCode(64 + levelCounts[levelKey]) : '' const branchLetter = String.fromCharCode(64 + branchNum) - return `BRANCH-${levelNumber}${letter}-${branchLetter}` + return `${t('workflow.common.branch')}-${levelNumber}${letter}-${branchLetter}` } // Count parallel children (for figuring out if we need to use letters) @@ -163,7 +164,8 @@ const TracingPanel: FC = ({ hideNodeInfo = false, hideNodeProcessDetail = false, }) => { - const treeNodes = buildLogTree(list) + const { t } = useTranslation() + const treeNodes = buildLogTree(list, t) const [collapsedNodes, setCollapsedNodes] = useState>(new Set()) const [hoveredParallel, setHoveredParallel] = useState(null) diff --git a/web/app/layout.tsx b/web/app/layout.tsx index 008fdefa0b..e9242edfad 100644 --- a/web/app/layout.tsx +++ b/web/app/layout.tsx @@ -43,6 +43,7 @@ const LocaleLayout = ({ data-public-sentry-dsn={process.env.NEXT_PUBLIC_SENTRY_DSN} data-public-maintenance-notice={process.env.NEXT_PUBLIC_MAINTENANCE_NOTICE} data-public-site-about={process.env.NEXT_PUBLIC_SITE_ABOUT} + data-public-text-generation-timeout-ms={process.env.NEXT_PUBLIC_TEXT_GENERATION_TIMEOUT_MS} > diff --git a/web/config/index.ts b/web/config/index.ts index 71edf8f939..21fc2f211c 100644 --- a/web/config/index.ts +++ b/web/config/index.ts @@ -246,6 +246,13 @@ Thought: {{agent_scratchpad}} export const VAR_REGEX = /\{\{(#[a-zA-Z0-9_-]{1,50}(\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10}#)\}\}/gi -export const TEXT_GENERATION_TIMEOUT_MS = 60000 +export let textGenerationTimeoutMs = 60000 + +if (process.env.NEXT_PUBLIC_TEXT_GENERATION_TIMEOUT_MS && process.env.NEXT_PUBLIC_TEXT_GENERATION_TIMEOUT_MS !== '') + textGenerationTimeoutMs = parseInt(process.env.NEXT_PUBLIC_TEXT_GENERATION_TIMEOUT_MS) +else if (globalThis.document?.body?.getAttribute('data-public-text-generation-timeout-ms') && globalThis.document.body.getAttribute('data-public-text-generation-timeout-ms') !== '') + textGenerationTimeoutMs = 
parseInt(globalThis.document.body.getAttribute('data-public-text-generation-timeout-ms') as string) + +export const TEXT_GENERATION_TIMEOUT_MS = textGenerationTimeoutMs export const DISABLE_UPLOAD_IMAGE_AS_ICON = process.env.NEXT_PUBLIC_DISABLE_UPLOAD_IMAGE_AS_ICON === 'true' diff --git a/web/docker/entrypoint.sh b/web/docker/entrypoint.sh index a19c543d68..fc4a8f45bc 100755 --- a/web/docker/entrypoint.sh +++ b/web/docker/entrypoint.sh @@ -21,4 +21,6 @@ export NEXT_PUBLIC_SENTRY_DSN=${SENTRY_DSN} export NEXT_PUBLIC_SITE_ABOUT=${SITE_ABOUT} export NEXT_TELEMETRY_DISABLED=${NEXT_TELEMETRY_DISABLED} +export NEXT_PUBLIC_TEXT_GENERATION_TIMEOUT_MS=${TEXT_GENERATION_TIMEOUT_MS} + pm2 start ./pm2.json --no-daemon diff --git a/web/i18n/de-DE/workflow.ts b/web/i18n/de-DE/workflow.ts index 5e154aeca5..c01d0e6f99 100644 --- a/web/i18n/de-DE/workflow.ts +++ b/web/i18n/de-DE/workflow.ts @@ -93,6 +93,8 @@ const translation = { disconnect: 'Trennen', jumpToNode: 'Zu diesem Knoten springen', addParallelNode: 'Parallelen Knoten hinzufügen', + parallel: 'PARALLEL', + branch: 'ZWEIG', }, env: { envPanelTitle: 'Umgebungsvariablen', diff --git a/web/i18n/en-US/app-debug.ts b/web/i18n/en-US/app-debug.ts index b1f3f33cd8..3156c9f6cc 100644 --- a/web/i18n/en-US/app-debug.ts +++ b/web/i18n/en-US/app-debug.ts @@ -268,6 +268,9 @@ const translation = { notSelectModel: 'Please choose a model', waitForImgUpload: 'Please wait for the image to upload', }, + warningMessage: { + timeoutExceeded: 'Results are not displayed due to timeout. 
Please refer to the logs to gather complete results.', + }, chatSubTitle: 'Instructions', completionSubTitle: 'Prefix Prompt', promptTip: diff --git a/web/i18n/en-US/workflow.ts b/web/i18n/en-US/workflow.ts index b83d213cb8..a7e768911f 100644 --- a/web/i18n/en-US/workflow.ts +++ b/web/i18n/en-US/workflow.ts @@ -93,6 +93,8 @@ const translation = { disconnect: 'Disconnect', jumpToNode: 'Jump to this node', addParallelNode: 'Add Parallel Node', + parallel: 'PARALLEL', + branch: 'BRANCH', }, env: { envPanelTitle: 'Environment Variables', diff --git a/web/i18n/es-ES/workflow.ts b/web/i18n/es-ES/workflow.ts index 0efc996f91..2260631d0f 100644 --- a/web/i18n/es-ES/workflow.ts +++ b/web/i18n/es-ES/workflow.ts @@ -93,6 +93,8 @@ const translation = { disconnect: 'Desconectar', jumpToNode: 'Saltar a este nodo', addParallelNode: 'Agregar nodo paralelo', + parallel: 'PARALELO', + branch: 'RAMA', }, env: { envPanelTitle: 'Variables de Entorno', diff --git a/web/i18n/fa-IR/workflow.ts b/web/i18n/fa-IR/workflow.ts index 67020f5025..eb36dfdc88 100644 --- a/web/i18n/fa-IR/workflow.ts +++ b/web/i18n/fa-IR/workflow.ts @@ -93,6 +93,8 @@ const translation = { jumpToNode: 'پرش به این گره', parallelRun: 'اجرای موازی', addParallelNode: 'افزودن گره موازی', + parallel: 'موازی', + branch: 'شاخه', }, env: { envPanelTitle: 'متغیرهای محیطی', diff --git a/web/i18n/fr-FR/workflow.ts b/web/i18n/fr-FR/workflow.ts index 3e56246e0c..878d25804e 100644 --- a/web/i18n/fr-FR/workflow.ts +++ b/web/i18n/fr-FR/workflow.ts @@ -93,6 +93,8 @@ const translation = { disconnect: 'Déconnecter', jumpToNode: 'Aller à ce nœud', addParallelNode: 'Ajouter un nœud parallèle', + parallel: 'PARALLÈLE', + branch: 'BRANCHE', }, env: { envPanelTitle: 'Variables d\'Environnement', diff --git a/web/i18n/hi-IN/workflow.ts b/web/i18n/hi-IN/workflow.ts index 072e4874e3..ac356c2067 100644 --- a/web/i18n/hi-IN/workflow.ts +++ b/web/i18n/hi-IN/workflow.ts @@ -96,6 +96,8 @@ const translation = { parallelRun: 'समानांतर रन', jumpToNode: 
'इस नोड पर जाएं', addParallelNode: 'समानांतर नोड जोड़ें', + parallel: 'समानांतर', + branch: 'शाखा', }, env: { envPanelTitle: 'पर्यावरण चर', diff --git a/web/i18n/it-IT/workflow.ts b/web/i18n/it-IT/workflow.ts index f5d6fc8bf5..0427a45cd9 100644 --- a/web/i18n/it-IT/workflow.ts +++ b/web/i18n/it-IT/workflow.ts @@ -97,6 +97,8 @@ const translation = { disconnect: 'Disconnettere', jumpToNode: 'Vai a questo nodo', addParallelNode: 'Aggiungi nodo parallelo', + parallel: 'PARALLELO', + branch: 'RAMO', }, env: { envPanelTitle: 'Variabili d\'Ambiente', diff --git a/web/i18n/ja-JP/workflow.ts b/web/i18n/ja-JP/workflow.ts index 755061e8f6..48c2019601 100644 --- a/web/i18n/ja-JP/workflow.ts +++ b/web/i18n/ja-JP/workflow.ts @@ -93,6 +93,8 @@ const translation = { disconnect: '切る', jumpToNode: 'このノードにジャンプします', addParallelNode: '並列ノードを追加', + parallel: '並列', + branch: 'ブランチ', }, env: { envPanelTitle: '環境変数', diff --git a/web/i18n/ko-KR/workflow.ts b/web/i18n/ko-KR/workflow.ts index 8fed0e0417..4a97943790 100644 --- a/web/i18n/ko-KR/workflow.ts +++ b/web/i18n/ko-KR/workflow.ts @@ -93,6 +93,8 @@ const translation = { disconnect: '분리하다', jumpToNode: '이 노드로 이동', addParallelNode: '병렬 노드 추가', + parallel: '병렬', + branch: '브랜치', }, env: { envPanelTitle: '환경 변수', diff --git a/web/i18n/pl-PL/workflow.ts b/web/i18n/pl-PL/workflow.ts index de05ee7169..41927668f7 100644 --- a/web/i18n/pl-PL/workflow.ts +++ b/web/i18n/pl-PL/workflow.ts @@ -93,6 +93,8 @@ const translation = { jumpToNode: 'Przejdź do tego węzła', disconnect: 'Odłączyć', addParallelNode: 'Dodaj węzeł równoległy', + parallel: 'RÓWNOLEGŁY', + branch: 'GAŁĄŹ', }, env: { envPanelTitle: 'Zmienne Środowiskowe', diff --git a/web/i18n/pt-BR/workflow.ts b/web/i18n/pt-BR/workflow.ts index ef589c0bde..222fc788bf 100644 --- a/web/i18n/pt-BR/workflow.ts +++ b/web/i18n/pt-BR/workflow.ts @@ -93,6 +93,8 @@ const translation = { disconnect: 'Desligar', jumpToNode: 'Ir para este nó', addParallelNode: 'Adicionar nó paralelo', + parallel: 'PARALELO', 
+ branch: 'RAMIFICAÇÃO', }, env: { envPanelTitle: 'Variáveis de Ambiente', diff --git a/web/i18n/ro-RO/workflow.ts b/web/i18n/ro-RO/workflow.ts index 689ebdead9..ac4b718b07 100644 --- a/web/i18n/ro-RO/workflow.ts +++ b/web/i18n/ro-RO/workflow.ts @@ -93,6 +93,8 @@ const translation = { disconnect: 'Deconecta', jumpToNode: 'Sari la acest nod', addParallelNode: 'Adăugare nod paralel', + parallel: 'PARALEL', + branch: 'RAMURĂ', }, env: { envPanelTitle: 'Variabile de Mediu', diff --git a/web/i18n/ru-RU/workflow.ts b/web/i18n/ru-RU/workflow.ts index 9d3ce1235c..1931863895 100644 --- a/web/i18n/ru-RU/workflow.ts +++ b/web/i18n/ru-RU/workflow.ts @@ -93,6 +93,8 @@ const translation = { disconnect: 'Разъединять', jumpToNode: 'Перейти к этому узлу', addParallelNode: 'Добавить параллельный узел', + parallel: 'ПАРАЛЛЕЛЬНЫЙ', + branch: 'ВЕТКА', }, env: { envPanelTitle: 'Переменные среды', diff --git a/web/i18n/tr-TR/workflow.ts b/web/i18n/tr-TR/workflow.ts index 96313d6d6b..a33a3724ad 100644 --- a/web/i18n/tr-TR/workflow.ts +++ b/web/i18n/tr-TR/workflow.ts @@ -93,6 +93,8 @@ const translation = { addParallelNode: 'Paralel Düğüm Ekle', disconnect: 'Ayırmak', parallelRun: 'Paralel Koşu', + parallel: 'PARALEL', + branch: 'DAL', }, env: { envPanelTitle: 'Çevre Değişkenleri', diff --git a/web/i18n/uk-UA/workflow.ts b/web/i18n/uk-UA/workflow.ts index 03471348c8..e1bea99bcd 100644 --- a/web/i18n/uk-UA/workflow.ts +++ b/web/i18n/uk-UA/workflow.ts @@ -93,6 +93,8 @@ const translation = { parallelRun: 'Паралельний біг', jumpToNode: 'Перейти до цього вузла', addParallelNode: 'Додати паралельний вузол', + parallel: 'ПАРАЛЕЛЬНИЙ', + branch: 'ГІЛКА', }, env: { envPanelTitle: 'Змінні середовища', diff --git a/web/i18n/vi-VN/workflow.ts b/web/i18n/vi-VN/workflow.ts index 5be19ab7fd..5e6f0e0063 100644 --- a/web/i18n/vi-VN/workflow.ts +++ b/web/i18n/vi-VN/workflow.ts @@ -93,6 +93,8 @@ const translation = { disconnect: 'Ngắt kết nối', jumpToNode: 'Chuyển đến nút này', addParallelNode: 'Thêm nút song 
song', + parallel: 'SONG SONG', + branch: 'NHÁNH', }, env: { envPanelTitle: 'Biến Môi Trường', diff --git a/web/i18n/zh-Hans/workflow.ts b/web/i18n/zh-Hans/workflow.ts index 39311649f4..3579ec5df3 100644 --- a/web/i18n/zh-Hans/workflow.ts +++ b/web/i18n/zh-Hans/workflow.ts @@ -93,6 +93,8 @@ const translation = { disconnect: '断开连接', jumpToNode: '跳转到节点', addParallelNode: '添加并行节点', + parallel: '并行', + branch: '分支', }, env: { envPanelTitle: '环境变量', diff --git a/web/i18n/zh-Hant/workflow.ts b/web/i18n/zh-Hant/workflow.ts index eef3ffaebd..8e1b7529fe 100644 --- a/web/i18n/zh-Hant/workflow.ts +++ b/web/i18n/zh-Hant/workflow.ts @@ -93,6 +93,8 @@ const translation = { disconnect: '斷開', jumpToNode: '跳轉到此節點', addParallelNode: '添加並行節點', + parallel: '並行', + branch: '分支', }, env: { envPanelTitle: '環境變數', diff --git a/web/package.json b/web/package.json index 197a1e7e05..bc532fb242 100644 --- a/web/package.json +++ b/web/package.json @@ -44,6 +44,7 @@ "classnames": "^2.3.2", "copy-to-clipboard": "^3.3.3", "crypto-js": "^4.2.0", + "@svgdotjs/svg.js": "^3.2.4", "dayjs": "^1.11.7", "echarts": "^5.4.1", "echarts-for-react": "^3.0.2", diff --git a/web/yarn.lock b/web/yarn.lock index 3c020c9664..bec2059a47 100644 --- a/web/yarn.lock +++ b/web/yarn.lock @@ -1489,6 +1489,11 @@ dependencies: "@sinonjs/commons" "^3.0.0" +"@svgdotjs/svg.js@^3.2.4": + version "3.2.4" + resolved "https://registry.yarnpkg.com/@svgdotjs/svg.js/-/svg.js-3.2.4.tgz#4716be92a64c66b29921b63f7235fcfb953fb13a" + integrity sha512-BjJ/7vWNowlX3Z8O4ywT58DqbNRyYlkk6Yz/D13aB7hGmfQTvGX4Tkgtm/ApYlu9M7lCQi15xUEidqMUmdMYwg== + "@swc/counter@^0.1.3": version "0.1.3" resolved "https://registry.npmjs.org/@swc/counter/-/counter-0.1.3.tgz"