feat: Add draft hash check in workflow (#4251)

This commit is contained in:
takatost 2024-05-10 14:48:29 +08:00 committed by GitHub
parent a1ab87107b
commit 8f3042e5b3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 53 additions and 8 deletions

View File

@ -91,3 +91,9 @@ class DraftWorkflowNotExist(BaseHTTPException):
error_code = 'draft_workflow_not_exist' error_code = 'draft_workflow_not_exist'
description = "Draft workflow need to be initialized." description = "Draft workflow need to be initialized."
code = 400 code = 400
class DraftWorkflowNotSync(BaseHTTPException):
error_code = 'draft_workflow_not_sync'
description = "Workflow graph might have been modified, please refresh and resubmit."
code = 400

View File

@ -7,7 +7,7 @@ from werkzeug.exceptions import InternalServerError, NotFound
import services import services
from controllers.console import api from controllers.console import api
from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
from controllers.console.setup import setup_required from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required from controllers.console.wraps import account_initialization_required
@ -20,6 +20,7 @@ from libs.helper import TimestampField, uuid_value
from libs.login import current_user, login_required from libs.login import current_user, login_required
from models.model import App, AppMode from models.model import App, AppMode
from services.app_generate_service import AppGenerateService from services.app_generate_service import AppGenerateService
from services.errors.app import WorkflowHashNotEqualError
from services.workflow_service import WorkflowService from services.workflow_service import WorkflowService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -59,6 +60,7 @@ class DraftWorkflowApi(Resource):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument('graph', type=dict, required=True, nullable=False, location='json') parser.add_argument('graph', type=dict, required=True, nullable=False, location='json')
parser.add_argument('features', type=dict, required=True, nullable=False, location='json') parser.add_argument('features', type=dict, required=True, nullable=False, location='json')
parser.add_argument('hash', type=str, required=False, location='json')
args = parser.parse_args() args = parser.parse_args()
elif 'text/plain' in content_type: elif 'text/plain' in content_type:
try: try:
@ -71,7 +73,8 @@ class DraftWorkflowApi(Resource):
args = { args = {
'graph': data.get('graph'), 'graph': data.get('graph'),
'features': data.get('features') 'features': data.get('features'),
'hash': data.get('hash')
} }
except json.JSONDecodeError: except json.JSONDecodeError:
return {'message': 'Invalid JSON data'}, 400 return {'message': 'Invalid JSON data'}, 400
@ -79,15 +82,21 @@ class DraftWorkflowApi(Resource):
abort(415) abort(415)
workflow_service = WorkflowService() workflow_service = WorkflowService()
try:
workflow = workflow_service.sync_draft_workflow( workflow = workflow_service.sync_draft_workflow(
app_model=app_model, app_model=app_model,
graph=args.get('graph'), graph=args.get('graph'),
features=args.get('features'), features=args.get('features'),
unique_hash=args.get('hash'),
account=current_user account=current_user
) )
except WorkflowHashNotEqualError:
raise DraftWorkflowNotSync()
return { return {
"result": "success", "result": "success",
"hash": workflow.unique_hash,
"updated_at": TimestampField().format(workflow.updated_at or workflow.created_at) "updated_at": TimestampField().format(workflow.updated_at or workflow.created_at)
} }

View File

@ -7,6 +7,7 @@ workflow_fields = {
'id': fields.String, 'id': fields.String,
'graph': fields.Raw(attribute='graph_dict'), 'graph': fields.Raw(attribute='graph_dict'),
'features': fields.Raw(attribute='features_dict'), 'features': fields.Raw(attribute='features_dict'),
'hash': fields.String(attribute='unique_hash'),
'created_by': fields.Nested(simple_account_fields, attribute='created_by_account'), 'created_by': fields.Nested(simple_account_fields, attribute='created_by_account'),
'created_at': TimestampField, 'created_at': TimestampField,
'updated_by': fields.Nested(simple_account_fields, attribute='updated_by_account', allow_null=True), 'updated_by': fields.Nested(simple_account_fields, attribute='updated_by_account', allow_null=True),

View File

@ -4,6 +4,7 @@ from typing import Optional, Union
from core.tools.tool_manager import ToolManager from core.tools.tool_manager import ToolManager
from extensions.ext_database import db from extensions.ext_database import db
from libs import helper
from models import StringUUID from models import StringUUID
from models.account import Account from models.account import Account
@ -156,6 +157,21 @@ class Workflow(db.Model):
return variables return variables
@property
def unique_hash(self) -> str:
"""
Get hash of workflow.
:return: hash
"""
entity = {
'graph': self.graph_dict,
'features': self.features_dict
}
return helper.generate_text_hash(json.dumps(entity, sort_keys=True))
class WorkflowRunTriggeredFrom(Enum): class WorkflowRunTriggeredFrom(Enum):
""" """
Workflow Run Triggered From Enum Workflow Run Triggered From Enum

View File

@ -196,6 +196,7 @@ class AppService:
app_model=app, app_model=app,
graph=workflow.get('graph'), graph=workflow.get('graph'),
features=workflow.get('features'), features=workflow.get('features'),
unique_hash=None,
account=account account=account
) )
workflow_service.publish_workflow( workflow_service.publish_workflow(

View File

@ -1,2 +1,6 @@
class MoreLikeThisDisabledError(Exception): class MoreLikeThisDisabledError(Exception):
pass pass
class WorkflowHashNotEqualError(Exception):
pass

View File

@ -21,6 +21,7 @@ from models.workflow import (
WorkflowNodeExecutionTriggeredFrom, WorkflowNodeExecutionTriggeredFrom,
WorkflowType, WorkflowType,
) )
from services.errors.app import WorkflowHashNotEqualError
from services.workflow.workflow_converter import WorkflowConverter from services.workflow.workflow_converter import WorkflowConverter
@ -63,13 +64,20 @@ class WorkflowService:
def sync_draft_workflow(self, app_model: App, def sync_draft_workflow(self, app_model: App,
graph: dict, graph: dict,
features: dict, features: dict,
unique_hash: Optional[str],
account: Account) -> Workflow: account: Account) -> Workflow:
""" """
Sync draft workflow Sync draft workflow
@throws WorkflowHashNotEqualError
""" """
# fetch draft workflow by app_model # fetch draft workflow by app_model
workflow = self.get_draft_workflow(app_model=app_model) workflow = self.get_draft_workflow(app_model=app_model)
if workflow:
# validate unique hash
if workflow.unique_hash != unique_hash:
raise WorkflowHashNotEqualError()
# validate features structure # validate features structure
self.validate_features_structure( self.validate_features_structure(
app_model=app_model, app_model=app_model,