mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-14 22:05:59 +08:00
dalle3 add style consistency parameter (#5067)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM> Co-authored-by: crazywoola <427733928@qq.com> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
This commit is contained in:
parent
9f7b38c068
commit
b7c72f7a97
@ -10,8 +10,8 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
|||||||
class DallE3Tool(BuiltinTool):
|
class DallE3Tool(BuiltinTool):
|
||||||
def _invoke(self,
|
def _invoke(self,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
tool_parameters: dict[str, Any],
|
tool_parameters: dict[str, Any],
|
||||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||||
"""
|
"""
|
||||||
invoke tools
|
invoke tools
|
||||||
"""
|
"""
|
||||||
@ -43,14 +43,18 @@ class DallE3Tool(BuiltinTool):
|
|||||||
style = tool_parameters.get('style', 'vivid')
|
style = tool_parameters.get('style', 'vivid')
|
||||||
if style not in ['natural', 'vivid']:
|
if style not in ['natural', 'vivid']:
|
||||||
return self.create_text_message('Invalid style')
|
return self.create_text_message('Invalid style')
|
||||||
|
# set extra body
|
||||||
|
seed_id = tool_parameters.get('seed_id', self._generate_random_id(8))
|
||||||
|
extra_body = {'seed': seed_id}
|
||||||
|
|
||||||
# call openapi dalle3
|
# call openapi dalle3
|
||||||
model=self.runtime.credentials['azure_openai_api_model_name']
|
model = self.runtime.credentials['azure_openai_api_model_name']
|
||||||
response = client.images.generate(
|
response = client.images.generate(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
model=model,
|
model=model,
|
||||||
size=size,
|
size=size,
|
||||||
n=n,
|
n=n,
|
||||||
|
extra_body=extra_body,
|
||||||
style=style,
|
style=style,
|
||||||
quality=quality,
|
quality=quality,
|
||||||
response_format='b64_json'
|
response_format='b64_json'
|
||||||
@ -60,7 +64,14 @@ class DallE3Tool(BuiltinTool):
|
|||||||
|
|
||||||
for image in response.data:
|
for image in response.data:
|
||||||
result.append(self.create_blob_message(blob=b64decode(image.b64_json),
|
result.append(self.create_blob_message(blob=b64decode(image.b64_json),
|
||||||
meta={ 'mime_type': 'image/png' },
|
meta={'mime_type': 'image/png'},
|
||||||
save_as=self.VARIABLE_KEY.IMAGE.value))
|
save_as=self.VARIABLE_KEY.IMAGE.value))
|
||||||
|
result.append(self.create_text_message(f'\nGenerate image source to Seed ID: {seed_id}'))
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _generate_random_id(length=8):
|
||||||
|
characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789'
|
||||||
|
random_id = ''.join(random.choices(characters, k=length))
|
||||||
|
return random_id
|
||||||
|
@ -29,6 +29,19 @@ parameters:
|
|||||||
pt_BR: Imagem prompt, você pode verificar a documentação oficial do DallE 3
|
pt_BR: Imagem prompt, você pode verificar a documentação oficial do DallE 3
|
||||||
llm_description: Image prompt of DallE 3, you should describe the image you want to generate as a list of words as possible as detailed
|
llm_description: Image prompt of DallE 3, you should describe the image you want to generate as a list of words as possible as detailed
|
||||||
form: llm
|
form: llm
|
||||||
|
- name: seed_id
|
||||||
|
type: string
|
||||||
|
required: false
|
||||||
|
label:
|
||||||
|
en_US: Seed ID
|
||||||
|
zh_Hans: 种子ID
|
||||||
|
pt_BR: ID da semente
|
||||||
|
human_description:
|
||||||
|
en_US: Image generation seed ID used to keep a series of generated images consistent
|
||||||
|
zh_Hans: 图像生成种子ID,确保系列生成图像的一致性
|
||||||
|
pt_BR: ID de semente de geração de imagem para garantir a consistência das imagens geradas em série
|
||||||
|
llm_description: If the user requests image consistency, extract the seed ID from the user's question or context. The seed ID is an 8-character string containing uppercase and lowercase letters and numbers
|
||||||
|
form: llm
|
||||||
- name: size
|
- name: size
|
||||||
type: select
|
type: select
|
||||||
required: true
|
required: true
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import random
|
||||||
from base64 import b64decode
|
from base64 import b64decode
|
||||||
from typing import Any, Union
|
from typing import Any, Union
|
||||||
|
|
||||||
@ -11,8 +12,8 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
|||||||
class DallE3Tool(BuiltinTool):
|
class DallE3Tool(BuiltinTool):
|
||||||
def _invoke(self,
|
def _invoke(self,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
tool_parameters: dict[str, Any],
|
tool_parameters: dict[str, Any],
|
||||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||||
"""
|
"""
|
||||||
invoke tools
|
invoke tools
|
||||||
"""
|
"""
|
||||||
@ -53,6 +54,9 @@ class DallE3Tool(BuiltinTool):
|
|||||||
style = tool_parameters.get('style', 'vivid')
|
style = tool_parameters.get('style', 'vivid')
|
||||||
if style not in ['natural', 'vivid']:
|
if style not in ['natural', 'vivid']:
|
||||||
return self.create_text_message('Invalid style')
|
return self.create_text_message('Invalid style')
|
||||||
|
# set extra body
|
||||||
|
seed_id = tool_parameters.get('seed_id', self._generate_random_id(8))
|
||||||
|
extra_body = {'seed': seed_id}
|
||||||
|
|
||||||
# call openapi dalle3
|
# call openapi dalle3
|
||||||
response = client.images.generate(
|
response = client.images.generate(
|
||||||
@ -60,6 +64,7 @@ class DallE3Tool(BuiltinTool):
|
|||||||
model='dall-e-3',
|
model='dall-e-3',
|
||||||
size=size,
|
size=size,
|
||||||
n=n,
|
n=n,
|
||||||
|
extra_body=extra_body,
|
||||||
style=style,
|
style=style,
|
||||||
quality=quality,
|
quality=quality,
|
||||||
response_format='b64_json'
|
response_format='b64_json'
|
||||||
@ -69,7 +74,14 @@ class DallE3Tool(BuiltinTool):
|
|||||||
|
|
||||||
for image in response.data:
|
for image in response.data:
|
||||||
result.append(self.create_blob_message(blob=b64decode(image.b64_json),
|
result.append(self.create_blob_message(blob=b64decode(image.b64_json),
|
||||||
meta={ 'mime_type': 'image/png' },
|
meta={'mime_type': 'image/png'},
|
||||||
save_as=self.VARIABLE_KEY.IMAGE.value))
|
save_as=self.VARIABLE_KEY.IMAGE.value))
|
||||||
|
result.append(self.create_text_message(f'\nGenerate image source to Seed ID: {seed_id}'))
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _generate_random_id(length=8):
|
||||||
|
characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789'
|
||||||
|
random_id = ''.join(random.choices(characters, k=length))
|
||||||
|
return random_id
|
||||||
|
@ -29,6 +29,19 @@ parameters:
|
|||||||
pt_BR: Image prompt, you can check the official documentation of DallE 3
|
pt_BR: Image prompt, you can check the official documentation of DallE 3
|
||||||
llm_description: Image prompt of DallE 3, you should describe the image you want to generate as a list of words as possible as detailed
|
llm_description: Image prompt of DallE 3, you should describe the image you want to generate as a list of words as possible as detailed
|
||||||
form: llm
|
form: llm
|
||||||
|
- name: seed_id
|
||||||
|
type: string
|
||||||
|
required: false
|
||||||
|
label:
|
||||||
|
en_US: Seed ID
|
||||||
|
zh_Hans: 种子ID
|
||||||
|
pt_BR: ID da semente
|
||||||
|
human_description:
|
||||||
|
en_US: Image generation seed ID used to keep a series of generated images consistent
|
||||||
|
zh_Hans: 图像生成种子ID,确保系列生成图像的一致性
|
||||||
|
pt_BR: ID de semente de geração de imagem para garantir a consistência das imagens geradas em série
|
||||||
|
llm_description: If the user requests image consistency, extract the seed ID from the user's question or context. The seed ID is an 8-character string containing uppercase and lowercase letters and numbers
|
||||||
|
form: llm
|
||||||
- name: size
|
- name: size
|
||||||
type: select
|
type: select
|
||||||
required: true
|
required: true
|
||||||
|
Loading…
x
Reference in New Issue
Block a user