mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-11 16:28:58 +08:00
Optimization stable diffusion verify (#2322)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM> Co-authored-by: crazywoola <427733928@qq.com> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
This commit is contained in:
parent
83063532a0
commit
5929e84036
@ -5,6 +5,7 @@ from core.tools.provider.builtin.stablediffusion.tools.stable_diffusion import S
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
class StableDiffusionProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: Dict[str, Any]) -> None:
|
||||
try:
|
||||
@ -12,15 +13,6 @@ class StableDiffusionProvider(BuiltinToolProviderController):
|
||||
meta={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
user_id='',
|
||||
tool_parameters={
|
||||
"prompt": "cat",
|
||||
"lora": "",
|
||||
"steps": 1,
|
||||
"width": 512,
|
||||
"height": 512,
|
||||
},
|
||||
)
|
||||
).validate_models()
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
@ -4,7 +4,7 @@ from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
|
||||
from typing import Any, Dict, List, Union
|
||||
from httpx import post
|
||||
from httpx import post, get
|
||||
from os.path import join
|
||||
from base64 import b64decode, b64encode
|
||||
from PIL import Image
|
||||
@ -59,6 +59,7 @@ DRAW_TEXT_OPTIONS = {
|
||||
"alwayson_scripts": {}
|
||||
}
|
||||
|
||||
|
||||
class StableDiffusionTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: Dict[str, Any]) \
|
||||
-> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
|
||||
@ -137,6 +138,30 @@ class StableDiffusionTool(BuiltinTool):
|
||||
height=height,
|
||||
steps=steps)
|
||||
|
||||
def validate_models(self) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
|
||||
"""
|
||||
validate models
|
||||
"""
|
||||
try:
|
||||
base_url = self.runtime.credentials.get('base_url', None)
|
||||
if not base_url:
|
||||
raise ToolProviderCredentialValidationError('Please input base_url')
|
||||
model = self.runtime.credentials.get('model', None)
|
||||
if not model:
|
||||
raise ToolProviderCredentialValidationError('Please input model')
|
||||
|
||||
response = get(url=f'{base_url}/sdapi/v1/sd-models', timeout=120)
|
||||
if response.status_code != 200:
|
||||
raise ToolProviderCredentialValidationError('Failed to get models')
|
||||
else:
|
||||
models = [d['model_name'] for d in response.json()]
|
||||
if len([d for d in models if d == model]) > 0:
|
||||
return self.create_text_message(json.dumps(models))
|
||||
else:
|
||||
raise ToolProviderCredentialValidationError(f'model {model} does not exist')
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(f'Failed to get models, {e}')
|
||||
|
||||
def img2img(self, base_url: str, lora: str, image_binary: bytes,
|
||||
prompt: str, negative_prompt: str,
|
||||
width: int, height: int, steps: int) \
|
||||
@ -211,7 +236,6 @@ class StableDiffusionTool(BuiltinTool):
|
||||
except Exception as e:
|
||||
return self.create_text_message('Failed to generate image')
|
||||
|
||||
|
||||
def get_runtime_parameters(self) -> List[ToolParameter]:
|
||||
parameters = [
|
||||
ToolParameter(name='prompt',
|
||||
|
Loading…
x
Reference in New Issue
Block a user