[Fix] Fix sagemaker_chinese_toxicity_detector and bedrock_retrieve (#12227)

This commit is contained in:
Warren Chen 2024-12-30 22:26:04 +08:00 committed by GitHub
parent adacd01f82
commit 562450751f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 13 additions and 13 deletions

View File

@ -21,7 +21,7 @@ class BedrockRetrieveTool(BuiltinTool):
retrieval_configuration = {"vectorSearchConfiguration": {"numberOfResults": num_results}}
# 如果有元数据过滤条件,则添加到检索配置中
# Add metadata filter to retrieval configuration if present
if metadata_filter:
retrieval_configuration["vectorSearchConfiguration"]["filter"] = metadata_filter
@ -77,7 +77,7 @@ class BedrockRetrieveTool(BuiltinTool):
if not query:
return self.create_text_message("Please input query")
# 获取元数据过滤条件(如果存在)
# Get metadata filter conditions (if they exist)
metadata_filter_str = tool_parameters.get("metadata_filter")
metadata_filter = json.loads(metadata_filter_str) if metadata_filter_str else None
@ -86,7 +86,7 @@ class BedrockRetrieveTool(BuiltinTool):
query_input=query,
knowledge_base_id=self.knowledge_base_id,
num_results=self.topk,
metadata_filter=metadata_filter, # 将元数据过滤条件传递给检索方法
metadata_filter=metadata_filter,
)
line = 5
@ -109,7 +109,7 @@ class BedrockRetrieveTool(BuiltinTool):
if not parameters.get("query"):
raise ValueError("query is required")
# 可选:可以验证元数据过滤条件是否为有效的 JSON 字符串(如果提供)
# Optional: Validate if metadata filter is a valid JSON string (if provided)
metadata_filter_str = parameters.get("metadata_filter")
if metadata_filter_str and not isinstance(json.loads(metadata_filter_str), dict):
raise ValueError("metadata_filter must be a valid JSON object")

View File

@ -73,9 +73,9 @@ parameters:
llm_description: AWS region where the Bedrock Knowledge Base is located
form: form
- name: metadata_filter
type: string
required: false
- name: metadata_filter # 新增的元数据过滤参数
type: string # 可以是字符串类型,包含 JSON 格式的过滤条件
required: false # 可选参数
label:
en_US: Metadata Filter
zh_Hans: 元数据过滤器

View File

@ -6,8 +6,8 @@ import boto3
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
# 定义标签映射
LABEL_MAPPING = {"LABEL_0": "SAFE", "LABEL_1": "NO_SAFE"}
# Define label mappings
LABEL_MAPPING = {0: "SAFE", 1: "NO_SAFE"}
class ContentModerationTool(BuiltinTool):
@ -28,12 +28,12 @@ class ContentModerationTool(BuiltinTool):
# Handle nested JSON if present
if isinstance(json_obj, dict) and "body" in json_obj:
body_content = json.loads(json_obj["body"])
raw_label = body_content.get("label")
prediction_result = body_content.get("prediction")
else:
raw_label = json_obj.get("label")
prediction_result = json_obj.get("prediction")
# 映射标签并返回
result = LABEL_MAPPING.get(raw_label, "NO_SAFE") # 如果映射中没有找到,默认返回NO_SAFE
# Map labels and return
result = LABEL_MAPPING.get(prediction_result, "NO_SAFE") # If not found in mapping, default to NO_SAFE
return result
def _invoke(