From bd6175157cb79bc731b876f3aba3a5073b9ec24f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=9E=E6=B3=95=E6=93=8D=E4=BD=9C?= Date: Thu, 31 Oct 2024 10:00:22 +0800 Subject: [PATCH] feat: enhance comfyui workflow (#10085) --- .../builtin/comfyui/tools/comfyui_client.py | 31 +++++++------- .../builtin/comfyui/tools/comfyui_workflow.py | 41 ++++++++++++++++--- .../comfyui/tools/comfyui_workflow.yaml | 20 +++++++-- 3 files changed, 68 insertions(+), 24 deletions(-) diff --git a/api/core/tools/provider/builtin/comfyui/tools/comfyui_client.py b/api/core/tools/provider/builtin/comfyui/tools/comfyui_client.py index d4bf713441..1aae7b2442 100644 --- a/api/core/tools/provider/builtin/comfyui/tools/comfyui_client.py +++ b/api/core/tools/provider/builtin/comfyui/tools/comfyui_client.py @@ -1,5 +1,3 @@ -import base64 -import io import json import random import uuid @@ -8,7 +6,7 @@ import httpx from websocket import WebSocket from yarl import URL -from core.file.file_manager import _get_encoded_string +from core.file.file_manager import download from core.file.models import File @@ -29,8 +27,7 @@ class ComfyUiClient: return response.content def upload_image(self, image_file: File) -> dict: - image_content = base64.b64decode(_get_encoded_string(image_file)) - file = io.BytesIO(image_content) + file = download(image_file) files = {"image": (image_file.filename, file, image_file.mime_type), "overwrite": "true"} res = httpx.post(str(self.base_url / "upload/image"), files=files) return res.json() @@ -47,12 +44,7 @@ class ComfyUiClient: ws.connect(ws_address) return ws, client_id - def set_prompt( - self, origin_prompt: dict, positive_prompt: str, negative_prompt: str = "", image_name: str = "" - ) -> dict: - """ - find the first KSampler, then can find the prompt node through it. - """ + def set_prompt_by_ksampler(self, origin_prompt: dict, positive_prompt: str, negative_prompt: str = "") -> dict: prompt = origin_prompt.copy() id_to_class_type = {id: details["class_type"] for id, details in prompt.items()} k_sampler = [key for key, value in id_to_class_type.items() if value == "KSampler"][0] @@ -64,9 +56,20 @@ class ComfyUiClient: negative_input_id = prompt.get(k_sampler)["inputs"]["negative"][0] prompt.get(negative_input_id)["inputs"]["text"] = negative_prompt - if image_name != "": - image_loader = [key for key, value in id_to_class_type.items() if value == "LoadImage"][0] - prompt.get(image_loader)["inputs"]["image"] = image_name + return prompt + + def set_prompt_images_by_ids(self, origin_prompt: dict, image_names: list[str], image_ids: list[str]) -> dict: + prompt = origin_prompt.copy() + for index, image_node_id in enumerate(image_ids): + prompt[image_node_id]["inputs"]["image"] = image_names[index] + return prompt + + def set_prompt_images_by_default(self, origin_prompt: dict, image_names: list[str]) -> dict: + prompt = origin_prompt.copy() + id_to_class_type = {id: details["class_type"] for id, details in prompt.items()} + load_image_nodes = [key for key, value in id_to_class_type.items() if value == "LoadImage"] + for load_image, image_name in zip(load_image_nodes, image_names): + prompt.get(load_image)["inputs"]["image"] = image_name return prompt def track_progress(self, prompt: dict, ws: WebSocket, prompt_id: str): diff --git a/api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.py b/api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.py index 11320d5d0f..79fe08a86b 100644 --- a/api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.py +++ b/api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.py @@ -1,7 +1,9 @@ import json from typing import Any +from core.file import FileType from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.errors import ToolParameterValidationError from core.tools.provider.builtin.comfyui.tools.comfyui_client import ComfyUiClient from core.tools.tool.builtin_tool import BuiltinTool @@ -10,19 +12,46 @@ class ComfyUIWorkflowTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: comfyui = ComfyUiClient(self.runtime.credentials["base_url"]) - positive_prompt = tool_parameters.get("positive_prompt") - negative_prompt = tool_parameters.get("negative_prompt") + positive_prompt = tool_parameters.get("positive_prompt", "") + negative_prompt = tool_parameters.get("negative_prompt", "") + images = tool_parameters.get("images") or [] workflow = tool_parameters.get("workflow_json") - image_name = "" - if image := tool_parameters.get("image"): + image_names = [] + for image in images: + if image.type != FileType.IMAGE: + continue image_name = comfyui.upload_image(image).get("name") + image_names.append(image_name) + + set_prompt_with_ksampler = True + if "{{positive_prompt}}" in workflow: + set_prompt_with_ksampler = False + workflow = workflow.replace("{{positive_prompt}}", positive_prompt) + workflow = workflow.replace("{{negative_prompt}}", negative_prompt) try: - origin_prompt = json.loads(workflow) + prompt = json.loads(workflow) except: return self.create_text_message("the Workflow JSON is not correct") - prompt = comfyui.set_prompt(origin_prompt, positive_prompt, negative_prompt, image_name) + if set_prompt_with_ksampler: + try: + prompt = comfyui.set_prompt_by_ksampler(prompt, positive_prompt, negative_prompt) + except: + raise ToolParameterValidationError( + "Failed set prompt with KSampler, try replace prompt to {{positive_prompt}} in your workflow json" + ) + + if image_names: + if image_ids := tool_parameters.get("image_ids"): + image_ids = image_ids.split(",") + try: + prompt = comfyui.set_prompt_images_by_ids(prompt, image_names, image_ids) + except: + raise ToolParameterValidationError("the Image Node ID List not match your upload image files.") + else: + prompt = comfyui.set_prompt_images_by_default(prompt, image_names) + images = comfyui.generate_image_by_prompt(prompt) result = [] for img in images: diff --git a/api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.yaml b/api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.yaml index 55fcdad825..dc4e0d77b2 100644 --- a/api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.yaml +++ b/api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.yaml @@ -24,12 +24,12 @@ parameters: zh_Hans: 负面提示词 llm_description: Negative prompt, you should describe the image you don't want to generate as a list of words as possible as detailed, the prompt must be written in English. form: llm - - name: image - type: file + - name: images + type: files label: - en_US: Input Image + en_US: Input Images zh_Hans: 输入的图片 - llm_description: The input image, used to transfer to the comfyui workflow to generate another image. + llm_description: The input images, used to transfer to the comfyui workflow to generate another image. form: llm - name: workflow_json type: string @@ -40,3 +40,15 @@ parameters: en_US: exported from ComfyUI workflow zh_Hans: 从ComfyUI的工作流中导出 form: form + - name: image_ids + type: string + label: + en_US: Image Node ID List + zh_Hans: 图片节点ID列表 + placeholder: + en_US: Use commas to separate multiple node ID + zh_Hans: 多个节点ID时使用半角逗号分隔 + human_description: + en_US: When the workflow has multiple image nodes, enter the ID list of these nodes, and the images will be passed to ComfyUI in the order of the list. + zh_Hans: 当工作流有多个图片节点时,输入这些节点的ID列表,图片将按列表顺序传给ComfyUI + form: form