From 063e006446082c899488f8f5ecc68aa29b3c278a Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Wed, 21 Aug 2024 14:44:47 +0200 Subject: [PATCH] feat: custom comfyui workflow Co-Authored-By: John Karabudak --- backend/apps/images/main.py | 41 ++++++++++++++-- .../components/admin/Settings/Images.svelte | 48 +++++++++++++++---- 2 files changed, 75 insertions(+), 14 deletions(-) diff --git a/backend/apps/images/main.py b/backend/apps/images/main.py index 35b84d75e..8bd88308b 100644 --- a/backend/apps/images/main.py +++ b/backend/apps/images/main.py @@ -268,12 +268,43 @@ def get_models(user=Depends(get_verified_user)): r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info") info = r.json() - return list( - map( - lambda model: {"id": model, "name": model}, - info["CheckpointLoaderSimple"]["input"]["required"]["ckpt_name"][0], + workflow = json.loads(app.state.config.COMFYUI_WORKFLOW) + model_node_id = None + + for node in app.state.config.COMFYUI_WORKFLOW_NODES: + if node["type"] == "model": + model_node_id = node["node_ids"][0] + break + + if model_node_id: + model_list_key = None + + print(workflow[model_node_id]["class_type"]) + for key in info[workflow[model_node_id]["class_type"]]["input"][ + "required" + ]: + if "_name" in key: + model_list_key = key + break + + if model_list_key: + return list( + map( + lambda model: {"id": model, "name": model}, + info[workflow[model_node_id]["class_type"]]["input"][ + "required" + ][model_list_key][0], + ) + ) + else: + return list( + map( + lambda model: {"id": model, "name": model}, + info["CheckpointLoaderSimple"]["input"]["required"][ + "ckpt_name" + ][0], + ) ) - ) elif ( app.state.config.ENGINE == "automatic1111" or app.state.config.ENGINE == "" ): diff --git a/src/lib/components/admin/Settings/Images.svelte b/src/lib/components/admin/Settings/Images.svelte index 3808755fe..1a792f764 100644 --- a/src/lib/components/admin/Settings/Images.svelte +++ b/src/lib/components/admin/Settings/Images.svelte @@ -30,27 +30,27 @@ { type: 'prompt', key: 'text', - node_ids: [] + node_ids: '' }, { type: 'model', key: 'ckpt_name', - node_ids: [] + node_ids: '' }, { type: 'width', key: 'width', - node_ids: [] + node_ids: '' }, { type: 'height', key: 'height', - node_ids: [] + node_ids: '' }, { type: 'steps', key: 'steps', - node_ids: [] + node_ids: '' } ]; @@ -99,6 +99,16 @@ } } + if (config?.comfyui?.COMFYUI_WORKFLOW) { + config.comfyui.COMFYUI_WORKFLOW_NODES = workflowNodes.map((node) => { + return { + type: node.type, + key: node.key, + node_ids: node.node_ids.split(',').map((id) => id.trim()) + }; + }); + } + await updateConfig(localStorage.token, config).catch((error) => { toast.error(error); loading = false; @@ -111,6 +121,7 @@ return null; }); + getModels(); dispatch('save'); loading = false; }; @@ -130,6 +141,24 @@ getModels(); } + if (config.comfyui.COMFYUI_WORKFLOW) { + config.comfyui.COMFYUI_WORKFLOW = JSON.stringify( + JSON.parse(config.comfyui.COMFYUI_WORKFLOW), + null, + 2 + ); + } + + if ((config?.comfyui?.COMFYUI_WORKFLOW_NODES ?? []).length >= 5) { + workflowNodes = config.comfyui.COMFYUI_WORKFLOW_NODES.map((node) => { + return { + type: node.type, + key: node.key, + node_ids: node.node_ids.join(',') + }; + }); + } + const imageConfigRes = await getImageGenerationConfig(localStorage.token).catch((error) => { toast.error(error); return null; @@ -321,7 +350,8 @@