mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-12 20:19:12 +08:00
feat: add the workflow tool of comfyUI (#9447)
This commit is contained in:
parent
f447ee7b9d
commit
d3c06a3f76
@ -1,17 +1,21 @@
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
import websocket
|
||||||
|
from yarl import URL
|
||||||
|
|
||||||
from core.tools.errors import ToolProviderCredentialValidationError
|
from core.tools.errors import ToolProviderCredentialValidationError
|
||||||
from core.tools.provider.builtin.comfyui.tools.comfyui_stable_diffusion import ComfyuiStableDiffusionTool
|
|
||||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||||
|
|
||||||
|
|
||||||
class ComfyUIProvider(BuiltinToolProviderController):
|
class ComfyUIProvider(BuiltinToolProviderController):
|
||||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||||
|
ws = websocket.WebSocket()
|
||||||
|
base_url = URL(credentials.get("base_url"))
|
||||||
|
ws_address = f"ws://{base_url.authority}/ws?clientId=test123"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
ComfyuiStableDiffusionTool().fork_tool_runtime(
|
ws.connect(ws_address)
|
||||||
runtime={
|
|
||||||
"credentials": credentials,
|
|
||||||
}
|
|
||||||
).validate_models()
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ToolProviderCredentialValidationError(str(e))
|
raise ToolProviderCredentialValidationError(f"can not connect to {ws_address}")
|
||||||
|
finally:
|
||||||
|
ws.close()
|
||||||
|
@ -4,11 +4,9 @@ identity:
|
|||||||
label:
|
label:
|
||||||
en_US: ComfyUI
|
en_US: ComfyUI
|
||||||
zh_Hans: ComfyUI
|
zh_Hans: ComfyUI
|
||||||
pt_BR: ComfyUI
|
|
||||||
description:
|
description:
|
||||||
en_US: ComfyUI is a tool for generating images which can be deployed locally.
|
en_US: ComfyUI is a tool for generating images which can be deployed locally.
|
||||||
zh_Hans: ComfyUI 是一个可以在本地部署的图片生成的工具。
|
zh_Hans: ComfyUI 是一个可以在本地部署的图片生成的工具。
|
||||||
pt_BR: ComfyUI is a tool for generating images which can be deployed locally.
|
|
||||||
icon: icon.png
|
icon: icon.png
|
||||||
tags:
|
tags:
|
||||||
- image
|
- image
|
||||||
@ -17,26 +15,9 @@ credentials_for_provider:
|
|||||||
type: text-input
|
type: text-input
|
||||||
required: true
|
required: true
|
||||||
label:
|
label:
|
||||||
en_US: Base URL
|
en_US: The URL of ComfyUI Server
|
||||||
zh_Hans: ComfyUI服务器的Base URL
|
zh_Hans: ComfyUI服务器的URL
|
||||||
pt_BR: Base URL
|
|
||||||
placeholder:
|
placeholder:
|
||||||
en_US: Please input your ComfyUI server's Base URL
|
en_US: Please input your ComfyUI server's Base URL
|
||||||
zh_Hans: 请输入你的 ComfyUI 服务器的 Base URL
|
zh_Hans: 请输入你的 ComfyUI 服务器的 Base URL
|
||||||
pt_BR: Please input your ComfyUI server's Base URL
|
url: https://docs.dify.ai/guides/tools/tool-configuration/comfyui
|
||||||
model:
|
|
||||||
type: text-input
|
|
||||||
required: true
|
|
||||||
label:
|
|
||||||
en_US: Model with suffix
|
|
||||||
zh_Hans: 模型, 需要带后缀
|
|
||||||
pt_BR: Model with suffix
|
|
||||||
placeholder:
|
|
||||||
en_US: Please input your model
|
|
||||||
zh_Hans: 请输入你的模型名称
|
|
||||||
pt_BR: Please input your model
|
|
||||||
help:
|
|
||||||
en_US: The checkpoint name of the ComfyUI server, e.g. xxx.safetensors
|
|
||||||
zh_Hans: ComfyUI服务器的模型名称, 比如 xxx.safetensors
|
|
||||||
pt_BR: The checkpoint name of the ComfyUI server, e.g. xxx.safetensors
|
|
||||||
url: https://github.com/comfyanonymous/ComfyUI#installing
|
|
||||||
|
105
api/core/tools/provider/builtin/comfyui/tools/comfyui_client.py
Normal file
105
api/core/tools/provider/builtin/comfyui/tools/comfyui_client.py
Normal file
@ -0,0 +1,105 @@
|
|||||||
|
import json
|
||||||
|
import random
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from websocket import WebSocket
|
||||||
|
from yarl import URL
|
||||||
|
|
||||||
|
|
||||||
|
class ComfyUiClient:
|
||||||
|
def __init__(self, base_url: str):
|
||||||
|
self.base_url = URL(base_url)
|
||||||
|
|
||||||
|
def get_history(self, prompt_id: str):
|
||||||
|
res = httpx.get(str(self.base_url / "history"), params={"prompt_id": prompt_id})
|
||||||
|
history = res.json()[prompt_id]
|
||||||
|
return history
|
||||||
|
|
||||||
|
def get_image(self, filename: str, subfolder: str, folder_type: str):
|
||||||
|
response = httpx.get(
|
||||||
|
str(self.base_url / "view"),
|
||||||
|
params={"filename": filename, "subfolder": subfolder, "type": folder_type},
|
||||||
|
)
|
||||||
|
return response.content
|
||||||
|
|
||||||
|
def upload_image(self, input_path: str, name: str, image_type: str = "input", overwrite: bool = False):
|
||||||
|
# plan to support img2img in dify 0.10.0
|
||||||
|
with open(input_path, "rb") as file:
|
||||||
|
files = {"image": (name, file, "image/png")}
|
||||||
|
data = {"type": image_type, "overwrite": str(overwrite).lower()}
|
||||||
|
|
||||||
|
res = httpx.post(str(self.base_url / "upload/image"), data=data, files=files)
|
||||||
|
return res
|
||||||
|
|
||||||
|
def queue_prompt(self, client_id: str, prompt: dict):
|
||||||
|
res = httpx.post(str(self.base_url / "prompt"), json={"client_id": client_id, "prompt": prompt})
|
||||||
|
prompt_id = res.json()["prompt_id"]
|
||||||
|
return prompt_id
|
||||||
|
|
||||||
|
def open_websocket_connection(self):
|
||||||
|
client_id = str(uuid.uuid4())
|
||||||
|
ws = WebSocket()
|
||||||
|
ws_address = f"ws://{self.base_url.authority}/ws?clientId={client_id}"
|
||||||
|
ws.connect(ws_address)
|
||||||
|
return ws, client_id
|
||||||
|
|
||||||
|
def set_prompt(self, origin_prompt: dict, positive_prompt: str, negative_prompt: str = ""):
|
||||||
|
"""
|
||||||
|
find the first KSampler, then can find the prompt node through it.
|
||||||
|
"""
|
||||||
|
prompt = origin_prompt.copy()
|
||||||
|
id_to_class_type = {id: details["class_type"] for id, details in prompt.items()}
|
||||||
|
k_sampler = [key for key, value in id_to_class_type.items() if value == "KSampler"][0]
|
||||||
|
prompt.get(k_sampler)["inputs"]["seed"] = random.randint(10**14, 10**15 - 1)
|
||||||
|
positive_input_id = prompt.get(k_sampler)["inputs"]["positive"][0]
|
||||||
|
prompt.get(positive_input_id)["inputs"]["text"] = positive_prompt
|
||||||
|
|
||||||
|
if negative_prompt != "":
|
||||||
|
negative_input_id = prompt.get(k_sampler)["inputs"]["negative"][0]
|
||||||
|
prompt.get(negative_input_id)["inputs"]["text"] = negative_prompt
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
def track_progress(self, prompt: dict, ws: WebSocket, prompt_id: str):
|
||||||
|
node_ids = list(prompt.keys())
|
||||||
|
finished_nodes = []
|
||||||
|
|
||||||
|
while True:
|
||||||
|
out = ws.recv()
|
||||||
|
if isinstance(out, str):
|
||||||
|
message = json.loads(out)
|
||||||
|
if message["type"] == "progress":
|
||||||
|
data = message["data"]
|
||||||
|
current_step = data["value"]
|
||||||
|
print("In K-Sampler -> Step: ", current_step, " of: ", data["max"])
|
||||||
|
if message["type"] == "execution_cached":
|
||||||
|
data = message["data"]
|
||||||
|
for itm in data["nodes"]:
|
||||||
|
if itm not in finished_nodes:
|
||||||
|
finished_nodes.append(itm)
|
||||||
|
print("Progress: ", len(finished_nodes), "/", len(node_ids), " Tasks done")
|
||||||
|
if message["type"] == "executing":
|
||||||
|
data = message["data"]
|
||||||
|
if data["node"] not in finished_nodes:
|
||||||
|
finished_nodes.append(data["node"])
|
||||||
|
print("Progress: ", len(finished_nodes), "/", len(node_ids), " Tasks done")
|
||||||
|
|
||||||
|
if data["node"] is None and data["prompt_id"] == prompt_id:
|
||||||
|
break # Execution is done
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
|
||||||
|
def generate_image_by_prompt(self, prompt: dict):
|
||||||
|
try:
|
||||||
|
ws, client_id = self.open_websocket_connection()
|
||||||
|
prompt_id = self.queue_prompt(client_id, prompt)
|
||||||
|
self.track_progress(prompt, ws, prompt_id)
|
||||||
|
history = self.get_history(prompt_id)
|
||||||
|
images = []
|
||||||
|
for output in history["outputs"].values():
|
||||||
|
for img in output.get("images", []):
|
||||||
|
image_data = self.get_image(img["filename"], img["subfolder"], img["type"])
|
||||||
|
images.append(image_data)
|
||||||
|
return images
|
||||||
|
finally:
|
||||||
|
ws.close()
|
@ -1,10 +1,10 @@
|
|||||||
identity:
|
identity:
|
||||||
name: txt2img workflow
|
name: txt2img
|
||||||
author: Qun
|
author: Qun
|
||||||
label:
|
label:
|
||||||
en_US: Txt2Img Workflow
|
en_US: Txt2Img
|
||||||
zh_Hans: Txt2Img Workflow
|
zh_Hans: Txt2Img
|
||||||
pt_BR: Txt2Img Workflow
|
pt_BR: Txt2Img
|
||||||
description:
|
description:
|
||||||
human:
|
human:
|
||||||
en_US: a pre-defined comfyui workflow that can use one model and up to 3 loras to generate images. Support SD1.5, SDXL, SD3 and FLUX which contain text encoders/clip, but does not support models that requires a triple clip loader.
|
en_US: a pre-defined comfyui workflow that can use one model and up to 3 loras to generate images. Support SD1.5, SDXL, SD3 and FLUX which contain text encoders/clip, but does not support models that requires a triple clip loader.
|
||||||
|
@ -0,0 +1,32 @@
|
|||||||
|
import json
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||||
|
from core.tools.tool.builtin_tool import BuiltinTool
|
||||||
|
|
||||||
|
from .comfyui_client import ComfyUiClient
|
||||||
|
|
||||||
|
|
||||||
|
class ComfyUIWorkflowTool(BuiltinTool):
|
||||||
|
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
|
||||||
|
comfyui = ComfyUiClient(self.runtime.credentials["base_url"])
|
||||||
|
|
||||||
|
positive_prompt = tool_parameters.get("positive_prompt")
|
||||||
|
negative_prompt = tool_parameters.get("negative_prompt")
|
||||||
|
workflow = tool_parameters.get("workflow_json")
|
||||||
|
|
||||||
|
try:
|
||||||
|
origin_prompt = json.loads(workflow)
|
||||||
|
except:
|
||||||
|
return self.create_text_message("the Workflow JSON is not correct")
|
||||||
|
|
||||||
|
prompt = comfyui.set_prompt(origin_prompt, positive_prompt, negative_prompt)
|
||||||
|
images = comfyui.generate_image_by_prompt(prompt)
|
||||||
|
result = []
|
||||||
|
for img in images:
|
||||||
|
result.append(
|
||||||
|
self.create_blob_message(
|
||||||
|
blob=img, meta={"mime_type": "image/png"}, save_as=self.VariableKey.IMAGE.value
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return result
|
@ -0,0 +1,35 @@
|
|||||||
|
identity:
|
||||||
|
name: workflow
|
||||||
|
author: hjlarry
|
||||||
|
label:
|
||||||
|
en_US: workflow
|
||||||
|
zh_Hans: 工作流
|
||||||
|
description:
|
||||||
|
human:
|
||||||
|
en_US: Run ComfyUI workflow.
|
||||||
|
zh_Hans: 运行ComfyUI工作流。
|
||||||
|
llm: Run ComfyUI workflow.
|
||||||
|
parameters:
|
||||||
|
- name: positive_prompt
|
||||||
|
type: string
|
||||||
|
label:
|
||||||
|
en_US: Prompt
|
||||||
|
zh_Hans: 提示词
|
||||||
|
llm_description: Image prompt, you should describe the image you want to generate as a list of words as possible as detailed, the prompt must be written in English.
|
||||||
|
form: llm
|
||||||
|
- name: negative_prompt
|
||||||
|
type: string
|
||||||
|
label:
|
||||||
|
en_US: Negative Prompt
|
||||||
|
zh_Hans: 负面提示词
|
||||||
|
llm_description: Negative prompt, you should describe the image you don't want to generate as a list of words as possible as detailed, the prompt must be written in English.
|
||||||
|
form: llm
|
||||||
|
- name: workflow_json
|
||||||
|
type: string
|
||||||
|
required: true
|
||||||
|
label:
|
||||||
|
en_US: Workflow JSON
|
||||||
|
human_description:
|
||||||
|
en_US: exported from ComfyUI workflow
|
||||||
|
zh_Hans: 从ComfyUI的工作流中导出
|
||||||
|
form: form
|
Loading…
x
Reference in New Issue
Block a user