mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-11 18:49:02 +08:00
Optimization stable diffusion verify (#2322)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM> Co-authored-by: crazywoola <427733928@qq.com> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
This commit is contained in:
parent
83063532a0
commit
5929e84036
@ -5,6 +5,7 @@ from core.tools.provider.builtin.stablediffusion.tools.stable_diffusion import S
|
|||||||
|
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionProvider(BuiltinToolProviderController):
|
class StableDiffusionProvider(BuiltinToolProviderController):
|
||||||
def _validate_credentials(self, credentials: Dict[str, Any]) -> None:
|
def _validate_credentials(self, credentials: Dict[str, Any]) -> None:
|
||||||
try:
|
try:
|
||||||
@ -12,15 +13,6 @@ class StableDiffusionProvider(BuiltinToolProviderController):
|
|||||||
meta={
|
meta={
|
||||||
"credentials": credentials,
|
"credentials": credentials,
|
||||||
}
|
}
|
||||||
).invoke(
|
).validate_models()
|
||||||
user_id='',
|
|
||||||
tool_parameters={
|
|
||||||
"prompt": "cat",
|
|
||||||
"lora": "",
|
|
||||||
"steps": 1,
|
|
||||||
"width": 512,
|
|
||||||
"height": 512,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ToolProviderCredentialValidationError(str(e))
|
raise ToolProviderCredentialValidationError(str(e))
|
@ -4,7 +4,7 @@ from core.tools.entities.common_entities import I18nObject
|
|||||||
from core.tools.errors import ToolProviderCredentialValidationError
|
from core.tools.errors import ToolProviderCredentialValidationError
|
||||||
|
|
||||||
from typing import Any, Dict, List, Union
|
from typing import Any, Dict, List, Union
|
||||||
from httpx import post
|
from httpx import post, get
|
||||||
from os.path import join
|
from os.path import join
|
||||||
from base64 import b64decode, b64encode
|
from base64 import b64decode, b64encode
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
@ -59,6 +59,7 @@ DRAW_TEXT_OPTIONS = {
|
|||||||
"alwayson_scripts": {}
|
"alwayson_scripts": {}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionTool(BuiltinTool):
|
class StableDiffusionTool(BuiltinTool):
|
||||||
def _invoke(self, user_id: str, tool_parameters: Dict[str, Any]) \
|
def _invoke(self, user_id: str, tool_parameters: Dict[str, Any]) \
|
||||||
-> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
|
-> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
|
||||||
@ -137,6 +138,30 @@ class StableDiffusionTool(BuiltinTool):
|
|||||||
height=height,
|
height=height,
|
||||||
steps=steps)
|
steps=steps)
|
||||||
|
|
||||||
|
def validate_models(self) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
|
||||||
|
"""
|
||||||
|
validate models
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
base_url = self.runtime.credentials.get('base_url', None)
|
||||||
|
if not base_url:
|
||||||
|
raise ToolProviderCredentialValidationError('Please input base_url')
|
||||||
|
model = self.runtime.credentials.get('model', None)
|
||||||
|
if not model:
|
||||||
|
raise ToolProviderCredentialValidationError('Please input model')
|
||||||
|
|
||||||
|
response = get(url=f'{base_url}/sdapi/v1/sd-models', timeout=120)
|
||||||
|
if response.status_code != 200:
|
||||||
|
raise ToolProviderCredentialValidationError('Failed to get models')
|
||||||
|
else:
|
||||||
|
models = [d['model_name'] for d in response.json()]
|
||||||
|
if len([d for d in models if d == model]) > 0:
|
||||||
|
return self.create_text_message(json.dumps(models))
|
||||||
|
else:
|
||||||
|
raise ToolProviderCredentialValidationError(f'model {model} does not exist')
|
||||||
|
except Exception as e:
|
||||||
|
raise ToolProviderCredentialValidationError(f'Failed to get models, {e}')
|
||||||
|
|
||||||
def img2img(self, base_url: str, lora: str, image_binary: bytes,
|
def img2img(self, base_url: str, lora: str, image_binary: bytes,
|
||||||
prompt: str, negative_prompt: str,
|
prompt: str, negative_prompt: str,
|
||||||
width: int, height: int, steps: int) \
|
width: int, height: int, steps: int) \
|
||||||
@ -211,7 +236,6 @@ class StableDiffusionTool(BuiltinTool):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
return self.create_text_message('Failed to generate image')
|
return self.create_text_message('Failed to generate image')
|
||||||
|
|
||||||
|
|
||||||
def get_runtime_parameters(self) -> List[ToolParameter]:
|
def get_runtime_parameters(self) -> List[ToolParameter]:
|
||||||
parameters = [
|
parameters = [
|
||||||
ToolParameter(name='prompt',
|
ToolParameter(name='prompt',
|
||||||
|
Loading…
x
Reference in New Issue
Block a user