class ComfyUIProvider(BuiltinToolProviderController):
    """Builtin tool provider for a ComfyUI server."""

    def _validate_credentials(self, credentials: dict[str, Any]) -> None:
        """Check that the configured ComfyUI server accepts a WebSocket connection.

        Args:
            credentials: Provider credentials; only ``base_url`` is read.

        Raises:
            ToolProviderCredentialValidationError: If ``base_url`` is missing or
                cannot be parsed, or if the WebSocket connection fails.
        """
        base_url = credentials.get("base_url")
        if not base_url:
            raise ToolProviderCredentialValidationError("base_url is required")

        # Parse inside a guarded block so that a malformed URL surfaces as a
        # credential error instead of a raw TypeError/ValueError.
        try:
            parsed_url = URL(base_url)
        except (TypeError, ValueError) as e:
            raise ToolProviderCredentialValidationError(f"invalid base_url: {base_url}") from e

        # Use the secure WebSocket scheme when the server is configured over HTTPS.
        scheme = "wss" if parsed_url.scheme == "https" else "ws"
        ws_address = f"{scheme}://{parsed_url.authority}/ws?clientId=test123"

        ws = websocket.WebSocket()
        try:
            ws.connect(ws_address)
        except Exception as e:
            # Chain the cause so the underlying network error is not lost.
            raise ToolProviderCredentialValidationError(f"can not connect to {ws_address}") from e
        finally:
            ws.close()
class ComfyUiClient:
    """Small HTTP/WebSocket client for talking to a ComfyUI server."""

    def __init__(self, base_url: str):
        """Store the server base URL as a ``yarl.URL``."""
        self.base_url = URL(base_url)

    def _ws_address(self, client_id: str) -> str:
        """Build the WebSocket endpoint URL for *client_id* (``wss`` when the base URL is HTTPS)."""
        scheme = "wss" if self.base_url.scheme == "https" else "ws"
        return f"{scheme}://{self.base_url.authority}/ws?clientId={client_id}"

    def get_history(self, prompt_id: str):
        """Return the history entry stored by the server under *prompt_id*.

        Raises an ``httpx`` status error on a non-2xx response and ``KeyError``
        if the response JSON has no entry for *prompt_id*.
        """
        res = httpx.get(str(self.base_url / "history"), params={"prompt_id": prompt_id})
        res.raise_for_status()
        return res.json()[prompt_id]

    def get_image(self, filename: str, subfolder: str, folder_type: str):
        """Download an image from the ``view`` endpoint and return its raw bytes."""
        response = httpx.get(
            str(self.base_url / "view"),
            params={"filename": filename, "subfolder": subfolder, "type": folder_type},
        )
        # Fail loudly rather than returning an error page as if it were image data.
        response.raise_for_status()
        return response.content

    def upload_image(self, input_path: str, name: str, image_type: str = "input", overwrite: bool = False):
        """Upload the file at *input_path* to ``upload/image`` and return the HTTP response."""
        # plan to support img2img in dify 0.10.0
        data = {"type": image_type, "overwrite": str(overwrite).lower()}
        with open(input_path, "rb") as file:
            files = {"image": (name, file, "image/png")}
            # The request must be sent while the file handle is still open.
            res = httpx.post(str(self.base_url / "upload/image"), data=data, files=files)
        return res

    def queue_prompt(self, client_id: str, prompt: dict):
        """Submit *prompt* to the ``prompt`` endpoint and return the ``prompt_id`` from the response."""
        res = httpx.post(str(self.base_url / "prompt"), json={"client_id": client_id, "prompt": prompt})
        res.raise_for_status()
        return res.json()["prompt_id"]

    def open_websocket_connection(self):
        """Open a WebSocket to the server and return ``(ws, client_id)``.

        A fresh UUID4 is used as the client id. The socket is closed again if
        the connection attempt fails.
        """
        client_id = str(uuid.uuid4())
        ws = WebSocket()
        try:
            ws.connect(self._ws_address(client_id))
        except Exception:
            ws.close()
            raise
        return ws, client_id

    def set_prompt(self, origin_prompt: dict, positive_prompt: str, negative_prompt: str = ""):
        """Return a copy of *origin_prompt* with a new seed and prompt texts.

        Finds the first node whose ``class_type`` is ``"KSampler"``, assigns it
        a random 15-digit seed, and writes *positive_prompt* into the node
        referenced by its ``positive`` input. When *negative_prompt* is
        non-empty, it is written into the node referenced by the ``negative``
        input. *origin_prompt* itself is left unmodified.

        Raises:
            IndexError: If the workflow contains no ``KSampler`` node.
        """
        from copy import deepcopy

        # Deep copy: the nested "inputs" dicts are mutated below, and a shallow
        # copy would leak those edits into the caller's workflow.
        prompt = deepcopy(origin_prompt)

        k_sampler = next(
            (node_id for node_id, details in prompt.items() if details["class_type"] == "KSampler"),
            None,
        )
        if k_sampler is None:
            raise IndexError("no KSampler node found in the workflow")

        sampler_inputs = prompt[k_sampler]["inputs"]
        sampler_inputs["seed"] = random.randint(10**14, 10**15 - 1)

        positive_input_id = sampler_inputs["positive"][0]
        prompt[positive_input_id]["inputs"]["text"] = positive_prompt

        # Truthiness check also covers None, so a missing negative prompt keeps
        # whatever text the workflow already had.
        if negative_prompt:
            negative_input_id = sampler_inputs["negative"][0]
            prompt[negative_input_id]["inputs"]["text"] = negative_prompt
        return prompt

    def track_progress(self, prompt: dict, ws: "WebSocket", prompt_id: str):
        """Block on *ws* until the server reports that *prompt_id* has finished.

        Text frames are parsed as JSON. ``progress``, ``execution_cached`` and
        ``executing`` messages are logged at debug level. The loop ends on an
        ``executing`` message whose ``node`` is ``None`` and whose ``prompt_id``
        matches.
        """
        import logging

        logger = logging.getLogger(__name__)
        total_nodes = len(prompt)
        finished_nodes = set()

        while True:
            out = ws.recv()
            if not isinstance(out, str):
                # Non-text frames carry nothing this loop interprets.
                continue

            message = json.loads(out)
            message_type = message["type"]

            if message_type == "progress":
                data = message["data"]
                logger.debug("In K-Sampler -> Step: %s of: %s", data["value"], data["max"])
            elif message_type == "execution_cached":
                finished_nodes.update(message["data"]["nodes"])
                logger.debug("Progress: %d/%d Tasks done", len(finished_nodes), total_nodes)
            elif message_type == "executing":
                data = message["data"]
                node = data["node"]
                if node is None:
                    # A None node is an end-of-execution marker, not a finished task.
                    if data["prompt_id"] == prompt_id:
                        break  # Execution is done
                    continue
                finished_nodes.add(node)
                logger.debug("Progress: %d/%d Tasks done", len(finished_nodes), total_nodes)

    def generate_image_by_prompt(self, prompt: dict):
        """Queue *prompt*, wait for it to finish, and return the output images as a list of bytes."""
        # Connect before entering the try block so that a failed connection does
        # not trigger the cleanup path on a socket that was never created.
        ws, client_id = self.open_websocket_connection()
        try:
            prompt_id = self.queue_prompt(client_id, prompt)
            self.track_progress(prompt, ws, prompt_id)
            history = self.get_history(prompt_id)
            return [
                self.get_image(img["filename"], img["subfolder"], img["type"])
                for output in history["outputs"].values()
                for img in output.get("images", [])
            ]
        finally:
            ws.close()
