From 2a14c67edcf76034346d55251ed92b4e0eaa366e Mon Sep 17 00:00:00 2001 From: ybalbert001 <120714773+ybalbert001@users.noreply.github.com> Date: Tue, 7 Jan 2025 19:51:23 +0800 Subject: [PATCH] =?UTF-8?q?Fix=20#12448=20-=20update=20bedrock=20retrieve?= =?UTF-8?q?=20tool,=20support=20hybrid=20search=20type=20and=20re=E2=80=A6?= =?UTF-8?q?=20(#12446)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Yuanbo Li --- .../builtin/aws/tools/bedrock_retrieve.py | 39 ++++++++++++-- .../builtin/aws/tools/bedrock_retrieve.yaml | 51 +++++++++++++++++++ 2 files changed, 85 insertions(+), 5 deletions(-) diff --git a/api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.py b/api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.py index aca369b438..2e6a9740c2 100644 --- a/api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.py +++ b/api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.py @@ -14,14 +14,38 @@ class BedrockRetrieveTool(BuiltinTool): topk: int = None def _bedrock_retrieve( - self, query_input: str, knowledge_base_id: str, num_results: int, metadata_filter: Optional[dict] = None + self, + query_input: str, + knowledge_base_id: str, + num_results: int, + search_type: str, + rerank_model_id: str, + metadata_filter: Optional[dict] = None, ): try: retrieval_query = {"text": query_input} - retrieval_configuration = {"vectorSearchConfiguration": {"numberOfResults": num_results}} + if search_type not in ["HYBRID", "SEMANTIC"]: + raise RuntimeException("search_type should be HYBRID or SEMANTIC") - # Add metadata filter to retrieval configuration if present + retrieval_configuration = { + "vectorSearchConfiguration": {"numberOfResults": num_results, "overrideSearchType": search_type} + } + + if rerank_model_id != "default": + model_for_rerank_arn = f"arn:aws:bedrock:us-west-2::foundation-model/{rerank_model_id}" + rerankingConfiguration = { + "bedrockRerankingConfiguration": { + 
"numberOfRerankedResults": num_results, + "modelConfiguration": {"modelArn": model_for_rerank_arn}, + }, + "type": "BEDROCK_RERANKING_MODEL", + } + + retrieval_configuration["vectorSearchConfiguration"]["rerankingConfiguration"] = rerankingConfiguration + retrieval_configuration["vectorSearchConfiguration"]["numberOfResults"] = num_results * 5 + + # 如果有元数据过滤条件,则添加到检索配置中 if metadata_filter: retrieval_configuration["vectorSearchConfiguration"]["filter"] = metadata_filter @@ -77,15 +101,20 @@ class BedrockRetrieveTool(BuiltinTool): if not query: return self.create_text_message("Please input query") - # Get metadata filter conditions (if they exist) + # 获取元数据过滤条件(如果存在) metadata_filter_str = tool_parameters.get("metadata_filter") metadata_filter = json.loads(metadata_filter_str) if metadata_filter_str else None + search_type = tool_parameters.get("search_type") + rerank_model_id = tool_parameters.get("rerank_model_id") + line = 4 retrieved_docs = self._bedrock_retrieve( query_input=query, knowledge_base_id=self.knowledge_base_id, num_results=self.topk, + search_type=search_type, + rerank_model_id=rerank_model_id, metadata_filter=metadata_filter, ) @@ -109,7 +138,7 @@ class BedrockRetrieveTool(BuiltinTool): if not parameters.get("query"): raise ValueError("query is required") - # Optional: Validate if metadata filter is a valid JSON string (if provided) + # 可选:可以验证元数据过滤条件是否为有效的 JSON 字符串(如果提供) metadata_filter_str = parameters.get("metadata_filter") if metadata_filter_str and not isinstance(json.loads(metadata_filter_str), dict): raise ValueError("metadata_filter must be a valid JSON object") diff --git a/api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.yaml b/api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.yaml index 31961a0cf0..f8d1d1d49d 100644 --- a/api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.yaml +++ b/api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.yaml @@ -59,6 +59,57 @@ parameters: max: 10 default: 5 + - name: 
search_type + type: select + required: false + label: + en_US: search type + zh_Hans: 搜索类型 + pt_BR: search type + human_description: + en_US: search type + zh_Hans: 搜索类型 + pt_BR: search type + llm_description: search type + default: SEMANTIC + options: + - value: SEMANTIC + label: + en_US: SEMANTIC + zh_Hans: 语义搜索 + - value: HYBRID + label: + en_US: HYBRID + zh_Hans: 混合搜索 + form: form + + - name: rerank_model_id + type: select + required: false + label: + en_US: rerank model id + zh_Hans: 重排模型ID + pt_BR: rerank model id + human_description: + en_US: rerank model id + zh_Hans: 重排模型ID + pt_BR: rerank model id + llm_description: rerank model id + options: + - value: default + label: + en_US: default + zh_Hans: 默认 + - value: cohere.rerank-v3-5:0 + label: + en_US: cohere.rerank-v3-5:0 + zh_Hans: cohere.rerank-v3-5:0 + - value: amazon.rerank-v1:0 + label: + en_US: amazon.rerank-v1:0 + zh_Hans: amazon.rerank-v1:0 + form: form + + - name: aws_region type: string required: false