mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-14 18:35:58 +08:00
Update stable_diffusion.py (#7536)
This commit is contained in:
parent
e42848f4b7
commit
70d6ab0bf5
@ -27,7 +27,7 @@ DRAW_TEXT_OPTIONS = {
|
|||||||
"seed_resize_from_w": -1,
|
"seed_resize_from_w": -1,
|
||||||
|
|
||||||
# Samplers
|
# Samplers
|
||||||
# "sampler_name": "DPM++ 2M",
|
"sampler_name": "DPM++ 2M",
|
||||||
# "scheduler": "",
|
# "scheduler": "",
|
||||||
# "sampler_index": "Automatic",
|
# "sampler_index": "Automatic",
|
||||||
|
|
||||||
@ -178,6 +178,23 @@ class StableDiffusionTool(BuiltinTool):
|
|||||||
return [d['model_name'] for d in response.json()]
|
return [d['model_name'] for d in response.json()]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
def get_sample_methods(self) -> list[str]:
|
||||||
|
"""
|
||||||
|
get sample method
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
base_url = self.runtime.credentials.get('base_url', None)
|
||||||
|
if not base_url:
|
||||||
|
return []
|
||||||
|
api_url = str(URL(base_url) / 'sdapi' / 'v1' / 'samplers')
|
||||||
|
response = get(url=api_url, timeout=(2, 10))
|
||||||
|
if response.status_code != 200:
|
||||||
|
return []
|
||||||
|
else:
|
||||||
|
return [d['name'] for d in response.json()]
|
||||||
|
except Exception as e:
|
||||||
|
return []
|
||||||
|
|
||||||
def img2img(self, base_url: str, tool_parameters: dict[str, Any]) \
|
def img2img(self, base_url: str, tool_parameters: dict[str, Any]) \
|
||||||
-> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
-> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||||
@ -339,7 +356,27 @@ class StableDiffusionTool(BuiltinTool):
|
|||||||
label=I18nObject(en_US=i, zh_Hans=i)
|
label=I18nObject(en_US=i, zh_Hans=i)
|
||||||
) for i in models])
|
) for i in models])
|
||||||
)
|
)
|
||||||
|
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
sample_methods = self.get_sample_methods()
|
||||||
|
if len(sample_methods) != 0:
|
||||||
|
parameters.append(
|
||||||
|
ToolParameter(name='sampler_name',
|
||||||
|
label=I18nObject(en_US='Sampling method', zh_Hans='Sampling method'),
|
||||||
|
human_description=I18nObject(
|
||||||
|
en_US='Sampling method of Stable Diffusion, you can check the official documentation of Stable Diffusion',
|
||||||
|
zh_Hans='Stable Diffusion 的Sampling method,您可以查看 Stable Diffusion 的官方文档',
|
||||||
|
),
|
||||||
|
type=ToolParameter.ToolParameterType.SELECT,
|
||||||
|
form=ToolParameter.ToolParameterForm.FORM,
|
||||||
|
llm_description='Sampling method of Stable Diffusion, you can check the official documentation of Stable Diffusion',
|
||||||
|
required=True,
|
||||||
|
default=sample_methods[0],
|
||||||
|
options=[ToolParameterOption(
|
||||||
|
value=i,
|
||||||
|
label=I18nObject(en_US=i, zh_Hans=i)
|
||||||
|
) for i in sample_methods])
|
||||||
|
)
|
||||||
return parameters
|
return parameters
|
||||||
|
Loading…
x
Reference in New Issue
Block a user