From b3fdd618a1271014357aa0f0f1ae42413c7af6b9 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Tue, 15 Oct 2024 11:04:28 +0800 Subject: [PATCH] refactor(core): simplify role handling and improve usability - Replaced explicit string usage with `CreatedByRole` enum for better maintainability. - Removed duplicate `CreatedByRole` class definition, improving codebase consistency. - Increased file number limits from 6 to 10 to allow more file uploads. - Transitioned `AppMode` to a string enum for consistent type usage. - Refactored `extract_thread_messages` function argument for flexibility. - Removed file extension limitation in file service to support custom extensions. - Improved enum import statements across multiple modules for clarity and consistency. --- .../advanced_chat/generate_task_pipeline.py | 5 +-- .../task_pipeline/workflow_cycle_manage.py | 3 +- api/core/memory/token_buffer_memory.py | 2 +- .../prompt/utils/extract_thread_messages.py | 4 ++- api/models/model.py | 4 +-- api/models/workflow.py | 35 ++++--------------- api/services/file_service.py | 13 ------- api/services/workflow_service.py | 3 +- 8 files changed, 18 insertions(+), 51 deletions(-) diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index a71e6153dd..71bab4e7da 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -51,6 +51,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder from core.ops.ops_trace_manager import TraceQueueManager from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from enums import CreatedByRole from enums.workflow_nodes import NodeType from events.message_event import message_was_created from extensions.ext_database import db @@ -512,9 +513,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc url=file["remote_url"], belongs_to="assistant", upload_file_id=file["related_id"], - created_by_role="account" + created_by_role=CreatedByRole.ACCOUNT if self._message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} - else "end_user", + else CreatedByRole.END_USER, created_by=self._message.from_account_id or self._message.from_end_user_id or "", ) for file in self._recorded_files diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index 1a50f291c2..a70b1bbed3 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -36,12 +36,11 @@ from core.tools.tool_manager import ToolManager from core.workflow.enums import SystemVariableKey from core.workflow.nodes.tool.entities import ToolNodeData from core.workflow.workflow_entry import WorkflowEntry -from enums import NodeType, WorkflowRunTriggeredFrom +from enums import CreatedByRole, NodeType, WorkflowRunTriggeredFrom from extensions.ext_database import db from models.account import Account from models.model import EndUser from models.workflow import ( - CreatedByRole, Workflow, WorkflowNodeExecution, WorkflowNodeExecutionStatus, diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index 297a17c788..98704a9f43 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -72,7 +72,7 @@ class TokenBufferMemory: files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all() if files: file_extra_config = None - if self.conversation.mode not in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: + if self.conversation.mode not in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config) else: if message.workflow_run_id: diff --git a/api/core/prompt/utils/extract_thread_messages.py b/api/core/prompt/utils/extract_thread_messages.py index e8b626499f..f7aef76c87 100644 --- a/api/core/prompt/utils/extract_thread_messages.py +++ b/api/core/prompt/utils/extract_thread_messages.py @@ -1,7 +1,9 @@ +from typing import Any + from constants import UUID_NIL -def extract_thread_messages(messages: list[dict]) -> list[dict]: +def extract_thread_messages(messages: list[Any]): thread_messages = [] next_message = None diff --git a/api/models/model.py b/api/models/model.py index 2e20542e38..4557a3cf58 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -29,7 +29,7 @@ class FileUploadConfig(BaseModel): allowed_file_types: Sequence[FileType] = Field(default_factory=list) allowed_extensions: Sequence[str] = Field(default_factory=list) allowed_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list) - number_limits: int = Field(default=0, gt=0, le=6) + number_limits: int = Field(default=0, gt=0, le=10) class DifySetup(db.Model): @@ -40,7 +40,7 @@ class DifySetup(db.Model): setup_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) -class AppMode(Enum): +class AppMode(str, Enum): COMPLETION = "completion" WORKFLOW = "workflow" CHAT = "chat" diff --git a/api/models/workflow.py b/api/models/workflow.py index c20df075ce..a858a034e3 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -11,6 +11,7 @@ import contexts from constants import HIDDEN_VALUE from core.helper import encrypter from core.variables import SecretVariable, Variable +from enums import CreatedByRole from extensions.ext_database import db from factories import variable_factory from libs import helper @@ -19,28 +20,6 @@ from .account import Account from .types import StringUUID -class CreatedByRole(Enum): - """ - Created By Role Enum - """ - - ACCOUNT = "account" - END_USER = "end_user" - - @classmethod - def value_of(cls, value: str) -> "CreatedByRole": - """ - Get value of given mode. - - :param value: mode value - :return: mode - """ - for mode in cls: - if mode.value == value: - return mode - raise ValueError(f"invalid created by role value {value}") - - class WorkflowType(Enum): """ Workflow Type Enum @@ -424,14 +403,14 @@ class WorkflowRun(db.Model): @property def created_by_account(self): - created_by_role = CreatedByRole.value_of(self.created_by_role) + created_by_role = CreatedByRole(self.created_by_role) return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None @property def created_by_end_user(self): from models.model import EndUser - created_by_role = CreatedByRole.value_of(self.created_by_role) + created_by_role = CreatedByRole(self.created_by_role) return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None @property @@ -651,14 +630,14 @@ class WorkflowNodeExecution(db.Model): @property def created_by_account(self): - created_by_role = CreatedByRole.value_of(self.created_by_role) + created_by_role = CreatedByRole(self.created_by_role) return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None @property def created_by_end_user(self): from models.model import EndUser - created_by_role = CreatedByRole.value_of(self.created_by_role) + created_by_role = CreatedByRole(self.created_by_role) return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None @property @@ -770,14 +749,14 @@ class WorkflowAppLog(db.Model): @property def created_by_account(self): - created_by_role = CreatedByRole.value_of(self.created_by_role) + created_by_role = CreatedByRole(self.created_by_role) return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None @property def created_by_end_user(self): from models.model import EndUser - created_by_role = CreatedByRole.value_of(self.created_by_role) + created_by_role = CreatedByRole(self.created_by_role) return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None diff --git a/api/services/file_service.py b/api/services/file_service.py index 8772c60aae..0b35561600 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -36,19 +36,6 @@ class FileService: extension = filename.split(".")[-1] if len(filename) > 200: filename = filename.split(".")[0][:200] + "." + extension - - # Cancel the limitation of file extension cause we need to support custom extensions in multi-modal feature. - # - # allowed_extensions = ( - # UNSTRUCTURED_ALLOWED_EXTENSIONS if dify_config.ETL_TYPE == "Unstructured" else ALLOWED_EXTENSIONS - # ) - # allowed_extensions = ( - # allowed_extensions + IMAGE_EXTENSIONS + VIDEO_EXTENSIONS + AUDIO_EXTENSIONS + DOCUMENT_EXTENSIONS - # ) - - # if extension not in allowed_extensions: - # raise UnsupportedFileTypeError() - # read file content file_content = file.read() diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 1d0e939016..374bcc5bcf 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -13,13 +13,12 @@ from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.nodes.event import RunCompletedEvent from core.workflow.nodes.node_mapping import node_classes from core.workflow.workflow_entry import WorkflowEntry -from enums import NodeType +from enums import CreatedByRole, NodeType from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated from extensions.ext_database import db from models.account import Account from models.model import App, AppMode from models.workflow import ( - CreatedByRole, Workflow, WorkflowNodeExecution, WorkflowNodeExecutionStatus,