feat: Add ComfyUI tool for Stable Diffusion (#8160)

This commit is contained in:
Qun 2024-09-18 10:56:29 +08:00 committed by GitHub
parent e896d1e9d7
commit cf645c3ba1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 853 additions and 0 deletions

Binary file not shown.

After

Width:  |  Height:  |  Size: 209 KiB

View File

@ -0,0 +1,17 @@
from typing import Any
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.comfyui.tools.comfyui_stable_diffusion import ComfyuiStableDiffusionTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
class ComfyUIProvider(BuiltinToolProviderController):
    """Builtin tool provider for a locally deployed ComfyUI server."""

    def _validate_credentials(self, credentials: dict[str, Any]) -> None:
        """
        Validate the provider credentials by asking the ComfyUI server
        whether the configured checkpoint model exists.

        :param credentials: expected to contain ``base_url`` and ``model``
        :raises ToolProviderCredentialValidationError: if the server is
            unreachable or the configured model is not available
        """
        try:
            ComfyuiStableDiffusionTool().fork_tool_runtime(
                runtime={
                    "credentials": credentials,
                }
            ).validate_models()
        except Exception as e:
            # chain the original exception so the root cause stays visible
            raise ToolProviderCredentialValidationError(str(e)) from e

View File

@ -0,0 +1,42 @@
identity:
author: Qun
name: comfyui
label:
en_US: ComfyUI
zh_Hans: ComfyUI
pt_BR: ComfyUI
description:
en_US: ComfyUI is a tool for generating images which can be deployed locally.
zh_Hans: ComfyUI 是一个可以在本地部署的图片生成的工具。
pt_BR: ComfyUI is a tool for generating images which can be deployed locally.
icon: icon.png
tags:
- image
credentials_for_provider:
base_url:
type: text-input
required: true
label:
en_US: Base URL
zh_Hans: ComfyUI服务器的Base URL
pt_BR: Base URL
placeholder:
en_US: Please input your ComfyUI server's Base URL
zh_Hans: 请输入你的 ComfyUI 服务器的 Base URL
pt_BR: Please input your ComfyUI server's Base URL
model:
type: text-input
required: true
label:
en_US: Model with suffix
zh_Hans: 模型, 需要带后缀
pt_BR: Model with suffix
placeholder:
en_US: Please input your model
zh_Hans: 请输入你的模型名称
pt_BR: Please input your model
help:
en_US: The checkpoint name of the ComfyUI server, e.g. xxx.safetensors
zh_Hans: ComfyUI服务器的模型名称, 比如 xxx.safetensors
pt_BR: The checkpoint name of the ComfyUI server, e.g. xxx.safetensors
url: https://docs.dify.ai/tutorials/tool-configuration/comfyui

View File

@ -0,0 +1,475 @@
import json
import os
import random
import uuid
from copy import deepcopy
from enum import Enum
from typing import Any, Union
import websocket
from httpx import get, post
from yarl import URL
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolParameterOption
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.tool.builtin_tool import BuiltinTool
# Process-wide cache for the txt2img workflow template; populated lazily from
# txt2img.json on first use (see ComfyuiStableDiffusionTool.text2img).
SD_TXT2IMG_OPTIONS = {}
# Template for a ComfyUI "LoraLoader" node. The ["11", 0] / ["11", 1] links are
# placeholders that get rewired when the lora chain is built in text2img.
LORA_NODE = {
    "inputs": {"lora_name": "", "strength_model": 1, "strength_clip": 1, "model": ["11", 0], "clip": ["11", 1]},
    "class_type": "LoraLoader",
    "_meta": {"title": "Load LoRA"},
}
# Template for a ComfyUI "FluxGuidance" node, required only for FLUX models;
# the conditioning link is rewired in text2img.
FluxGuidanceNode = {
    "inputs": {"guidance": 3.5, "conditioning": ["6", 0]},
    "class_type": "FluxGuidance",
    "_meta": {"title": "FluxGuidance"},
}
class ModelType(Enum):
    """Checkpoint-model families supported by the pre-defined workflow."""

    SD15 = 1  # Stable Diffusion 1.5
    SDXL = 2  # Stable Diffusion XL
    SD3 = 3  # Stable Diffusion 3
    FLUX = 4  # FLUX (uses the SD3 latent node plus a FluxGuidance node)
class ComfyuiStableDiffusionTool(BuiltinTool):
    """Generate images through a ComfyUI server.

    Drives the pre-defined txt2img workflow (txt2img.json) over ComfyUI's
    HTTP API and tracks task progress over its websocket API. Supports
    SD1.5 / SDXL / SD3 / FLUX checkpoints and up to three optional loras.
    """

    def _invoke(
        self, user_id: str, tool_parameters: dict[str, Any]
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """
        invoke tools

        :param user_id: id of the invoking user (unused)
        :param tool_parameters: drawing options; ``prompt`` is required,
            everything else falls back to a default
        :return: a blob message containing the generated image, or a text
            message describing the problem
        """
        # base url of the ComfyUI server comes from the provider credentials
        base_url = self.runtime.credentials.get("base_url", "")
        if not base_url:
            return self.create_text_message("Please input base_url")
        # a model passed as a tool parameter overrides the credential value
        if tool_parameters.get("model"):
            self.runtime.credentials["model"] = tool_parameters["model"]
        model = self.runtime.credentials.get("model", None)
        if not model:
            return self.create_text_message("Please input model")
        # prompt is the only other hard requirement
        prompt = tool_parameters.get("prompt", "")
        if not prompt:
            return self.create_text_message("Please input prompt")
        # optional drawing parameters, each with a default
        negative_prompt = tool_parameters.get("negative_prompt", "")
        width = tool_parameters.get("width", 1024)
        height = tool_parameters.get("height", 1024)
        steps = tool_parameters.get("steps", 1)
        sampler_name = tool_parameters.get("sampler_name", "euler")
        scheduler = tool_parameters.get("scheduler", "normal")
        cfg = tool_parameters.get("cfg", 7.0)
        model_type = tool_parameters.get("model_type", ModelType.SD15.name)
        # collect up to 3 optional loras together with their strengths
        lora_list = []
        lora_strength_list = []
        for n in range(1, 4):
            if tool_parameters.get(f"lora_{n}"):
                lora_list.append(tool_parameters[f"lora_{n}"])
                lora_strength_list.append(tool_parameters.get(f"lora_strength_{n}", 1))
        return self.text2img(
            base_url=base_url,
            model=model,
            model_type=model_type,
            prompt=prompt,
            negative_prompt=negative_prompt,
            width=width,
            height=height,
            steps=steps,
            sampler_name=sampler_name,
            scheduler=scheduler,
            cfg=cfg,
            lora_list=lora_list,
            lora_strength_list=lora_strength_list,
        )

    def _get_model_names(self, folder: str) -> list[str]:
        # Best-effort fetch of the model file names under /models/<folder>;
        # returns [] on any error so parameter discovery never crashes.
        try:
            base_url = self.runtime.credentials.get("base_url", None)
            if not base_url:
                return []
            api_url = str(URL(base_url) / "models" / folder)
            response = get(url=api_url, timeout=(2, 10))
            if response.status_code != 200:
                return []
            return response.json()
        except Exception:
            return []

    def get_checkpoints(self) -> list[str]:
        """
        get checkpoint model names available on the ComfyUI server
        """
        return self._get_model_names("checkpoints")

    def get_loras(self) -> list[str]:
        """
        get lora model names available on the ComfyUI server
        """
        return self._get_model_names("loras")

    def get_sample_methods(self) -> tuple[list[str], list[str]]:
        """
        get the sampler names and schedulers supported by the server's
        KSampler node; returns two empty lists on any error (best-effort)
        """
        try:
            base_url = self.runtime.credentials.get("base_url", None)
            if not base_url:
                return [], []
            api_url = str(URL(base_url) / "object_info" / "KSampler")
            response = get(url=api_url, timeout=(2, 10))
            if response.status_code != 200:
                return [], []
            data = response.json()["KSampler"]["input"]["required"]
            return data["sampler_name"][0], data["scheduler"][0]
        except Exception:
            return [], []

    def validate_models(self) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """
        validate that the configured checkpoint model exists on the server

        :raises ToolProviderCredentialValidationError: if base_url or model
            is missing, the server is unreachable, or the model is absent
        """
        try:
            base_url = self.runtime.credentials.get("base_url", None)
            if not base_url:
                raise ToolProviderCredentialValidationError("Please input base_url")
            model = self.runtime.credentials.get("model", None)
            if not model:
                raise ToolProviderCredentialValidationError("Please input model")
            api_url = str(URL(base_url) / "models" / "checkpoints")
            response = get(url=api_url, timeout=(2, 10))
            if response.status_code != 200:
                raise ToolProviderCredentialValidationError("Failed to get models")
            models = response.json()
            # simple membership test instead of building a filtered list
            if model in models:
                return self.create_text_message(json.dumps(models))
            raise ToolProviderCredentialValidationError(f"model {model} does not exist")
        except Exception as e:
            raise ToolProviderCredentialValidationError(f"Failed to get models, {e}") from e

    def get_history(self, base_url, prompt_id):
        """
        get the execution history entry for a queued prompt
        """
        url = str(URL(base_url) / "history")
        respond = get(url, params={"prompt_id": prompt_id}, timeout=(2, 10))
        return respond.json()

    def download_image(self, base_url, filename, subfolder, folder_type):
        """
        download one generated image from the server; returns raw bytes
        """
        url = str(URL(base_url) / "view")
        response = get(url, params={"filename": filename, "subfolder": subfolder, "type": folder_type}, timeout=(2, 10))
        return response.content

    def queue_prompt_image(self, base_url, client_id, prompt):
        """
        send a prompt task, wait for completion over the websocket, then
        download every produced image

        :return: mapping of output node id -> list of image bytes
        """
        # initiate task execution
        url = str(URL(base_url) / "prompt")
        respond = post(url, data=json.dumps({"client_id": client_id, "prompt": prompt}), timeout=(2, 10))
        prompt_id = respond.json()["prompt_id"]

        # derive the websocket scheme from the http(s) base url
        ws = websocket.WebSocket()
        if "https" in base_url:
            ws_url = base_url.replace("https", "ws")
        else:
            ws_url = base_url.replace("http", "ws")
        ws.connect(str(URL(f"{ws_url}") / "ws") + f"?clientId={client_id}", timeout=120)

        # watch execution status over the websocket until the task finishes
        output_images = {}
        while True:
            out = ws.recv()
            if isinstance(out, str):
                message = json.loads(out)
                if message["type"] == "executing":
                    data = message["data"]
                    if data["node"] is None and data["prompt_id"] == prompt_id:
                        break  # Execution is done
                elif message["type"] == "status":
                    data = message["data"]
                    if data["status"]["exec_info"]["queue_remaining"] == 0 and data.get("sid"):
                        break  # Execution is done
            else:
                continue  # previews are binary data

        # download images when execution finished; the original code wrapped
        # this in a redundant outer loop over the same dict, downloading every
        # image len(outputs) times
        history = self.get_history(base_url, prompt_id)[prompt_id]
        for node_id, node_output in history["outputs"].items():
            if "images" in node_output:
                output_images[node_id] = [
                    self.download_image(base_url, image["filename"], image["subfolder"], image["type"])
                    for image in node_output["images"]
                ]
        ws.close()
        return output_images

    def text2img(
        self,
        base_url: str,
        model: str,
        model_type: str,
        prompt: str,
        negative_prompt: str,
        width: int,
        height: int,
        steps: int,
        sampler_name: str,
        scheduler: str,
        cfg: float,
        lora_list: list,
        lora_strength_list: list,
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """
        generate an image by filling in the pre-defined txt2img workflow and
        queueing it on the ComfyUI server

        :param model_type: one of the ModelType member names
        :param lora_list: up to 3 lora names; lora_strength_list is parallel
        :return: a blob message with the first generated image, or a text
            message on failure
        """
        # lazily load the workflow template once per process
        if not SD_TXT2IMG_OPTIONS:
            current_dir = os.path.dirname(os.path.realpath(__file__))
            with open(os.path.join(current_dir, "txt2img.json")) as file:
                SD_TXT2IMG_OPTIONS.update(json.load(file))
        draw_options = deepcopy(SD_TXT2IMG_OPTIONS)
        draw_options["3"]["inputs"]["steps"] = steps
        draw_options["3"]["inputs"]["sampler_name"] = sampler_name
        draw_options["3"]["inputs"]["scheduler"] = scheduler
        draw_options["3"]["inputs"]["cfg"] = cfg
        # generate different image when using same prompt next time
        draw_options["3"]["inputs"]["seed"] = random.randint(0, 100000000)
        draw_options["4"]["inputs"]["ckpt_name"] = model
        draw_options["5"]["inputs"]["width"] = width
        draw_options["5"]["inputs"]["height"] = height
        draw_options["6"]["inputs"]["text"] = prompt
        draw_options["7"]["inputs"]["text"] = negative_prompt
        # if the model is SD3 or FLUX series, the Latent class should be corresponding to SD3 Latent
        if model_type in (ModelType.SD3.name, ModelType.FLUX.name):
            draw_options["5"]["class_type"] = "EmptySD3LatentImage"
        if lora_list:
            # last Lora node link to KSampler node
            draw_options["3"]["inputs"]["model"][0] = "10"
            # last Lora node link to positive and negative Clip node
            draw_options["6"]["inputs"]["clip"][0] = "10"
            draw_options["7"]["inputs"]["clip"][0] = "10"
            # every Lora node link to next Lora node, and Checkpoints node link to first Lora node
            for i, (lora, strength) in enumerate(zip(lora_list, lora_strength_list), 10):
                if i - 10 == len(lora_list) - 1:
                    next_node_id = "4"
                else:
                    next_node_id = str(i + 1)
                lora_node = deepcopy(LORA_NODE)
                lora_node["inputs"]["lora_name"] = lora
                lora_node["inputs"]["strength_model"] = strength
                lora_node["inputs"]["strength_clip"] = strength
                lora_node["inputs"]["model"][0] = next_node_id
                lora_node["inputs"]["clip"][0] = next_node_id
                draw_options[str(i)] = lora_node
        # FLUX need to add FluxGuidance Node
        if model_type == ModelType.FLUX.name:
            last_node_id = str(10 + len(lora_list))
            draw_options[last_node_id] = deepcopy(FluxGuidanceNode)
            draw_options[last_node_id]["inputs"]["conditioning"][0] = "6"
            draw_options["3"]["inputs"]["positive"][0] = last_node_id
        try:
            client_id = str(uuid.uuid4())
            result = self.queue_prompt_image(base_url, client_id, prompt=draw_options)
            # take the first non-empty image and stop looking; the original
            # break only left the inner loop, so a later node could overwrite it
            image = b""
            for images_output in result.values():
                for img in images_output:
                    if img:
                        image = img
                        break
                if image:
                    break
            return self.create_blob_message(
                blob=image, meta={"mime_type": "image/png"}, save_as=self.VARIABLE_KEY.IMAGE.value
            )
        except Exception as e:
            return self.create_text_message(f"Failed to generate image: {str(e)}")

    def get_runtime_parameters(self) -> list[ToolParameter]:
        """
        build the dynamic parameter list, querying the ComfyUI server (when
        credentials are configured) for models, loras, samplers and schedulers
        """
        parameters = [
            ToolParameter(
                name="prompt",
                label=I18nObject(en_US="Prompt", zh_Hans="Prompt"),
                human_description=I18nObject(
                    en_US="Image prompt, you can check the official documentation of Stable Diffusion",
                    zh_Hans="图像提示词,您可以查看 Stable Diffusion 的官方文档",
                ),
                type=ToolParameter.ToolParameterType.STRING,
                form=ToolParameter.ToolParameterForm.LLM,
                llm_description="Image prompt of Stable Diffusion, you should describe the image "
                "you want to generate as a list of words as possible as detailed, "
                "the prompt must be written in English.",
                required=True,
            ),
        ]
        if self.runtime.credentials:
            try:
                models = self.get_checkpoints()
                if len(models) != 0:
                    parameters.append(
                        ToolParameter(
                            name="model",
                            label=I18nObject(en_US="Model", zh_Hans="Model"),
                            human_description=I18nObject(
                                en_US="Model of Stable Diffusion or FLUX, "
                                "you can check the official documentation of Stable Diffusion or FLUX",
                                zh_Hans="Stable Diffusion 或者 FLUX 的模型,您可以查看 Stable Diffusion 的官方文档",
                            ),
                            type=ToolParameter.ToolParameterType.SELECT,
                            form=ToolParameter.ToolParameterForm.FORM,
                            llm_description="Model of Stable Diffusion or FLUX, "
                            "you can check the official documentation of Stable Diffusion or FLUX",
                            required=True,
                            default=models[0],
                            options=[
                                ToolParameterOption(value=i, label=I18nObject(en_US=i, zh_Hans=i)) for i in models
                            ],
                        )
                    )
                loras = self.get_loras()
                if len(loras) != 0:
                    # three independent optional lora slots
                    for n in range(1, 4):
                        parameters.append(
                            ToolParameter(
                                name=f"lora_{n}",
                                label=I18nObject(en_US=f"Lora {n}", zh_Hans=f"Lora {n}"),
                                human_description=I18nObject(
                                    en_US="Lora of Stable Diffusion, "
                                    "you can check the official documentation of Stable Diffusion",
                                    zh_Hans="Stable Diffusion 的 Lora 模型,您可以查看 Stable Diffusion 的官方文档",
                                ),
                                type=ToolParameter.ToolParameterType.SELECT,
                                form=ToolParameter.ToolParameterForm.FORM,
                                llm_description="Lora of Stable Diffusion, "
                                "you can check the official documentation of "
                                "Stable Diffusion",
                                required=False,
                                options=[
                                    ToolParameterOption(value=i, label=I18nObject(en_US=i, zh_Hans=i)) for i in loras
                                ],
                            )
                        )
                sample_methods, schedulers = self.get_sample_methods()
                if len(sample_methods) != 0:
                    parameters.append(
                        ToolParameter(
                            name="sampler_name",
                            label=I18nObject(en_US="Sampling method", zh_Hans="Sampling method"),
                            human_description=I18nObject(
                                en_US="Sampling method of Stable Diffusion, "
                                "you can check the official documentation of Stable Diffusion",
                                zh_Hans="Stable Diffusion 的Sampling method您可以查看 Stable Diffusion 的官方文档",
                            ),
                            type=ToolParameter.ToolParameterType.SELECT,
                            form=ToolParameter.ToolParameterForm.FORM,
                            llm_description="Sampling method of Stable Diffusion, "
                            "you can check the official documentation of Stable Diffusion",
                            required=True,
                            default=sample_methods[0],
                            options=[
                                ToolParameterOption(value=i, label=I18nObject(en_US=i, zh_Hans=i))
                                for i in sample_methods
                            ],
                        )
                    )
                if len(schedulers) != 0:
                    parameters.append(
                        ToolParameter(
                            name="scheduler",
                            label=I18nObject(en_US="Scheduler", zh_Hans="Scheduler"),
                            human_description=I18nObject(
                                en_US="Scheduler of Stable Diffusion, "
                                "you can check the official documentation of Stable Diffusion",
                                zh_Hans="Stable Diffusion 的Scheduler您可以查看 Stable Diffusion 的官方文档",
                            ),
                            type=ToolParameter.ToolParameterType.SELECT,
                            form=ToolParameter.ToolParameterForm.FORM,
                            llm_description="Scheduler of Stable Diffusion, "
                            "you can check the official documentation of Stable Diffusion",
                            required=True,
                            default=schedulers[0],
                            options=[
                                ToolParameterOption(value=i, label=I18nObject(en_US=i, zh_Hans=i)) for i in schedulers
                            ],
                        )
                    )
                parameters.append(
                    ToolParameter(
                        name="model_type",
                        label=I18nObject(en_US="Model Type", zh_Hans="Model Type"),
                        human_description=I18nObject(
                            en_US="Model Type of Stable Diffusion or Flux, "
                            "you can check the official documentation of Stable Diffusion or Flux",
                            zh_Hans="Stable Diffusion 或 FLUX 的模型类型,"
                            "您可以查看 Stable Diffusion 或 Flux 的官方文档",
                        ),
                        type=ToolParameter.ToolParameterType.SELECT,
                        form=ToolParameter.ToolParameterForm.FORM,
                        llm_description="Model Type of Stable Diffusion or Flux, "
                        "you can check the official documentation of Stable Diffusion or Flux",
                        required=True,
                        default=ModelType.SD15.name,
                        options=[
                            ToolParameterOption(value=i, label=I18nObject(en_US=i, zh_Hans=i))
                            for i in ModelType.__members__
                        ],
                    )
                )
            except Exception:
                # parameter discovery is best-effort; fall back to the static
                # prompt parameter if the server cannot be reached
                pass
        return parameters

View File

@ -0,0 +1,212 @@
identity:
name: txt2img workflow
author: Qun
label:
en_US: Txt2Img Workflow
zh_Hans: Txt2Img Workflow
pt_BR: Txt2Img Workflow
description:
human:
en_US: a pre-defined comfyui workflow that can use one model and up to 3 loras to generate images. Support SD1.5, SDXL, SD3 and FLUX which contain text encoders/clip, but does not support models that requires a triple clip loader.
zh_Hans: 一个预定义的 ComfyUI 工作流,可以使用一个模型和最多3个loras来生成图像。支持包含文本编码器/clip的SD1.5、SDXL、SD3和FLUX,但不支持需要三重clip加载器(triple clip loader)的模型。
pt_BR: a pre-defined comfyui workflow that can use one model and up to 3 loras to generate images. Support SD1.5, SDXL, SD3 and FLUX which contain text encoders/clip, but does not support models that requires a triple clip loader.
llm: draw the image you want based on your prompt.
parameters:
- name: prompt
type: string
required: true
label:
en_US: Prompt
zh_Hans: 提示词
pt_BR: Prompt
human_description:
en_US: Image prompt, you can check the official documentation of Stable Diffusion or FLUX
zh_Hans: 图像提示词,您可以查看 Stable Diffusion 或者 FLUX 的官方文档
pt_BR: Image prompt, you can check the official documentation of Stable Diffusion or FLUX
llm_description: Image prompt of Stable Diffusion, you should describe the image you want to generate as a detailed list of words, the prompt must be written in English.
form: llm
- name: model
type: string
required: true
label:
en_US: Model Name
zh_Hans: 模型名称
pt_BR: Model Name
human_description:
en_US: Model Name
zh_Hans: 模型名称
pt_BR: Model Name
form: form
- name: model_type
type: string
required: true
label:
en_US: Model Type
zh_Hans: 模型类型
pt_BR: Model Type
human_description:
en_US: Model Type
zh_Hans: 模型类型
pt_BR: Model Type
form: form
- name: lora_1
type: string
required: false
label:
en_US: Lora 1
zh_Hans: Lora 1
pt_BR: Lora 1
human_description:
en_US: Lora 1
zh_Hans: Lora 1
pt_BR: Lora 1
form: form
- name: lora_strength_1
type: number
required: false
label:
en_US: Lora Strength 1
zh_Hans: Lora Strength 1
pt_BR: Lora Strength 1
human_description:
en_US: Lora Strength 1
zh_Hans: Lora模型的权重
pt_BR: Lora Strength 1
form: form
- name: steps
type: number
required: false
label:
en_US: Steps
zh_Hans: Steps
pt_BR: Steps
human_description:
en_US: Steps
zh_Hans: Steps
pt_BR: Steps
form: form
default: 20
- name: width
type: number
required: false
label:
en_US: Width
zh_Hans: Width
pt_BR: Width
human_description:
en_US: Width
zh_Hans: Width
pt_BR: Width
form: form
default: 1024
- name: height
type: number
required: false
label:
en_US: Height
zh_Hans: Height
pt_BR: Height
human_description:
en_US: Height
zh_Hans: Height
pt_BR: Height
form: form
default: 1024
- name: negative_prompt
type: string
required: false
label:
en_US: Negative prompt
zh_Hans: Negative prompt
pt_BR: Negative prompt
human_description:
en_US: Negative prompt
zh_Hans: Negative prompt
pt_BR: Negative prompt
form: form
default: bad art, ugly, deformed, watermark, duplicated, discontinuous lines
- name: cfg
type: number
required: false
label:
en_US: CFG Scale
zh_Hans: CFG Scale
pt_BR: CFG Scale
human_description:
en_US: CFG Scale
zh_Hans: 提示词相关性(CFG Scale)
pt_BR: CFG Scale
form: form
default: 7.0
- name: sampler_name
type: string
required: false
label:
en_US: Sampling method
zh_Hans: Sampling method
pt_BR: Sampling method
human_description:
en_US: Sampling method
zh_Hans: Sampling method
pt_BR: Sampling method
form: form
- name: scheduler
type: string
required: false
label:
en_US: Scheduler
zh_Hans: Scheduler
pt_BR: Scheduler
human_description:
en_US: Scheduler
zh_Hans: Scheduler
pt_BR: Scheduler
form: form
- name: lora_2
type: string
required: false
label:
en_US: Lora 2
zh_Hans: Lora 2
pt_BR: Lora 2
human_description:
en_US: Lora 2
zh_Hans: Lora 2
pt_BR: Lora 2
form: form
- name: lora_strength_2
type: number
required: false
label:
en_US: Lora Strength 2
zh_Hans: Lora Strength 2
pt_BR: Lora Strength 2
human_description:
en_US: Lora Strength 2
zh_Hans: Lora模型的权重
pt_BR: Lora Strength 2
form: form
- name: lora_3
type: string
required: false
label:
en_US: Lora 3
zh_Hans: Lora 3
pt_BR: Lora 3
human_description:
en_US: Lora 3
zh_Hans: Lora 3
pt_BR: Lora 3
form: form
- name: lora_strength_3
type: number
required: false
label:
en_US: Lora Strength 3
zh_Hans: Lora Strength 3
pt_BR: Lora Strength 3
human_description:
en_US: Lora Strength 3
zh_Hans: Lora模型的权重
pt_BR: Lora Strength 3
form: form

View File

@ -0,0 +1,107 @@
{
"3": {
"inputs": {
"seed": 156680208700286,
"steps": 20,
"cfg": 8,
"sampler_name": "euler",
"scheduler": "normal",
"denoise": 1,
"model": [
"4",
0
],
"positive": [
"6",
0
],
"negative": [
"7",
0
],
"latent_image": [
"5",
0
]
},
"class_type": "KSampler",
"_meta": {
"title": "KSampler"
}
},
"4": {
"inputs": {
"ckpt_name": "3dAnimationDiffusion_v10.safetensors"
},
"class_type": "CheckpointLoaderSimple",
"_meta": {
"title": "Load Checkpoint"
}
},
"5": {
"inputs": {
"width": 512,
"height": 512,
"batch_size": 1
},
"class_type": "EmptyLatentImage",
"_meta": {
"title": "Empty Latent Image"
}
},
"6": {
"inputs": {
"text": "beautiful scenery nature glass bottle landscape, , purple galaxy bottle,",
"clip": [
"4",
1
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
"7": {
"inputs": {
"text": "text, watermark",
"clip": [
"4",
1
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
"8": {
"inputs": {
"samples": [
"3",
0
],
"vae": [
"4",
2
]
},
"class_type": "VAEDecode",
"_meta": {
"title": "VAE Decode"
}
},
"9": {
"inputs": {
"filename_prefix": "ComfyUI",
"images": [
"8",
0
]
},
"class_type": "SaveImage",
"_meta": {
"title": "Save Image"
}
}
}