mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-10 07:38:59 +08:00
[Fix] Fix sagemaker_chinese_toxicity_detector and bedrock_retrieve (#12227)
This commit is contained in:
parent
adacd01f82
commit
562450751f
@ -21,7 +21,7 @@ class BedrockRetrieveTool(BuiltinTool):
|
||||
|
||||
retrieval_configuration = {"vectorSearchConfiguration": {"numberOfResults": num_results}}
|
||||
|
||||
# 如果有元数据过滤条件,则添加到检索配置中
|
||||
# Add metadata filter to retrieval configuration if present
|
||||
if metadata_filter:
|
||||
retrieval_configuration["vectorSearchConfiguration"]["filter"] = metadata_filter
|
||||
|
||||
@ -77,7 +77,7 @@ class BedrockRetrieveTool(BuiltinTool):
|
||||
if not query:
|
||||
return self.create_text_message("Please input query")
|
||||
|
||||
# 获取元数据过滤条件(如果存在)
|
||||
# Get metadata filter conditions (if they exist)
|
||||
metadata_filter_str = tool_parameters.get("metadata_filter")
|
||||
metadata_filter = json.loads(metadata_filter_str) if metadata_filter_str else None
|
||||
|
||||
@ -86,7 +86,7 @@ class BedrockRetrieveTool(BuiltinTool):
|
||||
query_input=query,
|
||||
knowledge_base_id=self.knowledge_base_id,
|
||||
num_results=self.topk,
|
||||
metadata_filter=metadata_filter, # 将元数据过滤条件传递给检索方法
|
||||
metadata_filter=metadata_filter,
|
||||
)
|
||||
|
||||
line = 5
|
||||
@ -109,7 +109,7 @@ class BedrockRetrieveTool(BuiltinTool):
|
||||
if not parameters.get("query"):
|
||||
raise ValueError("query is required")
|
||||
|
||||
# 可选:可以验证元数据过滤条件是否为有效的 JSON 字符串(如果提供)
|
||||
# Optional: Validate if metadata filter is a valid JSON string (if provided)
|
||||
metadata_filter_str = parameters.get("metadata_filter")
|
||||
if metadata_filter_str and not isinstance(json.loads(metadata_filter_str), dict):
|
||||
raise ValueError("metadata_filter must be a valid JSON object")
|
||||
|
@ -73,9 +73,9 @@ parameters:
|
||||
llm_description: AWS region where the Bedrock Knowledge Base is located
|
||||
form: form
|
||||
|
||||
- name: metadata_filter
|
||||
type: string
|
||||
required: false
|
||||
- name: metadata_filter # 新增的元数据过滤参数
|
||||
type: string # 可以是字符串类型,包含 JSON 格式的过滤条件
|
||||
required: false # 可选参数
|
||||
label:
|
||||
en_US: Metadata Filter
|
||||
zh_Hans: 元数据过滤器
|
||||
|
@ -6,8 +6,8 @@ import boto3
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
# 定义标签映射
|
||||
LABEL_MAPPING = {"LABEL_0": "SAFE", "LABEL_1": "NO_SAFE"}
|
||||
# Define label mappings
|
||||
LABEL_MAPPING = {0: "SAFE", 1: "NO_SAFE"}
|
||||
|
||||
|
||||
class ContentModerationTool(BuiltinTool):
|
||||
@ -28,12 +28,12 @@ class ContentModerationTool(BuiltinTool):
|
||||
# Handle nested JSON if present
|
||||
if isinstance(json_obj, dict) and "body" in json_obj:
|
||||
body_content = json.loads(json_obj["body"])
|
||||
raw_label = body_content.get("label")
|
||||
prediction_result = body_content.get("prediction")
|
||||
else:
|
||||
raw_label = json_obj.get("label")
|
||||
prediction_result = json_obj.get("prediction")
|
||||
|
||||
# 映射标签并返回
|
||||
result = LABEL_MAPPING.get(raw_label, "NO_SAFE") # 如果映射中没有找到,默认返回NO_SAFE
|
||||
# Map labels and return
|
||||
result = LABEL_MAPPING.get(prediction_result, "NO_SAFE") # If not found in mapping, default to NO_SAFE
|
||||
return result
|
||||
|
||||
def _invoke(
|
||||
|
Loading…
x
Reference in New Issue
Block a user