From 2d9f55b632bd65ca8d182a360ceea561f2edef81 Mon Sep 17 00:00:00 2001 From: Henry Lu Date: Tue, 4 Jun 2024 14:05:29 +0800 Subject: [PATCH] feat: Add Vanna.AI as a builtin tool (#4878) Co-authored-by: Yeuoly --- .../provider/builtin/vanna/_assets/icon.png | Bin 0 -> 4612 bytes .../provider/builtin/vanna/tools/vanna.py | 119 ++++++++++ .../provider/builtin/vanna/tools/vanna.yaml | 213 ++++++++++++++++++ .../tools/provider/builtin/vanna/vanna.py | 25 ++ .../tools/provider/builtin/vanna/vanna.yaml | 25 ++ api/requirements.txt | 1 + 6 files changed, 383 insertions(+) create mode 100644 api/core/tools/provider/builtin/vanna/_assets/icon.png create mode 100644 api/core/tools/provider/builtin/vanna/tools/vanna.py create mode 100644 api/core/tools/provider/builtin/vanna/tools/vanna.yaml create mode 100644 api/core/tools/provider/builtin/vanna/vanna.py create mode 100644 api/core/tools/provider/builtin/vanna/vanna.yaml diff --git a/api/core/tools/provider/builtin/vanna/_assets/icon.png b/api/core/tools/provider/builtin/vanna/_assets/icon.png new file mode 100644 index 0000000000000000000000000000000000000000..3a9011b54d8a07f01e6b2fb934f3937bca0fd85a GIT binary patch literal 4612 zcmb7Ic{r49`+mm04=H<05wZ=Vk}<}<8)VBikzID9#%>VVDJ`-WLPcbcWX+PSk)7-m z#xVBfqu$^9e!t`RzCXVEKJMo_pX<8L^SbW)e*Su*^mNqdX*g&A0HD`USJpqVR;P=K z;^aACqpptv06%^Jz}*Lcy%QI15diLr1HiHs0Kk(0 zfYl@I!wtkq;J%%ShCK=e2%TUmfE;uNAU#2#lL>%0fK#;-1ZaXd|Hk?t!M}1y0N{ZW zK=xOT*@-L=@z^T(2Af&NjXC%7K;_b zN{YFAIY1?(rKO?bFenTrdXgcE@pbdQ?IY@j;r%1>e>%!`7#lAq4{s-TH^`~(ZEN?t z-m(`jo)-Fh{OQx%$^M^8ZkWHeb+SR|X#^@ECJy~iG&>)s|3W*B{Gt8c*Pn8TQ!%)% zmy_K|!>4T_B@n+0{6F76?IBK8;CfCzb}lB$POf%tnBN?6*cAlyKZgDz)%3qp|1Iz5f7I(Q>0 z9z#a%GgjtDak%Kka6#m8->AW@@hhfsTURS5tFfQiWw!jVvNaP~)lsrF8&%t5YiS$l z7#^F?wbK`-{^yw}NfKIGsMx)`#k`_gQXad{57eQNF`7fmfMu~wVEH~GTa6Rn+PpB# zTY$+Gsk!g2YUfq5-cpldXUW0_M#F-A`q>SMuxGUj7?mz~o(g_pXx^LGys`hS$_~09 zm_pczQS_+d($^))IFvKng6^IQ3^nat3cp!ff1doHNl~!jjGjsgB{7fKdWN;691)(( 
zn5=L>+=E!)CktcKAYr_FgSc-gxxVR)#B0hZf7V8T|RvC>M0i1l*+OL|_wIABIGkq2g z)m_wWh8lQ(qlSCgYPk^`3;hz?HAYOyzO4EF5-No6P)ZhwWd7t})~1Rh?=S%^C7y{7 z!W1b@#H!Ota?{DK2K!BU4Vq-yuHX}4qVz7=IqA^5a&^%UKy?z@hd7yRDDD$m3mTlC zsBflO-=I&z86|1VH1aEWZ}Eu>6A4Dkf7nj#+DhlX&zC#gMI*{n*_1%0bp1k+a|4=O zox*DZXXI@Yhnti2gzWv3c-~qWe6v+N5y~tx#s6JVr>m{!NodR`3 z+hK!T@P}LR$)lQY$y8+OuGR4c`h72T_N{tPKFrWY(InEon?wxTH0UyA^aY;X688kh zw~#eN?2*%|tx_Km@${FBNa8$X3@b2u4RYwh2XDAxaTloQN1o$01Xd&c9mdyj;Vg?3 z7V@^91{zoJ*6~c&{W_;f`N1k{bgn-?Xb8P{(#7hT2H;sGdl(1@{a_@voTbQ*af-Px zAD7BpG$R?8#+|asfATfnv}9Ls@XFG-OFI(5l10W6e|T;4;5aCn2@_FXzVx54q9U^034* z-$TnY(fCwEqcj`^fw9Ue^yxg%H9PVj9b8x60d?Hd(^l8^yu2NHAuhEAP0oig#Ep|K z??_y-En7|*ZGv*|RBzrNG{!DQzI>FEu2N9FR_jAg`+zC{?jL*4C9HTZ+ z2TxUdQDbEFD!Cq&4Q+bk9LzLvbX54ZM50NM|BB6CZjW^@gHna$_R1%8jNwL*Bm-Ve zRF0?rnPQ|i;-Mu^UgAf57s}9+%b?d#z1d|rv9^7eR=%DY&&Em`$S!(jT<@??hBHS% z#AS}Wg^arHWqP%{`zPZ!4C4ZA1nPsg1+NUm792V_lI3jn=0fS5F}H*dN@ZxWrZlt? zouvuVrBMZ!*p?2eqbL@iNIXL5HTGUfXm*QR9Xr-qqBhh(X2`FTHq_emzWkOZg|&E4 zoxZQltK+aiU8U(w7bTJRWKbxKZ)QBb*XUyzm&P00<%p$2e)GHU7}dXyeY+;WV|U~D z90K{sD#f&+I_k&AJy36t9#0_0c6!ka*J>-1boEn-@)VBD$nB#K!R-g;IV35N*R_{p z+Q=|!4-!BJ29smH>D&b$Sf4+Y-%q|(A5Z^4}4dl`X%Z=bu{O8R-gJ!tAvG<#H_h>2;4mIl?E==TIqY(fqH+QZs+sdzH7`BWR% z{UD`iJI-r*oyivV0z2t4-<@CJKicj(W+vwD3Ua;OdIitl*Eh?BCbF|Fvs^3*<$g%J zDHftjoD*UgPc^sC_5sV3vQ;kYOTeXN@#nd~%eyyH#ZWJ29#gnZ9OzqVH2JGM)8 z_}A$~Yv-&aDgN_?I{P!TgNBX+WJu9wOP1l*&Q2k2K7nF3B+p$fT?onGmFD}8S>nb` z_B%MSOmgjepMUjV7#UtPnEK=;`!0S=G`4CMJf^$ka`Z7|V8@y%gI)8g7G7qYGV^{1 z<)Z_5Z*10BXa61#rPR9PfDqD=&&u;$GUHcnQWtP>ytXf`XVoUzHILc44EG?{LpEK< zRub^_C z42#ka4L^>Y1l910gTm!6{*l|}(iGt;LfPZO`iI(|Nq^DPAO?ajzsjsEepq)UpDI0# zkG{O?&Ci+3N>_1?szzfQ|)09deVTeh>Q))4cu#)g73x}6T#cFPgLdt^N=xW3SGv? 
zJBeaxX2vx-Y-*Mc4OIE+kn~i|V`lWT00q(uC_dHVW@(pxr5}L`pPPh9L8b1+qt7%o zCLCNj#T`rEOF39t99dNL>{2PN*N?F0dD5@fv)%qM8BIwBlYde*Z*qA&hKVF2=dL}6 z6)LBIOIQhYKMh3Tgc+jj?HX5=}ex`WMp+ z(vv<2ma~HoiHqe-U%51|@m#lS>Ds$Ex8EvN)TQ0(#s5Zf-&dI1O5uo`r4ZdcJ?<4C zeoe4HfRWq@=jT+6n#-6ven1Xj+nX0OPnYxj|_Sx8f{mb)d~#rbO8RA&W~8GBLNy!UEl$umNr zadSX0%t`AV@TfVSaxx~6O6!2>U4EK4T|S${^F{KEn{bpAQ8-X8i0sDl*uiY9;$O!@9|` z#E5#g9pj8Xu3-ONLELl*@t&-40%y2{$74#8W;JqiAx8d*Rs^_buxt*E>He8p5b=CI zFpr}{4a9fuBRb(+QA4if4%t$J=Bs{MgvoHnox&ux_p;s0G1h}uhP@YK7V1P6e3!@3 zhS*TkHV0PqibVT7p9teVJk0W|mb?2@$xvbA^_PQS@oRb*rLJm7%6VtD_CxqAiDsDI@uZess7goI! zUF0_%C^FVvkI=oRFfc$c1lm>Cff-VprBH4UZQk6R9m-rtk-aBbzF1+)ORhI|s0YiR zvgdt^Ju+$8Bt3PX7Iogg&Rpiy5@s^WDc66;Ip`9FUsCchIQBbCT}@5>9)-5b3( z>tts(5X_tNRR(iEK^6Hz+>UCj!Mpcd(7_6C;wh}IQu_|;-~_)mt3SIiDMcCUhTb-> zO|TGhrm&NOVOj|{aj}%HnmUtHnu0&XX$^b$-)%bTMaba%0=&Fih;wM7Wu*_#D88fi z@h1{0c|Ui$QLukEJZ9^;rHFQvF#VZCOe4>Ib?2*7ft?!xgaP;3F`U0slksz#V3KIH zVY-gDvghRrMz=GPiEB3vO1D>0Y41N7SUFEm+9iDBd^*Q7sw2PC-nZ@`F<=QrIAQ&N z{_3i1{;y>y{px(55b0~ytg&-^Mvn>!3NaV~FFeDWvA?5m7B`*-sX+tyjKCQnva<%ORi&yP@c#*?;yUbly6Q)+=%V literal 0 HcmV?d00001 diff --git a/api/core/tools/provider/builtin/vanna/tools/vanna.py b/api/core/tools/provider/builtin/vanna/tools/vanna.py new file mode 100644 index 0000000000..bbc21cc107 --- /dev/null +++ b/api/core/tools/provider/builtin/vanna/tools/vanna.py @@ -0,0 +1,119 @@ +from typing import Any, Union + +from vanna.remote import VannaDefault + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.tool.builtin_tool import BuiltinTool + + +class VannaTool(BuiltinTool): + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + invoke tools + """ + api_key = self.runtime.credentials.get("api_key", None) + if not api_key: + raise ToolProviderCredentialValidationError("Please input api key") + + model = 
tool_parameters.get("model", "") + if not model: + return self.create_text_message("Please input RAG model") + + prompt = tool_parameters.get("prompt", "") + if not prompt: + return self.create_text_message("Please input prompt") + + url = tool_parameters.get("url", "") + if not url: + return self.create_text_message("Please input URL/Host/DSN") + + db_name = tool_parameters.get("db_name", "") + username = tool_parameters.get("username", "") + password = tool_parameters.get("password", "") + port = tool_parameters.get("port", 0) + + vn = VannaDefault(model=model, api_key=api_key) + + db_type = tool_parameters.get("db_type", "") + if db_type in ["Postgres", "MySQL", "Hive", "ClickHouse"]: + if not db_name: + return self.create_text_message("Please input database name") + if not username: + return self.create_text_message("Please input username") + if port < 1: + return self.create_text_message("Please input port") + + schema_sql = "SELECT * FROM INFORMATION_SCHEMA.COLUMNS" + match db_type: + case "SQLite": + schema_sql = "SELECT type, sql FROM sqlite_master WHERE sql is not null" + vn.connect_to_sqlite(url) + case "Postgres": + vn.connect_to_postgres(host=url, dbname=db_name, user=username, password=password, port=port) + case "DuckDB": + vn.connect_to_duckdb(url=url) + case "SQLServer": + vn.connect_to_mssql(url) + case "MySQL": + vn.connect_to_mysql(host=url, dbname=db_name, user=username, password=password, port=port) + case "Oracle": + vn.connect_to_oracle(user=username, password=password, dsn=url) + case "Hive": + vn.connect_to_hive(host=url, dbname=db_name, user=username, password=password, port=port) + case "ClickHouse": + vn.connect_to_clickhouse(host=url, dbname=db_name, user=username, password=password, port=port) + + enable_training = tool_parameters.get("enable_training", False) + reset_training_data = tool_parameters.get("reset_training_data", False) + if enable_training: + if reset_training_data: + existing_training_data = vn.get_training_data() + if 
len(existing_training_data) > 0: + for _, training_data in existing_training_data.iterrows(): + vn.remove_training_data(training_data["id"]) + + ddl = tool_parameters.get("ddl", "") + question = tool_parameters.get("question", "") + sql = tool_parameters.get("sql", "") + memos = tool_parameters.get("memos", "") + training_metadata = tool_parameters.get("training_metadata", False) + + if training_metadata: + if db_type == "SQLite": + df_ddl = vn.run_sql(schema_sql) + for ddl in df_ddl["sql"].to_list(): + vn.train(ddl=ddl) + else: + df_information_schema = vn.run_sql(schema_sql) + plan = vn.get_training_plan_generic(df_information_schema) + vn.train(plan=plan) + + if ddl: + vn.train(ddl=ddl) + + if sql: + if question: + vn.train(question=question, sql=sql) + else: + vn.train(sql=sql) + if memos: + vn.train(documentation=memos) + + generate_chart = tool_parameters.get("generate_chart", True) + res = vn.ask(prompt, False, True, generate_chart) + + result = [] + + if res is not None: + result.append(self.create_text_message(res[0])) + if len(res) > 1 and res[1] is not None: + result.append(self.create_text_message(res[1].to_markdown())) + if len(res) > 2 and res[2] is not None: + result.append( + self.create_blob_message(blob=res[2].to_image(format="svg"), meta={"mime_type": "image/svg+xml"}) + ) + + return result diff --git a/api/core/tools/provider/builtin/vanna/tools/vanna.yaml b/api/core/tools/provider/builtin/vanna/tools/vanna.yaml new file mode 100644 index 0000000000..ae2eae94c4 --- /dev/null +++ b/api/core/tools/provider/builtin/vanna/tools/vanna.yaml @@ -0,0 +1,213 @@ +identity: + name: vanna + author: QCTC + label: + en_US: Vanna.AI + zh_Hans: Vanna.AI +description: + human: + en_US: The fastest way to get actionable insights from your database just by asking questions. + zh_Hans: 一个基于大模型和RAG的Text2SQL工具。 + llm: A tool for converting text to SQL. 
+parameters: + - name: prompt + type: string + required: true + label: + en_US: Prompt + zh_Hans: 提示词 + pt_BR: Prompt + human_description: + en_US: used for generating SQL + zh_Hans: 用于生成SQL + llm_description: key words for generating SQL + form: llm + - name: model + type: string + required: true + label: + en_US: RAG Model + zh_Hans: RAG模型 + human_description: + en_US: RAG Model for your database DDL + zh_Hans: 存储数据库训练数据的RAG模型 + llm_description: RAG Model for generating SQL + form: form + - name: db_type + type: select + required: true + options: + - value: SQLite + label: + en_US: SQLite + zh_Hans: SQLite + - value: Postgres + label: + en_US: Postgres + zh_Hans: Postgres + - value: DuckDB + label: + en_US: DuckDB + zh_Hans: DuckDB + - value: SQLServer + label: + en_US: Microsoft SQL Server + zh_Hans: 微软 SQL Server + - value: MySQL + label: + en_US: MySQL + zh_Hans: MySQL + - value: Oracle + label: + en_US: Oracle + zh_Hans: Oracle + - value: Hive + label: + en_US: Hive + zh_Hans: Hive + - value: ClickHouse + label: + en_US: ClickHouse + zh_Hans: ClickHouse + default: SQLite + label: + en_US: DB Type + zh_Hans: 数据库类型 + human_description: + en_US: Database type. 
+ zh_Hans: 选择要链接的数据库类型。 + form: form + - name: url + type: string + required: true + label: + en_US: URL/Host/DSN + zh_Hans: URL/Host/DSN + human_description: + en_US: Please input depending on DB type, visit https://vanna.ai/docs/ for more specification + zh_Hans: 请根据数据库类型,填入对应值,详情参考https://vanna.ai/docs/ + form: form + - name: db_name + type: string + required: false + label: + en_US: DB name + zh_Hans: 数据库名 + human_description: + en_US: Database name + zh_Hans: 数据库名 + form: form + - name: username + type: string + required: false + label: + en_US: Username + zh_Hans: 用户名 + human_description: + en_US: Username + zh_Hans: 用户名 + form: form + - name: password + type: secret-input + required: false + label: + en_US: Password + zh_Hans: 密码 + human_description: + en_US: Password + zh_Hans: 密码 + form: form + - name: port + type: number + required: false + label: + en_US: Port + zh_Hans: 端口 + human_description: + en_US: Port + zh_Hans: 端口 + form: form + - name: ddl + type: string + required: false + label: + en_US: Training DDL + zh_Hans: 训练DDL + human_description: + en_US: DDL statements for training data + zh_Hans: 用于训练RAG Model的建表语句 + form: form + - name: question + type: string + required: false + label: + en_US: Training Question + zh_Hans: 训练问题 + human_description: + en_US: Question-SQL Pairs + zh_Hans: Question-SQL中的问题 + form: form + - name: sql + type: string + required: false + label: + en_US: Training SQL + zh_Hans: 训练SQL + human_description: + en_US: SQL queries to your training data + zh_Hans: 用于训练RAG Model的SQL语句 + form: form + - name: memos + type: string + required: false + label: + en_US: Training Memos + zh_Hans: 训练说明 + human_description: + en_US: Sometimes you may want to add documentation about your business terminology or definitions + zh_Hans: 添加更多关于数据库的业务说明 + form: form + - name: enable_training + type: boolean + required: false + default: false + label: + en_US: Training Data + zh_Hans: 训练数据 + human_description: + en_US: You only need to train once. 
class VannaProvider(BuiltinToolProviderController):
    """Provider controller for the Vanna.AI builtin tool."""

    def _validate_credentials(self, credentials: dict[str, Any]) -> None:
        """Validate credentials by invoking ``VannaTool`` with a sample request.

        Args:
            credentials: Provider credentials handed to the tool runtime.

        Raises:
            ToolProviderCredentialValidationError: If the sample invocation
                raises for any reason; the original exception is chained.
        """
        try:
            VannaTool().fork_tool_runtime(
                runtime={
                    "credentials": credentials,
                }
            ).invoke(
                user_id="",
                tool_parameters={
                    "model": "chinook",
                    "db_type": "SQLite",
                    "url": "https://vanna.ai/Chinook.sqlite",
                    # VannaTool reads the question from "prompt". Sent as
                    # "query", the tool returned "Please input prompt"
                    # before ever using the API key.
                    "prompt": "What are the top 10 customers by sales?",
                    # Validation only needs the request to succeed, so
                    # skip chart generation.
                    "generate_chart": False,
                },
            )
        except Exception as e:
            raise ToolProviderCredentialValidationError(str(e)) from e