mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-12 04:58:59 +08:00
feat: enhance comfyui workflow (#10085)
This commit is contained in:
parent
6692e8c508
commit
bd6175157c
@ -1,5 +1,3 @@
|
|||||||
import base64
|
|
||||||
import io
|
|
||||||
import json
|
import json
|
||||||
import random
|
import random
|
||||||
import uuid
|
import uuid
|
||||||
@ -8,7 +6,7 @@ import httpx
|
|||||||
from websocket import WebSocket
|
from websocket import WebSocket
|
||||||
from yarl import URL
|
from yarl import URL
|
||||||
|
|
||||||
from core.file.file_manager import _get_encoded_string
|
from core.file.file_manager import download
|
||||||
from core.file.models import File
|
from core.file.models import File
|
||||||
|
|
||||||
|
|
||||||
@ -29,8 +27,7 @@ class ComfyUiClient:
|
|||||||
return response.content
|
return response.content
|
||||||
|
|
||||||
def upload_image(self, image_file: File) -> dict:
|
def upload_image(self, image_file: File) -> dict:
|
||||||
image_content = base64.b64decode(_get_encoded_string(image_file))
|
file = download(image_file)
|
||||||
file = io.BytesIO(image_content)
|
|
||||||
files = {"image": (image_file.filename, file, image_file.mime_type), "overwrite": "true"}
|
files = {"image": (image_file.filename, file, image_file.mime_type), "overwrite": "true"}
|
||||||
res = httpx.post(str(self.base_url / "upload/image"), files=files)
|
res = httpx.post(str(self.base_url / "upload/image"), files=files)
|
||||||
return res.json()
|
return res.json()
|
||||||
@ -47,12 +44,7 @@ class ComfyUiClient:
|
|||||||
ws.connect(ws_address)
|
ws.connect(ws_address)
|
||||||
return ws, client_id
|
return ws, client_id
|
||||||
|
|
||||||
def set_prompt(
|
def set_prompt_by_ksampler(self, origin_prompt: dict, positive_prompt: str, negative_prompt: str = "") -> dict:
|
||||||
self, origin_prompt: dict, positive_prompt: str, negative_prompt: str = "", image_name: str = ""
|
|
||||||
) -> dict:
|
|
||||||
"""
|
|
||||||
find the first KSampler, then can find the prompt node through it.
|
|
||||||
"""
|
|
||||||
prompt = origin_prompt.copy()
|
prompt = origin_prompt.copy()
|
||||||
id_to_class_type = {id: details["class_type"] for id, details in prompt.items()}
|
id_to_class_type = {id: details["class_type"] for id, details in prompt.items()}
|
||||||
k_sampler = [key for key, value in id_to_class_type.items() if value == "KSampler"][0]
|
k_sampler = [key for key, value in id_to_class_type.items() if value == "KSampler"][0]
|
||||||
@ -64,9 +56,20 @@ class ComfyUiClient:
|
|||||||
negative_input_id = prompt.get(k_sampler)["inputs"]["negative"][0]
|
negative_input_id = prompt.get(k_sampler)["inputs"]["negative"][0]
|
||||||
prompt.get(negative_input_id)["inputs"]["text"] = negative_prompt
|
prompt.get(negative_input_id)["inputs"]["text"] = negative_prompt
|
||||||
|
|
||||||
if image_name != "":
|
return prompt
|
||||||
image_loader = [key for key, value in id_to_class_type.items() if value == "LoadImage"][0]
|
|
||||||
prompt.get(image_loader)["inputs"]["image"] = image_name
|
def set_prompt_images_by_ids(self, origin_prompt: dict, image_names: list[str], image_ids: list[str]) -> dict:
|
||||||
|
prompt = origin_prompt.copy()
|
||||||
|
for index, image_node_id in enumerate(image_ids):
|
||||||
|
prompt[image_node_id]["inputs"]["image"] = image_names[index]
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
def set_prompt_images_by_default(self, origin_prompt: dict, image_names: list[str]) -> dict:
|
||||||
|
prompt = origin_prompt.copy()
|
||||||
|
id_to_class_type = {id: details["class_type"] for id, details in prompt.items()}
|
||||||
|
load_image_nodes = [key for key, value in id_to_class_type.items() if value == "LoadImage"]
|
||||||
|
for load_image, image_name in zip(load_image_nodes, image_names):
|
||||||
|
prompt.get(load_image)["inputs"]["image"] = image_name
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
def track_progress(self, prompt: dict, ws: WebSocket, prompt_id: str):
|
def track_progress(self, prompt: dict, ws: WebSocket, prompt_id: str):
|
||||||
|
@ -1,7 +1,9 @@
|
|||||||
import json
|
import json
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from core.file import FileType
|
||||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||||
|
from core.tools.errors import ToolParameterValidationError
|
||||||
from core.tools.provider.builtin.comfyui.tools.comfyui_client import ComfyUiClient
|
from core.tools.provider.builtin.comfyui.tools.comfyui_client import ComfyUiClient
|
||||||
from core.tools.tool.builtin_tool import BuiltinTool
|
from core.tools.tool.builtin_tool import BuiltinTool
|
||||||
|
|
||||||
@ -10,19 +12,46 @@ class ComfyUIWorkflowTool(BuiltinTool):
|
|||||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
|
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
|
||||||
comfyui = ComfyUiClient(self.runtime.credentials["base_url"])
|
comfyui = ComfyUiClient(self.runtime.credentials["base_url"])
|
||||||
|
|
||||||
positive_prompt = tool_parameters.get("positive_prompt")
|
positive_prompt = tool_parameters.get("positive_prompt", "")
|
||||||
negative_prompt = tool_parameters.get("negative_prompt")
|
negative_prompt = tool_parameters.get("negative_prompt", "")
|
||||||
|
images = tool_parameters.get("images") or []
|
||||||
workflow = tool_parameters.get("workflow_json")
|
workflow = tool_parameters.get("workflow_json")
|
||||||
image_name = ""
|
image_names = []
|
||||||
if image := tool_parameters.get("image"):
|
for image in images:
|
||||||
|
if image.type != FileType.IMAGE:
|
||||||
|
continue
|
||||||
image_name = comfyui.upload_image(image).get("name")
|
image_name = comfyui.upload_image(image).get("name")
|
||||||
|
image_names.append(image_name)
|
||||||
|
|
||||||
|
set_prompt_with_ksampler = True
|
||||||
|
if "{{positive_prompt}}" in workflow:
|
||||||
|
set_prompt_with_ksampler = False
|
||||||
|
workflow = workflow.replace("{{positive_prompt}}", positive_prompt)
|
||||||
|
workflow = workflow.replace("{{negative_prompt}}", negative_prompt)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
origin_prompt = json.loads(workflow)
|
prompt = json.loads(workflow)
|
||||||
except:
|
except:
|
||||||
return self.create_text_message("the Workflow JSON is not correct")
|
return self.create_text_message("the Workflow JSON is not correct")
|
||||||
|
|
||||||
prompt = comfyui.set_prompt(origin_prompt, positive_prompt, negative_prompt, image_name)
|
if set_prompt_with_ksampler:
|
||||||
|
try:
|
||||||
|
prompt = comfyui.set_prompt_by_ksampler(prompt, positive_prompt, negative_prompt)
|
||||||
|
except:
|
||||||
|
raise ToolParameterValidationError(
|
||||||
|
"Failed set prompt with KSampler, try replace prompt to {{positive_prompt}} in your workflow json"
|
||||||
|
)
|
||||||
|
|
||||||
|
if image_names:
|
||||||
|
if image_ids := tool_parameters.get("image_ids"):
|
||||||
|
image_ids = image_ids.split(",")
|
||||||
|
try:
|
||||||
|
prompt = comfyui.set_prompt_images_by_ids(prompt, image_names, image_ids)
|
||||||
|
except:
|
||||||
|
raise ToolParameterValidationError("the Image Node ID List not match your upload image files.")
|
||||||
|
else:
|
||||||
|
prompt = comfyui.set_prompt_images_by_default(prompt, image_names)
|
||||||
|
|
||||||
images = comfyui.generate_image_by_prompt(prompt)
|
images = comfyui.generate_image_by_prompt(prompt)
|
||||||
result = []
|
result = []
|
||||||
for img in images:
|
for img in images:
|
||||||
|
@ -24,12 +24,12 @@ parameters:
|
|||||||
zh_Hans: 负面提示词
|
zh_Hans: 负面提示词
|
||||||
llm_description: Negative prompt, you should describe the image you don't want to generate as a list of words as possible as detailed, the prompt must be written in English.
|
llm_description: Negative prompt, you should describe the image you don't want to generate as a list of words as possible as detailed, the prompt must be written in English.
|
||||||
form: llm
|
form: llm
|
||||||
- name: image
|
- name: images
|
||||||
type: file
|
type: files
|
||||||
label:
|
label:
|
||||||
en_US: Input Image
|
en_US: Input Images
|
||||||
zh_Hans: 输入的图片
|
zh_Hans: 输入的图片
|
||||||
llm_description: The input image, used to transfer to the comfyui workflow to generate another image.
|
llm_description: The input images, used to transfer to the comfyui workflow to generate another image.
|
||||||
form: llm
|
form: llm
|
||||||
- name: workflow_json
|
- name: workflow_json
|
||||||
type: string
|
type: string
|
||||||
@ -40,3 +40,15 @@ parameters:
|
|||||||
en_US: exported from ComfyUI workflow
|
en_US: exported from ComfyUI workflow
|
||||||
zh_Hans: 从ComfyUI的工作流中导出
|
zh_Hans: 从ComfyUI的工作流中导出
|
||||||
form: form
|
form: form
|
||||||
|
- name: image_ids
|
||||||
|
type: string
|
||||||
|
label:
|
||||||
|
en_US: Image Node ID List
|
||||||
|
zh_Hans: 图片节点ID列表
|
||||||
|
placeholder:
|
||||||
|
en_US: Use commas to separate multiple node ID
|
||||||
|
zh_Hans: 多个节点ID时使用半角逗号分隔
|
||||||
|
human_description:
|
||||||
|
en_US: When the workflow has multiple image nodes, enter the ID list of these nodes, and the images will be passed to ComfyUI in the order of the list.
|
||||||
|
zh_Hans: 当工作流有多个图片节点时,输入这些节点的ID列表,图片将按列表顺序传给ComfyUI
|
||||||
|
form: form
|
||||||
|
Loading…
x
Reference in New Issue
Block a user