mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-07-28 04:22:00 +08:00
Co-authored-by: Yuanbo Li <ybalbert@amazon.com>
This commit is contained in:
parent
c236f05f4b
commit
2a14c67edc
@ -14,14 +14,38 @@ class BedrockRetrieveTool(BuiltinTool):
|
|||||||
topk: int = None
|
topk: int = None
|
||||||
|
|
||||||
def _bedrock_retrieve(
|
def _bedrock_retrieve(
|
||||||
self, query_input: str, knowledge_base_id: str, num_results: int, metadata_filter: Optional[dict] = None
|
self,
|
||||||
|
query_input: str,
|
||||||
|
knowledge_base_id: str,
|
||||||
|
num_results: int,
|
||||||
|
search_type: str,
|
||||||
|
rerank_model_id: str,
|
||||||
|
metadata_filter: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
retrieval_query = {"text": query_input}
|
retrieval_query = {"text": query_input}
|
||||||
|
|
||||||
retrieval_configuration = {"vectorSearchConfiguration": {"numberOfResults": num_results}}
|
if search_type not in ["HYBRID", "SEMANTIC"]:
|
||||||
|
raise RuntimeException("search_type should be HYBRID or SEMANTIC")
|
||||||
|
|
||||||
# Add metadata filter to retrieval configuration if present
|
retrieval_configuration = {
|
||||||
|
"vectorSearchConfiguration": {"numberOfResults": num_results, "overrideSearchType": search_type}
|
||||||
|
}
|
||||||
|
|
||||||
|
if rerank_model_id != "default":
|
||||||
|
model_for_rerank_arn = f"arn:aws:bedrock:us-west-2::foundation-model/{rerank_model_id}"
|
||||||
|
rerankingConfiguration = {
|
||||||
|
"bedrockRerankingConfiguration": {
|
||||||
|
"numberOfRerankedResults": num_results,
|
||||||
|
"modelConfiguration": {"modelArn": model_for_rerank_arn},
|
||||||
|
},
|
||||||
|
"type": "BEDROCK_RERANKING_MODEL",
|
||||||
|
}
|
||||||
|
|
||||||
|
retrieval_configuration["vectorSearchConfiguration"]["rerankingConfiguration"] = rerankingConfiguration
|
||||||
|
retrieval_configuration["vectorSearchConfiguration"]["numberOfResults"] = num_results * 5
|
||||||
|
|
||||||
|
# Add the metadata filter to the retrieval configuration, if one was provided
|
||||||
if metadata_filter:
|
if metadata_filter:
|
||||||
retrieval_configuration["vectorSearchConfiguration"]["filter"] = metadata_filter
|
retrieval_configuration["vectorSearchConfiguration"]["filter"] = metadata_filter
|
||||||
|
|
||||||
@ -77,15 +101,20 @@ class BedrockRetrieveTool(BuiltinTool):
|
|||||||
if not query:
|
if not query:
|
||||||
return self.create_text_message("Please input query")
|
return self.create_text_message("Please input query")
|
||||||
|
|
||||||
# Get metadata filter conditions (if they exist)
|
# Get the metadata filter conditions (if they exist)
|
||||||
metadata_filter_str = tool_parameters.get("metadata_filter")
|
metadata_filter_str = tool_parameters.get("metadata_filter")
|
||||||
metadata_filter = json.loads(metadata_filter_str) if metadata_filter_str else None
|
metadata_filter = json.loads(metadata_filter_str) if metadata_filter_str else None
|
||||||
|
|
||||||
|
search_type = tool_parameters.get("search_type")
|
||||||
|
rerank_model_id = tool_parameters.get("rerank_model_id")
|
||||||
|
|
||||||
line = 4
|
line = 4
|
||||||
retrieved_docs = self._bedrock_retrieve(
|
retrieved_docs = self._bedrock_retrieve(
|
||||||
query_input=query,
|
query_input=query,
|
||||||
knowledge_base_id=self.knowledge_base_id,
|
knowledge_base_id=self.knowledge_base_id,
|
||||||
num_results=self.topk,
|
num_results=self.topk,
|
||||||
|
search_type=search_type,
|
||||||
|
rerank_model_id=rerank_model_id,
|
||||||
metadata_filter=metadata_filter,
|
metadata_filter=metadata_filter,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -109,7 +138,7 @@ class BedrockRetrieveTool(BuiltinTool):
|
|||||||
if not parameters.get("query"):
|
if not parameters.get("query"):
|
||||||
raise ValueError("query is required")
|
raise ValueError("query is required")
|
||||||
|
|
||||||
# Optional: Validate if metadata filter is a valid JSON string (if provided)
|
# Optional: validate that the metadata filter is a valid JSON string (if provided)
|
||||||
metadata_filter_str = parameters.get("metadata_filter")
|
metadata_filter_str = parameters.get("metadata_filter")
|
||||||
if metadata_filter_str and not isinstance(json.loads(metadata_filter_str), dict):
|
if metadata_filter_str and not isinstance(json.loads(metadata_filter_str), dict):
|
||||||
raise ValueError("metadata_filter must be a valid JSON object")
|
raise ValueError("metadata_filter must be a valid JSON object")
|
||||||
|
@ -59,6 +59,57 @@ parameters:
|
|||||||
max: 10
|
max: 10
|
||||||
default: 5
|
default: 5
|
||||||
|
|
||||||
|
- name: search_type
|
||||||
|
type: select
|
||||||
|
required: false
|
||||||
|
label:
|
||||||
|
en_US: search type
|
||||||
|
zh_Hans: 搜索类型
|
||||||
|
pt_BR: search type
|
||||||
|
human_description:
|
||||||
|
en_US: search type
|
||||||
|
zh_Hans: 搜索类型
|
||||||
|
pt_BR: search type
|
||||||
|
llm_description: search type
|
||||||
|
default: SEMANTIC
|
||||||
|
options:
|
||||||
|
- value: SEMANTIC
|
||||||
|
label:
|
||||||
|
en_US: SEMANTIC
|
||||||
|
zh_Hans: 语义搜索
|
||||||
|
- value: HYBRID
|
||||||
|
label:
|
||||||
|
en_US: HYBRID
|
||||||
|
zh_Hans: 混合搜索
|
||||||
|
form: form
|
||||||
|
|
||||||
|
- name: rerank_model_id
|
||||||
|
type: select
|
||||||
|
required: false
|
||||||
|
label:
|
||||||
|
en_US: rerank model id
|
||||||
|
zh_Hans: 重排模型ID
|
||||||
|
pt_BR: rerank model id
|
||||||
|
human_description:
|
||||||
|
en_US: rerank model id
|
||||||
|
zh_Hans: 重排模型ID
|
||||||
|
pt_BR: rerank model id
|
||||||
|
llm_description: rerank model id
|
||||||
|
options:
|
||||||
|
- value: default
|
||||||
|
label:
|
||||||
|
en_US: default
|
||||||
|
zh_Hans: 默认
|
||||||
|
- value: cohere.rerank-v3-5:0
|
||||||
|
label:
|
||||||
|
en_US: cohere.rerank-v3-5:0
|
||||||
|
zh_Hans: cohere.rerank-v3-5:0
|
||||||
|
- value: amazon.rerank-v1:0
|
||||||
|
label:
|
||||||
|
en_US: amazon.rerank-v1:0
|
||||||
|
zh_Hans: amazon.rerank-v1:0
|
||||||
|
form: form
|
||||||
|
|
||||||
- name: aws_region
|
- name: aws_region
|
||||||
type: string
|
type: string
|
||||||
required: false
|
required: false
|
||||||
|
Loading…
x
Reference in New Issue
Block a user