mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-11 21:09:05 +08:00
feat(vannaai): add base_url configuration (#10294)
This commit is contained in:
parent
1279e27825
commit
d7b4d0756e
@ -35,7 +35,8 @@ class VannaTool(BuiltinTool):
|
|||||||
password = tool_parameters.get("password", "")
|
password = tool_parameters.get("password", "")
|
||||||
port = tool_parameters.get("port", 0)
|
port = tool_parameters.get("port", 0)
|
||||||
|
|
||||||
vn = VannaDefault(model=model, api_key=api_key)
|
base_url = self.runtime.credentials.get("base_url", None)
|
||||||
|
vn = VannaDefault(model=model, api_key=api_key, config={"endpoint": base_url})
|
||||||
|
|
||||||
db_type = tool_parameters.get("db_type", "")
|
db_type = tool_parameters.get("db_type", "")
|
||||||
if db_type in {"Postgres", "MySQL", "Hive", "ClickHouse"}:
|
if db_type in {"Postgres", "MySQL", "Hive", "ClickHouse"}:
|
||||||
|
@ -1,4 +1,6 @@
|
|||||||
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from core.tools.errors import ToolProviderCredentialValidationError
|
from core.tools.errors import ToolProviderCredentialValidationError
|
||||||
from core.tools.provider.builtin.vanna.tools.vanna import VannaTool
|
from core.tools.provider.builtin.vanna.tools.vanna import VannaTool
|
||||||
@ -6,7 +8,26 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl
|
|||||||
|
|
||||||
|
|
||||||
class VannaProvider(BuiltinToolProviderController):
|
class VannaProvider(BuiltinToolProviderController):
|
||||||
|
def _get_protocol_and_main_domain(self, url):
|
||||||
|
parsed_url = urlparse(url)
|
||||||
|
protocol = parsed_url.scheme
|
||||||
|
hostname = parsed_url.hostname
|
||||||
|
port = f":{parsed_url.port}" if parsed_url.port else ""
|
||||||
|
|
||||||
|
# Check if the hostname is an IP address
|
||||||
|
is_ip = re.match(r"^\d{1,3}(\.\d{1,3}){3}$", hostname) is not None
|
||||||
|
|
||||||
|
# Return the full hostname (with port if present) for IP addresses, otherwise return the main domain
|
||||||
|
main_domain = f"{hostname}{port}" if is_ip else ".".join(hostname.split(".")[-2:]) + port
|
||||||
|
return f"{protocol}://{main_domain}"
|
||||||
|
|
||||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||||
|
base_url = credentials.get("base_url")
|
||||||
|
if not base_url:
|
||||||
|
base_url = "https://ask.vanna.ai/rpc"
|
||||||
|
else:
|
||||||
|
base_url = base_url.removesuffix("/")
|
||||||
|
credentials["base_url"] = base_url
|
||||||
try:
|
try:
|
||||||
VannaTool().fork_tool_runtime(
|
VannaTool().fork_tool_runtime(
|
||||||
runtime={
|
runtime={
|
||||||
@ -17,7 +38,7 @@ class VannaProvider(BuiltinToolProviderController):
|
|||||||
tool_parameters={
|
tool_parameters={
|
||||||
"model": "chinook",
|
"model": "chinook",
|
||||||
"db_type": "SQLite",
|
"db_type": "SQLite",
|
||||||
"url": "https://vanna.ai/Chinook.sqlite",
|
"url": f'{self._get_protocol_and_main_domain(credentials["base_url"])}/Chinook.sqlite',
|
||||||
"query": "What are the top 10 customers by sales?",
|
"query": "What are the top 10 customers by sales?",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
@ -26,3 +26,10 @@ credentials_for_provider:
|
|||||||
en_US: Get your API key from Vanna.AI
|
en_US: Get your API key from Vanna.AI
|
||||||
zh_Hans: 从 Vanna.AI 获取你的 API key
|
zh_Hans: 从 Vanna.AI 获取你的 API key
|
||||||
url: https://vanna.ai/account/profile
|
url: https://vanna.ai/account/profile
|
||||||
|
base_url:
|
||||||
|
type: text-input
|
||||||
|
required: false
|
||||||
|
label:
|
||||||
|
en_US: Vanna.AI Endpoint Base URL
|
||||||
|
placeholder:
|
||||||
|
en_US: https://ask.vanna.ai/rpc
|
||||||
|
Loading…
x
Reference in New Issue
Block a user