feat(vannaai): add base_url configuration (#10294)

This commit is contained in:
Benjamin 2024-11-05 20:58:49 +08:00 committed by GitHub
parent 1279e27825
commit d7b4d0756e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 31 additions and 2 deletions

View File

@ -35,7 +35,8 @@ class VannaTool(BuiltinTool):
password = tool_parameters.get("password", "") password = tool_parameters.get("password", "")
port = tool_parameters.get("port", 0) port = tool_parameters.get("port", 0)
vn = VannaDefault(model=model, api_key=api_key) base_url = self.runtime.credentials.get("base_url", None)
vn = VannaDefault(model=model, api_key=api_key, config={"endpoint": base_url})
db_type = tool_parameters.get("db_type", "") db_type = tool_parameters.get("db_type", "")
if db_type in {"Postgres", "MySQL", "Hive", "ClickHouse"}: if db_type in {"Postgres", "MySQL", "Hive", "ClickHouse"}:

View File

@ -1,4 +1,6 @@
import re
from typing import Any from typing import Any
from urllib.parse import urlparse
from core.tools.errors import ToolProviderCredentialValidationError from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.vanna.tools.vanna import VannaTool from core.tools.provider.builtin.vanna.tools.vanna import VannaTool
@ -6,7 +8,26 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl
class VannaProvider(BuiltinToolProviderController): class VannaProvider(BuiltinToolProviderController):
def _get_protocol_and_main_domain(self, url):
parsed_url = urlparse(url)
protocol = parsed_url.scheme
hostname = parsed_url.hostname
port = f":{parsed_url.port}" if parsed_url.port else ""
# Check if the hostname is an IP address
is_ip = re.match(r"^\d{1,3}(\.\d{1,3}){3}$", hostname) is not None
# Return the full hostname (with port if present) for IP addresses, otherwise return the main domain
main_domain = f"{hostname}{port}" if is_ip else ".".join(hostname.split(".")[-2:]) + port
return f"{protocol}://{main_domain}"
def _validate_credentials(self, credentials: dict[str, Any]) -> None: def _validate_credentials(self, credentials: dict[str, Any]) -> None:
base_url = credentials.get("base_url")
if not base_url:
base_url = "https://ask.vanna.ai/rpc"
else:
base_url = base_url.removesuffix("/")
credentials["base_url"] = base_url
try: try:
VannaTool().fork_tool_runtime( VannaTool().fork_tool_runtime(
runtime={ runtime={
@ -17,7 +38,7 @@ class VannaProvider(BuiltinToolProviderController):
tool_parameters={ tool_parameters={
"model": "chinook", "model": "chinook",
"db_type": "SQLite", "db_type": "SQLite",
"url": "https://vanna.ai/Chinook.sqlite", "url": f'{self._get_protocol_and_main_domain(credentials["base_url"])}/Chinook.sqlite',
"query": "What are the top 10 customers by sales?", "query": "What are the top 10 customers by sales?",
}, },
) )

View File

@ -26,3 +26,10 @@ credentials_for_provider:
en_US: Get your API key from Vanna.AI en_US: Get your API key from Vanna.AI
zh_Hans: 从 Vanna.AI 获取你的 API key zh_Hans: 从 Vanna.AI 获取你的 API key
url: https://vanna.ai/account/profile url: https://vanna.ai/account/profile
base_url:
type: text-input
required: false
label:
en_US: Vanna.AI Endpoint Base URL
placeholder:
en_US: https://ask.vanna.ai/rpc