add three aws tools (#11905)

Co-authored-by: Yuanbo Li <ybalbert@amazon.com>
This commit is contained in:
ybalbert001 2024-12-21 13:36:13 +08:00 committed by GitHub
parent 7a00798027
commit de8800f41a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 1407 additions and 0 deletions

View File

@ -0,0 +1,115 @@
import json
import operator
from typing import Any, Optional, Union
import boto3
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
class BedrockRetrieveTool(BuiltinTool):
    """Retrieve passages from an Amazon Bedrock Knowledge Base.

    The Bedrock client, the knowledge base id and the result limit are stored
    on the instance the first time ``_invoke`` resolves them and are reused by
    later calls on the same instance.
    """

    # Client for the "bedrock-agent-runtime" service; created lazily in ``_invoke``.
    bedrock_client: Any = None
    # Id of the knowledge base to query; filled from the tool parameters on first use.
    knowledge_base_id: Optional[str] = None
    # Number of results to request; filled from the tool parameters on first use.
    topk: Optional[int] = None

    def _bedrock_retrieve(
        self, query_input: str, knowledge_base_id: str, num_results: int, metadata_filter: Optional[dict] = None
    ):
        """Run a vector search against the knowledge base.

        Args:
            query_input: Text of the search query.
            knowledge_base_id: Id of the knowledge base to search.
            num_results: Value sent as ``numberOfResults``.
            metadata_filter: Optional filter placed under
                ``vectorSearchConfiguration.filter``; omitted when falsy.

        Returns:
            A list of dicts with ``content``, ``score`` and ``metadata`` keys,
            in the order the service returned them.

        Raises:
            Exception: If the request or the response handling fails. The
                original error is chained as ``__cause__``.
        """
        try:
            retrieval_query = {"text": query_input}
            retrieval_configuration = {"vectorSearchConfiguration": {"numberOfResults": num_results}}

            # Add the metadata filter to the retrieval configuration when one was supplied.
            if metadata_filter:
                retrieval_configuration["vectorSearchConfiguration"]["filter"] = metadata_filter

            response = self.bedrock_client.retrieve(
                knowledgeBaseId=knowledge_base_id,
                retrievalQuery=retrieval_query,
                retrievalConfiguration=retrieval_configuration,
            )

            return [
                {
                    "content": result.get("content", {}).get("text", ""),
                    "score": result.get("score", 0.0),
                    "metadata": result.get("metadata", {}),
                }
                for result in response.get("retrievalResults", [])
            ]
        except Exception as e:
            # Chain the original error so the root cause stays visible in tracebacks.
            raise Exception(f"Error retrieving from knowledge base: {str(e)}") from e

    def _invoke(
        self,
        user_id: str,
        tool_parameters: dict[str, Any],
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """
        invoke tools

        Reads ``aws_region``, ``knowledge_base_id``, ``topk``, ``query`` and
        ``metadata_filter`` from ``tool_parameters``, queries the knowledge
        base and returns one JSON message per result, highest score first.
        Missing required values and any exception are reported as a text
        message instead of being raised.
        """
        # Progress marker echoed in the error message to show which step failed.
        line = 0
        try:
            if not self.bedrock_client:
                aws_region = tool_parameters.get("aws_region")
                if aws_region:
                    self.bedrock_client = boto3.client("bedrock-agent-runtime", region_name=aws_region)
                else:
                    self.bedrock_client = boto3.client("bedrock-agent-runtime")

            line = 1
            if not self.knowledge_base_id:
                self.knowledge_base_id = tool_parameters.get("knowledge_base_id")
                if not self.knowledge_base_id:
                    return self.create_text_message("Please provide knowledge_base_id")

            line = 2
            if not self.topk:
                self.topk = tool_parameters.get("topk", 5)

            line = 3
            query = tool_parameters.get("query", "")
            if not query:
                return self.create_text_message("Please input query")

            # Parse the metadata filter (a JSON string) when one was supplied.
            metadata_filter_str = tool_parameters.get("metadata_filter")
            metadata_filter = json.loads(metadata_filter_str) if metadata_filter_str else None

            line = 4
            retrieved_docs = self._bedrock_retrieve(
                query_input=query,
                knowledge_base_id=self.knowledge_base_id,
                num_results=self.topk,
                metadata_filter=metadata_filter,  # forwarded to the retrieval request
            )

            line = 5
            # Sort results by score in descending order
            sorted_docs = sorted(retrieved_docs, key=operator.itemgetter("score"), reverse=True)

            line = 6
            return [self.create_json_message(res) for res in sorted_docs]

        except Exception as e:
            return self.create_text_message(f"Exception {str(e)}, line : {line}")

    def validate_parameters(self, parameters: dict[str, Any]) -> None:
        """
        Validate the parameters

        Raises:
            ValueError: If ``knowledge_base_id`` or ``query`` is missing, or if
                ``metadata_filter`` is supplied but is not a JSON object.
        """
        if not parameters.get("knowledge_base_id"):
            raise ValueError("knowledge_base_id is required")

        if not parameters.get("query"):
            raise ValueError("query is required")

        # Optional: when a metadata filter is supplied it must be a JSON string
        # that decodes to an object.
        metadata_filter_str = parameters.get("metadata_filter")
        if metadata_filter_str:
            try:
                metadata_filter = json.loads(metadata_filter_str)
            except (TypeError, ValueError) as e:
                # Malformed JSON (JSONDecodeError is a ValueError) or a non-string
                # value: report it with the same message as a non-object filter.
                raise ValueError("metadata_filter must be a valid JSON object") from e
            if not isinstance(metadata_filter, dict):
                raise ValueError("metadata_filter must be a valid JSON object")

View File

@ -0,0 +1,87 @@
identity:
name: bedrock_retrieve
author: AWS
label:
en_US: Bedrock Retrieve
zh_Hans: Bedrock检索
pt_BR: Bedrock Retrieve
icon: icon.svg
description:
human:
en_US: A tool for retrieving relevant information from Amazon Bedrock Knowledge Base. You can find deploy instructions on Github Repo - https://github.com/aws-samples/dify-aws-tool
zh_Hans: Amazon Bedrock知识库检索工具, 请参考 Github Repo - https://github.com/aws-samples/dify-aws-tool上的部署说明
pt_BR: A tool for retrieving relevant information from Amazon Bedrock Knowledge Base.
llm: A tool for retrieving relevant information from Amazon Bedrock Knowledge Base. You can find deploy instructions on Github Repo - https://github.com/aws-samples/dify-aws-tool
parameters:
- name: knowledge_base_id
type: string
required: true
label:
en_US: Bedrock Knowledge Base ID
zh_Hans: Bedrock知识库ID
pt_BR: Bedrock Knowledge Base ID
human_description:
en_US: ID of the Bedrock Knowledge Base to retrieve from
zh_Hans: 用于检索的Bedrock知识库ID
pt_BR: ID of the Bedrock Knowledge Base to retrieve from
llm_description: ID of the Bedrock Knowledge Base to retrieve from
form: form
- name: query
type: string
required: true
label:
en_US: Query string
zh_Hans: 查询语句
pt_BR: Query string
human_description:
en_US: The search query to retrieve relevant information
zh_Hans: 用于检索相关信息的查询语句
pt_BR: The search query to retrieve relevant information
llm_description: The search query to retrieve relevant information
form: llm
- name: topk
type: number
required: false
form: form
label:
en_US: Limit for results count
zh_Hans: 返回结果数量限制
pt_BR: Limit for results count
human_description:
en_US: Maximum number of results to return
zh_Hans: 最大返回结果数量
pt_BR: Maximum number of results to return
min: 1
max: 10
default: 5
- name: aws_region
type: string
required: false
label:
en_US: AWS Region
zh_Hans: AWS 区域
pt_BR: AWS Region
human_description:
en_US: AWS region where the Bedrock Knowledge Base is located
zh_Hans: Bedrock知识库所在的AWS区域
pt_BR: AWS region where the Bedrock Knowledge Base is located
llm_description: AWS region where the Bedrock Knowledge Base is located
form: form
- name: metadata_filter
type: string
required: false
label:
en_US: Metadata Filter
zh_Hans: 元数据过滤器
pt_BR: Metadata Filter
human_description:
en_US: 'JSON formatted filter conditions for metadata (e.g., {"greaterThan": {"key": "aaa", "value": 10}})'
zh_Hans: '元数据的JSON格式过滤条件(例如 {"greaterThan": {"key": "aaa", "value": 10}})'
pt_BR: 'JSON formatted filter conditions for metadata (e.g., {"greaterThan": {"key": "aaa", "value": 10}})'
form: form

View File

@ -0,0 +1,357 @@
import base64
import json
import logging
import re
from datetime import datetime
from typing import Any, Union
from urllib.parse import urlparse
import boto3
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.tool.builtin_tool import BuiltinTool
# Configure the root logger at INFO level as a side effect of importing this
# module, then create the logger used by NovaCanvasTool.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class NovaCanvasTool(BuiltinTool):
    """Generate or modify images with the AWS Bedrock Nova Canvas model."""

    def _invoke(
        self, user_id: str, tool_parameters: dict[str, Any]
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """
        Invoke AWS Bedrock Nova Canvas model for image generation

        Args:
            user_id: Id of the calling user (not used by this method).
            tool_parameters: Tool parameters; ``prompt`` and
                ``image_output_s3uri`` are required, ``image_input_s3uri`` is
                required for every task type except ``TEXT_IMAGE``.

        Returns:
            On success, a text message with the upload result followed by a
            blob message carrying the PNG bytes. Validation problems and any
            failure while calling AWS are returned as a single text message.
        """
        # Get common parameters
        prompt = tool_parameters.get("prompt", "")
        image_output_s3uri = tool_parameters.get("image_output_s3uri", "").strip()
        if not prompt:
            return self.create_text_message("Please provide a text prompt for image generation.")
        if not image_output_s3uri or urlparse(image_output_s3uri).scheme != "s3":
            return self.create_text_message("Please provide an valid S3 URI for image output.")

        task_type = tool_parameters.get("task_type", "TEXT_IMAGE")
        aws_region = tool_parameters.get("aws_region", "us-east-1")

        # Get common image generation config parameters
        width = tool_parameters.get("width", 1024)
        height = tool_parameters.get("height", 1024)
        cfg_scale = tool_parameters.get("cfg_scale", 8.0)
        negative_prompt = tool_parameters.get("negative_prompt", "")
        seed = tool_parameters.get("seed", 0)
        quality = tool_parameters.get("quality", "standard")

        # Every task except TEXT_IMAGE works on an input image stored in S3.
        image_input_s3uri = tool_parameters.get("image_input_s3uri", "")
        needs_input_image = task_type != "TEXT_IMAGE"
        if needs_input_image and (not image_input_s3uri or urlparse(image_input_s3uri).scheme != "s3"):
            return self.create_text_message("Please provide a valid S3 URI for image to image generation.")

        try:
            # Download inside the try block so S3 errors are reported as a
            # message like every other AWS failure in this method.
            input_image = self._load_input_image(image_input_s3uri) if needs_input_image else None

            # Initialize Bedrock client
            bedrock = boto3.client(service_name="bedrock-runtime", region_name=aws_region)

            # Base image generation config
            image_generation_config = {
                "width": width,
                "height": height,
                "cfgScale": cfg_scale,
                "seed": seed,
                "numberOfImages": 1,
                "quality": quality,
            }

            # Pick the task-specific parameter block.
            if task_type == "TEXT_IMAGE":
                params_key = "textToImageParams"
                task_params = {"text": prompt}
            elif task_type == "COLOR_GUIDED_GENERATION":
                colors = tool_parameters.get("colors", "#ff8080-#ffb280-#ffe680-#ffe680")
                if not self._validate_color_string(colors):
                    return self.create_text_message("Please provide valid colors in hexadecimal format.")
                params_key = "colorGuidedGenerationParams"
                task_params = {
                    "colors": colors.split("-"),
                    "referenceImage": input_image,
                    "text": prompt,
                }
            elif task_type == "IMAGE_VARIATION":
                similarity_strength = tool_parameters.get("similarity_strength", 0.5)
                params_key = "imageVariationParams"
                task_params = {
                    "images": [input_image],
                    "similarityStrength": similarity_strength,
                    "text": prompt,
                }
            elif task_type == "INPAINTING":
                mask_prompt = tool_parameters.get("mask_prompt")
                if not mask_prompt:
                    return self.create_text_message("Please provide a mask prompt for image inpainting.")
                params_key = "inPaintingParams"
                task_params = {"image": input_image, "maskPrompt": mask_prompt, "text": prompt}
            elif task_type == "OUTPAINTING":
                mask_prompt = tool_parameters.get("mask_prompt")
                if not mask_prompt:
                    return self.create_text_message("Please provide a mask prompt for image outpainting.")
                outpainting_mode = tool_parameters.get("outpainting_mode", "DEFAULT")
                params_key = "outPaintingParams"
                task_params = {
                    "image": input_image,
                    "maskPrompt": mask_prompt,
                    "outPaintingMode": outpainting_mode,
                    "text": prompt,
                }
            elif task_type == "BACKGROUND_REMOVAL":
                params_key = "backgroundRemovalParams"
                task_params = {"image": input_image}
            else:
                return self.create_text_message(f"Unsupported task type: {task_type}")

            # Background removal takes no text, so it gets no negative prompt either.
            if negative_prompt and task_type != "BACKGROUND_REMOVAL":
                task_params["negativeText"] = negative_prompt

            body = {
                "imageGenerationConfig": image_generation_config,
                "taskType": task_type,
                params_key: task_params,
            }

            # Call Nova Canvas model
            response = bedrock.invoke_model(
                body=json.dumps(body),
                modelId="amazon.nova-canvas-v1:0",
                accept="application/json",
                contentType="application/json",
            )

            # Process response
            response_body = json.loads(response.get("body").read())
            if response_body.get("error"):
                raise Exception(f"Error in model response: {response_body.get('error')}")
            images = response_body.get("images")
            if not images:
                # Report a clear error instead of failing on ``None[0]`` / ``[][0]``.
                raise Exception("Model response did not contain any images")
            image_data = base64.b64decode(images[0])

            # Upload is best effort: the generated image is still returned as a
            # blob when the upload fails, but the message must not claim success.
            try:
                uploaded_uri = self._upload_image(image_data, image_output_s3uri, aws_region)
                location_message = f"Image is available at: {uploaded_uri}"
            except Exception as e:
                logger.exception("Failed to upload image to S3")
                location_message = f"Image was generated but could not be uploaded to {image_output_s3uri}: {str(e)}"

            # Return image
            return [
                self.create_text_message(location_message),
                self.create_blob_message(
                    blob=image_data,
                    meta={"mime_type": "image/png"},
                    save_as=self.VariableKey.IMAGE.value,
                ),
            ]

        except Exception as e:
            return self.create_text_message(f"Failed to generate image: {str(e)}")

    def _load_input_image(self, image_input_s3uri: str) -> str:
        """Download the object at ``image_input_s3uri`` and return it base64 encoded."""
        parsed_uri = urlparse(image_input_s3uri)
        s3_client = boto3.client("s3")
        response = s3_client.get_object(Bucket=parsed_uri.netloc, Key=parsed_uri.path.lstrip("/"))
        return base64.b64encode(response["Body"].read()).decode("utf-8")

    def _upload_image(self, image_data: bytes, image_output_s3uri: str, aws_region: str) -> str:
        """Upload PNG bytes under ``image_output_s3uri`` and return the object's S3 URI.

        The object is named ``canvas-output-<timestamp>.png``. Leading and
        trailing slashes of the URI path are dropped so the key never starts
        with ``/`` or contains ``//``.
        """
        parsed_uri = urlparse(image_output_s3uri)
        output_bucket = parsed_uri.netloc
        output_base_path = parsed_uri.path.strip("/")

        # Generate filename with timestamp
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"canvas-output-{timestamp}.png"
        output_key = f"{output_base_path}/{filename}" if output_base_path else filename

        s3_client = boto3.client("s3", region_name=aws_region)
        s3_client.put_object(Bucket=output_bucket, Key=output_key, Body=image_data, ContentType="image/png")
        logger.info(f"Image uploaded to s3://{output_bucket}/{output_key}")
        return f"s3://{output_bucket}/{output_key}"

    def _validate_color_string(self, color_string) -> bool:
        """Return True when ``color_string`` is one or more ``#rrggbb`` values joined by ``-``."""
        color_pattern = r"^#[0-9a-fA-F]{6}(?:-#[0-9a-fA-F]{6})*$"
        return bool(re.match(color_pattern, color_string))

    def get_runtime_parameters(self) -> list[ToolParameter]:
        """Return the parameter definitions (names, types, defaults, labels) of this tool."""
        parameters = [
            ToolParameter(
                name="prompt",
                label=I18nObject(en_US="Prompt", zh_Hans="提示词"),
                type=ToolParameter.ToolParameterType.STRING,
                required=True,
                form=ToolParameter.ToolParameterForm.LLM,
                human_description=I18nObject(
                    en_US="Text description of the image you want to generate or modify",
                    zh_Hans="您想要生成或修改的图像的文本描述",
                ),
                llm_description="Describe the image you want to generate or how you want to modify the input image",
            ),
            ToolParameter(
                name="image_input_s3uri",
                label=I18nObject(en_US="Input image s3 uri", zh_Hans="输入图片的s3 uri"),
                type=ToolParameter.ToolParameterType.STRING,
                required=False,
                form=ToolParameter.ToolParameterForm.LLM,
                human_description=I18nObject(en_US="Image to be modified", zh_Hans="想要修改的图片"),
            ),
            ToolParameter(
                name="image_output_s3uri",
                label=I18nObject(en_US="Output Image S3 URI", zh_Hans="输出图片的S3 URI目录"),
                type=ToolParameter.ToolParameterType.STRING,
                required=True,
                form=ToolParameter.ToolParameterForm.FORM,
                human_description=I18nObject(
                    en_US="S3 URI where the generated image should be uploaded", zh_Hans="生成的图像应该上传到的S3 URI"
                ),
            ),
            ToolParameter(
                name="width",
                label=I18nObject(en_US="Width", zh_Hans="宽度"),
                type=ToolParameter.ToolParameterType.NUMBER,
                required=False,
                default=1024,
                form=ToolParameter.ToolParameterForm.FORM,
                human_description=I18nObject(en_US="Width of the generated image", zh_Hans="生成图像的宽度"),
            ),
            ToolParameter(
                name="height",
                label=I18nObject(en_US="Height", zh_Hans="高度"),
                type=ToolParameter.ToolParameterType.NUMBER,
                required=False,
                default=1024,
                form=ToolParameter.ToolParameterForm.FORM,
                human_description=I18nObject(en_US="Height of the generated image", zh_Hans="生成图像的高度"),
            ),
            ToolParameter(
                name="cfg_scale",
                label=I18nObject(en_US="CFG Scale", zh_Hans="CFG比例"),
                type=ToolParameter.ToolParameterType.NUMBER,
                required=False,
                default=8.0,
                form=ToolParameter.ToolParameterForm.FORM,
                human_description=I18nObject(
                    en_US="How strongly the image should conform to the prompt", zh_Hans="图像应该多大程度上符合提示词"
                ),
            ),
            ToolParameter(
                name="negative_prompt",
                label=I18nObject(en_US="Negative Prompt", zh_Hans="负面提示词"),
                type=ToolParameter.ToolParameterType.STRING,
                required=False,
                default="",
                form=ToolParameter.ToolParameterForm.LLM,
                human_description=I18nObject(
                    en_US="Things you don't want in the generated image", zh_Hans="您不想在生成的图像中出现的内容"
                ),
            ),
            ToolParameter(
                name="seed",
                label=I18nObject(en_US="Seed", zh_Hans="种子值"),
                type=ToolParameter.ToolParameterType.NUMBER,
                required=False,
                default=0,
                form=ToolParameter.ToolParameterForm.FORM,
                human_description=I18nObject(en_US="Random seed for image generation", zh_Hans="图像生成的随机种子"),
            ),
            ToolParameter(
                name="aws_region",
                label=I18nObject(en_US="AWS Region", zh_Hans="AWS 区域"),
                type=ToolParameter.ToolParameterType.STRING,
                required=False,
                default="us-east-1",
                form=ToolParameter.ToolParameterForm.FORM,
                human_description=I18nObject(en_US="AWS region for Bedrock service", zh_Hans="Bedrock 服务的 AWS 区域"),
            ),
            ToolParameter(
                name="task_type",
                label=I18nObject(en_US="Task Type", zh_Hans="任务类型"),
                type=ToolParameter.ToolParameterType.STRING,
                required=False,
                default="TEXT_IMAGE",
                form=ToolParameter.ToolParameterForm.LLM,
                human_description=I18nObject(en_US="Type of image generation task", zh_Hans="图像生成任务的类型"),
            ),
            ToolParameter(
                name="quality",
                label=I18nObject(en_US="Quality", zh_Hans="质量"),
                type=ToolParameter.ToolParameterType.STRING,
                required=False,
                default="standard",
                form=ToolParameter.ToolParameterForm.FORM,
                human_description=I18nObject(
                    en_US="Quality of the generated image (standard or premium)", zh_Hans="生成图像的质量(标准或高级)"
                ),
            ),
            ToolParameter(
                name="colors",
                label=I18nObject(en_US="Colors", zh_Hans="颜色"),
                type=ToolParameter.ToolParameterType.STRING,
                required=False,
                form=ToolParameter.ToolParameterForm.FORM,
                human_description=I18nObject(
                    en_US="List of colors for color-guided generation, example: #ff8080-#ffb280-#ffe680-#ffe680",
                    zh_Hans="颜色引导生成的颜色列表, 例子: #ff8080-#ffb280-#ffe680-#ffe680",
                ),
            ),
            ToolParameter(
                name="similarity_strength",
                label=I18nObject(en_US="Similarity Strength", zh_Hans="相似度强度"),
                type=ToolParameter.ToolParameterType.NUMBER,
                required=False,
                default=0.5,
                form=ToolParameter.ToolParameterForm.FORM,
                human_description=I18nObject(
                    en_US="How similar the generated image should be to the input image (0.0 to 1.0)",
                    zh_Hans="生成的图像应该与输入图像的相似程度0.0到1.0",
                ),
            ),
            ToolParameter(
                name="mask_prompt",
                label=I18nObject(en_US="Mask Prompt", zh_Hans="蒙版提示词"),
                type=ToolParameter.ToolParameterType.STRING,
                required=False,
                form=ToolParameter.ToolParameterForm.LLM,
                human_description=I18nObject(
                    en_US="Text description to generate mask for inpainting/outpainting",
                    zh_Hans="用于生成内补绘制/外补绘制蒙版的文本描述",
                ),
            ),
            ToolParameter(
                name="outpainting_mode",
                label=I18nObject(en_US="Outpainting Mode", zh_Hans="外补绘制模式"),
                type=ToolParameter.ToolParameterType.STRING,
                required=False,
                default="DEFAULT",
                form=ToolParameter.ToolParameterForm.FORM,
                human_description=I18nObject(
                    en_US="Mode for outpainting (DEFAULT or other supported modes)",
                    zh_Hans="外补绘制的模式DEFAULT或其他支持的模式",
                ),
            ),
        ]
        return parameters

View File

@ -0,0 +1,175 @@
identity:
name: nova_canvas
author: AWS
label:
en_US: AWS Bedrock Nova Canvas
zh_Hans: AWS Bedrock Nova Canvas
icon: icon.svg
description:
human:
en_US: A tool for generating and modifying images using AWS Bedrock's Nova Canvas model. Supports text-to-image, color-guided generation, image variation, inpainting, outpainting, and background removal. Input parameters reference https://docs.aws.amazon.com/nova/latest/userguide/image-gen-req-resp-structure.html
zh_Hans: 使用 AWS Bedrock 的 Nova Canvas 模型生成和修改图像的工具。支持文生图、颜色引导生成、图像变体、内补绘制、外补绘制和背景移除功能, 输入参数参考 https://docs.aws.amazon.com/nova/latest/userguide/image-gen-req-resp-structure.html。
llm: Generate or modify images using AWS Bedrock's Nova Canvas model with multiple task types including text-to-image, color-guided generation, image variation, inpainting, outpainting, and background removal.
parameters:
- name: task_type
type: string
required: false
default: TEXT_IMAGE
label:
en_US: Task Type
zh_Hans: 任务类型
human_description:
en_US: Type of image generation task (TEXT_IMAGE, COLOR_GUIDED_GENERATION, IMAGE_VARIATION, INPAINTING, OUTPAINTING, BACKGROUND_REMOVAL)
zh_Hans: 图像生成任务的类型(文生图、颜色引导生成、图像变体、内补绘制、外补绘制、背景移除)
form: llm
- name: prompt
type: string
required: true
label:
en_US: Prompt
zh_Hans: 提示词
human_description:
en_US: Text description of the image you want to generate or modify
zh_Hans: 您想要生成或修改的图像的文本描述
llm_description: Describe the image you want to generate or how you want to modify the input image
form: llm
- name: image_input_s3uri
type: string
required: false
label:
en_US: Input image s3 uri
zh_Hans: 输入图片的s3 uri
human_description:
en_US: The input image to modify (required for all modes except TEXT_IMAGE)
zh_Hans: 要修改的输入图像(除文生图外的所有模式都需要)
llm_description: The input image you want to modify. Required for all modes except TEXT_IMAGE.
form: llm
- name: image_output_s3uri
type: string
required: true
label:
en_US: Output S3 URI
zh_Hans: 输出S3 URI
human_description:
en_US: The S3 URI where the generated image will be saved. The image will be uploaded with name format canvas-output-{timestamp}.png
zh_Hans: 生成的图像将保存到的S3 URI。图像将以canvas-output-{timestamp}.png的格式上传
llm_description: S3 URI where the generated image will be uploaded. The image will be saved with a timestamp-based filename.
form: form
- name: negative_prompt
type: string
required: false
label:
en_US: Negative Prompt
zh_Hans: 负面提示词
human_description:
en_US: Things you don't want in the generated image
zh_Hans: 您不想在生成的图像中出现的内容
form: llm
- name: width
type: number
required: false
label:
en_US: Width
zh_Hans: 宽度
human_description:
en_US: Width of the generated image
zh_Hans: 生成图像的宽度
form: form
default: 1024
- name: height
type: number
required: false
label:
en_US: Height
zh_Hans: 高度
human_description:
en_US: Height of the generated image
zh_Hans: 生成图像的高度
form: form
default: 1024
- name: cfg_scale
type: number
required: false
label:
en_US: CFG Scale
zh_Hans: CFG比例
human_description:
en_US: How strongly the image should conform to the prompt
zh_Hans: 图像应该多大程度上符合提示词
form: form
default: 8.0
- name: seed
type: number
required: false
label:
en_US: Seed
zh_Hans: 种子值
human_description:
en_US: Random seed for image generation
zh_Hans: 图像生成的随机种子
form: form
default: 0
- name: aws_region
type: string
required: false
default: us-east-1
label:
en_US: AWS Region
zh_Hans: AWS 区域
human_description:
en_US: AWS region for Bedrock service
zh_Hans: Bedrock 服务的 AWS 区域
form: form
- name: quality
type: string
required: false
default: standard
label:
en_US: Quality
zh_Hans: 质量
human_description:
en_US: Quality of the generated image (standard or premium)
zh_Hans: 生成图像的质量(标准或高级)
form: form
- name: colors
type: string
required: false
label:
en_US: Colors
zh_Hans: 颜色
human_description:
en_US: List of colors for color-guided generation
zh_Hans: 颜色引导生成的颜色列表
form: form
- name: similarity_strength
type: number
required: false
default: 0.5
label:
en_US: Similarity Strength
zh_Hans: 相似度强度
human_description:
en_US: How similar the generated image should be to the input image (0.0 to 1.0)
zh_Hans: 生成的图像应该与输入图像的相似程度0.0到1.0
form: form
- name: mask_prompt
type: string
required: false
label:
en_US: Mask Prompt
zh_Hans: 蒙版提示词
human_description:
en_US: Text description to generate mask for inpainting/outpainting
zh_Hans: 用于生成内补绘制/外补绘制蒙版的文本描述
form: llm
- name: outpainting_mode
type: string
required: false
default: DEFAULT
label:
en_US: Outpainting Mode
zh_Hans: 外补绘制模式
human_description:
en_US: Mode for outpainting (DEFAULT or other supported modes)
zh_Hans: 外补绘制的模式DEFAULT或其他支持的模式
form: form

View File

@ -0,0 +1,371 @@
import base64
import logging
import time
from io import BytesIO
from typing import Any, Optional, Union
from urllib.parse import urlparse
import boto3
from botocore.exceptions import ClientError
from PIL import Image
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.tool.builtin_tool import BuiltinTool
# Configure the root logger at INFO level as a side effect of importing this
# module, then create the logger used by NovaReelTool.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Defaults applied when the corresponding tool parameter is not supplied.
NOVA_REEL_DEFAULT_REGION = "us-east-1"
NOVA_REEL_DEFAULT_DIMENSION = "1280x720"
NOVA_REEL_DEFAULT_FPS = 24
# Sent to the model as "durationSeconds".
NOVA_REEL_DEFAULT_DURATION = 6
# Model id passed to start_async_invoke.
NOVA_REEL_MODEL_ID = "amazon.nova-reel-v1:0"
# Seconds to sleep between status polls while a generation is in progress.
NOVA_REEL_STATUS_CHECK_INTERVAL = 5
# Image requirements: input images are converted to this mode and resized to
# this width x height before being sent to the model.
NOVA_REEL_REQUIRED_IMAGE_WIDTH = 1280
NOVA_REEL_REQUIRED_IMAGE_HEIGHT = 720
NOVA_REEL_REQUIRED_IMAGE_MODE = "RGB"
class NovaReelTool(BuiltinTool):
def _invoke(
    self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
    """
    Run one Nova Reel video generation request from start to finish.

    Args:
        user_id: The ID of the user making the request
        tool_parameters: Dictionary containing the tool parameters

    Returns:
        The message(s) produced by the generation step, or a text message
        describing the validation problem or the error that occurred.
    """
    try:
        extracted = self._validate_and_extract_parameters(tool_parameters)
        if isinstance(extracted, ToolInvokeMessage):
            # Validation already produced a user-facing message; hand it back.
            return extracted

        bedrock_client, s3 = self._initialize_aws_clients(extracted["aws_region"])

        request_payload = self._prepare_model_input(extracted, s3)
        if isinstance(request_payload, ToolInvokeMessage):
            # Same convention: a message here means the input image was rejected.
            return request_payload

        started = self._start_video_generation(bedrock_client, request_payload, extracted["video_output_s3uri"])

        # Either report the pending job (async) or block until it finishes (sync).
        return self._handle_generation_mode(bedrock_client, s3, started["invocationArn"], extracted["async_mode"])
    except ClientError as e:
        error_details = e.response.get("Error", {})
        error_code = error_details.get("Code", "Unknown")
        error_message = error_details.get("Message", str(e))
        logger.error(f"AWS API error: {error_code} - {error_message}", exc_info=True)
        return self.create_text_message(f"AWS service error: {error_code} - {error_message}")
    except Exception as e:
        logger.exception(f"Unexpected error in video generation: {str(e)}")
        return self.create_text_message(f"Failed to generate video: {str(e)}")
def _validate_and_extract_parameters(
    self, tool_parameters: dict[str, Any]
) -> Union[dict[str, Any], ToolInvokeMessage]:
    """Validate and extract parameters from the input dictionary.

    A parameter that is absent or present with a ``None`` value falls back to
    its default instead of failing on ``.strip()`` / ``int()``.

    Args:
        tool_parameters: Raw tool parameters.

    Returns:
        A dict with the keys ``prompt``, ``video_output_s3uri`` (always ending
        in "/"), ``image_input_s3uri``, ``aws_region``, ``dimension``,
        ``seed``, ``fps``, ``duration`` and ``async_mode``; or a text message
        when ``prompt`` or ``video_output_s3uri`` is missing or the output URI
        does not start with "s3://".
    """

    def _param(name: str, default: Any) -> Any:
        # Treat an explicit None the same as a missing key.
        value = tool_parameters.get(name)
        return default if value is None else value

    prompt = _param("prompt", "")
    video_output_s3uri = _param("video_output_s3uri", "").strip()

    # Validate required parameters
    if not prompt:
        return self.create_text_message("Please provide a text prompt for video generation.")
    if not video_output_s3uri:
        return self.create_text_message("Please provide an S3 URI for video output.")

    # Validate S3 URI format
    if not video_output_s3uri.startswith("s3://"):
        return self.create_text_message("Invalid S3 URI format. Must start with 's3://'")

    # Ensure S3 URI ends with '/'
    if not video_output_s3uri.endswith("/"):
        video_output_s3uri += "/"

    return {
        "prompt": prompt,
        "video_output_s3uri": video_output_s3uri,
        "image_input_s3uri": _param("image_input_s3uri", "").strip(),
        "aws_region": _param("aws_region", NOVA_REEL_DEFAULT_REGION),
        "dimension": _param("dimension", NOVA_REEL_DEFAULT_DIMENSION),
        "seed": int(_param("seed", 0)),
        "fps": int(_param("fps", NOVA_REEL_DEFAULT_FPS)),
        "duration": int(_param("duration", NOVA_REEL_DEFAULT_DURATION)),
        "async_mode": bool(_param("async", True)),
    }
def _initialize_aws_clients(self, region: str) -> tuple[Any, Any]:
    """Create the AWS clients used by this tool.

    Returns:
        A ``(bedrock_runtime_client, s3_client)`` pair, both bound to ``region``.
    """
    return (
        boto3.client(service_name="bedrock-runtime", region_name=region),
        boto3.client("s3", region_name=region),
    )
def _prepare_model_input(self, params: dict[str, Any], s3_client: Any) -> Union[dict[str, Any], ToolInvokeMessage]:
    """Build the request payload for the Nova Reel model.

    Args:
        params: Extracted parameters; reads ``prompt``, ``duration``, ``fps``,
            ``dimension``, ``seed`` and ``image_input_s3uri``.
        s3_client: S3 client used to download the optional input image.

    Returns:
        The payload dict, or a text message when the input image could not be
        fetched or processed.
    """
    text_to_video = {"text": params["prompt"]}
    generation_config = {
        "durationSeconds": params["duration"],
        "fps": params["fps"],
        "dimension": params["dimension"],
        "seed": params["seed"],
    }
    model_input = {
        "taskType": "TEXT_VIDEO",
        "textToVideoParams": text_to_video,
        "videoGenerationConfig": generation_config,
    }

    source_uri = params["image_input_s3uri"]
    if not source_uri:
        # Text-only request: nothing else to attach.
        return model_input

    try:
        raw_image = self._get_image_from_s3(s3_client, source_uri)
        if not raw_image:
            return self.create_text_message("Failed to retrieve image from S3")

        # The helper returns a message instead of an image when validation fails.
        prepared = self._process_and_validate_image(raw_image)
        if isinstance(prepared, ToolInvokeMessage):
            return prepared

        # Re-encode as PNG and attach it base64 encoded.
        png_buffer = BytesIO()
        prepared.save(png_buffer, format="PNG")
        png_buffer.seek(0)
        encoded_png = base64.b64encode(png_buffer.getvalue()).decode("utf-8")
        text_to_video["images"] = [{"format": "png", "source": {"bytes": encoded_png}}]
    except Exception as e:
        logger.exception(f"Error processing input image: {str(e)}")
        return self.create_text_message(f"Failed to process input image: {str(e)}")

    return model_input
def _process_and_validate_image(self, image_data: bytes) -> Union[Image.Image, ToolInvokeMessage]:
    """
    Turn raw image bytes into an image that meets the Nova Reel input rules.

    Rules applied:
    - An RGBA image with any alpha value below 255 is rejected.
    - Any other non-RGB image is converted to RGB.
    - An image that is not 1280x720 is resized to that size (a warning is logged).

    Returns:
        The prepared image, or a text message when the image is rejected or
        cannot be opened/processed.
    """
    required_size = (NOVA_REEL_REQUIRED_IMAGE_WIDTH, NOVA_REEL_REQUIRED_IMAGE_HEIGHT)
    try:
        img = Image.open(BytesIO(image_data))

        if img.mode == "RGBA":
            # getextrema() on the alpha band yields (min, max); a minimum under
            # 255 means at least one pixel is not fully opaque.
            lowest_alpha = img.getchannel("A").getextrema()[0]
            if lowest_alpha < 255:
                return self.create_text_message(
                    "PNG image contains transparent or translucent pixels, which is not supported. "
                    "Please provide an image without transparency."
                )

        # Fully opaque RGBA and every other non-RGB mode end up as RGB.
        if img.mode != "RGB":
            img = img.convert("RGB")

        if img.size != required_size:
            logger.warning(
                f"Image dimensions {img.size} do not match required dimensions "
                f"({NOVA_REEL_REQUIRED_IMAGE_WIDTH}x{NOVA_REEL_REQUIRED_IMAGE_HEIGHT}). Resizing..."
            )
            img = img.resize(required_size, Image.Resampling.LANCZOS)

        # Final check against the required mode.
        if img.mode != NOVA_REEL_REQUIRED_IMAGE_MODE:
            return self.create_text_message(
                f"Image must be in {NOVA_REEL_REQUIRED_IMAGE_MODE} mode with 8 bits per channel"
            )
        return img
    except Exception as e:
        logger.exception(f"Error processing image: {str(e)}")
        return self.create_text_message(
            "Failed to process image. Please ensure the image is a valid JPEG or PNG file."
        )
def _get_image_from_s3(self, s3_client: Any, s3_uri: str) -> Optional[bytes]:
    """Fetch the object addressed by ``s3_uri`` and return its bytes.

    The URI's host part is used as the bucket and its path, without the
    leading "/", as the key.
    """
    location = urlparse(s3_uri)
    s3_object = s3_client.get_object(Bucket=location.netloc, Key=location.path.lstrip("/"))
    return s3_object["Body"].read()
def _start_video_generation(self, bedrock: Any, model_input: dict[str, Any], output_s3uri: str) -> dict[str, Any]:
    """Submit the generation job via ``start_async_invoke`` and return its response.

    Args:
        bedrock: Bedrock runtime client.
        model_input: Payload passed as ``modelInput``.
        output_s3uri: S3 URI passed as the job's output location.
    """
    output_config = {"s3OutputDataConfig": {"s3Uri": output_s3uri}}
    return bedrock.start_async_invoke(
        modelId=NOVA_REEL_MODEL_ID,
        modelInput=model_input,
        outputDataConfig=output_config,
    )
def _handle_generation_mode(
    self, bedrock: Any, s3_client: Any, invocation_arn: str, async_mode: bool
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
    """Handle async or sync video generation mode.

    Args:
        bedrock: Bedrock runtime client.
        s3_client: S3 client used to download the finished video in sync mode.
        invocation_arn: ARN of the started invocation.
        async_mode: When true, return immediately with the job details;
            otherwise block until the job finishes.

    Returns:
        In async mode, a text message with the invocation ARN and the expected
        video location. In sync mode, whatever ``_wait_for_completion``
        returns.
    """
    if not async_mode:
        # The polling loop queries the job itself, so no status call is needed here.
        return self._wait_for_completion(bedrock, s3_client, invocation_arn)

    invocation_response = bedrock.get_async_invoke(invocationArn=invocation_arn)
    video_path = invocation_response["outputDataConfig"]["s3OutputDataConfig"]["s3Uri"]
    video_uri = f"{video_path}/output.mp4"
    return self.create_text_message(
        f"Video generation started.\nInvocation ARN: {invocation_arn}\n"
        f"Video will be available at: {video_uri}"
    )
def _wait_for_completion(
    self, bedrock: Any, s3_client: Any, invocation_arn: str
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
    """Wait for video generation completion and handle the result.

    Polls ``get_async_invoke`` every ``NOVA_REEL_STATUS_CHECK_INTERVAL``
    seconds for as long as the status is "InProgress". There is no upper
    bound on the wait.

    Returns:
        The result of ``_handle_completed_video`` when the job completes, or a
        text message when it fails or reports an unexpected status.
    """
    while True:
        status_response = bedrock.get_async_invoke(invocationArn=invocation_arn)
        status = status_response["status"]

        if status == "Completed":
            # The output location is only needed, and only read, once the job is done.
            video_path = status_response["outputDataConfig"]["s3OutputDataConfig"]["s3Uri"]
            return self._handle_completed_video(s3_client, video_path)
        if status == "Failed":
            failure_message = status_response.get("failureMessage", "Unknown error")
            return self.create_text_message(f"Video generation failed.\nError: {failure_message}")
        if status != "InProgress":
            return self.create_text_message(f"Unexpected status: {status}")

        time.sleep(NOVA_REEL_STATUS_CHECK_INTERVAL)
def _handle_completed_video(
    self, s3_client: Any, video_path: str
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
    """Handle completed video generation and return the result.

    Args:
        s3_client: S3 client used to download the video.
        video_path: S3 URI of the job's output folder; the video is read from
            ``output.mp4`` inside it.

    Returns:
        On success, a two-element list: a text message with the video's S3 URI
        and a blob message with the MP4 bytes. If the download fails, a single
        text message with the error and the video's S3 URI.
    """
    parsed_uri = urlparse(video_path)
    bucket = parsed_uri.netloc
    # Drop surrounding slashes so a trailing "/" on the folder never yields "//" in the key.
    prefix = parsed_uri.path.strip("/")
    key = f"{prefix}/output.mp4" if prefix else "output.mp4"
    # Same URI in both messages, built from the bucket/key that is actually read.
    video_uri = f"s3://{bucket}/{key}"

    try:
        response = s3_client.get_object(Bucket=bucket, Key=key)
        video_content = response["Body"].read()
        return [
            self.create_text_message(f"Video is available at: {video_uri}"),
            self.create_blob_message(blob=video_content, meta={"mime_type": "video/mp4"}, save_as="output.mp4"),
        ]
    except Exception as e:
        logger.exception(f"Error downloading video: {str(e)}")
        return self.create_text_message(
            f"Video generation completed but failed to download video: {str(e)}\n"
            f"Video is available at: {video_uri}"
        )
def get_runtime_parameters(self) -> list[ToolParameter]:
    """Build the list of parameters this tool accepts at runtime.

    ``prompt``, ``async`` and ``image_input_s3uri`` use the LLM form; every
    other parameter uses the FORM form. Only ``prompt`` and
    ``video_output_s3uri`` are required.
    """
    kind = ToolParameter.ToolParameterType
    form = ToolParameter.ToolParameterForm

    prompt = ToolParameter(
        name="prompt",
        label=I18nObject(en_US="Prompt", zh_Hans="提示词"),
        type=kind.STRING,
        required=True,
        form=form.LLM,
        human_description=I18nObject(
            en_US="Text description of the video you want to generate", zh_Hans="您想要生成的视频的文本描述"
        ),
        llm_description="Describe the video you want to generate",
    )
    video_output_s3uri = ToolParameter(
        name="video_output_s3uri",
        label=I18nObject(en_US="Output S3 URI", zh_Hans="输出S3 URI"),
        type=kind.STRING,
        required=True,
        form=form.FORM,
        human_description=I18nObject(
            en_US="S3 URI where the generated video will be stored", zh_Hans="生成的视频将存储的S3 URI"
        ),
    )
    dimension = ToolParameter(
        name="dimension",
        label=I18nObject(en_US="Dimension", zh_Hans="尺寸"),
        type=kind.STRING,
        required=False,
        default=NOVA_REEL_DEFAULT_DIMENSION,
        form=form.FORM,
        human_description=I18nObject(en_US="Video dimensions (width x height)", zh_Hans="视频尺寸(宽 x 高)"),
    )
    duration = ToolParameter(
        name="duration",
        label=I18nObject(en_US="Duration", zh_Hans="时长"),
        type=kind.NUMBER,
        required=False,
        default=NOVA_REEL_DEFAULT_DURATION,
        form=form.FORM,
        human_description=I18nObject(en_US="Video duration in seconds", zh_Hans="视频时长(秒)"),
    )
    seed = ToolParameter(
        name="seed",
        label=I18nObject(en_US="Seed", zh_Hans="种子值"),
        type=kind.NUMBER,
        required=False,
        default=0,
        form=form.FORM,
        human_description=I18nObject(en_US="Random seed for video generation", zh_Hans="视频生成的随机种子"),
    )
    fps = ToolParameter(
        name="fps",
        label=I18nObject(en_US="FPS", zh_Hans="帧率"),
        type=kind.NUMBER,
        required=False,
        default=NOVA_REEL_DEFAULT_FPS,
        form=form.FORM,
        human_description=I18nObject(
            en_US="Frames per second for the generated video", zh_Hans="生成视频的每秒帧数"
        ),
    )
    # "async" is a reserved word, so the local uses a different name.
    async_mode = ToolParameter(
        name="async",
        label=I18nObject(en_US="Async Mode", zh_Hans="异步模式"),
        type=kind.BOOLEAN,
        required=False,
        default=True,
        form=form.LLM,
        human_description=I18nObject(
            en_US="Whether to run in async mode (return immediately) or sync mode (wait for completion)",
            zh_Hans="是否以异步模式运行(立即返回)或同步模式(等待完成)",
        ),
    )
    aws_region = ToolParameter(
        name="aws_region",
        label=I18nObject(en_US="AWS Region", zh_Hans="AWS 区域"),
        type=kind.STRING,
        required=False,
        default=NOVA_REEL_DEFAULT_REGION,
        form=form.FORM,
        human_description=I18nObject(en_US="AWS region for Bedrock service", zh_Hans="Bedrock 服务的 AWS 区域"),
    )
    image_input_s3uri = ToolParameter(
        name="image_input_s3uri",
        label=I18nObject(en_US="Input Image S3 URI", zh_Hans="输入图像S3 URI"),
        type=kind.STRING,
        required=False,
        form=form.LLM,
        human_description=I18nObject(
            en_US="S3 URI of the input image (1280x720 JPEG/PNG) to use as first frame",
            zh_Hans="用作第一帧的输入图像1280x720 JPEG/PNG的S3 URI",
        ),
    )

    return [
        prompt,
        video_output_s3uri,
        dimension,
        duration,
        seed,
        fps,
        async_mode,
        aws_region,
        image_input_s3uri,
    ]

View File

@ -0,0 +1,124 @@
identity:
name: nova_reel
author: AWS
label:
en_US: AWS Bedrock Nova Reel
zh_Hans: AWS Bedrock Nova Reel
icon: icon.svg
description:
human:
en_US: A tool for generating videos using AWS Bedrock's Nova Reel model. Supports text-to-video generation and image-to-video generation with customizable parameters like duration, FPS, and dimensions. Input parameters reference https://docs.aws.amazon.com/nova/latest/userguide/video-generation.html
zh_Hans: 使用 AWS Bedrock 的 Nova Reel 模型生成视频的工具。支持文本生成视频和图像生成视频功能,可自定义持续时间、帧率和尺寸等参数。输入参数参考 https://docs.aws.amazon.com/nova/latest/userguide/video-generation.html
llm: Generate videos using AWS Bedrock's Nova Reel model with support for both text-to-video and image-to-video generation, allowing customization of video properties like duration, frame rate, and resolution.
parameters:
- name: prompt
type: string
required: true
label:
en_US: Prompt
zh_Hans: 提示词
human_description:
en_US: Text description of the video you want to generate
zh_Hans: 您想要生成的视频的文本描述
llm_description: Describe the video you want to generate
form: llm
- name: video_output_s3uri
type: string
required: true
label:
en_US: Output S3 URI
zh_Hans: 输出S3 URI
human_description:
en_US: S3 URI where the generated video will be stored
zh_Hans: 生成的视频将存储的S3 URI
form: form
- name: dimension
type: string
required: false
default: 1280x720
label:
en_US: Dimension
zh_Hans: 尺寸
human_description:
en_US: Video dimensions (width x height)
zh_Hans: 视频尺寸(宽 x 高)
form: form
- name: duration
type: number
required: false
default: 6
label:
en_US: Duration
zh_Hans: 时长
human_description:
en_US: Video duration in seconds
zh_Hans: 视频时长(秒)
form: form
- name: seed
type: number
required: false
default: 0
label:
en_US: Seed
zh_Hans: 种子值
human_description:
en_US: Random seed for video generation
zh_Hans: 视频生成的随机种子
form: form
- name: fps
type: number
required: false
default: 24
label:
en_US: FPS
zh_Hans: 帧率
human_description:
en_US: Frames per second for the generated video
zh_Hans: 生成视频的每秒帧数
form: form
- name: async
type: boolean
required: false
default: true
label:
en_US: Async Mode
zh_Hans: 异步模式
human_description:
en_US: Whether to run in async mode (return immediately) or sync mode (wait for completion)
zh_Hans: 是否以异步模式运行(立即返回)或同步模式(等待完成)
form: llm
- name: aws_region
type: string
required: false
default: us-east-1
label:
en_US: AWS Region
zh_Hans: AWS 区域
human_description:
en_US: AWS region for Bedrock service
zh_Hans: Bedrock 服务的 AWS 区域
form: form
- name: image_input_s3uri
type: string
required: false
label:
en_US: Input Image S3 URI
zh_Hans: 输入图像S3 URI
human_description:
en_US: S3 URI of the input image (1280x720 JPEG/PNG) to use as first frame
zh_Hans: 用作第一帧的输入图像1280x720 JPEG/PNG的S3 URI
form: llm
development:
dependencies:
- boto3
- pillow

View File

@ -0,0 +1,80 @@
from typing import Any, Union
from urllib.parse import urlparse
import boto3
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
class S3Operator(BuiltinTool):
    """Read or write a UTF-8 text object in S3, optionally returning a presigned URL."""

    # boto3 S3 client; created on the first invocation and reused afterwards.
    s3_client: Any = None

    def _invoke(
        self,
        user_id: str,
        tool_parameters: dict[str, Any],
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """
        Read from or write to the object named by ``s3_uri``.

        Values read from ``tool_parameters``:
            aws_region: optional region used when the client is first created.
            s3_uri: required; must use the ``s3://`` scheme.
            operation_type: "write" stores ``text_content``; any other value reads.
            text_content: text to store, required for a write.
            generate_presign_url: when truthy, the message carries a presigned
                GET URL instead of the object text (read) or S3 URI (write).
            presign_expiry: presigned URL lifetime in seconds; 3600 when
                missing, None or empty.

        Returns a single text message. Failures are reported in the message
        text rather than raised.
        """
        # Create the client in its own try block. The handlers of the main
        # block look up exception classes on the client, which would itself
        # raise AttributeError (hiding the real error) if the client were
        # still None because its creation failed.
        try:
            if not self.s3_client:
                aws_region = tool_parameters.get("aws_region")
                if aws_region:
                    self.s3_client = boto3.client("s3", region_name=aws_region)
                else:
                    self.s3_client = boto3.client("s3")
        except Exception as e:
            return self.create_text_message(f"Exception: {str(e)}")

        s3_client = self.s3_client
        try:
            # Parse S3 URI
            s3_uri = tool_parameters.get("s3_uri")
            if not s3_uri:
                return self.create_text_message("s3_uri parameter is required")

            parsed_uri = urlparse(s3_uri)
            if parsed_uri.scheme != "s3":
                return self.create_text_message("Invalid S3 URI format. Must start with 's3://'")

            bucket = parsed_uri.netloc
            # Remove leading slash from key
            key = parsed_uri.path.lstrip("/")

            operation_type = tool_parameters.get("operation_type", "read")
            generate_presign_url = tool_parameters.get("generate_presign_url", False)

            # A form field left blank can arrive as None or "", which int()
            # rejects; fall back to the one-hour default in that case.
            raw_expiry = tool_parameters.get("presign_expiry")
            presign_expiry = 3600 if raw_expiry in (None, "") else int(raw_expiry)

            if operation_type == "write":
                text_content = tool_parameters.get("text_content")
                if not text_content:
                    return self.create_text_message("text_content parameter is required for write operation")

                s3_client.put_object(Bucket=bucket, Key=key, Body=text_content.encode("utf-8"))
                result = f"s3://{bucket}/{key}"
            else:  # read operation
                # get_object also acts as the existence check (NoSuchKey).
                response = s3_client.get_object(Bucket=bucket, Key=key)
                body = response["Body"]
                try:
                    # Only download and decode the content when it is the
                    # requested result; a presigned URL must also work for
                    # objects that are not UTF-8 text.
                    result = None if generate_presign_url else body.read().decode("utf-8")
                finally:
                    body.close()

            if generate_presign_url:
                result = s3_client.generate_presigned_url(
                    "get_object", Params={"Bucket": bucket, "Key": key}, ExpiresIn=presign_expiry
                )

            return self.create_text_message(text=result)

        except s3_client.exceptions.NoSuchBucket:
            return self.create_text_message(f"Bucket '{bucket}' does not exist")
        except s3_client.exceptions.NoSuchKey:
            return self.create_text_message(f"Object '{key}' does not exist in bucket '{bucket}'")
        except Exception as e:
            return self.create_text_message(f"Exception: {str(e)}")

View File

@ -0,0 +1,98 @@
identity:
name: s3_operator
author: AWS
label:
en_US: AWS S3 Operator
zh_Hans: AWS S3 读写器
pt_BR: AWS S3 Operator
icon: icon.svg
description:
human:
en_US: AWS S3 Writer and Reader
zh_Hans: 读写S3 bucket中的文件
pt_BR: AWS S3 Writer and Reader
llm: Read a UTF-8 text object from Amazon S3 or write text to an S3 URI, optionally returning a presigned URL for the object.
parameters:
- name: text_content
type: string
required: false
label:
en_US: The text to write
zh_Hans: 待写入的文本
pt_BR: The text to write
human_description:
en_US: The text to write
zh_Hans: 待写入的文本
pt_BR: The text to write
llm_description: The text to write
form: llm
- name: s3_uri
type: string
required: true
label:
en_US: s3 uri
zh_Hans: s3 uri
pt_BR: s3 uri
human_description:
en_US: s3 uri
zh_Hans: s3 uri
pt_BR: s3 uri
llm_description: S3 URI of the object to read or write, for example s3://my-bucket/path/to/file.txt
form: llm
- name: aws_region
type: string
required: true
label:
en_US: region of bucket
zh_Hans: bucket 所在的region
pt_BR: region of bucket
human_description:
en_US: region of bucket
zh_Hans: bucket 所在的region
pt_BR: region of bucket
llm_description: region of bucket
form: form
- name: operation_type
type: select
required: true
label:
en_US: operation type
zh_Hans: 操作类型
pt_BR: operation type
human_description:
en_US: operation type
zh_Hans: 操作类型
pt_BR: operation type
default: read
options:
- value: read
label:
en_US: read
zh_Hans: 读
- value: write
label:
en_US: write
zh_Hans: 写
form: form
- name: generate_presign_url
type: boolean
required: false
label:
en_US: Generate presigned URL
zh_Hans: 生成预签名URL
human_description:
en_US: Whether to generate a presigned URL for the S3 object
zh_Hans: 是否生成S3对象的预签名URL
default: false
form: form
- name: presign_expiry
type: number
required: false
label:
en_US: Presigned URL expiration time
zh_Hans: 预签名URL有效期
human_description:
en_US: Expiration time in seconds for the presigned URL
zh_Hans: 预签名URL的有效期(秒)
default: 3600
form: form