diff --git a/api/core/tools/provider/builtin/azuredalle/tools/dalle3.py b/api/core/tools/provider/builtin/azuredalle/tools/dalle3.py index 3c4e6ee9a5..45cb9b1fe5 100644 --- a/api/core/tools/provider/builtin/azuredalle/tools/dalle3.py +++ b/api/core/tools/provider/builtin/azuredalle/tools/dalle3.py @@ -8,10 +8,10 @@ from core.tools.tool.builtin_tool import BuiltinTool class DallE3Tool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke(self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ invoke tools """ @@ -43,14 +43,18 @@ class DallE3Tool(BuiltinTool): style = tool_parameters.get('style', 'vivid') if style not in ['natural', 'vivid']: return self.create_text_message('Invalid style') + # set extra body + seed_id = tool_parameters.get('seed_id', self._generate_random_id(8)) + extra_body = {'seed': seed_id} # call openapi dalle3 - model=self.runtime.credentials['azure_openai_api_model_name'] + model = self.runtime.credentials['azure_openai_api_model_name'] response = client.images.generate( prompt=prompt, model=model, size=size, n=n, + extra_body=extra_body, style=style, quality=quality, response_format='b64_json' @@ -59,8 +63,15 @@ class DallE3Tool(BuiltinTool): result = [] for image in response.data: - result.append(self.create_blob_message(blob=b64decode(image.b64_json), - meta={ 'mime_type': 'image/png' }, - save_as=self.VARIABLE_KEY.IMAGE.value)) + result.append(self.create_blob_message(blob=b64decode(image.b64_json), + meta={'mime_type': 'image/png'}, + save_as=self.VARIABLE_KEY.IMAGE.value)) + result.append(self.create_text_message(f'\nGenerate image source to Seed ID: {seed_id}')) return result + + @staticmethod + def _generate_random_id(length=8): + characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789' + random_id = ''.join(random.choices(characters, k=length)) + return random_id diff --git a/api/core/tools/provider/builtin/azuredalle/tools/dalle3.yaml b/api/core/tools/provider/builtin/azuredalle/tools/dalle3.yaml index 5553984e1e..63a8c99d97 100644 --- a/api/core/tools/provider/builtin/azuredalle/tools/dalle3.yaml +++ b/api/core/tools/provider/builtin/azuredalle/tools/dalle3.yaml @@ -29,6 +29,19 @@ parameters: pt_BR: Imagem prompt, você pode verificar a documentação oficial do DallE 3 llm_description: Image prompt of DallE 3, you should describe the image you want to generate as a list of words as possible as detailed form: llm + - name: seed_id + type: string + required: false + label: + en_US: Seed ID + zh_Hans: 种子ID + pt_BR: ID da semente + human_description: + en_US: Image generation seed ID to ensure consistency of series generated images + zh_Hans: 图像生成种子ID,确保系列生成图像的一致性 + pt_BR: ID de semente de geração de imagem para garantir a consistência das imagens geradas em série + llm_description: If the user requests image consistency, extract the seed ID from the user's question or context.The seed id consists of an 8-bit string containing uppercase and lowercase letters and numbers + form: llm - name: size type: select required: true diff --git a/api/core/tools/provider/builtin/dalle/tools/dalle3.py b/api/core/tools/provider/builtin/dalle/tools/dalle3.py index 87d18f68e0..45a289ddf8 100644 --- a/api/core/tools/provider/builtin/dalle/tools/dalle3.py +++ b/api/core/tools/provider/builtin/dalle/tools/dalle3.py @@ -1,3 +1,4 @@ +import random from base64 import b64decode from typing import Any, Union @@ -9,10 +10,10 @@ from core.tools.tool.builtin_tool import BuiltinTool class DallE3Tool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke(self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ invoke tools """ @@ -53,6 +54,9 @@ class DallE3Tool(BuiltinTool): style = tool_parameters.get('style', 'vivid') if style not in ['natural', 'vivid']: return self.create_text_message('Invalid style') + # set extra body + seed_id = tool_parameters.get('seed_id', self._generate_random_id(8)) + extra_body = {'seed': seed_id} # call openapi dalle3 response = client.images.generate( @@ -60,6 +64,7 @@ class DallE3Tool(BuiltinTool): model='dall-e-3', size=size, n=n, + extra_body=extra_body, style=style, quality=quality, response_format='b64_json' @@ -68,8 +73,15 @@ class DallE3Tool(BuiltinTool): result = [] for image in response.data: - result.append(self.create_blob_message(blob=b64decode(image.b64_json), - meta={ 'mime_type': 'image/png' }, - save_as=self.VARIABLE_KEY.IMAGE.value)) + result.append(self.create_blob_message(blob=b64decode(image.b64_json), + meta={'mime_type': 'image/png'}, + save_as=self.VARIABLE_KEY.IMAGE.value)) + result.append(self.create_text_message(f'\nGenerate image source to Seed ID: {seed_id}')) return result + + @staticmethod + def _generate_random_id(length=8): + characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789' + random_id = ''.join(random.choices(characters, k=length)) + return random_id diff --git a/api/core/tools/provider/builtin/dalle/tools/dalle3.yaml b/api/core/tools/provider/builtin/dalle/tools/dalle3.yaml index 7ba5c56889..b07a17212e 100644 --- a/api/core/tools/provider/builtin/dalle/tools/dalle3.yaml +++ b/api/core/tools/provider/builtin/dalle/tools/dalle3.yaml @@ -29,6 +29,19 @@ parameters: pt_BR: Image prompt, you can check the official documentation of DallE 3 llm_description: Image prompt of DallE 3, you should describe the image you want to generate as a list of words as possible as detailed form: llm + - name: seed_id + type: string + required: false + label: + en_US: Seed ID + zh_Hans: 种子ID + pt_BR: ID da semente + human_description: + en_US: Image generation seed ID to ensure consistency of series generated images + zh_Hans: 图像生成种子ID,确保系列生成图像的一致性 + pt_BR: ID de semente de geração de imagem para garantir a consistência das imagens geradas em série + llm_description: If the user requests image consistency, extract the seed ID from the user's question or context.The seed id consists of an 8-bit string containing uppercase and lowercase letters and numbers + form: llm - name: size type: select required: true