mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-11 17:09:01 +08:00
[Fix] Fix sagemaker_chinese_toxicity_detector and bedrock_retrieve (#12227)
This commit is contained in:
parent
adacd01f82
commit
562450751f
@ -21,7 +21,7 @@ class BedrockRetrieveTool(BuiltinTool):
|
|||||||
|
|
||||||
retrieval_configuration = {"vectorSearchConfiguration": {"numberOfResults": num_results}}
|
retrieval_configuration = {"vectorSearchConfiguration": {"numberOfResults": num_results}}
|
||||||
|
|
||||||
# 如果有元数据过滤条件,则添加到检索配置中
|
# Add metadata filter to retrieval configuration if present
|
||||||
if metadata_filter:
|
if metadata_filter:
|
||||||
retrieval_configuration["vectorSearchConfiguration"]["filter"] = metadata_filter
|
retrieval_configuration["vectorSearchConfiguration"]["filter"] = metadata_filter
|
||||||
|
|
||||||
@ -77,7 +77,7 @@ class BedrockRetrieveTool(BuiltinTool):
|
|||||||
if not query:
|
if not query:
|
||||||
return self.create_text_message("Please input query")
|
return self.create_text_message("Please input query")
|
||||||
|
|
||||||
# 获取元数据过滤条件(如果存在)
|
# Get metadata filter conditions (if they exist)
|
||||||
metadata_filter_str = tool_parameters.get("metadata_filter")
|
metadata_filter_str = tool_parameters.get("metadata_filter")
|
||||||
metadata_filter = json.loads(metadata_filter_str) if metadata_filter_str else None
|
metadata_filter = json.loads(metadata_filter_str) if metadata_filter_str else None
|
||||||
|
|
||||||
@ -86,7 +86,7 @@ class BedrockRetrieveTool(BuiltinTool):
|
|||||||
query_input=query,
|
query_input=query,
|
||||||
knowledge_base_id=self.knowledge_base_id,
|
knowledge_base_id=self.knowledge_base_id,
|
||||||
num_results=self.topk,
|
num_results=self.topk,
|
||||||
metadata_filter=metadata_filter, # 将元数据过滤条件传递给检索方法
|
metadata_filter=metadata_filter,
|
||||||
)
|
)
|
||||||
|
|
||||||
line = 5
|
line = 5
|
||||||
@ -109,7 +109,7 @@ class BedrockRetrieveTool(BuiltinTool):
|
|||||||
if not parameters.get("query"):
|
if not parameters.get("query"):
|
||||||
raise ValueError("query is required")
|
raise ValueError("query is required")
|
||||||
|
|
||||||
# 可选:可以验证元数据过滤条件是否为有效的 JSON 字符串(如果提供)
|
# Optional: Validate if metadata filter is a valid JSON string (if provided)
|
||||||
metadata_filter_str = parameters.get("metadata_filter")
|
metadata_filter_str = parameters.get("metadata_filter")
|
||||||
if metadata_filter_str and not isinstance(json.loads(metadata_filter_str), dict):
|
if metadata_filter_str and not isinstance(json.loads(metadata_filter_str), dict):
|
||||||
raise ValueError("metadata_filter must be a valid JSON object")
|
raise ValueError("metadata_filter must be a valid JSON object")
|
||||||
|
@ -73,9 +73,9 @@ parameters:
|
|||||||
llm_description: AWS region where the Bedrock Knowledge Base is located
|
llm_description: AWS region where the Bedrock Knowledge Base is located
|
||||||
form: form
|
form: form
|
||||||
|
|
||||||
- name: metadata_filter
|
- name: metadata_filter # 新增的元数据过滤参数
|
||||||
type: string
|
type: string # 可以是字符串类型,包含 JSON 格式的过滤条件
|
||||||
required: false
|
required: false # 可选参数
|
||||||
label:
|
label:
|
||||||
en_US: Metadata Filter
|
en_US: Metadata Filter
|
||||||
zh_Hans: 元数据过滤器
|
zh_Hans: 元数据过滤器
|
||||||
|
@ -6,8 +6,8 @@ import boto3
|
|||||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||||
from core.tools.tool.builtin_tool import BuiltinTool
|
from core.tools.tool.builtin_tool import BuiltinTool
|
||||||
|
|
||||||
# 定义标签映射
|
# Define label mappings
|
||||||
LABEL_MAPPING = {"LABEL_0": "SAFE", "LABEL_1": "NO_SAFE"}
|
LABEL_MAPPING = {0: "SAFE", 1: "NO_SAFE"}
|
||||||
|
|
||||||
|
|
||||||
class ContentModerationTool(BuiltinTool):
|
class ContentModerationTool(BuiltinTool):
|
||||||
@ -28,12 +28,12 @@ class ContentModerationTool(BuiltinTool):
|
|||||||
# Handle nested JSON if present
|
# Handle nested JSON if present
|
||||||
if isinstance(json_obj, dict) and "body" in json_obj:
|
if isinstance(json_obj, dict) and "body" in json_obj:
|
||||||
body_content = json.loads(json_obj["body"])
|
body_content = json.loads(json_obj["body"])
|
||||||
raw_label = body_content.get("label")
|
prediction_result = body_content.get("prediction")
|
||||||
else:
|
else:
|
||||||
raw_label = json_obj.get("label")
|
prediction_result = json_obj.get("prediction")
|
||||||
|
|
||||||
# 映射标签并返回
|
# Map labels and return
|
||||||
result = LABEL_MAPPING.get(raw_label, "NO_SAFE") # 如果映射中没有找到,默认返回NO_SAFE
|
result = LABEL_MAPPING.get(prediction_result, "NO_SAFE") # If not found in mapping, default to NO_SAFE
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _invoke(
|
def _invoke(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user