feat: Add base URL settings and secure_ascii options to the Brave search tool (#8463)

Co-authored-by: crazywoola <427733928@qq.com>
This commit is contained in:
Xiao Ley 2024-09-15 17:38:43 +08:00 committed by GitHub
parent 3d083b758f
commit 6dba68f62d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 40 additions and 5 deletions

View File

@ -29,3 +29,11 @@ credentials_for_provider:
zh_Hans: 从 Brave 获取您的 Brave Search API key zh_Hans: 从 Brave 获取您的 Brave Search API key
pt_BR: Get your Brave Search API key from Brave pt_BR: Get your Brave Search API key from Brave
url: https://brave.com/search/api/ url: https://brave.com/search/api/
base_url:
type: text-input
required: false
label:
en_US: Brave server's Base URL
zh_Hans: Brave服务器的API URL
placeholder:
en_US: https://api.search.brave.com/res/v1/web/search

View File

@ -7,6 +7,8 @@ from pydantic import BaseModel, Field
from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool from core.tools.tool.builtin_tool import BuiltinTool
BRAVE_BASE_URL = "https://api.search.brave.com/res/v1/web/search"
class BraveSearchWrapper(BaseModel): class BraveSearchWrapper(BaseModel):
"""Wrapper around the Brave search engine.""" """Wrapper around the Brave search engine."""
@ -15,8 +17,10 @@ class BraveSearchWrapper(BaseModel):
"""The API key to use for the Brave search engine.""" """The API key to use for the Brave search engine."""
search_kwargs: dict = Field(default_factory=dict) search_kwargs: dict = Field(default_factory=dict)
"""Additional keyword arguments to pass to the search request.""" """Additional keyword arguments to pass to the search request."""
base_url: str = "https://api.search.brave.com/res/v1/web/search" base_url: str = BRAVE_BASE_URL
"""The base URL for the Brave search engine.""" """The base URL for the Brave search engine."""
ensure_ascii: bool = True
"""Ensure the JSON output is ASCII encoded."""
def run(self, query: str) -> str: def run(self, query: str) -> str:
"""Query the Brave search engine and return the results as a JSON string. """Query the Brave search engine and return the results as a JSON string.
@ -36,7 +40,7 @@ class BraveSearchWrapper(BaseModel):
} }
for item in web_search_results for item in web_search_results
] ]
return json.dumps(final_results) return json.dumps(final_results, ensure_ascii=self.ensure_ascii)
def _search_request(self, query: str) -> list[dict]: def _search_request(self, query: str) -> list[dict]:
headers = { headers = {
@ -68,7 +72,9 @@ class BraveSearch(BaseModel):
search_wrapper: BraveSearchWrapper search_wrapper: BraveSearchWrapper
@classmethod @classmethod
def from_api_key(cls, api_key: str, search_kwargs: Optional[dict] = None, **kwargs: Any) -> "BraveSearch": def from_api_key(
cls, api_key: str, base_url: str, search_kwargs: Optional[dict] = None, ensure_ascii: bool = True, **kwargs: Any
) -> "BraveSearch":
"""Create a tool from an api key. """Create a tool from an api key.
Args: Args:
@ -79,7 +85,9 @@ class BraveSearch(BaseModel):
Returns: Returns:
A tool. A tool.
""" """
wrapper = BraveSearchWrapper(api_key=api_key, search_kwargs=search_kwargs or {}) wrapper = BraveSearchWrapper(
api_key=api_key, base_url=base_url, search_kwargs=search_kwargs or {}, ensure_ascii=ensure_ascii
)
return cls(search_wrapper=wrapper, **kwargs) return cls(search_wrapper=wrapper, **kwargs)
def _run( def _run(
@ -109,11 +117,18 @@ class BraveSearchTool(BuiltinTool):
query = tool_parameters.get("query", "") query = tool_parameters.get("query", "")
count = tool_parameters.get("count", 3) count = tool_parameters.get("count", 3)
api_key = self.runtime.credentials["brave_search_api_key"] api_key = self.runtime.credentials["brave_search_api_key"]
base_url = self.runtime.credentials.get("base_url", BRAVE_BASE_URL)
ensure_ascii = tool_parameters.get("ensure_ascii", True)
if len(base_url) == 0:
base_url = BRAVE_BASE_URL
if not query: if not query:
return self.create_text_message("Please input query") return self.create_text_message("Please input query")
tool = BraveSearch.from_api_key(api_key=api_key, search_kwargs={"count": count}) tool = BraveSearch.from_api_key(
api_key=api_key, base_url=base_url, search_kwargs={"count": count}, ensure_ascii=ensure_ascii
)
results = tool._run(query) results = tool._run(query)

View File

@ -39,3 +39,15 @@ parameters:
pt_BR: O número de resultados de pesquisa a serem retornados, permitindo que os usuários controlem a amplitude de sua saída de pesquisa. pt_BR: O número de resultados de pesquisa a serem retornados, permitindo que os usuários controlem a amplitude de sua saída de pesquisa.
llm_description: Specifies the amount of search results to be displayed, offering users the ability to adjust the scope of their search findings. llm_description: Specifies the amount of search results to be displayed, offering users the ability to adjust the scope of their search findings.
form: llm form: llm
- name: ensure_ascii
type: boolean
default: true
label:
en_US: Ensure ASCII
zh_Hans: 确保 ASCII
pt_BR: Ensure ASCII
human_description:
en_US: Ensure the JSON output is ASCII encoded
zh_Hans: 确保输出的 JSON 是 ASCII 编码
pt_BR: Ensure the JSON output is ASCII encoded
form: form