From 95b74c211df5e8191924c98f5cc1627a87343d9c Mon Sep 17 00:00:00 2001 From: Yeuoly <45712896+Yeuoly@users.noreply.github.com> Date: Mon, 18 Mar 2024 16:55:26 +0800 Subject: [PATCH] Feat/support tool credentials bool schema (#2875) --- api/core/tools/entities/tool_entities.py | 3 +- api/core/tools/provider/builtin/bing/bing.py | 5 +- .../tools/provider/builtin/bing/bing.yaml | 60 +++++++ .../builtin/bing/tools/bing_web_search.py | 148 +++++++++++++----- .../tools/provider/builtin_tool_provider.py | 23 ++- api/services/tools_manage_service.py | 6 +- .../setting/build-in/config-credentials.tsx | 9 +- 7 files changed, 204 insertions(+), 50 deletions(-) diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index f7a61b0b0c..437f871864 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -171,6 +171,7 @@ class ToolProviderCredentials(BaseModel): SECRET_INPUT = "secret-input" TEXT_INPUT = "text-input" SELECT = "select" + BOOLEAN = "boolean" @classmethod def value_of(cls, value: str) -> "ToolProviderCredentials.CredentialsType": @@ -192,7 +193,7 @@ class ToolProviderCredentials(BaseModel): name: str = Field(..., description="The name of the credentials") type: CredentialsType = Field(..., description="The type of the credentials") required: bool = False - default: Optional[str] = None + default: Optional[Union[int, str]] = None options: Optional[list[ToolCredentialsOption]] = None label: Optional[I18nObject] = None help: Optional[I18nObject] = None diff --git a/api/core/tools/provider/builtin/bing/bing.py b/api/core/tools/provider/builtin/bing/bing.py index ff131b26cd..6e62abfc10 100644 --- a/api/core/tools/provider/builtin/bing/bing.py +++ b/api/core/tools/provider/builtin/bing/bing.py @@ -12,12 +12,11 @@ class BingProvider(BuiltinToolProviderController): meta={ "credentials": credentials, } - ).invoke( - user_id='', + ).validate_credentials( + credentials=credentials, tool_parameters={ "query": "test", "result_type": "link", - "enable_webpages": True, }, ) except Exception as e: diff --git a/api/core/tools/provider/builtin/bing/bing.yaml b/api/core/tools/provider/builtin/bing/bing.yaml index 9df836929c..35cd729208 100644 --- a/api/core/tools/provider/builtin/bing/bing.yaml +++ b/api/core/tools/provider/builtin/bing/bing.yaml @@ -43,3 +43,63 @@ credentials_for_provider: zh_Hans: 例如 "https://api.bing.microsoft.com/v7.0/search" pt_BR: An endpoint is like "https://api.bing.microsoft.com/v7.0/search" default: https://api.bing.microsoft.com/v7.0/search + allow_entities: + type: boolean + required: false + label: + en_US: Allow Entities Search + zh_Hans: 支持实体搜索 + pt_BR: Allow Entities Search + help: + en_US: Does your subscription plan allow entity search + zh_Hans: 您的订阅计划是否支持实体搜索 + pt_BR: Does your subscription plan allow entity search + default: true + allow_web_pages: + type: boolean + required: false + label: + en_US: Allow Web Pages Search + zh_Hans: 支持网页搜索 + pt_BR: Allow Web Pages Search + help: + en_US: Does your subscription plan allow web pages search + zh_Hans: 您的订阅计划是否支持网页搜索 + pt_BR: Does your subscription plan allow web pages search + default: true + allow_computation: + type: boolean + required: false + label: + en_US: Allow Computation Search + zh_Hans: 支持计算搜索 + pt_BR: Allow Computation Search + help: + en_US: Does your subscription plan allow computation search + zh_Hans: 您的订阅计划是否支持计算搜索 + pt_BR: Does your subscription plan allow computation search + default: false + allow_news: + type: boolean + required: false + label: + en_US: Allow News Search + zh_Hans: 支持新闻搜索 + pt_BR: Allow News Search + help: + en_US: Does your subscription plan allow news search + zh_Hans: 您的订阅计划是否支持新闻搜索 + pt_BR: Does your subscription plan allow news search + default: false + allow_related_searches: + type: boolean + required: false + label: + en_US: Allow Related Searches + zh_Hans: 支持相关搜索 + pt_BR: Allow Related Searches + help: + en_US: Does your subscription plan allow related searches + zh_Hans: 您的订阅计划是否支持相关搜索 + pt_BR: Does your subscription plan allow related searches + default: false diff --git a/api/core/tools/provider/builtin/bing/tools/bing_web_search.py b/api/core/tools/provider/builtin/bing/tools/bing_web_search.py index 7b740293dd..8f11d2173c 100644 --- a/api/core/tools/provider/builtin/bing/tools/bing_web_search.py +++ b/api/core/tools/provider/builtin/bing/tools/bing_web_search.py @@ -10,53 +10,23 @@ from core.tools.tool.builtin_tool import BuiltinTool class BingSearchTool(BuiltinTool): url = 'https://api.bing.microsoft.com/v7.0/search' - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke_bing(self, + user_id: str, + subscription_key: str, query: str, limit: int, + result_type: str, market: str, lang: str, + filters: list[str]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke bing search """ - - key = self.runtime.credentials.get('subscription_key', None) - if not key: - raise Exception('subscription_key is required') - - server_url = self.runtime.credentials.get('server_url', None) - if not server_url: - server_url = self.url - - query = tool_parameters.get('query', None) - if not query: - raise Exception('query is required') - - limit = min(tool_parameters.get('limit', 5), 10) - result_type = tool_parameters.get('result_type', 'text') or 'text' - - market = tool_parameters.get('market', 'US') - lang = tool_parameters.get('language', 'en') - filter = [] - - if tool_parameters.get('enable_computation', False): - filter.append('Computation') - if tool_parameters.get('enable_entities', False): - filter.append('Entities') - if tool_parameters.get('enable_news', False): - filter.append('News') - if tool_parameters.get('enable_related_search', False): - filter.append('RelatedSearches') - if tool_parameters.get('enable_webpages', False): - filter.append('WebPages') - market_code = f'{lang}-{market}' accept_language = f'{lang},{market_code};q=0.9' headers = { - 'Ocp-Apim-Subscription-Key': key, + 'Ocp-Apim-Subscription-Key': subscription_key, 'Accept-Language': accept_language } query = quote(query) - server_url = f'{server_url}?q={query}&mkt={market_code}&count={limit}&responseFilter={",".join(filter)}' + server_url = f'{self.url}?q={query}&mkt={market_code}&count={limit}&responseFilter={",".join(filters)}' response = get(server_url, headers=headers) if response.status_code != 200: @@ -124,3 +94,105 @@ class BingSearchTool(BuiltinTool): text += f'{related["displayText"]} - {related["webSearchUrl"]}\n' return self.create_text_message(text=self.summary(user_id=user_id, content=text)) + + + def validate_credentials(self, credentials: dict[str, Any], tool_parameters: dict[str, Any]) -> None: + key = credentials.get('subscription_key', None) + if not key: + raise Exception('subscription_key is required') + + server_url = credentials.get('server_url', None) + if not server_url: + server_url = self.url + + query = tool_parameters.get('query', None) + if not query: + raise Exception('query is required') + + limit = min(tool_parameters.get('limit', 5), 10) + result_type = tool_parameters.get('result_type', 'text') or 'text' + + market = tool_parameters.get('market', 'US') + lang = tool_parameters.get('language', 'en') + filter = [] + + if credentials.get('allow_entities', False): + filter.append('Entities') + + if credentials.get('allow_computation', False): + filter.append('Computation') + + if credentials.get('allow_news', False): + filter.append('News') + + if credentials.get('allow_related_searches', False): + filter.append('RelatedSearches') + + if credentials.get('allow_web_pages', False): + filter.append('WebPages') + + if not filter: + raise Exception('At least one filter is required') + + self._invoke_bing( + user_id='test', + subscription_key=key, + query=query, + limit=limit, + result_type=result_type, + market=market, + lang=lang, + filters=filter + ) + + def _invoke(self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + invoke tools + """ + + key = self.runtime.credentials.get('subscription_key', None) + if not key: + raise Exception('subscription_key is required') + + server_url = self.runtime.credentials.get('server_url', None) + if not server_url: + server_url = self.url + + query = tool_parameters.get('query', None) + if not query: + raise Exception('query is required') + + limit = min(tool_parameters.get('limit', 5), 10) + result_type = tool_parameters.get('result_type', 'text') or 'text' + + market = tool_parameters.get('market', 'US') + lang = tool_parameters.get('language', 'en') + filter = [] + + if tool_parameters.get('enable_computation', False): + filter.append('Computation') + if tool_parameters.get('enable_entities', False): + filter.append('Entities') + if tool_parameters.get('enable_news', False): + filter.append('News') + if tool_parameters.get('enable_related_search', False): + filter.append('RelatedSearches') + if tool_parameters.get('enable_webpages', False): + filter.append('WebPages') + + if not filter: + raise Exception('At least one filter is required') + + return self._invoke_bing( + user_id=user_id, + subscription_key=key, + query=query, + limit=limit, + result_type=result_type, + market=market, + lang=lang, + filters=filter + ) \ No newline at end of file diff --git a/api/core/tools/provider/builtin_tool_provider.py b/api/core/tools/provider/builtin_tool_provider.py index 93e7d5a39e..824f91c822 100644 --- a/api/core/tools/provider/builtin_tool_provider.py +++ b/api/core/tools/provider/builtin_tool_provider.py @@ -246,8 +246,27 @@ class BuiltinToolProviderController(ToolProviderController): if credentials[credential_name] not in [x.value for x in options]: raise ToolProviderCredentialValidationError(f'credential {credential_schema.label.en_US} should be one of {options}') - - if credentials[credential_name]: + elif credential_schema.type == ToolProviderCredentials.CredentialsType.BOOLEAN: + if isinstance(credentials[credential_name], bool): + pass + elif isinstance(credentials[credential_name], str): + if credentials[credential_name].lower() == 'true': + credentials[credential_name] = True + elif credentials[credential_name].lower() == 'false': + credentials[credential_name] = False + else: + raise ToolProviderCredentialValidationError(f'credential {credential_schema.label.en_US} should be boolean') + elif isinstance(credentials[credential_name], int): + if credentials[credential_name] == 1: + credentials[credential_name] = True + elif credentials[credential_name] == 0: + credentials[credential_name] = False + else: + raise ToolProviderCredentialValidationError(f'credential {credential_schema.label.en_US} should be boolean') + else: + raise ToolProviderCredentialValidationError(f'credential {credential_schema.label.en_US} should be boolean') + + if credentials[credential_name] or credentials[credential_name] == False: credentials_need_to_validate.pop(credential_name) for credential_name in credentials_need_to_validate: diff --git a/api/services/tools_manage_service.py b/api/services/tools_manage_service.py index ff618e5d2b..70c6a44459 100644 --- a/api/services/tools_manage_service.py +++ b/api/services/tools_manage_service.py @@ -138,9 +138,9 @@ class ToolManageService: :return: the list of tool providers """ provider = ToolManager.get_builtin_provider(provider_name) - return [ - v.to_dict() for _, v in (provider.credentials_schema or {}).items() - ] + return json.loads(serialize_base_model_array([ + v for _, v in (provider.credentials_schema or {}).items() + ])) @staticmethod def parser_api_schema(schema: str) -> list[ApiBasedToolBundle]: diff --git a/web/app/components/tools/setting/build-in/config-credentials.tsx b/web/app/components/tools/setting/build-in/config-credentials.tsx index d5365001c8..1a3c8f015a 100644 --- a/web/app/components/tools/setting/build-in/config-credentials.tsx +++ b/web/app/components/tools/setting/build-in/config-credentials.tsx @@ -3,7 +3,7 @@ import type { FC } from 'react' import React, { useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' import cn from 'classnames' -import { toolCredentialToFormSchemas } from '../../utils/to-form-schema' +import { addDefaultValue, toolCredentialToFormSchemas } from '../../utils/to-form-schema' import type { Collection } from '../../types' import Drawer from '@/app/components/base/drawer-plus' import Button from '@/app/components/base/button' @@ -28,12 +28,15 @@ const ConfigCredential: FC = ({ const { t } = useTranslation() const [credentialSchema, setCredentialSchema] = useState(null) const { team_credentials: credentialValue, name: collectionName } = collection + const [tempCredential, setTempCredential] = React.useState(credentialValue) useEffect(() => { fetchBuiltInToolCredentialSchema(collectionName).then((res) => { - setCredentialSchema(toolCredentialToFormSchemas(res)) + const toolCredentialSchemas = toolCredentialToFormSchemas(res) + const defaultCredentials = addDefaultValue(credentialValue, toolCredentialSchemas) + setCredentialSchema(toolCredentialSchemas) + setTempCredential(defaultCredentials) }) }, []) - const [tempCredential, setTempCredential] = React.useState(credentialValue) return (