diff --git a/api/core/tools/provider/builtin/vanna/_assets/icon.png b/api/core/tools/provider/builtin/vanna/_assets/icon.png new file mode 100644 index 0000000000..3a9011b54d Binary files /dev/null and b/api/core/tools/provider/builtin/vanna/_assets/icon.png differ diff --git a/api/core/tools/provider/builtin/vanna/tools/vanna.py b/api/core/tools/provider/builtin/vanna/tools/vanna.py new file mode 100644 index 0000000000..bbc21cc107 --- /dev/null +++ b/api/core/tools/provider/builtin/vanna/tools/vanna.py @@ -0,0 +1,119 @@ +from typing import Any, Union + +from vanna.remote import VannaDefault + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.tool.builtin_tool import BuiltinTool + + +class VannaTool(BuiltinTool): + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + invoke tools + """ + api_key = self.runtime.credentials.get("api_key", None) + if not api_key: + raise ToolProviderCredentialValidationError("Please input api key") + + model = tool_parameters.get("model", "") + if not model: + return self.create_text_message("Please input RAG model") + + prompt = tool_parameters.get("prompt", "") + if not prompt: + return self.create_text_message("Please input prompt") + + url = tool_parameters.get("url", "") + if not url: + return self.create_text_message("Please input URL/Host/DSN") + + db_name = tool_parameters.get("db_name", "") + username = tool_parameters.get("username", "") + password = tool_parameters.get("password", "") + port = tool_parameters.get("port", 0) + + vn = VannaDefault(model=model, api_key=api_key) + + db_type = tool_parameters.get("db_type", "") + if db_type in ["Postgres", "MySQL", "Hive", "ClickHouse"]: + if not db_name: + return self.create_text_message("Please input database name") + if not username: + return self.create_text_message("Please input username") + if 
port < 1: + return self.create_text_message("Please input port") + + schema_sql = "SELECT * FROM INFORMATION_SCHEMA.COLUMNS" + match db_type: + case "SQLite": + schema_sql = "SELECT type, sql FROM sqlite_master WHERE sql is not null" + vn.connect_to_sqlite(url) + case "Postgres": + vn.connect_to_postgres(host=url, dbname=db_name, user=username, password=password, port=port) + case "DuckDB": + vn.connect_to_duckdb(url=url) + case "SQLServer": + vn.connect_to_mssql(url) + case "MySQL": + vn.connect_to_mysql(host=url, dbname=db_name, user=username, password=password, port=port) + case "Oracle": + vn.connect_to_oracle(user=username, password=password, dsn=url) + case "Hive": + vn.connect_to_hive(host=url, dbname=db_name, user=username, password=password, port=port) + case "ClickHouse": + vn.connect_to_clickhouse(host=url, dbname=db_name, user=username, password=password, port=port) + + enable_training = tool_parameters.get("enable_training", False) + reset_training_data = tool_parameters.get("reset_training_data", False) + if enable_training: + if reset_training_data: + existing_training_data = vn.get_training_data() + if len(existing_training_data) > 0: + for _, training_data in existing_training_data.iterrows(): + vn.remove_training_data(training_data["id"]) + + ddl = tool_parameters.get("ddl", "") + question = tool_parameters.get("question", "") + sql = tool_parameters.get("sql", "") + memos = tool_parameters.get("memos", "") + training_metadata = tool_parameters.get("training_metadata", False) + + if training_metadata: + if db_type == "SQLite": + df_ddl = vn.run_sql(schema_sql) + for ddl in df_ddl["sql"].to_list(): + vn.train(ddl=ddl) + else: + df_information_schema = vn.run_sql(schema_sql) + plan = vn.get_training_plan_generic(df_information_schema) + vn.train(plan=plan) + + if ddl: + vn.train(ddl=ddl) + + if sql: + if question: + vn.train(question=question, sql=sql) + else: + vn.train(sql=sql) + if memos: + vn.train(documentation=memos) + + generate_chart = 
tool_parameters.get("generate_chart", True) + res = vn.ask(prompt, False, True, generate_chart) + + result = [] + + if res is not None: + result.append(self.create_text_message(res[0])) + if len(res) > 1 and res[1] is not None: + result.append(self.create_text_message(res[1].to_markdown())) + if len(res) > 2 and res[2] is not None: + result.append( + self.create_blob_message(blob=res[2].to_image(format="svg"), meta={"mime_type": "image/svg+xml"}) + ) + + return result diff --git a/api/core/tools/provider/builtin/vanna/tools/vanna.yaml b/api/core/tools/provider/builtin/vanna/tools/vanna.yaml new file mode 100644 index 0000000000..ae2eae94c4 --- /dev/null +++ b/api/core/tools/provider/builtin/vanna/tools/vanna.yaml @@ -0,0 +1,213 @@ +identity: + name: vanna + author: QCTC + label: + en_US: Vanna.AI + zh_Hans: Vanna.AI +description: + human: + en_US: The fastest way to get actionable insights from your database just by asking questions. + zh_Hans: 一个基于大模型和RAG的Text2SQL工具。 + llm: A tool for converting text to SQL. 
+parameters: + - name: prompt + type: string + required: true + label: + en_US: Prompt + zh_Hans: 提示词 + pt_BR: Prompt + human_description: + en_US: used for generating SQL + zh_Hans: 用于生成SQL + llm_description: key words for generating SQL + form: llm + - name: model + type: string + required: true + label: + en_US: RAG Model + zh_Hans: RAG模型 + human_description: + en_US: RAG Model for your database DDL + zh_Hans: 存储数据库训练数据的RAG模型 + llm_description: RAG Model for generating SQL + form: form + - name: db_type + type: select + required: true + options: + - value: SQLite + label: + en_US: SQLite + zh_Hans: SQLite + - value: Postgres + label: + en_US: Postgres + zh_Hans: Postgres + - value: DuckDB + label: + en_US: DuckDB + zh_Hans: DuckDB + - value: SQLServer + label: + en_US: Microsoft SQL Server + zh_Hans: 微软 SQL Server + - value: MySQL + label: + en_US: MySQL + zh_Hans: MySQL + - value: Oracle + label: + en_US: Oracle + zh_Hans: Oracle + - value: Hive + label: + en_US: Hive + zh_Hans: Hive + - value: ClickHouse + label: + en_US: ClickHouse + zh_Hans: ClickHouse + default: SQLite + label: + en_US: DB Type + zh_Hans: 数据库类型 + human_description: + en_US: Database type. 
+ zh_Hans: 选择要链接的数据库类型。 + form: form + - name: url + type: string + required: true + label: + en_US: URL/Host/DSN + zh_Hans: URL/Host/DSN + human_description: + en_US: Please input depending on DB type, visit https://vanna.ai/docs/ for more specification + zh_Hans: 请根据数据库类型,填入对应值,详情参考https://vanna.ai/docs/ + form: form + - name: db_name + type: string + required: false + label: + en_US: DB name + zh_Hans: 数据库名 + human_description: + en_US: Database name + zh_Hans: 数据库名 + form: form + - name: username + type: string + required: false + label: + en_US: Username + zh_Hans: 用户名 + human_description: + en_US: Username + zh_Hans: 用户名 + form: form + - name: password + type: secret-input + required: false + label: + en_US: Password + zh_Hans: 密码 + human_description: + en_US: Password + zh_Hans: 密码 + form: form + - name: port + type: number + required: false + label: + en_US: Port + zh_Hans: 端口 + human_description: + en_US: Port + zh_Hans: 端口 + form: form + - name: ddl + type: string + required: false + label: + en_US: Training DDL + zh_Hans: 训练DDL + human_description: + en_US: DDL statements for training data + zh_Hans: 用于训练RAG Model的建表语句 + form: form + - name: question + type: string + required: false + label: + en_US: Training Question + zh_Hans: 训练问题 + human_description: + en_US: Question-SQL Pairs + zh_Hans: Question-SQL中的问题 + form: form + - name: sql + type: string + required: false + label: + en_US: Training SQL + zh_Hans: 训练SQL + human_description: + en_US: SQL queries to your training data + zh_Hans: 用于训练RAG Model的SQL语句 + form: form + - name: memos + type: string + required: false + label: + en_US: Training Memos + zh_Hans: 训练说明 + human_description: + en_US: Sometimes you may want to add documentation about your business terminology or definitions + zh_Hans: 添加更多关于数据库的业务说明 + form: form + - name: enable_training + type: boolean + required: false + default: false + label: + en_US: Training Data + zh_Hans: 训练数据 + human_description: + en_US: You only need to train once. 
class VannaProvider(BuiltinToolProviderController):
    """Provider controller for the Vanna built-in tool."""

    def _validate_credentials(self, credentials: dict[str, Any]) -> None:
        """
        Validate the provider credentials by running a sample VannaTool invocation.

        :param credentials: provider credentials passed to the tool runtime
        :raises ToolProviderCredentialValidationError: if the sample invocation
            raises any exception
        """
        try:
            VannaTool().fork_tool_runtime(
                runtime={
                    "credentials": credentials,
                }
            ).invoke(
                user_id="",
                tool_parameters={
                    "model": "chinook",
                    "db_type": "SQLite",
                    "url": "https://vanna.ai/Chinook.sqlite",
                    # The tool reads the question from "prompt"; under any
                    # other key it returns "Please input prompt" before
                    # contacting Vanna, so nothing would be validated.
                    "prompt": "What are the top 10 customers by sales?",
                    # A chart is not needed to check the credentials.
                    "generate_chart": False,
                },
            )
        except Exception as e:
            # Chain the original exception so the root cause stays in the traceback.
            raise ToolProviderCredentialValidationError(str(e)) from e