From 5929e8403624aca993e6e98e28f7cf134708b93d Mon Sep 17 00:00:00 2001 From: "Charlie.Wei" Date: Thu, 1 Feb 2024 12:05:09 +0800 Subject: [PATCH] Optimization stable diffusion verify (#2322) Co-authored-by: luowei Co-authored-by: crazywoola <427733928@qq.com> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com> --- .../stablediffusion/stablediffusion.py | 12 ++----- .../stablediffusion/tools/stable_diffusion.py | 34 ++++++++++++++++--- 2 files changed, 31 insertions(+), 15 deletions(-) diff --git a/api/core/tools/provider/builtin/stablediffusion/stablediffusion.py b/api/core/tools/provider/builtin/stablediffusion/stablediffusion.py index 5820d0f62b..2e79df4ff3 100644 --- a/api/core/tools/provider/builtin/stablediffusion/stablediffusion.py +++ b/api/core/tools/provider/builtin/stablediffusion/stablediffusion.py @@ -5,6 +5,7 @@ from core.tools.provider.builtin.stablediffusion.tools.stable_diffusion import S from typing import Any, Dict + class StableDiffusionProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: Dict[str, Any]) -> None: try: @@ -12,15 +13,6 @@ class StableDiffusionProvider(BuiltinToolProviderController): meta={ "credentials": credentials, } - ).invoke( - user_id='', - tool_parameters={ - "prompt": "cat", - "lora": "", - "steps": 1, - "width": 512, - "height": 512, - }, - ) + ).validate_models() except Exception as e: raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file diff --git a/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py b/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py index 1a52ba0f2f..dbf2fd749f 100644 --- a/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py +++ b/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py @@ -4,7 +4,7 @@ from core.tools.entities.common_entities import I18nObject from core.tools.errors import ToolProviderCredentialValidationError from typing import Any, Dict, List, Union -from httpx import post +from httpx import post, get from os.path import join from base64 import b64decode, b64encode from PIL import Image @@ -59,6 +59,7 @@ DRAW_TEXT_OPTIONS = { "alwayson_scripts": {} } + class StableDiffusionTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: Dict[str, Any]) \ -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]: @@ -136,7 +137,31 @@ class StableDiffusionTool(BuiltinTool): width=width, height=height, steps=steps) - + + def validate_models(self) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]: + """ + validate models + """ + try: + base_url = self.runtime.credentials.get('base_url', None) + if not base_url: + raise ToolProviderCredentialValidationError('Please input base_url') + model = self.runtime.credentials.get('model', None) + if not model: + raise ToolProviderCredentialValidationError('Please input model') + + response = get(url=f'{base_url}/sdapi/v1/sd-models', timeout=120) + if response.status_code != 200: + raise ToolProviderCredentialValidationError('Failed to get models') + else: + models = [d['model_name'] for d in response.json()] + if len([d for d in models if d == model]) > 0: + return self.create_text_message(json.dumps(models)) + else: + raise ToolProviderCredentialValidationError(f'model {model} does not exist') + except Exception as e: + raise ToolProviderCredentialValidationError(f'Failed to get models, {e}') + def img2img(self, base_url: str, lora: str, image_binary: bytes, prompt: str, negative_prompt: str, width: int, height: int, steps: int) \ @@ -211,10 +236,9 @@ class StableDiffusionTool(BuiltinTool): except Exception as e: return self.create_text_message('Failed to generate image') - def get_runtime_parameters(self) -> List[ToolParameter]: parameters = [ - ToolParameter(name='prompt', + ToolParameter(name='prompt', label=I18nObject(en_US='Prompt', zh_Hans='Prompt'), human_description=I18nObject( en_US='Image prompt, you can check the official documentation of Stable Diffusion', @@ -227,7 +251,7 @@ class StableDiffusionTool(BuiltinTool): ] if len(self.list_default_image_variables()) != 0: parameters.append( - ToolParameter(name='image_id', + ToolParameter(name='image_id', label=I18nObject(en_US='image_id', zh_Hans='image_id'), human_description=I18nObject( en_US='Image id of the image you want to generate based on, if you want to generate image based on the default image, you can leave this field empty.',