class ComfyUIWorkflowTool(BuiltinTool):
    """Tool that runs a user-supplied ComfyUI workflow and returns the generated images."""

    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
        """Run the workflow given in ``workflow_json`` with the supplied prompts.

        Args:
            user_id: Id of the invoking user (not used here).
            tool_parameters: Reads ``positive_prompt``, ``negative_prompt`` and
                ``workflow_json``.

        Returns:
            A text message when the workflow JSON cannot be used; otherwise a
            list with one blob message per generated image.
        """
        comfyui = ComfyUiClient(self.runtime.credentials["base_url"])

        # Both prompts may be absent; normalise to "" so None is never written
        # into the workflow's text nodes.
        positive_prompt = tool_parameters.get("positive_prompt") or ""
        negative_prompt = tool_parameters.get("negative_prompt") or ""
        workflow = tool_parameters.get("workflow_json")

        try:
            origin_prompt = json.loads(workflow)
        except (TypeError, ValueError):
            # TypeError: workflow is missing / not a string.
            # ValueError: covers json.JSONDecodeError for malformed JSON.
            return self.create_text_message("the Workflow JSON is not correct")

        # The client expects a mapping of node id -> node definition.
        if not isinstance(origin_prompt, dict):
            return self.create_text_message("the Workflow JSON is not correct")

        prompt = comfyui.set_prompt(origin_prompt, positive_prompt, negative_prompt)
        images = comfyui.generate_image_by_prompt(prompt)
        return [
            self.create_blob_message(
                blob=img, meta={"mime_type": "image/png"}, save_as=self.VariableKey.IMAGE.value
            )
            for img in images
        ]
+parameters: + - name: positive_prompt + type: string + label: + en_US: Prompt + zh_Hans: 提示词 + llm_description: Image prompt, you should describe the image you want to generate as a list of words as possible as detailed, the prompt must be written in English. + form: llm + - name: negative_prompt + type: string + label: + en_US: Negative Prompt + zh_Hans: 负面提示词 + llm_description: Negative prompt, you should describe the image you don't want to generate as a list of words as possible as detailed, the prompt must be written in English. + form: llm + - name: workflow_json + type: string + required: true + label: + en_US: Workflow JSON + human_description: + en_US: exported from ComfyUI workflow + zh_Hans: 从ComfyUI的工作流中导出 + form: form