Optimization stable diffusion verify (#2322)

Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
This commit is contained in:
Charlie.Wei 2024-02-01 12:05:09 +08:00 committed by GitHub
parent 83063532a0
commit 5929e84036
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 31 additions and 15 deletions

View File

@ -5,6 +5,7 @@ from core.tools.provider.builtin.stablediffusion.tools.stable_diffusion import S
from typing import Any, Dict from typing import Any, Dict
class StableDiffusionProvider(BuiltinToolProviderController): class StableDiffusionProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: Dict[str, Any]) -> None: def _validate_credentials(self, credentials: Dict[str, Any]) -> None:
try: try:
@ -12,15 +13,6 @@ class StableDiffusionProvider(BuiltinToolProviderController):
meta={ meta={
"credentials": credentials, "credentials": credentials,
} }
).invoke( ).validate_models()
user_id='',
tool_parameters={
"prompt": "cat",
"lora": "",
"steps": 1,
"width": 512,
"height": 512,
},
)
except Exception as e: except Exception as e:
raise ToolProviderCredentialValidationError(str(e)) raise ToolProviderCredentialValidationError(str(e))

View File

@ -4,7 +4,7 @@ from core.tools.entities.common_entities import I18nObject
from core.tools.errors import ToolProviderCredentialValidationError from core.tools.errors import ToolProviderCredentialValidationError
from typing import Any, Dict, List, Union from typing import Any, Dict, List, Union
from httpx import post from httpx import post, get
from os.path import join from os.path import join
from base64 import b64decode, b64encode from base64 import b64decode, b64encode
from PIL import Image from PIL import Image
@ -59,6 +59,7 @@ DRAW_TEXT_OPTIONS = {
"alwayson_scripts": {} "alwayson_scripts": {}
} }
class StableDiffusionTool(BuiltinTool): class StableDiffusionTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: Dict[str, Any]) \ def _invoke(self, user_id: str, tool_parameters: Dict[str, Any]) \
-> Union[ToolInvokeMessage, List[ToolInvokeMessage]]: -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
@ -137,6 +138,30 @@ class StableDiffusionTool(BuiltinTool):
height=height, height=height,
steps=steps) steps=steps)
def validate_models(self) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
"""
validate models
"""
try:
base_url = self.runtime.credentials.get('base_url', None)
if not base_url:
raise ToolProviderCredentialValidationError('Please input base_url')
model = self.runtime.credentials.get('model', None)
if not model:
raise ToolProviderCredentialValidationError('Please input model')
response = get(url=f'{base_url}/sdapi/v1/sd-models', timeout=120)
if response.status_code != 200:
raise ToolProviderCredentialValidationError('Failed to get models')
else:
models = [d['model_name'] for d in response.json()]
if len([d for d in models if d == model]) > 0:
return self.create_text_message(json.dumps(models))
else:
raise ToolProviderCredentialValidationError(f'model {model} does not exist')
except Exception as e:
raise ToolProviderCredentialValidationError(f'Failed to get models, {e}')
def img2img(self, base_url: str, lora: str, image_binary: bytes, def img2img(self, base_url: str, lora: str, image_binary: bytes,
prompt: str, negative_prompt: str, prompt: str, negative_prompt: str,
width: int, height: int, steps: int) \ width: int, height: int, steps: int) \
@ -211,7 +236,6 @@ class StableDiffusionTool(BuiltinTool):
except Exception as e: except Exception as e:
return self.create_text_message('Failed to generate image') return self.create_text_message('Failed to generate image')
def get_runtime_parameters(self) -> List[ToolParameter]: def get_runtime_parameters(self) -> List[ToolParameter]:
parameters = [ parameters = [
ToolParameter(name='prompt', ToolParameter(name='prompt',