mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-14 06:05:51 +08:00
dalle3 add style consistency parameter (#5067)
Co-authored-by: luowei <redacted> Co-authored-by: crazywoola <427733928@qq.com> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
This commit is contained in:
parent
9f7b38c068
commit
b7c72f7a97
@ -8,10 +8,10 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class DallE3Tool(BuiltinTool):
|
||||
def _invoke(self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
def _invoke(self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
@ -43,14 +43,18 @@ class DallE3Tool(BuiltinTool):
|
||||
style = tool_parameters.get('style', 'vivid')
|
||||
if style not in ['natural', 'vivid']:
|
||||
return self.create_text_message('Invalid style')
|
||||
# set extra body
|
||||
seed_id = tool_parameters.get('seed_id', self._generate_random_id(8))
|
||||
extra_body = {'seed': seed_id}
|
||||
|
||||
# call openapi dalle3
|
||||
model=self.runtime.credentials['azure_openai_api_model_name']
|
||||
model = self.runtime.credentials['azure_openai_api_model_name']
|
||||
response = client.images.generate(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
size=size,
|
||||
n=n,
|
||||
extra_body=extra_body,
|
||||
style=style,
|
||||
quality=quality,
|
||||
response_format='b64_json'
|
||||
@ -59,8 +63,15 @@ class DallE3Tool(BuiltinTool):
|
||||
result = []
|
||||
|
||||
for image in response.data:
|
||||
result.append(self.create_blob_message(blob=b64decode(image.b64_json),
|
||||
meta={ 'mime_type': 'image/png' },
|
||||
save_as=self.VARIABLE_KEY.IMAGE.value))
|
||||
result.append(self.create_blob_message(blob=b64decode(image.b64_json),
|
||||
meta={'mime_type': 'image/png'},
|
||||
save_as=self.VARIABLE_KEY.IMAGE.value))
|
||||
result.append(self.create_text_message(f'\nGenerate image source to Seed ID: {seed_id}'))
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _generate_random_id(length=8):
|
||||
characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789'
|
||||
random_id = ''.join(random.choices(characters, k=length))
|
||||
return random_id
|
||||
|
@ -29,6 +29,19 @@ parameters:
|
||||
pt_BR: Prompt de imagem, você pode verificar a documentação oficial do DallE 3
|
||||
llm_description: Image prompt for DallE 3. Describe the image you want to generate as a list of words, in as much detail as possible.
|
||||
form: llm
|
||||
- name: seed_id
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Seed ID
|
||||
zh_Hans: 种子ID
|
||||
pt_BR: ID da semente
|
||||
human_description:
|
||||
en_US: Image generation seed ID to ensure consistency of series generated images
|
||||
zh_Hans: 图像生成种子ID,确保系列生成图像的一致性
|
||||
pt_BR: ID de semente de geração de imagem para garantir a consistência das imagens geradas em série
|
||||
llm_description: If the user requests image consistency, extract the seed ID from the user's question or context. The seed ID is an 8-character string of uppercase letters, lowercase letters, and digits.
|
||||
form: llm
|
||||
- name: size
|
||||
type: select
|
||||
required: true
|
||||
|
@ -1,3 +1,4 @@
|
||||
import random
|
||||
from base64 import b64decode
|
||||
from typing import Any, Union
|
||||
|
||||
@ -9,10 +10,10 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class DallE3Tool(BuiltinTool):
|
||||
def _invoke(self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
def _invoke(self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
@ -53,6 +54,9 @@ class DallE3Tool(BuiltinTool):
|
||||
style = tool_parameters.get('style', 'vivid')
|
||||
if style not in ['natural', 'vivid']:
|
||||
return self.create_text_message('Invalid style')
|
||||
# set extra body
|
||||
seed_id = tool_parameters.get('seed_id', self._generate_random_id(8))
|
||||
extra_body = {'seed': seed_id}
|
||||
|
||||
# call openapi dalle3
|
||||
response = client.images.generate(
|
||||
@ -60,6 +64,7 @@ class DallE3Tool(BuiltinTool):
|
||||
model='dall-e-3',
|
||||
size=size,
|
||||
n=n,
|
||||
extra_body=extra_body,
|
||||
style=style,
|
||||
quality=quality,
|
||||
response_format='b64_json'
|
||||
@ -68,8 +73,15 @@ class DallE3Tool(BuiltinTool):
|
||||
result = []
|
||||
|
||||
for image in response.data:
|
||||
result.append(self.create_blob_message(blob=b64decode(image.b64_json),
|
||||
meta={ 'mime_type': 'image/png' },
|
||||
save_as=self.VARIABLE_KEY.IMAGE.value))
|
||||
result.append(self.create_blob_message(blob=b64decode(image.b64_json),
|
||||
meta={'mime_type': 'image/png'},
|
||||
save_as=self.VARIABLE_KEY.IMAGE.value))
|
||||
result.append(self.create_text_message(f'\nGenerate image source to Seed ID: {seed_id}'))
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _generate_random_id(length=8):
|
||||
characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789'
|
||||
random_id = ''.join(random.choices(characters, k=length))
|
||||
return random_id
|
||||
|
@ -29,6 +29,19 @@ parameters:
|
||||
pt_BR: Prompt de imagem, você pode verificar a documentação oficial do DallE 3
|
||||
llm_description: Image prompt for DallE 3. Describe the image you want to generate as a list of words, in as much detail as possible.
|
||||
form: llm
|
||||
- name: seed_id
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Seed ID
|
||||
zh_Hans: 种子ID
|
||||
pt_BR: ID da semente
|
||||
human_description:
|
||||
en_US: Image generation seed ID to ensure consistency of series generated images
|
||||
zh_Hans: 图像生成种子ID,确保系列生成图像的一致性
|
||||
pt_BR: ID de semente de geração de imagem para garantir a consistência das imagens geradas em série
|
||||
llm_description: If the user requests image consistency, extract the seed ID from the user's question or context. The seed ID is an 8-character string of uppercase letters, lowercase letters, and digits.
|
||||
form: llm
|
||||
- name: size
|
||||
type: select
|
||||
required: true
|
||||
|
Loading…
x
Reference in New Issue
Block a user