mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-12 06:58:59 +08:00
add three aws tools (#11905)
Co-authored-by: Yuanbo Li <ybalbert@amazon.com>
This commit is contained in:
parent
7a00798027
commit
de8800f41a
115
api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.py
Normal file
115
api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.py
Normal file
@ -0,0 +1,115 @@
|
||||
import json
|
||||
import operator
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import boto3
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class BedrockRetrieveTool(BuiltinTool):
|
||||
bedrock_client: Any = None
|
||||
knowledge_base_id: str = None
|
||||
topk: int = None
|
||||
|
||||
def _bedrock_retrieve(
|
||||
self, query_input: str, knowledge_base_id: str, num_results: int, metadata_filter: Optional[dict] = None
|
||||
):
|
||||
try:
|
||||
retrieval_query = {"text": query_input}
|
||||
|
||||
retrieval_configuration = {"vectorSearchConfiguration": {"numberOfResults": num_results}}
|
||||
|
||||
# 如果有元数据过滤条件,则添加到检索配置中
|
||||
if metadata_filter:
|
||||
retrieval_configuration["vectorSearchConfiguration"]["filter"] = metadata_filter
|
||||
|
||||
response = self.bedrock_client.retrieve(
|
||||
knowledgeBaseId=knowledge_base_id,
|
||||
retrievalQuery=retrieval_query,
|
||||
retrievalConfiguration=retrieval_configuration,
|
||||
)
|
||||
|
||||
results = []
|
||||
for result in response.get("retrievalResults", []):
|
||||
results.append(
|
||||
{
|
||||
"content": result.get("content", {}).get("text", ""),
|
||||
"score": result.get("score", 0.0),
|
||||
"metadata": result.get("metadata", {}),
|
||||
}
|
||||
)
|
||||
|
||||
return results
|
||||
except Exception as e:
|
||||
raise Exception(f"Error retrieving from knowledge base: {str(e)}")
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
line = 0
|
||||
try:
|
||||
if not self.bedrock_client:
|
||||
aws_region = tool_parameters.get("aws_region")
|
||||
if aws_region:
|
||||
self.bedrock_client = boto3.client("bedrock-agent-runtime", region_name=aws_region)
|
||||
else:
|
||||
self.bedrock_client = boto3.client("bedrock-agent-runtime")
|
||||
|
||||
line = 1
|
||||
if not self.knowledge_base_id:
|
||||
self.knowledge_base_id = tool_parameters.get("knowledge_base_id")
|
||||
if not self.knowledge_base_id:
|
||||
return self.create_text_message("Please provide knowledge_base_id")
|
||||
|
||||
line = 2
|
||||
if not self.topk:
|
||||
self.topk = tool_parameters.get("topk", 5)
|
||||
|
||||
line = 3
|
||||
query = tool_parameters.get("query", "")
|
||||
if not query:
|
||||
return self.create_text_message("Please input query")
|
||||
|
||||
# 获取元数据过滤条件(如果存在)
|
||||
metadata_filter_str = tool_parameters.get("metadata_filter")
|
||||
metadata_filter = json.loads(metadata_filter_str) if metadata_filter_str else None
|
||||
|
||||
line = 4
|
||||
retrieved_docs = self._bedrock_retrieve(
|
||||
query_input=query,
|
||||
knowledge_base_id=self.knowledge_base_id,
|
||||
num_results=self.topk,
|
||||
metadata_filter=metadata_filter, # 将元数据过滤条件传递给检索方法
|
||||
)
|
||||
|
||||
line = 5
|
||||
# Sort results by score in descending order
|
||||
sorted_docs = sorted(retrieved_docs, key=operator.itemgetter("score"), reverse=True)
|
||||
|
||||
line = 6
|
||||
return [self.create_json_message(res) for res in sorted_docs]
|
||||
|
||||
except Exception as e:
|
||||
return self.create_text_message(f"Exception {str(e)}, line : {line}")
|
||||
|
||||
def validate_parameters(self, parameters: dict[str, Any]) -> None:
|
||||
"""
|
||||
Validate the parameters
|
||||
"""
|
||||
if not parameters.get("knowledge_base_id"):
|
||||
raise ValueError("knowledge_base_id is required")
|
||||
|
||||
if not parameters.get("query"):
|
||||
raise ValueError("query is required")
|
||||
|
||||
# 可选:可以验证元数据过滤条件是否为有效的 JSON 字符串(如果提供)
|
||||
metadata_filter_str = parameters.get("metadata_filter")
|
||||
if metadata_filter_str and not isinstance(json.loads(metadata_filter_str), dict):
|
||||
raise ValueError("metadata_filter must be a valid JSON object")
|
@ -0,0 +1,87 @@
|
||||
identity:
|
||||
name: bedrock_retrieve
|
||||
author: AWS
|
||||
label:
|
||||
en_US: Bedrock Retrieve
|
||||
zh_Hans: Bedrock检索
|
||||
pt_BR: Bedrock Retrieve
|
||||
icon: icon.svg
|
||||
|
||||
description:
|
||||
human:
|
||||
en_US: A tool for retrieving relevant information from Amazon Bedrock Knowledge Base. You can find deploy instructions on Github Repo - https://github.com/aws-samples/dify-aws-tool
|
||||
zh_Hans: Amazon Bedrock知识库检索工具, 请参考 Github Repo - https://github.com/aws-samples/dify-aws-tool上的部署说明
|
||||
pt_BR: A tool for retrieving relevant information from Amazon Bedrock Knowledge Base.
|
||||
llm: A tool for retrieving relevant information from Amazon Bedrock Knowledge Base. You can find deploy instructions on Github Repo - https://github.com/aws-samples/dify-aws-tool
|
||||
|
||||
parameters:
|
||||
- name: knowledge_base_id
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Bedrock Knowledge Base ID
|
||||
zh_Hans: Bedrock知识库ID
|
||||
pt_BR: Bedrock Knowledge Base ID
|
||||
human_description:
|
||||
en_US: ID of the Bedrock Knowledge Base to retrieve from
|
||||
zh_Hans: 用于检索的Bedrock知识库ID
|
||||
pt_BR: ID of the Bedrock Knowledge Base to retrieve from
|
||||
llm_description: ID of the Bedrock Knowledge Base to retrieve from
|
||||
form: form
|
||||
|
||||
- name: query
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Query string
|
||||
zh_Hans: 查询语句
|
||||
pt_BR: Query string
|
||||
human_description:
|
||||
en_US: The search query to retrieve relevant information
|
||||
zh_Hans: 用于检索相关信息的查询语句
|
||||
pt_BR: The search query to retrieve relevant information
|
||||
llm_description: The search query to retrieve relevant information
|
||||
form: llm
|
||||
|
||||
- name: topk
|
||||
type: number
|
||||
required: false
|
||||
form: form
|
||||
label:
|
||||
en_US: Limit for results count
|
||||
zh_Hans: 返回结果数量限制
|
||||
pt_BR: Limit for results count
|
||||
human_description:
|
||||
en_US: Maximum number of results to return
|
||||
zh_Hans: 最大返回结果数量
|
||||
pt_BR: Maximum number of results to return
|
||||
min: 1
|
||||
max: 10
|
||||
default: 5
|
||||
|
||||
- name: aws_region
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: AWS Region
|
||||
zh_Hans: AWS 区域
|
||||
pt_BR: AWS Region
|
||||
human_description:
|
||||
en_US: AWS region where the Bedrock Knowledge Base is located
|
||||
zh_Hans: Bedrock知识库所在的AWS区域
|
||||
pt_BR: AWS region where the Bedrock Knowledge Base is located
|
||||
llm_description: AWS region where the Bedrock Knowledge Base is located
|
||||
form: form
|
||||
|
||||
- name: metadata_filter
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Metadata Filter
|
||||
zh_Hans: 元数据过滤器
|
||||
pt_BR: Metadata Filter
|
||||
human_description:
|
||||
en_US: 'JSON formatted filter conditions for metadata (e.g., {"greaterThan": {"key: "aaa", "value": 10}})'
|
||||
zh_Hans: '元数据的JSON格式过滤条件(例如,{{"greaterThan": {"key: "aaa", "value": 10}})'
|
||||
pt_BR: 'JSON formatted filter conditions for metadata (e.g., {"greaterThan": {"key: "aaa", "value": 10}})'
|
||||
form: form
|
357
api/core/tools/provider/builtin/aws/tools/nova_canvas.py
Normal file
357
api/core/tools/provider/builtin/aws/tools/nova_canvas.py
Normal file
@ -0,0 +1,357 @@
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import boto3
|
||||
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NovaCanvasTool(BuiltinTool):
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
Invoke AWS Bedrock Nova Canvas model for image generation
|
||||
"""
|
||||
# Get common parameters
|
||||
prompt = tool_parameters.get("prompt", "")
|
||||
image_output_s3uri = tool_parameters.get("image_output_s3uri", "").strip()
|
||||
if not prompt:
|
||||
return self.create_text_message("Please provide a text prompt for image generation.")
|
||||
if not image_output_s3uri or urlparse(image_output_s3uri).scheme != "s3":
|
||||
return self.create_text_message("Please provide an valid S3 URI for image output.")
|
||||
|
||||
task_type = tool_parameters.get("task_type", "TEXT_IMAGE")
|
||||
aws_region = tool_parameters.get("aws_region", "us-east-1")
|
||||
|
||||
# Get common image generation config parameters
|
||||
width = tool_parameters.get("width", 1024)
|
||||
height = tool_parameters.get("height", 1024)
|
||||
cfg_scale = tool_parameters.get("cfg_scale", 8.0)
|
||||
negative_prompt = tool_parameters.get("negative_prompt", "")
|
||||
seed = tool_parameters.get("seed", 0)
|
||||
quality = tool_parameters.get("quality", "standard")
|
||||
|
||||
# Handle S3 image if provided
|
||||
image_input_s3uri = tool_parameters.get("image_input_s3uri", "")
|
||||
if task_type != "TEXT_IMAGE":
|
||||
if not image_input_s3uri or urlparse(image_input_s3uri).scheme != "s3":
|
||||
return self.create_text_message("Please provide a valid S3 URI for image to image generation.")
|
||||
|
||||
# Parse S3 URI
|
||||
parsed_uri = urlparse(image_input_s3uri)
|
||||
bucket = parsed_uri.netloc
|
||||
key = parsed_uri.path.lstrip("/")
|
||||
|
||||
# Initialize S3 client and download image
|
||||
s3_client = boto3.client("s3")
|
||||
response = s3_client.get_object(Bucket=bucket, Key=key)
|
||||
image_data = response["Body"].read()
|
||||
|
||||
# Base64 encode the image
|
||||
input_image = base64.b64encode(image_data).decode("utf-8")
|
||||
|
||||
try:
|
||||
# Initialize Bedrock client
|
||||
bedrock = boto3.client(service_name="bedrock-runtime", region_name=aws_region)
|
||||
|
||||
# Base image generation config
|
||||
image_generation_config = {
|
||||
"width": width,
|
||||
"height": height,
|
||||
"cfgScale": cfg_scale,
|
||||
"seed": seed,
|
||||
"numberOfImages": 1,
|
||||
"quality": quality,
|
||||
}
|
||||
|
||||
# Prepare request body based on task type
|
||||
body = {"imageGenerationConfig": image_generation_config}
|
||||
|
||||
if task_type == "TEXT_IMAGE":
|
||||
body["taskType"] = "TEXT_IMAGE"
|
||||
body["textToImageParams"] = {"text": prompt}
|
||||
if negative_prompt:
|
||||
body["textToImageParams"]["negativeText"] = negative_prompt
|
||||
|
||||
elif task_type == "COLOR_GUIDED_GENERATION":
|
||||
colors = tool_parameters.get("colors", "#ff8080-#ffb280-#ffe680-#ffe680")
|
||||
if not self._validate_color_string(colors):
|
||||
return self.create_text_message("Please provide valid colors in hexadecimal format.")
|
||||
|
||||
body["taskType"] = "COLOR_GUIDED_GENERATION"
|
||||
body["colorGuidedGenerationParams"] = {
|
||||
"colors": colors.split("-"),
|
||||
"referenceImage": input_image,
|
||||
"text": prompt,
|
||||
}
|
||||
if negative_prompt:
|
||||
body["colorGuidedGenerationParams"]["negativeText"] = negative_prompt
|
||||
|
||||
elif task_type == "IMAGE_VARIATION":
|
||||
similarity_strength = tool_parameters.get("similarity_strength", 0.5)
|
||||
|
||||
body["taskType"] = "IMAGE_VARIATION"
|
||||
body["imageVariationParams"] = {
|
||||
"images": [input_image],
|
||||
"similarityStrength": similarity_strength,
|
||||
"text": prompt,
|
||||
}
|
||||
if negative_prompt:
|
||||
body["imageVariationParams"]["negativeText"] = negative_prompt
|
||||
|
||||
elif task_type == "INPAINTING":
|
||||
mask_prompt = tool_parameters.get("mask_prompt")
|
||||
if not mask_prompt:
|
||||
return self.create_text_message("Please provide a mask prompt for image inpainting.")
|
||||
|
||||
body["taskType"] = "INPAINTING"
|
||||
body["inPaintingParams"] = {"image": input_image, "maskPrompt": mask_prompt, "text": prompt}
|
||||
if negative_prompt:
|
||||
body["inPaintingParams"]["negativeText"] = negative_prompt
|
||||
|
||||
elif task_type == "OUTPAINTING":
|
||||
mask_prompt = tool_parameters.get("mask_prompt")
|
||||
if not mask_prompt:
|
||||
return self.create_text_message("Please provide a mask prompt for image outpainting.")
|
||||
outpainting_mode = tool_parameters.get("outpainting_mode", "DEFAULT")
|
||||
|
||||
body["taskType"] = "OUTPAINTING"
|
||||
body["outPaintingParams"] = {
|
||||
"image": input_image,
|
||||
"maskPrompt": mask_prompt,
|
||||
"outPaintingMode": outpainting_mode,
|
||||
"text": prompt,
|
||||
}
|
||||
if negative_prompt:
|
||||
body["outPaintingParams"]["negativeText"] = negative_prompt
|
||||
|
||||
elif task_type == "BACKGROUND_REMOVAL":
|
||||
body["taskType"] = "BACKGROUND_REMOVAL"
|
||||
body["backgroundRemovalParams"] = {"image": input_image}
|
||||
|
||||
else:
|
||||
return self.create_text_message(f"Unsupported task type: {task_type}")
|
||||
|
||||
# Call Nova Canvas model
|
||||
response = bedrock.invoke_model(
|
||||
body=json.dumps(body),
|
||||
modelId="amazon.nova-canvas-v1:0",
|
||||
accept="application/json",
|
||||
contentType="application/json",
|
||||
)
|
||||
|
||||
# Process response
|
||||
response_body = json.loads(response.get("body").read())
|
||||
if response_body.get("error"):
|
||||
raise Exception(f"Error in model response: {response_body.get('error')}")
|
||||
base64_image = response_body.get("images")[0]
|
||||
|
||||
# Upload to S3 if image_output_s3uri is provided
|
||||
try:
|
||||
# Parse S3 URI for output
|
||||
parsed_uri = urlparse(image_output_s3uri)
|
||||
output_bucket = parsed_uri.netloc
|
||||
output_base_path = parsed_uri.path.lstrip("/")
|
||||
# Generate filename with timestamp
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
output_key = f"{output_base_path}/canvas-output-{timestamp}.png"
|
||||
|
||||
# Initialize S3 client if not already done
|
||||
s3_client = boto3.client("s3", region_name=aws_region)
|
||||
|
||||
# Decode base64 image and upload to S3
|
||||
image_data = base64.b64decode(base64_image)
|
||||
s3_client.put_object(Bucket=output_bucket, Key=output_key, Body=image_data, ContentType="image/png")
|
||||
logger.info(f"Image uploaded to s3://{output_bucket}/{output_key}")
|
||||
except Exception as e:
|
||||
logger.exception("Failed to upload image to S3")
|
||||
# Return image
|
||||
return [
|
||||
self.create_text_message(f"Image is available at: s3://{output_bucket}/{output_key}"),
|
||||
self.create_blob_message(
|
||||
blob=base64.b64decode(base64_image),
|
||||
meta={"mime_type": "image/png"},
|
||||
save_as=self.VariableKey.IMAGE.value,
|
||||
),
|
||||
]
|
||||
|
||||
except Exception as e:
|
||||
return self.create_text_message(f"Failed to generate image: {str(e)}")
|
||||
|
||||
def _validate_color_string(self, color_string) -> bool:
|
||||
color_pattern = r"^#[0-9a-fA-F]{6}(?:-#[0-9a-fA-F]{6})*$"
|
||||
|
||||
if re.match(color_pattern, color_string):
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_runtime_parameters(self) -> list[ToolParameter]:
|
||||
parameters = [
|
||||
ToolParameter(
|
||||
name="prompt",
|
||||
label=I18nObject(en_US="Prompt", zh_Hans="提示词"),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=True,
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
human_description=I18nObject(
|
||||
en_US="Text description of the image you want to generate or modify",
|
||||
zh_Hans="您想要生成或修改的图像的文本描述",
|
||||
),
|
||||
llm_description="Describe the image you want to generate or how you want to modify the input image",
|
||||
),
|
||||
ToolParameter(
|
||||
name="image_input_s3uri",
|
||||
label=I18nObject(en_US="Input image s3 uri", zh_Hans="输入图片的s3 uri"),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=False,
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
human_description=I18nObject(en_US="Image to be modified", zh_Hans="想要修改的图片"),
|
||||
),
|
||||
ToolParameter(
|
||||
name="image_output_s3uri",
|
||||
label=I18nObject(en_US="Output Image S3 URI", zh_Hans="输出图片的S3 URI目录"),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=True,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
human_description=I18nObject(
|
||||
en_US="S3 URI where the generated image should be uploaded", zh_Hans="生成的图像应该上传到的S3 URI"
|
||||
),
|
||||
),
|
||||
ToolParameter(
|
||||
name="width",
|
||||
label=I18nObject(en_US="Width", zh_Hans="宽度"),
|
||||
type=ToolParameter.ToolParameterType.NUMBER,
|
||||
required=False,
|
||||
default=1024,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
human_description=I18nObject(en_US="Width of the generated image", zh_Hans="生成图像的宽度"),
|
||||
),
|
||||
ToolParameter(
|
||||
name="height",
|
||||
label=I18nObject(en_US="Height", zh_Hans="高度"),
|
||||
type=ToolParameter.ToolParameterType.NUMBER,
|
||||
required=False,
|
||||
default=1024,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
human_description=I18nObject(en_US="Height of the generated image", zh_Hans="生成图像的高度"),
|
||||
),
|
||||
ToolParameter(
|
||||
name="cfg_scale",
|
||||
label=I18nObject(en_US="CFG Scale", zh_Hans="CFG比例"),
|
||||
type=ToolParameter.ToolParameterType.NUMBER,
|
||||
required=False,
|
||||
default=8.0,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
human_description=I18nObject(
|
||||
en_US="How strongly the image should conform to the prompt", zh_Hans="图像应该多大程度上符合提示词"
|
||||
),
|
||||
),
|
||||
ToolParameter(
|
||||
name="negative_prompt",
|
||||
label=I18nObject(en_US="Negative Prompt", zh_Hans="负面提示词"),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=False,
|
||||
default="",
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
human_description=I18nObject(
|
||||
en_US="Things you don't want in the generated image", zh_Hans="您不想在生成的图像中出现的内容"
|
||||
),
|
||||
),
|
||||
ToolParameter(
|
||||
name="seed",
|
||||
label=I18nObject(en_US="Seed", zh_Hans="种子值"),
|
||||
type=ToolParameter.ToolParameterType.NUMBER,
|
||||
required=False,
|
||||
default=0,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
human_description=I18nObject(en_US="Random seed for image generation", zh_Hans="图像生成的随机种子"),
|
||||
),
|
||||
ToolParameter(
|
||||
name="aws_region",
|
||||
label=I18nObject(en_US="AWS Region", zh_Hans="AWS 区域"),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=False,
|
||||
default="us-east-1",
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
human_description=I18nObject(en_US="AWS region for Bedrock service", zh_Hans="Bedrock 服务的 AWS 区域"),
|
||||
),
|
||||
ToolParameter(
|
||||
name="task_type",
|
||||
label=I18nObject(en_US="Task Type", zh_Hans="任务类型"),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=False,
|
||||
default="TEXT_IMAGE",
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
human_description=I18nObject(en_US="Type of image generation task", zh_Hans="图像生成任务的类型"),
|
||||
),
|
||||
ToolParameter(
|
||||
name="quality",
|
||||
label=I18nObject(en_US="Quality", zh_Hans="质量"),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=False,
|
||||
default="standard",
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
human_description=I18nObject(
|
||||
en_US="Quality of the generated image (standard or premium)", zh_Hans="生成图像的质量(标准或高级)"
|
||||
),
|
||||
),
|
||||
ToolParameter(
|
||||
name="colors",
|
||||
label=I18nObject(en_US="Colors", zh_Hans="颜色"),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=False,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
human_description=I18nObject(
|
||||
en_US="List of colors for color-guided generation, example: #ff8080-#ffb280-#ffe680-#ffe680",
|
||||
zh_Hans="颜色引导生成的颜色列表, 例子: #ff8080-#ffb280-#ffe680-#ffe680",
|
||||
),
|
||||
),
|
||||
ToolParameter(
|
||||
name="similarity_strength",
|
||||
label=I18nObject(en_US="Similarity Strength", zh_Hans="相似度强度"),
|
||||
type=ToolParameter.ToolParameterType.NUMBER,
|
||||
required=False,
|
||||
default=0.5,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
human_description=I18nObject(
|
||||
en_US="How similar the generated image should be to the input image (0.0 to 1.0)",
|
||||
zh_Hans="生成的图像应该与输入图像的相似程度(0.0到1.0)",
|
||||
),
|
||||
),
|
||||
ToolParameter(
|
||||
name="mask_prompt",
|
||||
label=I18nObject(en_US="Mask Prompt", zh_Hans="蒙版提示词"),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=False,
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
human_description=I18nObject(
|
||||
en_US="Text description to generate mask for inpainting/outpainting",
|
||||
zh_Hans="用于生成内补绘制/外补绘制蒙版的文本描述",
|
||||
),
|
||||
),
|
||||
ToolParameter(
|
||||
name="outpainting_mode",
|
||||
label=I18nObject(en_US="Outpainting Mode", zh_Hans="外补绘制模式"),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=False,
|
||||
default="DEFAULT",
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
human_description=I18nObject(
|
||||
en_US="Mode for outpainting (DEFAULT or other supported modes)",
|
||||
zh_Hans="外补绘制的模式(DEFAULT或其他支持的模式)",
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
return parameters
|
175
api/core/tools/provider/builtin/aws/tools/nova_canvas.yaml
Normal file
175
api/core/tools/provider/builtin/aws/tools/nova_canvas.yaml
Normal file
@ -0,0 +1,175 @@
|
||||
identity:
|
||||
name: nova_canvas
|
||||
author: AWS
|
||||
label:
|
||||
en_US: AWS Bedrock Nova Canvas
|
||||
zh_Hans: AWS Bedrock Nova Canvas
|
||||
icon: icon.svg
|
||||
description:
|
||||
human:
|
||||
en_US: A tool for generating and modifying images using AWS Bedrock's Nova Canvas model. Supports text-to-image, color-guided generation, image variation, inpainting, outpainting, and background removal. Input parameters reference https://docs.aws.amazon.com/nova/latest/userguide/image-gen-req-resp-structure.html
|
||||
zh_Hans: 使用 AWS Bedrock 的 Nova Canvas 模型生成和修改图像的工具。支持文生图、颜色引导生成、图像变体、内补绘制、外补绘制和背景移除功能, 输入参数参考 https://docs.aws.amazon.com/nova/latest/userguide/image-gen-req-resp-structure.html。
|
||||
llm: Generate or modify images using AWS Bedrock's Nova Canvas model with multiple task types including text-to-image, color-guided generation, image variation, inpainting, outpainting, and background removal.
|
||||
parameters:
|
||||
- name: task_type
|
||||
type: string
|
||||
required: false
|
||||
default: TEXT_IMAGE
|
||||
label:
|
||||
en_US: Task Type
|
||||
zh_Hans: 任务类型
|
||||
human_description:
|
||||
en_US: Type of image generation task (TEXT_IMAGE, COLOR_GUIDED_GENERATION, IMAGE_VARIATION, INPAINTING, OUTPAINTING, BACKGROUND_REMOVAL)
|
||||
zh_Hans: 图像生成任务的类型(文生图、颜色引导生成、图像变体、内补绘制、外补绘制、背景移除)
|
||||
form: llm
|
||||
- name: prompt
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Prompt
|
||||
zh_Hans: 提示词
|
||||
human_description:
|
||||
en_US: Text description of the image you want to generate or modify
|
||||
zh_Hans: 您想要生成或修改的图像的文本描述
|
||||
llm_description: Describe the image you want to generate or how you want to modify the input image
|
||||
form: llm
|
||||
- name: image_input_s3uri
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Input image s3 uri
|
||||
zh_Hans: 输入图片的s3 uri
|
||||
human_description:
|
||||
en_US: The input image to modify (required for all modes except TEXT_IMAGE)
|
||||
zh_Hans: 要修改的输入图像(除文生图外的所有模式都需要)
|
||||
llm_description: The input image you want to modify. Required for all modes except TEXT_IMAGE.
|
||||
form: llm
|
||||
- name: image_output_s3uri
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Output S3 URI
|
||||
zh_Hans: 输出S3 URI
|
||||
human_description:
|
||||
en_US: The S3 URI where the generated image will be saved. If provided, the image will be uploaded with name format canvas-output-{timestamp}.png
|
||||
zh_Hans: 生成的图像将保存到的S3 URI。如果提供,图像将以canvas-output-{timestamp}.png的格式上传
|
||||
llm_description: Optional S3 URI where the generated image will be uploaded. The image will be saved with a timestamp-based filename.
|
||||
form: form
|
||||
- name: negative_prompt
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Negative Prompt
|
||||
zh_Hans: 负面提示词
|
||||
human_description:
|
||||
en_US: Things you don't want in the generated image
|
||||
zh_Hans: 您不想在生成的图像中出现的内容
|
||||
form: llm
|
||||
- name: width
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: Width
|
||||
zh_Hans: 宽度
|
||||
human_description:
|
||||
en_US: Width of the generated image
|
||||
zh_Hans: 生成图像的宽度
|
||||
form: form
|
||||
default: 1024
|
||||
- name: height
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: Height
|
||||
zh_Hans: 高度
|
||||
human_description:
|
||||
en_US: Height of the generated image
|
||||
zh_Hans: 生成图像的高度
|
||||
form: form
|
||||
default: 1024
|
||||
- name: cfg_scale
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: CFG Scale
|
||||
zh_Hans: CFG比例
|
||||
human_description:
|
||||
en_US: How strongly the image should conform to the prompt
|
||||
zh_Hans: 图像应该多大程度上符合提示词
|
||||
form: form
|
||||
default: 8.0
|
||||
- name: seed
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: Seed
|
||||
zh_Hans: 种子值
|
||||
human_description:
|
||||
en_US: Random seed for image generation
|
||||
zh_Hans: 图像生成的随机种子
|
||||
form: form
|
||||
default: 0
|
||||
- name: aws_region
|
||||
type: string
|
||||
required: false
|
||||
default: us-east-1
|
||||
label:
|
||||
en_US: AWS Region
|
||||
zh_Hans: AWS 区域
|
||||
human_description:
|
||||
en_US: AWS region for Bedrock service
|
||||
zh_Hans: Bedrock 服务的 AWS 区域
|
||||
form: form
|
||||
- name: quality
|
||||
type: string
|
||||
required: false
|
||||
default: standard
|
||||
label:
|
||||
en_US: Quality
|
||||
zh_Hans: 质量
|
||||
human_description:
|
||||
en_US: Quality of the generated image (standard or premium)
|
||||
zh_Hans: 生成图像的质量(标准或高级)
|
||||
form: form
|
||||
- name: colors
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Colors
|
||||
zh_Hans: 颜色
|
||||
human_description:
|
||||
en_US: List of colors for color-guided generation
|
||||
zh_Hans: 颜色引导生成的颜色列表
|
||||
form: form
|
||||
- name: similarity_strength
|
||||
type: number
|
||||
required: false
|
||||
default: 0.5
|
||||
label:
|
||||
en_US: Similarity Strength
|
||||
zh_Hans: 相似度强度
|
||||
human_description:
|
||||
en_US: How similar the generated image should be to the input image (0.0 to 1.0)
|
||||
zh_Hans: 生成的图像应该与输入图像的相似程度(0.0到1.0)
|
||||
form: form
|
||||
- name: mask_prompt
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Mask Prompt
|
||||
zh_Hans: 蒙版提示词
|
||||
human_description:
|
||||
en_US: Text description to generate mask for inpainting/outpainting
|
||||
zh_Hans: 用于生成内补绘制/外补绘制蒙版的文本描述
|
||||
form: llm
|
||||
- name: outpainting_mode
|
||||
type: string
|
||||
required: false
|
||||
default: DEFAULT
|
||||
label:
|
||||
en_US: Outpainting Mode
|
||||
zh_Hans: 外补绘制模式
|
||||
human_description:
|
||||
en_US: Mode for outpainting (DEFAULT or other supported modes)
|
||||
zh_Hans: 外补绘制的模式(DEFAULT或其他支持的模式)
|
||||
form: form
|
371
api/core/tools/provider/builtin/aws/tools/nova_reel.py
Normal file
371
api/core/tools/provider/builtin/aws/tools/nova_reel.py
Normal file
@ -0,0 +1,371 @@
|
||||
import base64
|
||||
import logging
|
||||
import time
|
||||
from io import BytesIO
|
||||
from typing import Any, Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import boto3
|
||||
from botocore.exceptions import ClientError
|
||||
from PIL import Image
|
||||
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
NOVA_REEL_DEFAULT_REGION = "us-east-1"
|
||||
NOVA_REEL_DEFAULT_DIMENSION = "1280x720"
|
||||
NOVA_REEL_DEFAULT_FPS = 24
|
||||
NOVA_REEL_DEFAULT_DURATION = 6
|
||||
NOVA_REEL_MODEL_ID = "amazon.nova-reel-v1:0"
|
||||
NOVA_REEL_STATUS_CHECK_INTERVAL = 5
|
||||
|
||||
# Image requirements
|
||||
NOVA_REEL_REQUIRED_IMAGE_WIDTH = 1280
|
||||
NOVA_REEL_REQUIRED_IMAGE_HEIGHT = 720
|
||||
NOVA_REEL_REQUIRED_IMAGE_MODE = "RGB"
|
||||
|
||||
|
||||
class NovaReelTool(BuiltinTool):
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
Invoke AWS Bedrock Nova Reel model for video generation.
|
||||
|
||||
Args:
|
||||
user_id: The ID of the user making the request
|
||||
tool_parameters: Dictionary containing the tool parameters
|
||||
|
||||
Returns:
|
||||
ToolInvokeMessage containing either the video content or status information
|
||||
"""
|
||||
try:
|
||||
# Validate and extract parameters
|
||||
params = self._validate_and_extract_parameters(tool_parameters)
|
||||
if isinstance(params, ToolInvokeMessage):
|
||||
return params
|
||||
|
||||
# Initialize AWS clients
|
||||
bedrock, s3_client = self._initialize_aws_clients(params["aws_region"])
|
||||
|
||||
# Prepare model input
|
||||
model_input = self._prepare_model_input(params, s3_client)
|
||||
if isinstance(model_input, ToolInvokeMessage):
|
||||
return model_input
|
||||
|
||||
# Start video generation
|
||||
invocation = self._start_video_generation(bedrock, model_input, params["video_output_s3uri"])
|
||||
invocation_arn = invocation["invocationArn"]
|
||||
|
||||
# Handle async/sync mode
|
||||
return self._handle_generation_mode(bedrock, s3_client, invocation_arn, params["async_mode"])
|
||||
|
||||
except ClientError as e:
|
||||
error_code = e.response.get("Error", {}).get("Code", "Unknown")
|
||||
error_message = e.response.get("Error", {}).get("Message", str(e))
|
||||
logger.exception(f"AWS API error: {error_code} - {error_message}")
|
||||
return self.create_text_message(f"AWS service error: {error_code} - {error_message}")
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in video generation: {str(e)}", exc_info=True)
|
||||
return self.create_text_message(f"Failed to generate video: {str(e)}")
|
||||
|
||||
def _validate_and_extract_parameters(
|
||||
self, tool_parameters: dict[str, Any]
|
||||
) -> Union[dict[str, Any], ToolInvokeMessage]:
|
||||
"""Validate and extract parameters from the input dictionary."""
|
||||
prompt = tool_parameters.get("prompt", "")
|
||||
video_output_s3uri = tool_parameters.get("video_output_s3uri", "").strip()
|
||||
|
||||
# Validate required parameters
|
||||
if not prompt:
|
||||
return self.create_text_message("Please provide a text prompt for video generation.")
|
||||
if not video_output_s3uri:
|
||||
return self.create_text_message("Please provide an S3 URI for video output.")
|
||||
|
||||
# Validate S3 URI format
|
||||
if not video_output_s3uri.startswith("s3://"):
|
||||
return self.create_text_message("Invalid S3 URI format. Must start with 's3://'")
|
||||
|
||||
# Ensure S3 URI ends with '/'
|
||||
video_output_s3uri = video_output_s3uri if video_output_s3uri.endswith("/") else video_output_s3uri + "/"
|
||||
|
||||
return {
|
||||
"prompt": prompt,
|
||||
"video_output_s3uri": video_output_s3uri,
|
||||
"image_input_s3uri": tool_parameters.get("image_input_s3uri", "").strip(),
|
||||
"aws_region": tool_parameters.get("aws_region", NOVA_REEL_DEFAULT_REGION),
|
||||
"dimension": tool_parameters.get("dimension", NOVA_REEL_DEFAULT_DIMENSION),
|
||||
"seed": int(tool_parameters.get("seed", 0)),
|
||||
"fps": int(tool_parameters.get("fps", NOVA_REEL_DEFAULT_FPS)),
|
||||
"duration": int(tool_parameters.get("duration", NOVA_REEL_DEFAULT_DURATION)),
|
||||
"async_mode": bool(tool_parameters.get("async", True)),
|
||||
}
|
||||
|
||||
def _initialize_aws_clients(self, region: str) -> tuple[Any, Any]:
|
||||
"""Initialize AWS Bedrock and S3 clients."""
|
||||
bedrock = boto3.client(service_name="bedrock-runtime", region_name=region)
|
||||
s3_client = boto3.client("s3", region_name=region)
|
||||
return bedrock, s3_client
|
||||
|
||||
def _prepare_model_input(self, params: dict[str, Any], s3_client: Any) -> Union[dict[str, Any], ToolInvokeMessage]:
|
||||
"""Prepare the input for the Nova Reel model."""
|
||||
model_input = {
|
||||
"taskType": "TEXT_VIDEO",
|
||||
"textToVideoParams": {"text": params["prompt"]},
|
||||
"videoGenerationConfig": {
|
||||
"durationSeconds": params["duration"],
|
||||
"fps": params["fps"],
|
||||
"dimension": params["dimension"],
|
||||
"seed": params["seed"],
|
||||
},
|
||||
}
|
||||
|
||||
# Add image if provided
|
||||
if params["image_input_s3uri"]:
|
||||
try:
|
||||
image_data = self._get_image_from_s3(s3_client, params["image_input_s3uri"])
|
||||
if not image_data:
|
||||
return self.create_text_message("Failed to retrieve image from S3")
|
||||
|
||||
# Process and validate image
|
||||
processed_image = self._process_and_validate_image(image_data)
|
||||
if isinstance(processed_image, ToolInvokeMessage):
|
||||
return processed_image
|
||||
|
||||
# Convert processed image to base64
|
||||
img_buffer = BytesIO()
|
||||
processed_image.save(img_buffer, format="PNG")
|
||||
img_buffer.seek(0)
|
||||
input_image_base64 = base64.b64encode(img_buffer.getvalue()).decode("utf-8")
|
||||
|
||||
model_input["textToVideoParams"]["images"] = [
|
||||
{"format": "png", "source": {"bytes": input_image_base64}}
|
||||
]
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing input image: {str(e)}", exc_info=True)
|
||||
return self.create_text_message(f"Failed to process input image: {str(e)}")
|
||||
|
||||
return model_input
|
||||
|
||||
def _process_and_validate_image(self, image_data: bytes) -> Union[Image.Image, ToolInvokeMessage]:
|
||||
"""
|
||||
Process and validate the input image according to Nova Reel requirements.
|
||||
|
||||
Requirements:
|
||||
- Must be 1280x720 pixels
|
||||
- Must be RGB format (8 bits per channel)
|
||||
- If PNG, alpha channel must not have transparent/translucent pixels
|
||||
"""
|
||||
try:
|
||||
# Open image
|
||||
img = Image.open(BytesIO(image_data))
|
||||
|
||||
# Convert RGBA to RGB if needed, ensuring no transparency
|
||||
if img.mode == "RGBA":
|
||||
# Check for transparency
|
||||
if img.getchannel("A").getextrema()[0] < 255:
|
||||
return self.create_text_message(
|
||||
"PNG image contains transparent or translucent pixels, which is not supported. "
|
||||
"Please provide an image without transparency."
|
||||
)
|
||||
# Convert to RGB
|
||||
img = img.convert("RGB")
|
||||
elif img.mode != "RGB":
|
||||
# Convert any other mode to RGB
|
||||
img = img.convert("RGB")
|
||||
|
||||
# Validate/adjust dimensions
|
||||
if img.size != (NOVA_REEL_REQUIRED_IMAGE_WIDTH, NOVA_REEL_REQUIRED_IMAGE_HEIGHT):
|
||||
logger.warning(
|
||||
f"Image dimensions {img.size} do not match required dimensions "
|
||||
f"({NOVA_REEL_REQUIRED_IMAGE_WIDTH}x{NOVA_REEL_REQUIRED_IMAGE_HEIGHT}). Resizing..."
|
||||
)
|
||||
img = img.resize(
|
||||
(NOVA_REEL_REQUIRED_IMAGE_WIDTH, NOVA_REEL_REQUIRED_IMAGE_HEIGHT), Image.Resampling.LANCZOS
|
||||
)
|
||||
|
||||
# Validate bit depth
|
||||
if img.mode != NOVA_REEL_REQUIRED_IMAGE_MODE:
|
||||
return self.create_text_message(
|
||||
f"Image must be in {NOVA_REEL_REQUIRED_IMAGE_MODE} mode with 8 bits per channel"
|
||||
)
|
||||
|
||||
return img
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing image: {str(e)}", exc_info=True)
|
||||
return self.create_text_message(
|
||||
"Failed to process image. Please ensure the image is a valid JPEG or PNG file."
|
||||
)
|
||||
|
||||
def _get_image_from_s3(self, s3_client: Any, s3_uri: str) -> Optional[bytes]:
|
||||
"""Download and return image data from S3."""
|
||||
parsed_uri = urlparse(s3_uri)
|
||||
bucket = parsed_uri.netloc
|
||||
key = parsed_uri.path.lstrip("/")
|
||||
|
||||
response = s3_client.get_object(Bucket=bucket, Key=key)
|
||||
return response["Body"].read()
|
||||
|
||||
def _start_video_generation(self, bedrock: Any, model_input: dict[str, Any], output_s3uri: str) -> dict[str, Any]:
|
||||
"""Start the async video generation process."""
|
||||
return bedrock.start_async_invoke(
|
||||
modelId=NOVA_REEL_MODEL_ID,
|
||||
modelInput=model_input,
|
||||
outputDataConfig={"s3OutputDataConfig": {"s3Uri": output_s3uri}},
|
||||
)
|
||||
|
||||
def _handle_generation_mode(
|
||||
self, bedrock: Any, s3_client: Any, invocation_arn: str, async_mode: bool
|
||||
) -> ToolInvokeMessage:
|
||||
"""Handle async or sync video generation mode."""
|
||||
invocation_response = bedrock.get_async_invoke(invocationArn=invocation_arn)
|
||||
video_path = invocation_response["outputDataConfig"]["s3OutputDataConfig"]["s3Uri"]
|
||||
video_uri = f"{video_path}/output.mp4"
|
||||
|
||||
if async_mode:
|
||||
return self.create_text_message(
|
||||
f"Video generation started.\nInvocation ARN: {invocation_arn}\n"
|
||||
f"Video will be available at: {video_uri}"
|
||||
)
|
||||
|
||||
return self._wait_for_completion(bedrock, s3_client, invocation_arn)
|
||||
|
||||
def _wait_for_completion(self, bedrock: Any, s3_client: Any, invocation_arn: str) -> ToolInvokeMessage:
|
||||
"""Wait for video generation completion and handle the result."""
|
||||
while True:
|
||||
status_response = bedrock.get_async_invoke(invocationArn=invocation_arn)
|
||||
status = status_response["status"]
|
||||
video_path = status_response["outputDataConfig"]["s3OutputDataConfig"]["s3Uri"]
|
||||
|
||||
if status == "Completed":
|
||||
return self._handle_completed_video(s3_client, video_path)
|
||||
elif status == "Failed":
|
||||
failure_message = status_response.get("failureMessage", "Unknown error")
|
||||
return self.create_text_message(f"Video generation failed.\nError: {failure_message}")
|
||||
elif status == "InProgress":
|
||||
time.sleep(NOVA_REEL_STATUS_CHECK_INTERVAL)
|
||||
else:
|
||||
return self.create_text_message(f"Unexpected status: {status}")
|
||||
|
||||
def _handle_completed_video(self, s3_client: Any, video_path: str) -> ToolInvokeMessage:
|
||||
"""Handle completed video generation and return the result."""
|
||||
parsed_uri = urlparse(video_path)
|
||||
bucket = parsed_uri.netloc
|
||||
key = parsed_uri.path.lstrip("/") + "/output.mp4"
|
||||
|
||||
try:
|
||||
response = s3_client.get_object(Bucket=bucket, Key=key)
|
||||
video_content = response["Body"].read()
|
||||
return [
|
||||
self.create_text_message(f"Video is available at: {video_path}/output.mp4"),
|
||||
self.create_blob_message(blob=video_content, meta={"mime_type": "video/mp4"}, save_as="output.mp4"),
|
||||
]
|
||||
except Exception as e:
|
||||
logger.error(f"Error downloading video: {str(e)}", exc_info=True)
|
||||
return self.create_text_message(
|
||||
f"Video generation completed but failed to download video: {str(e)}\n"
|
||||
f"Video is available at: s3://{bucket}/{key}"
|
||||
)
|
||||
|
||||
def get_runtime_parameters(self) -> list[ToolParameter]:
|
||||
"""Define the tool's runtime parameters."""
|
||||
parameters = [
|
||||
ToolParameter(
|
||||
name="prompt",
|
||||
label=I18nObject(en_US="Prompt", zh_Hans="提示词"),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=True,
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
human_description=I18nObject(
|
||||
en_US="Text description of the video you want to generate", zh_Hans="您想要生成的视频的文本描述"
|
||||
),
|
||||
llm_description="Describe the video you want to generate",
|
||||
),
|
||||
ToolParameter(
|
||||
name="video_output_s3uri",
|
||||
label=I18nObject(en_US="Output S3 URI", zh_Hans="输出S3 URI"),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=True,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
human_description=I18nObject(
|
||||
en_US="S3 URI where the generated video will be stored", zh_Hans="生成的视频将存储的S3 URI"
|
||||
),
|
||||
),
|
||||
ToolParameter(
|
||||
name="dimension",
|
||||
label=I18nObject(en_US="Dimension", zh_Hans="尺寸"),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=False,
|
||||
default=NOVA_REEL_DEFAULT_DIMENSION,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
human_description=I18nObject(en_US="Video dimensions (width x height)", zh_Hans="视频尺寸(宽 x 高)"),
|
||||
),
|
||||
ToolParameter(
|
||||
name="duration",
|
||||
label=I18nObject(en_US="Duration", zh_Hans="时长"),
|
||||
type=ToolParameter.ToolParameterType.NUMBER,
|
||||
required=False,
|
||||
default=NOVA_REEL_DEFAULT_DURATION,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
human_description=I18nObject(en_US="Video duration in seconds", zh_Hans="视频时长(秒)"),
|
||||
),
|
||||
ToolParameter(
|
||||
name="seed",
|
||||
label=I18nObject(en_US="Seed", zh_Hans="种子值"),
|
||||
type=ToolParameter.ToolParameterType.NUMBER,
|
||||
required=False,
|
||||
default=0,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
human_description=I18nObject(en_US="Random seed for video generation", zh_Hans="视频生成的随机种子"),
|
||||
),
|
||||
ToolParameter(
|
||||
name="fps",
|
||||
label=I18nObject(en_US="FPS", zh_Hans="帧率"),
|
||||
type=ToolParameter.ToolParameterType.NUMBER,
|
||||
required=False,
|
||||
default=NOVA_REEL_DEFAULT_FPS,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
human_description=I18nObject(
|
||||
en_US="Frames per second for the generated video", zh_Hans="生成视频的每秒帧数"
|
||||
),
|
||||
),
|
||||
ToolParameter(
|
||||
name="async",
|
||||
label=I18nObject(en_US="Async Mode", zh_Hans="异步模式"),
|
||||
type=ToolParameter.ToolParameterType.BOOLEAN,
|
||||
required=False,
|
||||
default=True,
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
human_description=I18nObject(
|
||||
en_US="Whether to run in async mode (return immediately) or sync mode (wait for completion)",
|
||||
zh_Hans="是否以异步模式运行(立即返回)或同步模式(等待完成)",
|
||||
),
|
||||
),
|
||||
ToolParameter(
|
||||
name="aws_region",
|
||||
label=I18nObject(en_US="AWS Region", zh_Hans="AWS 区域"),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=False,
|
||||
default=NOVA_REEL_DEFAULT_REGION,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
human_description=I18nObject(en_US="AWS region for Bedrock service", zh_Hans="Bedrock 服务的 AWS 区域"),
|
||||
),
|
||||
ToolParameter(
|
||||
name="image_input_s3uri",
|
||||
label=I18nObject(en_US="Input Image S3 URI", zh_Hans="输入图像S3 URI"),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=False,
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
human_description=I18nObject(
|
||||
en_US="S3 URI of the input image (1280x720 JPEG/PNG) to use as first frame",
|
||||
zh_Hans="用作第一帧的输入图像(1280x720 JPEG/PNG)的S3 URI",
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
return parameters
|
124
api/core/tools/provider/builtin/aws/tools/nova_reel.yaml
Normal file
124
api/core/tools/provider/builtin/aws/tools/nova_reel.yaml
Normal file
@ -0,0 +1,124 @@
|
||||
identity:
|
||||
name: nova_reel
|
||||
author: AWS
|
||||
label:
|
||||
en_US: AWS Bedrock Nova Reel
|
||||
zh_Hans: AWS Bedrock Nova Reel
|
||||
icon: icon.svg
|
||||
description:
|
||||
human:
|
||||
en_US: A tool for generating videos using AWS Bedrock's Nova Reel model. Supports text-to-video generation and image-to-video generation with customizable parameters like duration, FPS, and dimensions. Input parameters reference https://docs.aws.amazon.com/nova/latest/userguide/video-generation.html
|
||||
zh_Hans: 使用 AWS Bedrock 的 Nova Reel 模型生成视频的工具。支持文本生成视频和图像生成视频功能,可自定义持续时间、帧率和尺寸等参数。输入参数参考 https://docs.aws.amazon.com/nova/latest/userguide/video-generation.html
|
||||
llm: Generate videos using AWS Bedrock's Nova Reel model with support for both text-to-video and image-to-video generation, allowing customization of video properties like duration, frame rate, and resolution.
|
||||
|
||||
parameters:
|
||||
- name: prompt
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Prompt
|
||||
zh_Hans: 提示词
|
||||
human_description:
|
||||
en_US: Text description of the video you want to generate
|
||||
zh_Hans: 您想要生成的视频的文本描述
|
||||
llm_description: Describe the video you want to generate
|
||||
form: llm
|
||||
|
||||
- name: video_output_s3uri
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Output S3 URI
|
||||
zh_Hans: 输出S3 URI
|
||||
human_description:
|
||||
en_US: S3 URI where the generated video will be stored
|
||||
zh_Hans: 生成的视频将存储的S3 URI
|
||||
form: form
|
||||
|
||||
- name: dimension
|
||||
type: string
|
||||
required: false
|
||||
default: 1280x720
|
||||
label:
|
||||
en_US: Dimension
|
||||
zh_Hans: 尺寸
|
||||
human_description:
|
||||
en_US: Video dimensions (width x height)
|
||||
zh_Hans: 视频尺寸(宽 x 高)
|
||||
form: form
|
||||
|
||||
- name: duration
|
||||
type: number
|
||||
required: false
|
||||
default: 6
|
||||
label:
|
||||
en_US: Duration
|
||||
zh_Hans: 时长
|
||||
human_description:
|
||||
en_US: Video duration in seconds
|
||||
zh_Hans: 视频时长(秒)
|
||||
form: form
|
||||
|
||||
- name: seed
|
||||
type: number
|
||||
required: false
|
||||
default: 0
|
||||
label:
|
||||
en_US: Seed
|
||||
zh_Hans: 种子值
|
||||
human_description:
|
||||
en_US: Random seed for video generation
|
||||
zh_Hans: 视频生成的随机种子
|
||||
form: form
|
||||
|
||||
- name: fps
|
||||
type: number
|
||||
required: false
|
||||
default: 24
|
||||
label:
|
||||
en_US: FPS
|
||||
zh_Hans: 帧率
|
||||
human_description:
|
||||
en_US: Frames per second for the generated video
|
||||
zh_Hans: 生成视频的每秒帧数
|
||||
form: form
|
||||
|
||||
- name: async
|
||||
type: boolean
|
||||
required: false
|
||||
default: true
|
||||
label:
|
||||
en_US: Async Mode
|
||||
zh_Hans: 异步模式
|
||||
human_description:
|
||||
en_US: Whether to run in async mode (return immediately) or sync mode (wait for completion)
|
||||
zh_Hans: 是否以异步模式运行(立即返回)或同步模式(等待完成)
|
||||
form: llm
|
||||
|
||||
- name: aws_region
|
||||
type: string
|
||||
required: false
|
||||
default: us-east-1
|
||||
label:
|
||||
en_US: AWS Region
|
||||
zh_Hans: AWS 区域
|
||||
human_description:
|
||||
en_US: AWS region for Bedrock service
|
||||
zh_Hans: Bedrock 服务的 AWS 区域
|
||||
form: form
|
||||
|
||||
- name: image_input_s3uri
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Input Image S3 URI
|
||||
zh_Hans: 输入图像S3 URI
|
||||
human_description:
|
||||
en_US: S3 URI of the input image (1280x720 JPEG/PNG) to use as first frame
|
||||
zh_Hans: 用作第一帧的输入图像(1280x720 JPEG/PNG)的S3 URI
|
||||
form: llm
|
||||
|
||||
development:
|
||||
dependencies:
|
||||
- boto3
|
||||
- pillow
|
80
api/core/tools/provider/builtin/aws/tools/s3_operator.py
Normal file
80
api/core/tools/provider/builtin/aws/tools/s3_operator.py
Normal file
@ -0,0 +1,80 @@
|
||||
from typing import Any, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import boto3
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class S3Operator(BuiltinTool):
|
||||
s3_client: Any = None
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
try:
|
||||
# Initialize S3 client if not already done
|
||||
if not self.s3_client:
|
||||
aws_region = tool_parameters.get("aws_region")
|
||||
if aws_region:
|
||||
self.s3_client = boto3.client("s3", region_name=aws_region)
|
||||
else:
|
||||
self.s3_client = boto3.client("s3")
|
||||
|
||||
# Parse S3 URI
|
||||
s3_uri = tool_parameters.get("s3_uri")
|
||||
if not s3_uri:
|
||||
return self.create_text_message("s3_uri parameter is required")
|
||||
|
||||
parsed_uri = urlparse(s3_uri)
|
||||
if parsed_uri.scheme != "s3":
|
||||
return self.create_text_message("Invalid S3 URI format. Must start with 's3://'")
|
||||
|
||||
bucket = parsed_uri.netloc
|
||||
# Remove leading slash from key
|
||||
key = parsed_uri.path.lstrip("/")
|
||||
|
||||
operation_type = tool_parameters.get("operation_type", "read")
|
||||
generate_presign_url = tool_parameters.get("generate_presign_url", False)
|
||||
presign_expiry = int(tool_parameters.get("presign_expiry", 3600)) # default 1 hour
|
||||
|
||||
if operation_type == "write":
|
||||
text_content = tool_parameters.get("text_content")
|
||||
if not text_content:
|
||||
return self.create_text_message("text_content parameter is required for write operation")
|
||||
|
||||
# Write content to S3
|
||||
self.s3_client.put_object(Bucket=bucket, Key=key, Body=text_content.encode("utf-8"))
|
||||
result = f"s3://{bucket}/{key}"
|
||||
|
||||
# Generate presigned URL for the written object if requested
|
||||
if generate_presign_url:
|
||||
result = self.s3_client.generate_presigned_url(
|
||||
"get_object", Params={"Bucket": bucket, "Key": key}, ExpiresIn=presign_expiry
|
||||
)
|
||||
|
||||
else: # read operation
|
||||
# Get object from S3
|
||||
response = self.s3_client.get_object(Bucket=bucket, Key=key)
|
||||
result = response["Body"].read().decode("utf-8")
|
||||
|
||||
# Generate presigned URL if requested
|
||||
if generate_presign_url:
|
||||
result = self.s3_client.generate_presigned_url(
|
||||
"get_object", Params={"Bucket": bucket, "Key": key}, ExpiresIn=presign_expiry
|
||||
)
|
||||
|
||||
return self.create_text_message(text=result)
|
||||
|
||||
except self.s3_client.exceptions.NoSuchBucket:
|
||||
return self.create_text_message(f"Bucket '{bucket}' does not exist")
|
||||
except self.s3_client.exceptions.NoSuchKey:
|
||||
return self.create_text_message(f"Object '{key}' does not exist in bucket '{bucket}'")
|
||||
except Exception as e:
|
||||
return self.create_text_message(f"Exception: {str(e)}")
|
98
api/core/tools/provider/builtin/aws/tools/s3_operator.yaml
Normal file
98
api/core/tools/provider/builtin/aws/tools/s3_operator.yaml
Normal file
@ -0,0 +1,98 @@
|
||||
identity:
|
||||
name: s3_operator
|
||||
author: AWS
|
||||
label:
|
||||
en_US: AWS S3 Operator
|
||||
zh_Hans: AWS S3 读写器
|
||||
pt_BR: AWS S3 Operator
|
||||
icon: icon.svg
|
||||
description:
|
||||
human:
|
||||
en_US: AWS S3 Writer and Reader
|
||||
zh_Hans: 读写S3 bucket中的文件
|
||||
pt_BR: AWS S3 Writer and Reader
|
||||
llm: AWS S3 Writer and Reader
|
||||
parameters:
|
||||
- name: text_content
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: The text to write
|
||||
zh_Hans: 待写入的文本
|
||||
pt_BR: The text to write
|
||||
human_description:
|
||||
en_US: The text to write
|
||||
zh_Hans: 待写入的文本
|
||||
pt_BR: The text to write
|
||||
llm_description: The text to write
|
||||
form: llm
|
||||
- name: s3_uri
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: s3 uri
|
||||
zh_Hans: s3 uri
|
||||
pt_BR: s3 uri
|
||||
human_description:
|
||||
en_US: s3 uri
|
||||
zh_Hans: s3 uri
|
||||
pt_BR: s3 uri
|
||||
llm_description: s3 uri
|
||||
form: llm
|
||||
- name: aws_region
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: region of bucket
|
||||
zh_Hans: bucket 所在的region
|
||||
pt_BR: region of bucket
|
||||
human_description:
|
||||
en_US: region of bucket
|
||||
zh_Hans: bucket 所在的region
|
||||
pt_BR: region of bucket
|
||||
llm_description: region of bucket
|
||||
form: form
|
||||
- name: operation_type
|
||||
type: select
|
||||
required: true
|
||||
label:
|
||||
en_US: operation type
|
||||
zh_Hans: 操作类型
|
||||
pt_BR: operation type
|
||||
human_description:
|
||||
en_US: operation type
|
||||
zh_Hans: 操作类型
|
||||
pt_BR: operation type
|
||||
default: read
|
||||
options:
|
||||
- value: read
|
||||
label:
|
||||
en_US: read
|
||||
zh_Hans: 读
|
||||
- value: write
|
||||
label:
|
||||
en_US: write
|
||||
zh_Hans: 写
|
||||
form: form
|
||||
- name: generate_presign_url
|
||||
type: boolean
|
||||
required: false
|
||||
label:
|
||||
en_US: Generate presigned URL
|
||||
zh_Hans: 生成预签名URL
|
||||
human_description:
|
||||
en_US: Whether to generate a presigned URL for the S3 object
|
||||
zh_Hans: 是否生成S3对象的预签名URL
|
||||
default: false
|
||||
form: form
|
||||
- name: presign_expiry
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: Presigned URL expiration time
|
||||
zh_Hans: 预签名URL有效期
|
||||
human_description:
|
||||
en_US: Expiration time in seconds for the presigned URL
|
||||
zh_Hans: 预签名URL的有效期(秒)
|
||||
default: 3600
|
||||
form: form
|
Loading…
x
Reference in New Issue
Block a user