Update aws tools (#11174)

Co-authored-by: Yuanbo Li <ybalbert@amazon.com>
This commit is contained in:
ybalbert001 2024-11-29 09:28:28 +08:00 committed by GitHub
parent e576d32fb6
commit cc0b92bc75
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 668 additions and 4 deletions

View File

@ -12,7 +12,7 @@ class LambdaTranslateUtilsTool(BuiltinTool):
def _invoke_lambda(self, text_content, src_lang, dest_lang, model_id, dictionary_name, request_type, lambda_name):
msg = {
"src_content": text_content,
"src_contents": [text_content],
"src_lang": src_lang,
"dest_lang": dest_lang,
"dictionary_id": dictionary_name,

View File

@ -8,9 +8,9 @@ identity:
icon: icon.svg
description:
human:
en_US: A util tools for LLM translation, extra deployment is needed on AWS. Please refer Github Repo - https://github.com/ybalbert001/dynamodb-rag
zh_Hans: 大语言模型翻译工具(专词映射获取)需要在AWS上进行额外部署可参考Github Repo - https://github.com/ybalbert001/dynamodb-rag
pt_BR: A util tools for LLM translation, specific Lambda Function deployment is needed on AWS. Please refer Github Repo - https://github.com/ybalbert001/dynamodb-rag
en_US: A util tools for LLM translation, extra deployment is needed on AWS. Please refer Github Repo - https://github.com/aws-samples/rag-based-translation-with-dynamodb-and-bedrock
zh_Hans: 大语言模型翻译工具(专词映射获取)需要在AWS上进行额外部署可参考Github Repo - https://github.com/aws-samples/rag-based-translation-with-dynamodb-and-bedrock
pt_BR: A util tools for LLM translation, specific Lambda Function deployment is needed on AWS. Please refer Github Repo - https://github.com/aws-samples/rag-based-translation-with-dynamodb-and-bedrock
llm: A util tools for translation.
parameters:
- name: text_content

View File

@ -0,0 +1,67 @@
import json
from typing import Any, Union
import boto3
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
# Label mapping: translates the raw label found in the endpoint response into the verdict string returned by the tool
LABEL_MAPPING = {"LABEL_0": "SAFE", "LABEL_1": "NO_SAFE"}
class ContentModerationTool(BuiltinTool):
    """Classify a piece of text as ``SAFE`` / ``NO_SAFE`` via a SageMaker inference endpoint."""

    # boto3 "sagemaker-runtime" client, created lazily on the first invocation and then reused.
    sagemaker_client: Any = None
    # Endpoint name captured from the tool parameters of the first invocation.
    sagemaker_endpoint: str = None

    def _invoke_sagemaker(self, payload: dict, endpoint: str):
        """Send ``payload`` to ``endpoint`` and map the returned label to a verdict.

        Parameters:
        - payload (dict): JSON-serialisable request body
        - endpoint (str): name of the SageMaker endpoint to invoke
        Returns:
        - str: the ``LABEL_MAPPING`` entry for the returned label, or ``"NO_SAFE"``
          when the label is missing or unknown
        Raises:
        - ValueError: when the response payload is not a JSON object
        """
        response = self.sagemaker_client.invoke_endpoint(
            EndpointName=endpoint,
            Body=json.dumps(payload),
            ContentType="application/json",
        )
        # Parse response
        response_body = response["Body"].read().decode("utf8")
        prediction = json.loads(response_body)

        # Handle nested JSON if present: the prediction may be wrapped as {"body": ...},
        # where "body" is either a JSON-encoded string or an already-decoded object.
        if isinstance(prediction, dict) and "body" in prediction:
            prediction = prediction["body"]
            if isinstance(prediction, (str, bytes, bytearray)):
                prediction = json.loads(prediction)

        if not isinstance(prediction, dict):
            # Fail with a readable message instead of an AttributeError on ``.get``.
            raise ValueError(f"unexpected response payload of type {type(prediction).__name__}")

        raw_label = prediction.get("label")
        if not isinstance(raw_label, str):
            # Missing or malformed label: same default as an unknown label.
            return "NO_SAFE"
        # Map the label; fall back to NO_SAFE when it is not present in the mapping.
        return LABEL_MAPPING.get(raw_label, "NO_SAFE")

    def _invoke(
        self,
        user_id: str,
        tool_parameters: dict[str, Any],
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """
        invoke tools

        Reads ``aws_region``, ``sagemaker_endpoint`` and ``content_text`` from
        ``tool_parameters`` and returns a text message holding the verdict. Any
        exception is reported as a text message rather than raised.
        """
        try:
            if not self.sagemaker_client:
                aws_region = tool_parameters.get("aws_region")
                # Without an explicit region boto3 falls back to its own configuration.
                client_kwargs = {"region_name": aws_region} if aws_region else {}
                self.sagemaker_client = boto3.client("sagemaker-runtime", **client_kwargs)

            if not self.sagemaker_endpoint:
                self.sagemaker_endpoint = tool_parameters.get("sagemaker_endpoint")
            if not self.sagemaker_endpoint:
                return self.create_text_message(text="sagemaker_endpoint is required")

            content_text = tool_parameters.get("content_text")
            if not content_text:
                return self.create_text_message(text="content_text is required")

            result = self._invoke_sagemaker({"text": content_text}, self.sagemaker_endpoint)
            return self.create_text_message(text=result)
        except Exception as e:
            return self.create_text_message(f"Exception {str(e)}")

View File

@ -0,0 +1,46 @@
identity:
name: chinese_toxicity_detector
author: AWS
label:
en_US: Chinese Toxicity Detector
zh_Hans: 中文有害内容检测
icon: icon.svg
description:
human:
en_US: A tool to detect Chinese toxicity
zh_Hans: 检测中文有害内容的工具
llm: A tool that checks if Chinese content is safe for work
parameters:
- name: sagemaker_endpoint
type: string
required: true
label:
en_US: sagemaker endpoint for moderation
zh_Hans: 内容审核的SageMaker端点
human_description:
en_US: sagemaker endpoint for content moderation
zh_Hans: 内容审核的SageMaker端点
llm_description: sagemaker endpoint for content moderation
form: form
- name: content_text
type: string
required: true
label:
en_US: content text
zh_Hans: 待审核文本
human_description:
en_US: text content to be moderated
zh_Hans: 需要审核的文本内容
llm_description: text content to be moderated
form: llm
- name: aws_region
type: string
required: false
label:
en_US: region of sagemaker endpoint
zh_Hans: SageMaker 端点所在的region
human_description:
en_US: region of sagemaker endpoint
zh_Hans: SageMaker 端点所在的region
llm_description: region of sagemaker endpoint
form: form

View File

@ -0,0 +1,418 @@
import json
import logging
import os
import re
import time
import uuid
from typing import Any, Union
from urllib.parse import urlparse
import boto3
import requests
from botocore.exceptions import ClientError
from requests.exceptions import RequestException
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
# NOTE(review): basicConfig() touches the root logger as an import-time side effect — confirm this is intended for a tool module.
logging.basicConfig(level=logging.INFO)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Language codes accepted for the `language_code` and `language_options` tool parameters;
# TranscribeTool._invoke rejects any value that is not in this list.
# Presumably mirrors the Amazon Transcribe supported-language table — verify against the AWS docs when updating.
LanguageCodeOptions = [
"af-ZA",
"ar-AE",
"ar-SA",
"da-DK",
"de-CH",
"de-DE",
"en-AB",
"en-AU",
"en-GB",
"en-IE",
"en-IN",
"en-US",
"en-WL",
"es-ES",
"es-US",
"fa-IR",
"fr-CA",
"fr-FR",
"he-IL",
"hi-IN",
"id-ID",
"it-IT",
"ja-JP",
"ko-KR",
"ms-MY",
"nl-NL",
"pt-BR",
"pt-PT",
"ru-RU",
"ta-IN",
"te-IN",
"tr-TR",
"zh-CN",
"zh-TW",
"th-TH",
"en-ZA",
"en-NZ",
"vi-VN",
"sv-SE",
"ab-GE",
"ast-ES",
"az-AZ",
"ba-RU",
"be-BY",
"bg-BG",
"bn-IN",
"bs-BA",
"ca-ES",
"ckb-IQ",
"ckb-IR",
"cs-CZ",
"cy-WL",
"el-GR",
"et-ET",
"eu-ES",
"fi-FI",
"gl-ES",
"gu-IN",
"ha-NG",
"hr-HR",
"hu-HU",
"hy-AM",
"is-IS",
"ka-GE",
"kab-DZ",
"kk-KZ",
"kn-IN",
"ky-KG",
"lg-IN",
"lt-LT",
"lv-LV",
"mhr-RU",
"mi-NZ",
"mk-MK",
"ml-IN",
"mn-MN",
"mr-IN",
"mt-MT",
"no-NO",
"or-IN",
"pa-IN",
"pl-PL",
"ps-AF",
"ro-RO",
"rw-RW",
"si-LK",
"sk-SK",
"sl-SI",
"so-SO",
"sr-RS",
"su-ID",
"sw-BI",
"sw-KE",
"sw-RW",
"sw-TZ",
"sw-UG",
"tl-PH",
"tt-RU",
"ug-CN",
"uk-UA",
"uz-UZ",
"wo-SN",
"zu-ZA",
]
# Media format names; not referenced by the code visible in this file.
# Presumably intended for validating the `file_type` tool parameter — TODO confirm.
MediaFormat = ["mp3", "mp4", "wav", "flac", "ogg", "amr", "webm", "m4a"]
def is_url(text):
    """Tell whether ``text`` is an http(s) URL.

    Surrounding whitespace is ignored and matching is case-insensitive.
    Empty or ``None`` input yields ``False``.
    """
    if text:
        candidate = text.strip()
        # URL grammar, assembled piece by piece
        url_regex = (
            r"^"  # Start of the string
            r"(?:http|https)://"  # Protocol (http or https)
            r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|"  # Domain
            r"localhost|"  # localhost
            r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})"  # IP address
            r"(?::\d+)?"  # Optional port
            r"(?:/?|[/?]\S+)"  # Path
            r"$"  # End of the string
        )
        return re.match(url_regex, candidate, re.IGNORECASE) is not None
    return False
def upload_file_from_url_to_s3(s3_client, url, bucket_name, s3_key=None, max_retries=3):
    """
    Stream a file from a URL into an S3 bucket, retrying failed downloads.

    Parameters:
    - s3_client: boto3 S3 client used for the upload
    - url (str): The URL of the file to upload
    - bucket_name (str): The name of the S3 bucket
    - s3_key (str): The desired key (path) in S3. If None, "transcribe-files/" plus
      the filename taken from the URL path is used
    - max_retries (int): Maximum number of download attempts

    Returns:
    - tuple: (s3_uri, message) - ``s3_uri`` is "s3://<bucket>/<key>" on success and
      None on failure; ``message`` describes the outcome
    """
    # Validate inputs
    if not url or not bucket_name:
        return None, "URL and bucket name are required"

    retry_count = 0
    while retry_count < max_retries:
        try:
            # The context manager releases the streamed connection on every exit path.
            with requests.get(url, stream=True, timeout=30) as response:
                response.raise_for_status()

                # If s3_key is not provided, try to get filename from URL
                if not s3_key:
                    parsed_url = urlparse(url)
                    filename = os.path.basename(parsed_url.path.split("/file-preview")[0])
                    # A URL path without a final segment yields an empty name; use a unique one instead.
                    s3_key = "transcribe-files/" + (filename or uuid.uuid4().hex)

                extra_args = {"ACL": "private"}  # Ensure the uploaded file is private
                content_type = response.headers.get("content-type")
                if content_type:
                    # Only forwarded when present: boto3 rejects ContentType=None.
                    extra_args["ContentType"] = content_type

                # Have the raw stream undo any Content-Encoding (gzip/deflate) so the
                # stored object holds the actual media bytes.
                response.raw.decode_content = True

                # Upload the file to S3
                s3_client.upload_fileobj(response.raw, bucket_name, s3_key, ExtraArgs=extra_args)

            s3_uri = f"s3://{bucket_name}/{s3_key}"
            return s3_uri, f"Successfully uploaded file to {s3_uri}"
        except RequestException as e:
            # Only download problems are retried.
            retry_count += 1
            if retry_count == max_retries:
                return None, f"Failed to download file from URL after {max_retries} attempts: {str(e)}"
        except ClientError as e:
            return None, f"AWS S3 error: {str(e)}"
        except Exception as e:
            return None, f"Unexpected error: {str(e)}"
    return None, "Maximum retries exceeded"
class TranscribeTool(BuiltinTool):
    """Transcribe an audio/video file with Amazon Transcribe.

    The media is copied from ``file_url`` into an S3 bucket, a transcription job
    is started and polled until it finishes, and the resulting transcript JSON
    is downloaded and flattened into text.

    Note that you must include one of LanguageCode, IdentifyLanguage,
    or IdentifyMultipleLanguages in your request.
    If you include more than one of these parameters, your transcription job fails.
    """

    # boto3 clients, created lazily on the first invocation and then reused.
    s3_client: Any = None
    transcribe_client: Any = None

    def _transcribe_audio(self, audio_file_uri, file_type, **extra_args):
        """Start a transcription job and block until it completes or fails.

        Parameters:
        - audio_file_uri (str): media location, sent as ``Media.MediaFileUri``
        - file_type: accepted for interface compatibility; currently unused
        - extra_args: forwarded verbatim to ``start_transcription_job``
        Returns:
        - tuple: (transcript_file_uri, error) - exactly one of the two is None
        """
        job_name = f"{int(time.time())}-{uuid.uuid4()}"
        try:
            # Start transcription job
            self.transcribe_client.start_transcription_job(
                TranscriptionJobName=job_name, Media={"MediaFileUri": audio_file_uri}, **extra_args
            )

            # Wait for the job to complete, polling every 5 seconds.
            # NOTE(review): there is no upper bound on the wait; long media blocks the caller.
            while True:
                job = self.transcribe_client.get_transcription_job(TranscriptionJobName=job_name)["TranscriptionJob"]
                job_status = job["TranscriptionJobStatus"]
                if job_status in ("COMPLETED", "FAILED"):
                    break
                time.sleep(5)

            if job_status == "COMPLETED":
                return job["Transcript"]["TranscriptFileUri"], None

            error = f"Error: TranscriptionJobStatus:{job_status} "
            failure_reason = job.get("FailureReason")
            if failure_reason:
                # Surface the service's explanation instead of only the bare status.
                error += f"FailureReason:{failure_reason}"
            return None, error
        except Exception as e:
            return None, f"Error: {str(e)}"

    @staticmethod
    def _format_speaker_transcript(results: dict) -> str:
        """Render diarized ``results`` as one ``[speaker]: text`` line per speaker turn."""
        # Create a mapping of start_time -> speaker_label
        time_to_speaker = {
            item["start_time"]: segment["speaker_label"]
            for segment in results["speaker_labels"]["segments"]
            for item in segment["items"]
        }

        lines = []
        current_speaker = None
        for item in results["items"]:
            content = item["alternatives"][0]["content"]

            # Punctuation carries no start_time; attach it directly to the preceding word.
            if item["type"] == "punctuation":
                if lines:
                    lines[-1] += content
                else:
                    lines.append(content)
                continue

            speaker = time_to_speaker.get(item["start_time"])
            if speaker != current_speaker:
                # Speaker changed: open a new labelled line.
                current_speaker = speaker
                lines.append(f"[{speaker}]: {content}")
            elif lines:
                lines[-1] += f" {content}"
            else:
                lines.append(content)
        return "\n".join(lines).strip()

    def _download_and_read_transcript(
        self, transcript_file_uri: str, max_retries: int = 3
    ) -> tuple[Union[str, None], Union[str, None]]:
        """
        Download and read the transcript file from the given URI.

        Parameters:
        - transcript_file_uri (str): The URI of the transcript file
        - max_retries (int): Maximum number of retry attempts

        Returns:
        - tuple: (text, error) - (Transcribed text if successful, error message if failed);
          ``text`` is None whenever ``error`` is set
        """
        retry_count = 0
        while retry_count < max_retries:
            try:
                # Download the transcript file
                response = requests.get(transcript_file_uri, timeout=30)
                response.raise_for_status()

                # Parse the JSON content
                transcript_data = response.json()
                results = transcript_data["results"] if "results" in transcript_data else {}

                # Check if speaker labels are present and enabled
                has_speaker_labels = "speaker_labels" in results and "segments" in results["speaker_labels"]
                if has_speaker_labels:
                    return self._format_speaker_transcript(results), None

                # Extract the transcription text
                # The transcript text is typically in the 'results' -> 'transcripts' array
                transcripts = results["transcripts"] if "transcripts" in results else None
                if transcripts:
                    # Combine all transcript segments
                    return " ".join(t.get("transcript", "") for t in transcripts), None
                return None, "No transcripts found in the response"
            except requests.exceptions.RequestException as e:
                # Only download problems are retried.
                retry_count += 1
                if retry_count == max_retries:
                    return None, f"Failed to download transcript file after {max_retries} attempts: {str(e)}"
            except json.JSONDecodeError as e:
                return None, f"Failed to parse transcript JSON: {str(e)}"
            except Exception as e:
                return None, f"Unexpected error while processing transcript: {str(e)}"
        return None, "Maximum retries exceeded"

    def _invoke(
        self,
        user_id: str,
        tool_parameters: dict[str, Any],
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """
        invoke tools

        Validates the parameters, uploads the media to S3, runs the transcription
        job and returns the transcript as a text message. Validation problems and
        exceptions are reported as text messages rather than raised.
        """
        try:
            if not self.transcribe_client or not self.s3_client:
                aws_region = tool_parameters.get("aws_region")
                # Without an explicit region boto3 falls back to its own configuration.
                client_kwargs = {"region_name": aws_region} if aws_region else {}
                # Build both clients before storing either, so a failure cannot leave
                # the instance with only one of them set.
                transcribe_client = boto3.client("transcribe", **client_kwargs)
                s3_client = boto3.client("s3", **client_kwargs)
                self.transcribe_client = transcribe_client
                self.s3_client = s3_client

            file_url = tool_parameters.get("file_url")
            file_type = tool_parameters.get("file_type")
            language_code = tool_parameters.get("language_code")
            identify_language = tool_parameters.get("identify_language", True)
            identify_multiple_languages = tool_parameters.get("identify_multiple_languages", False)
            language_options_str = tool_parameters.get("language_options")
            s3_bucket_name = tool_parameters.get("s3_bucket_name")
            show_speaker_labels = tool_parameters.get("ShowSpeakerLabels", True)
            max_speaker_labels = tool_parameters.get("MaxSpeakerLabels", 2)

            # Check the input params
            if not s3_bucket_name:
                return self.create_text_message(text="s3_bucket_name is required")

            language_options = None
            if language_options_str:
                # Tolerate blanks around the "|" separators.
                language_options = [code.strip() for code in language_options_str.split("|") if code.strip()]
                for lang in language_options:
                    if lang not in LanguageCodeOptions:
                        return self.create_text_message(
                            text=f"{lang} is not supported, should be one of {LanguageCodeOptions}"
                        )

            if language_code and language_code not in LanguageCodeOptions:
                err_msg = f"language_code:{language_code} is not supported, should be one of {LanguageCodeOptions}"
                return self.create_text_message(text=err_msg)

            err_msg = (
                f"identify_language:{identify_language}, "
                f"identify_multiple_languages:{identify_multiple_languages}, "
                "Note that you must include one of LanguageCode, IdentifyLanguage, "
                "or IdentifyMultipleLanguages in your request. "
                "If you include more than one of these parameters, "
                "your transcription job fails."
            )
            # Exactly one of the three language modes must be selected.
            selected_modes = sum(
                bool(mode) for mode in (language_code, identify_language, identify_multiple_languages)
            )
            if selected_modes != 1:
                return self.create_text_message(text=err_msg)

            # Send only the selected mode; disabled flags are left out of the request.
            extra_args = {}
            if language_code:
                extra_args["LanguageCode"] = language_code
            elif identify_multiple_languages:
                extra_args["IdentifyMultipleLanguages"] = True
            else:
                extra_args["IdentifyLanguage"] = True
            if language_options:
                extra_args["LanguageOptions"] = language_options
            if show_speaker_labels:
                extra_args["Settings"] = {
                    "ShowSpeakerLabels": True,
                    # The parameter is declared as a generic number; the API expects an integer.
                    "MaxSpeakerLabels": int(max_speaker_labels or 2),
                }

            # upload to s3 bucket
            s3_path_result, error = upload_file_from_url_to_s3(self.s3_client, url=file_url, bucket_name=s3_bucket_name)
            if not s3_path_result:
                return self.create_text_message(text=error)

            transcript_file_uri, error = self._transcribe_audio(
                audio_file_uri=s3_path_result,
                file_type=file_type,
                **extra_args,
            )
            if not transcript_file_uri:
                return self.create_text_message(text=error)

            # Download and read the transcript
            transcript_text, error = self._download_and_read_transcript(transcript_file_uri)
            if transcript_text is None:
                return self.create_text_message(text=error)
            # An empty transcript (e.g. silent media) is a valid result, not an error.
            return self.create_text_message(text=transcript_text)
        except Exception as e:
            return self.create_text_message(f"Exception {str(e)}")

View File

@ -0,0 +1,133 @@
identity:
name: transcribe_asr
author: AWS
label:
en_US: TranscribeASR
zh_Hans: Transcribe语音识别转录
pt_BR: TranscribeASR
icon: icon.svg
description:
human:
en_US: A tool for ASR (Automatic Speech Recognition) - https://github.com/aws-samples/dify-aws-tool
zh_Hans: AWS 语音识别转录服务, 请参考 https://aws.amazon.com/cn/pm/transcribe/#Learn_More_About_Amazon_Transcribe
pt_BR: A tool for ASR (Automatic Speech Recognition).
llm: A tool for ASR (Automatic Speech Recognition).
parameters:
- name: file_url
type: string
required: true
label:
en_US: video or audio file url for transcribe
zh_Hans: 语音或者视频文件url
pt_BR: video or audio file url for transcribe
human_description:
en_US: video or audio file url for transcribe
zh_Hans: 语音或者视频文件url
pt_BR: video or audio file url for transcribe
llm_description: video or audio file url for transcribe
form: llm
- name: language_code
type: string
required: false
label:
en_US: Language Code
zh_Hans: 语言编码
pt_BR: Language Code
human_description:
en_US: The language code used to create your transcription job. refer to :https://docs.aws.amazon.com/transcribe/latest/dg/supported-languages.html
zh_Hans: 语言编码,例如zh-CN, en-US 可参考 https://docs.aws.amazon.com/transcribe/latest/dg/supported-languages.html
pt_BR: The language code used to create your transcription job. refer to :https://docs.aws.amazon.com/transcribe/latest/dg/supported-languages.html
llm_description: The language code used to create your transcription job.
form: llm
- name: identify_language
type: boolean
default: true
required: false
label:
en_US: Automatically Identify Language
zh_Hans: 自动识别语言
pt_BR: Automatically Identify Language
human_description:
en_US: Automatically Identify Language
zh_Hans: 自动识别语言
pt_BR: Automatically Identify Language
llm_description: Enable Automatically Identify Language
form: form
- name: identify_multiple_languages
type: boolean
required: false
label:
en_US: Automatically Identify Multiple Languages
zh_Hans: 自动识别多种语言
pt_BR: Automatically Identify Multiple Languages
human_description:
en_US: Automatically Identify Multiple Languages
zh_Hans: 自动识别多种语言
pt_BR: Automatically Identify Multiple Languages
llm_description: Enable Automatically Identify Multiple Languages
form: form
- name: language_options
type: string
required: false
label:
en_US: Language Options
zh_Hans: 语言种类选项
pt_BR: Language Options
human_description:
en_US: Separated by |, e.g:zh-CN|en-US, You can specify two or more language codes that represent the languages you think may be present in your media
zh_Hans: 您可以指定两个或更多的语言代码来表示您认为可能出现在媒体中的语言。用|分隔,如 zh-CN|en-US
pt_BR: Separated by |, e.g:zh-CN|en-US, You can specify two or more language codes that represent the languages you think may be present in your media
llm_description: Separated by |, e.g:zh-CN|en-US, You can specify two or more language codes that represent the languages you think may be present in your media
form: llm
- name: s3_bucket_name
type: string
required: true
label:
en_US: s3 bucket name
zh_Hans: s3 存储桶名称
pt_BR: s3 bucket name
human_description:
en_US: s3 bucket name to store transcribe files (don't add prefix s3://)
zh_Hans: s3 存储桶名称,用于存储转录文件 (不需要前缀 s3://)
pt_BR: s3 bucket name to store transcribe files (don't add prefix s3://)
llm_description: s3 bucket name to store transcribe files
form: form
- name: ShowSpeakerLabels
type: boolean
required: true
default: true
label:
en_US: ShowSpeakerLabels
zh_Hans: 显示说话人标签
pt_BR: ShowSpeakerLabels
human_description:
en_US: Enables speaker partitioning (diarization) in your transcription output
zh_Hans: 在转录输出中启用说话人分区(说话人分离)
pt_BR: Enables speaker partitioning (diarization) in your transcription output
llm_description: Enables speaker partitioning (diarization) in your transcription output
form: form
- name: MaxSpeakerLabels
type: number
required: true
default: 2
label:
en_US: MaxSpeakerLabels
zh_Hans: 说话人标签数量
pt_BR: MaxSpeakerLabels
human_description:
en_US: Specify the maximum number of speakers you want to partition in your media
zh_Hans: 指定您希望在媒体中划分的最多演讲者数量。
pt_BR: Specify the maximum number of speakers you want to partition in your media
llm_description: Specify the maximum number of speakers you want to partition in your media
form: form
- name: aws_region
type: string
required: false
label:
en_US: AWS Region
zh_Hans: AWS 区域
human_description:
en_US: Please enter the AWS region for the transcribe service, for example 'us-east-1'.
zh_Hans: 请输入Transcribe的 AWS 区域,例如 'us-east-1'。
llm_description: Please enter the AWS region for the transcribe service, for example 'us-east-1'.
form: form