diff --git a/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py b/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py index e449062718..4c022f983f 100644 --- a/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py +++ b/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py @@ -131,7 +131,8 @@ class StableDiffusionTool(BuiltinTool): negative_prompt=negative_prompt, width=width, height=height, - steps=steps) + steps=steps, + model=model) return self.text2img(base_url=base_url, lora=lora, @@ -139,7 +140,8 @@ class StableDiffusionTool(BuiltinTool): negative_prompt=negative_prompt, width=width, height=height, - steps=steps) + steps=steps, + model=model) def validate_models(self) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ @@ -197,7 +199,7 @@ class StableDiffusionTool(BuiltinTool): def img2img(self, base_url: str, lora: str, image_binary: bytes, prompt: str, negative_prompt: str, - width: int, height: int, steps: int) \ + width: int, height: int, steps: int, model: str) \ -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ generate image @@ -213,7 +215,8 @@ class StableDiffusionTool(BuiltinTool): "sampler_name": "Euler a", "restore_faces": False, "steps": steps, - "script_args": ["outpainting mk2"] + "script_args": ["outpainting mk2"], + "override_settings": {"sd_model_checkpoint": model} } if lora: @@ -236,7 +239,7 @@ class StableDiffusionTool(BuiltinTool): except Exception as e: return self.create_text_message('Failed to generate image') - def text2img(self, base_url: str, lora: str, prompt: str, negative_prompt: str, width: int, height: int, steps: int) \ + def text2img(self, base_url: str, lora: str, prompt: str, negative_prompt: str, width: int, height: int, steps: int, model: str) \ -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ generate image @@ -253,6 +256,7 @@ class StableDiffusionTool(BuiltinTool): draw_options['height'] = height draw_options['steps'] = steps draw_options['negative_prompt'] = negative_prompt + draw_options['override_settings']['sd_model_checkpoint'] = model try: url = str(URL(base_url) / 'sdapi' / 'v1' / 'txt2img')