From 8d5456b6d07639824944fbac9f609b55216d99dd Mon Sep 17 00:00:00 2001 From: larcane97 <70624819+larcane97@users.noreply.github.com> Date: Fri, 1 Nov 2024 14:38:52 +0900 Subject: [PATCH] Add VESSL AI OpenAI API-compatible model provider and LLM model (#9474) Co-authored-by: moon --- .../model_providers/vessl_ai/__init__.py | 0 .../vessl_ai/_assets/icon_l_en.png | Bin 0 -> 11261 bytes .../vessl_ai/_assets/icon_s_en.svg | 3 + .../model_providers/vessl_ai/llm/__init__.py | 0 .../model_providers/vessl_ai/llm/llm.py | 83 +++++++++++ .../model_providers/vessl_ai/vessl_ai.py | 10 ++ .../model_providers/vessl_ai/vessl_ai.yaml | 56 ++++++++ api/tests/integration_tests/.env.example | 7 +- .../model_runtime/vessl_ai/__init__.py | 0 .../model_runtime/vessl_ai/test_llm.py | 131 ++++++++++++++++++ 10 files changed, 289 insertions(+), 1 deletion(-) create mode 100644 api/core/model_runtime/model_providers/vessl_ai/__init__.py create mode 100644 api/core/model_runtime/model_providers/vessl_ai/_assets/icon_l_en.png create mode 100644 api/core/model_runtime/model_providers/vessl_ai/_assets/icon_s_en.svg create mode 100644 api/core/model_runtime/model_providers/vessl_ai/llm/__init__.py create mode 100644 api/core/model_runtime/model_providers/vessl_ai/llm/llm.py create mode 100644 api/core/model_runtime/model_providers/vessl_ai/vessl_ai.py create mode 100644 api/core/model_runtime/model_providers/vessl_ai/vessl_ai.yaml create mode 100644 api/tests/integration_tests/model_runtime/vessl_ai/__init__.py create mode 100644 api/tests/integration_tests/model_runtime/vessl_ai/test_llm.py diff --git a/api/core/model_runtime/model_providers/vessl_ai/__init__.py b/api/core/model_runtime/model_providers/vessl_ai/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/vessl_ai/_assets/icon_l_en.png b/api/core/model_runtime/model_providers/vessl_ai/_assets/icon_l_en.png new file mode 100644 index 0000000000000000000000000000000000000000..18ba350fa0c98f288a0511a9793873fe68532d20 GIT binary patch literal 11261 zcmd_QhgXx$*Dj0-f+8p#MXE~208*rj1ZfhGA{_*zR{;qKA&Me|CLN_pXwn3!(vnAM zN{{qr6e*!bBmuc0$+>yn^Zw4a*1Ntx;H*PP=FZH%=9<~l_TCwkDNK*$?4`4GbaX5R z`cN}EItCIQ-3eKyQ(#W-JG$3kBIK%XW=u!-M3j#1=?gkK0L*$yprd;%OGmf|74XNEOvZiVI5-T3>OLv3UK%GbrT8+4|Ed>^9!V-6HkpXw+rnSVxo9(Ih;H} z-*B0aT)RqY7=U=Xr+vL!zOz*N1@ZLvE}2E}(HXwq?r##@YF^qZz9HP^rL}q8NLG9g zZF}cw^{v}&7+q~e zU*3D|AshBQYvIZ{od8p->}^LC2(L>#lEGo&}7!g~aCFz&`Yqtxg`ou?Mg$iF-PrEG=3 zQ%BuVWC*attzLjWX*g$)97DB;tT%|P{L1K0^Xj4%s4bg(pVQ!xJ9TI`*Es@b-TZX8 z;~_=%N{yW?!B?~X@>kQCc{R5XCLdXTt=j5W*mnUj>}qJl#oBerf7qYmBJ3+==gnLL zWe?}&b}B>#-_b$?W(qrTRDyo+?jYqG&sgoF!NW_GIBMOck!tOaeh{%ahMQkkRQ>n! zh^_c!tDY}>Xj+ei#2WU#8o z@cX#~3*5%c<>`pe_*cGt=W9KM0F#xB(YyBR37ED9DiLq6yz0&-F30m3|1|_|%R?D@ zM(HY75vU|k0g-vsQCTD+nxNG>m=ryc5qJSX4m`vsX{hS_%%h#sDtf>=hi{!h#M;Bl zy>NoCn0VTc!B-CQg2V6R4g=cjZ!w~25AQS9n6H3sfT!y+&Jyyb88VY~i=_1?Tcm0;HACwWN5efkx!P zen1ti?0j)sF+lBi>(_+RPOLh}#Shy(W4H5uLcmEi{*s`eK?50FnzOtC$XMd`5nIzF6#2TQMQX9E^a)=V2R3Cf2eYcwgzn^TMYuGY2cuEMSua zWqmMrPa2Bh;9H|+v`tyJ++ivV(>|6(;EO#d%wWbp;9Jr9c;*yS9F^Zc()|zm35t40 zgZ?M+#kkfSNrCUo=C`XH$nt*{ybMWG{cbs&eYZW3ZUQS>)2L_!h}kQf7RlVW2E6~b z$L`7XPQlFoiW`t>GaD@gpdOmZF3|g&P#Wu;K%@S08w#`Gd~UXIi00`!J)GIo`~j?5 zpG3~R%gz~Bh5;0Y5AaK{ zYROP4^HC!=lxEy0!k35I5c{u2Y=!alp1(}tZ@vs;fYiiA4gEckv$o0oQe-y!lW%XG z9+z#ya(btvm{27cC&VjA)_4L|t%ofCQsxUKRry2}tWQlvQeQ+92A2B%Q4QrpOtH$4 zP4U}HhHOz(R(x6BY)X?b08B8RlefUux9mQx04^Z3_coHFv_F;;;E#EA!1|QI`WC!& zNL4F=PN(ly0<4FrYU&fQ1m-y%vUzI(x)v!EEi`U*{$=!_`5rU!KJ`bv2-p?;mfG3Q zt|v`7*awzow@9Tubn6GVt#JDt*R3cltY*z$M=}UeczJj^18xc;?|#MbpPQ`L`cO{^ zyYOEZO$wh^5%|WjJ1nwfzdI@Q#(Rn1$Y%%U^TB$1#A(n`^skRfA3|p@AS#UB3BZ7O zj7Mx1Rq4u|nX+`O`LqkK6Uia%V<6kct5wbwYv3vM|YtF$9&`G$d7=7ux?l%_|UIw5k|ihaj=@XqaO-Bghd zLPU-l{54xlRco-v{r5Z&AY*Bpka}Gnhj6=#e1P8C2jC|&? zVWoM^C0gZ~Z3ls)wqIt4RlJHu?xh>(imrt~VK>|i(`S=3rI^bVQgaj(N!~ATtUeBaSJCG{$ z8eymp{*qG;q(jw2VdzZRbFBILHk1cJyG0Z9aAag6d0Z)sNzi=kHmP~=H^jV^imuhV zcTj}EnDZ6+Hcb0G`rAE?CK3twH^|*9^gwXx$Y)Sn5lUEzgGxJao-h+f6w1?F_#xpSqMk zmZFUWy@R0O-Pso8W^T!C89P#i#?ET4zG0Dfv+Pqsd#GoH2d7Ae-{cDgeLlLK9|r2gbpKX@)7+~gtB0J(k#5!?9NNSS$VOWkEcaeOAkt&Bpg@n#(c#R{ zyLF-`i2=%u(1u3wxm6JiJqt^2d!pKdIW*yGe=>d-Wl00$SAaeLmzOH$dC>u*MLpkH z);u^*X83Uz_(ODxPu;8|Fn5PYIT20j*RpwUYft^wl@@wlHybSwcP&AaY>QbsvBqQw z$>f9nt2u~3&=*B2`nP}-o1_PYo*RLA?Y{h*G1a4BTKWSm-){24d)-|2C*sGvCux6N zMj_FTlFD%ezSq*5Z?nz*{d_n6@*3|^_fPTS>x3Sg|nSA;rjoJZn$Jd@XLuztzJBKwXyhp94 z_DV!*87YO_5r8HtFhuA$Hr^zQ1LIJjfpsi5q}0M+2&Zi774pMzRBLU;kDutdjHE3|712+zQ7^(JmD30txzb0Qw#e1@I|&br zj`**^%OB&cE|_ocq7!u{1*OK>gO!&uel}U)%A@O06*}irVaKT5h_f?wTrR~9>PuRF zGuI!X*u`7bz>!dtQz%e^@0rjP1BOO<)mL(F%fO7lafnse(I-5%6wBl@Jno|9Hb)>F zE_tA4xD9)N^FT->F(tNp<$S;W=d`Ko@7Ir=$NB0^aZYDmddb3;U$P^1P(OCV9(#H@ zhNK&?AoT_oyc|fD7z-TT@6)Ege1ZhIc2=3WDEUj--wr)8E&Qc%`|79lo><;t-;W>Lp<{bB( zl!#9&{1M352R$iPv-+<0utv+3W#0?b zdhzqb9RBN;m}M|%rXn&O_T2|uY{kM(y{o_J5ROl;OXHRREN&SL<#u!1YTKbvm{qx( z-WM5Yza|ChDEWi7S2T{ttC03=QkHSFJ>p>z+N3A_p)9mKh^zZr9w7px30cXFaGK=6 zVAHdIj0MAUco6%8h*V{ZblrGOE?63o-{C&GuHu{+P%yPD>33ss(ApKLH?nX>8W@xd zIZqCE+FjXJzkZKYK@k7Y)qHokmjgj@4g7oWb-&VcWHMHz8MHDM-Z+K+ir2is^vKG5 z<@ee`qqj;yK1!3E%Mt$A_Vs{r<_`4u;LLn3bduzz@qrAV-ez?*@Ny&>y%)B4rWWO$ zy={~1;ecy^g5GE`C0@kHJ(89{@H*5j63`T-Z&|`B{ zQX{%gt#F@tugnSKF&+XgXbgcc)mfMqFS?(X}J^uI+evAfJ$2FPp${n7VbpjV5&&80jmI+e)Xt`sn8`@@0gog0O=VS}8J; z7g@%#$xy2tC!7NeKE;rh>3^ z%ZV2s4r5OxMBVYHIqYxLQZFv|9vV@|ULazIyTMMT!NTA4UvZ&Vm)}7)P?^}`ju<9s z;Ef&j*Rq8&Ha$G2Y5%q5ukfdoTND^ieqB1-sLE3s)D@4}c-$JU_nUrW;0Zd6?vL_p zEbxLGL9PhPws<~+cGwMV&^M4aB7K#loBU-W{p;FV*GDvQwkUHqN3RvKp&W0OA6XRD zq2r00GLc1`YiTQ4hu7n0Hp;GMT>~sM6I)(SM@zT^``3h@b1wkM<_IIpa?2UA)6H#IX+KsRS1!P$cDn?za?(|}{#fl&j*sMt+_ldz*e)LJ zZFxRBl8g=|8}dkE`}QhXrN2Bfb);yS%^N`iC}z*&(c{V~Dj{Q;^C+Z76* zW-$Mxi+`JySVFjUnQ{}g^cwzPE)7~3KgQ5ZN2y#kf_dWL8j0GI;uq?{$=8f6CT_pr zp=@5?lKyl4pX%mo;gKVWVMqhCNvq${;-hp6f8T%Ake4B(5*>@ov)goD>Uj-p!y<7I z)kkX2T~}I0b%gnr8sl(@<;kGFR^X5u%Xi1?*sHE~o0*{l;Aw2$%RvrX;@36k0Za1Pj{Nayn z1tEYuK|-($r%;M2*~p)agOq;e)XDKTA1J?z?H=NO0--d{HP~k4)R@$#lM6 zNedW=m``9s1Q{|*HLw%y%4#k<9*^mADQD#Nr>r6Z(fXsBUmPZVoz||eAs{$`{SX=o z%6r3QFy5%pH#f{%cXXYw_hEINizS@5L5_PQ`1H%%ZnC!ysndccE;F}(kAk%8U2o*X zi5>pvaSwFKEy(Uqf*f}}PvTM=WaA7T|0W{fAfc}N5$KxYPsd&=fct-Ine%$wHdG5U zVnfI~W|Xh2$x78~y*q)Ms=afedu5kh?CwC=CxLF2aA&7}vDw~c5rCBpoG;n)o1XBJ zK!f_bsG1+PTyNIO5_h~{cWnlK91rUyS!D_1`TUnqnxEye_<&a2^Mv$QJvYh$W{}5r z@C%)P24f||@m>cHjc=CMI9P?0wG=9C!;H{0r6q?V`*>@zGM9-{g!3ITj|gCql9?G% z+9zEVRL6i*p1v0>^UX#k9uIE9Uyq%_`lC=IG?8m_AgTMgQC(rdEs(`~si9JugM3w` z^to{y#wm7k^BhZLA{CMWT8YIeNq6l)qC|lpNh^e)ht>fhTNo(0>(Tn@Ozjo1%(<6@ z4Lr-IT+4>Sxm!+Ik>|{rMeqx$e$Os(P+D4&V)NV^qiWUF3oNwfA&R>4!=YVM+T=*#^;)ta8oKiTe(f^rivnfCMewO zu80FW(u0=^g4=Q;&d-b(3D;D56ma}%Wx|QM7SD4EErg~zD=CNoM|`wt+1nJG7^$|A zMx{&C}0}Ose%kpnFRIBxp+c^eYTpN!ni(YWWlJKpQEXBc6j4#a{vSSk~vV?$#k(r3QQ@37a{tMY%emRtF;-o~Dc?*K=3gIqFMrs35)pHAVer^y#Bc9V z-zNt+hHdw)nD)Iwqh}f7G$kyry-JP>?>USOg*st%q%^D>PENs`xDUlNiWKGs7uIf( zh}AyyM`wBDYrW}@x-VNlC79P5;G~Vjn-t7PQ$Qo&I$S9@Z^r$nL4}&@>UM|-Ac$QG zY)<<7*vz2ITfy;doZOZZ>1(S!r>H2sL0hfAK8fXe87#woIO{>ncs992)~qEMuF6hP z%N^@6y+?lk&`~h-r7(>&H-ROjd(>psOroA~%-)efVKB0zIKbC57mj##Bc{?1BohS$ zt7|fx4mMw9X8tj;Qh-B%=f3LGdE+muuHlqFUe6}!c{n}4lZ>Mo%(Q;7Ex%#%iHFkF zPuFn0edw`zm00mP%od>F-{erSX8xj`2T|WyPWpfw(c?X>a18+XZ7uGq+DVqA81KF% zoaStbuonH15S7q8yY67dRf_li!&23yYYvEUBrYUqazVZuUP*^eWJ+hwWf1tldC_?^ z+}l_~cGHk(+c4eg?KuN*Q%3Obu2Gsj+Sd(eC55rz|K)Ra^Re=M{z$HpH~Sh@ zB|u-q2=2j#rhSSn$%4(E{l4g~xNV4CU0NJ(CmhZjOXslpp)lI9liilXyE*t#IRGpE z8$R8jBIex0i$j`vlC6B$BpQM8W?X&(M%%tp^Fa~{^EHK9uYWQ7hI?V|66_haO905j z*}NtoW$lBxnzjv)2Mq7b6UG zINF2!v4Q2fvBCA_9NRoq9z@{)!7cT^4;U=t?PBcfitT0T2oAEF%=bn1jW5`|Cm@gh zCWPE?j~x0bso~&RdHHPsPFY-bN!VanNepmzP;ZfGX+(c;1%Fc6BED`@8{nqE-`6wY zd@huf>zf>LAQCXtGr#TWE9O_Act2W(V#m&pSP?B`SxtLG7;g0--sz^mUQ61hp{Qdf zMi(!ZBz3ud4~lPf#{N8W@RYxjLW27)@W)ZrtS~MghqblOEOd9R+ma64hSr(>sBfV# z=E9piHG6Iv&YWTzf7gi9ZE6=NYJ7Xqn;Ww4r_-I$AKHCzwNa-i%tt)6+Bo^;jKjNPSphB=udQXc@k z95{~=wA0SwL7KMDJ<>sa=_`0#>8)g}>((|sC*pexu6=Zgqs%{&LL1Bx+`8uqL!+ij zU2*XDa?ZIIjN>uC;=MzEKhc#x;`f1z6%cN~*Vnb_ZoEAJXm5mD0I$K05r-^-TK4NY z_>5mXX6yvdcS{b#I%PwrqEmR&0gX)YJVG!wB~*sz#tSNlNl1Sp8NYw_x*>^i7}D?} zx8Kj4)T!F&Zw#KDKM&PL5G)a?LyzH$C{1`u-cf!nE&tH7vE`k5j9msnt!=nacDdpn z4jwd;lb3*WK2$R?f9epG*U52K445A$e>h5y`pmZhf+C^|0D;V=o1neMLVWG<`Ev6e zU3@vS3qB0oZY!`|_!jp<6@C*^Yp%f-v1``{Lqo1iSr8>2CXJ*)@yq?yg|fyS_jU25 zENC)D`rwDg-4o62>BP%@90>oRrcg;MNnpOmz)nvX!2X8)`u?-?a!ov@?@{QwTDbBH zDrJ;(?}go;)GBzj+nY^#rS}QSX3DqHNYmiu3s?Xb6D; zXO#A5aZk`ti0^&!*FCWh5twL(}&ZN>l#Et_oo6#Ua(^n7b^_gxYO!-7_l!JPps&MPd z8SCSCjBE3I9aB`!Vatl&B^#IilVQUo+r%C!0hA=#PK1jI92_x!YbFGEKlUUad!yLB zig)MKO9%sLcl!RQhHd+yu-z7+9lu&5A1}fRDj3R)aAJzn{lfa{e-1P}=Fm=6I#tEs zbXbo?X1i72j>r72FR=5vFoX#>nqoXtxQq#5nmN%mEcOd!LXU|!I%gGH{UMHYZMp&JWeN&HQzq^EWk%Tx7XHc9utDmO}dw@89#NBw3 zmW2MY5(GGMKk%K+dFC5J$^ZOh=`Fg@P0)XNMh&%O;>OqjaKaahI1My^#)panOI_wz zqurS9-oW8Eo{pp~wZv*wIr zkNlBnlU$#y)_=ys`ESHeU00jj6(w1P`mT^TUi;Zyg~n&x0znqz`32 zZow{lU4p6Y=-&;X0-zr+qn=HhJHviTrJ>8yiBzN32pp5I5i|TB7RN{7q6p3WCaBtW zRfqcFb9srk4ZlZI!aKf6jn}De*2{a0B^bcbCVdReKV4khhnLt@FCczx1RegTiYkEu zu_v=V9IJi<{2)gvBwO{i1@xO|I4-e#C;5%?>0>%=~s9zg~*OG z>p9iHl+o`0Wr}d08c?q9cY6#y@r*DR>iKvYoNPF}#CO&nE{ZsQlUc;riXz&o5_hvj z4JTH7G8c@n8u)I1tC$yn;v03U`oKB%gbI7~<~FBSK2z$|SnJ8B@-CH*5V`YQt9znA zBkpa4##gySVJR6>qW8B0mC+!?fC&5*(864K?KA>?J{ zt6YCb!=y0k8ltTBVbv3!GZTWPAA5@~?#^wwP(7q~wl#m|6UCZ8<*i5w@g~F!s$=5K@*FRAJ#Shz?Lpf3gJ* z#Kf?D!CxjBAa(Q4_<;RZ$xb*yCddPAveM(5K=RO3-!g)W0C#e5K43dkC5wJI_Ht7U zV_W5?73Jb9l*=Kw^0ANHJZPFISL3zC2^bXbE@@Mc&3TF+srRWfQN#)=#DC2|HQP1K zq1;99w$&j!!c1$!Q_oSdg$Jh$u9+5aPe{fVn@(P`Cv9B*sWGzo$9C%^P66fJgMk)W zIBbBv$C-og9+U~Jq>~-<^O^6Dz?NLX#N6?-4Z*!))4AbZLS5*)Sh6nI#*R?+tQ^Cq z$3Ncf@2HIQlR4*WwXqV<1~W!D)AwL$3mDD%IB>J|bous!2*zfu42KOFi_Jf~v8+54cpqSWE|b zood$8J>EhvAStpG*o~C?P~VQ@ack(<7}#f6CDdE%ipB48}4E8Re^$ zg?yFX*naMjq=}*FB*K`)c-!|t?XN?V@6ouEyHj9g zL28^7bX&|XbkN7br3fx>_~#}ufS6F)6y^8$l3MMg<2cL=J&hua{d`Q>wKGdvG?(*z zg?2KvEWco*sE@%G-!C6U6ZrIS5uG9_>7|>@*<~&*eKjMbezQ49!HDSu%P=xdW&m?9XGX3=n^NM+xuGYZf%=MLy zeEibs=qDTZa7M9;j{`{<;uk?3vtwZXT|N0v!Uv*tZI{1+!YI2rN%?c8?TdvfDrpFQ85(H z-;r(CGiWeIz~7?P>xS*j%YrYSGyu1{H;*T|Nz|LwQbOph2HFDg_EzLdTKcE#V#+Q8k|m`)OWW}-WFf}QU8j^_#R zMR!7u?j-ok2EOHEz3Axw&8MS-f+?N$!vGgIAp<{W4>vP6=ZE27)~UBA{tx!;$^VvT zeERm(e?QaTJjT}8GvmDs}d_D000Q>?Tdbx-H{{wjNsV*Qu{67pKUVd)DAt%q7`HGlcU7GL-+&f)Xa~?!j&$bV@fzM-1J;(u@Bo zHT4Q}^RGcm?Wz3fcUBgk*2ZHg^;~1ce0tQ}#&as9Af-C|8eE&yHk~EXI Qpe&t%E)4qq_JipE1w@ALBme*a literal 0 HcmV?d00001 diff --git a/api/core/model_runtime/model_providers/vessl_ai/_assets/icon_s_en.svg b/api/core/model_runtime/model_providers/vessl_ai/_assets/icon_s_en.svg new file mode 100644 index 0000000000..242f4e82b2 --- /dev/null +++ b/api/core/model_runtime/model_providers/vessl_ai/_assets/icon_s_en.svg @@ -0,0 +1,3 @@ + + + diff --git a/api/core/model_runtime/model_providers/vessl_ai/llm/__init__.py b/api/core/model_runtime/model_providers/vessl_ai/llm/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/vessl_ai/llm/llm.py b/api/core/model_runtime/model_providers/vessl_ai/llm/llm.py new file mode 100644 index 0000000000..034c066ab5 --- /dev/null +++ b/api/core/model_runtime/model_providers/vessl_ai/llm/llm.py @@ -0,0 +1,83 @@ +from decimal import Decimal + +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.llm_entities import LLMMode +from core.model_runtime.entities.model_entities import ( + AIModelEntity, + DefaultParameterName, + FetchFrom, + ModelPropertyKey, + ModelType, + ParameterRule, + ParameterType, + PriceConfig, +) +from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel + + +class VesslAILargeLanguageModel(OAIAPICompatLargeLanguageModel): + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: + features = [] + + entity = AIModelEntity( + model=model, + label=I18nObject(en_US=model), + model_type=ModelType.LLM, + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + features=features, + model_properties={ + ModelPropertyKey.MODE: credentials.get("mode"), + }, + parameter_rules=[ + ParameterRule( + name=DefaultParameterName.TEMPERATURE.value, + label=I18nObject(en_US="Temperature"), + type=ParameterType.FLOAT, + default=float(credentials.get("temperature", 0.7)), + min=0, + max=2, + precision=2, + ), + ParameterRule( + name=DefaultParameterName.TOP_P.value, + label=I18nObject(en_US="Top P"), + type=ParameterType.FLOAT, + default=float(credentials.get("top_p", 1)), + min=0, + max=1, + precision=2, + ), + ParameterRule( + name=DefaultParameterName.TOP_K.value, + label=I18nObject(en_US="Top K"), + type=ParameterType.INT, + default=int(credentials.get("top_k", 50)), + min=-2147483647, + max=2147483647, + precision=0, + ), + ParameterRule( + name=DefaultParameterName.MAX_TOKENS.value, + label=I18nObject(en_US="Max Tokens"), + type=ParameterType.INT, + default=512, + min=1, + max=int(credentials.get("max_tokens_to_sample", 4096)), + ), + ], + pricing=PriceConfig( + input=Decimal(credentials.get("input_price", 0)), + output=Decimal(credentials.get("output_price", 0)), + unit=Decimal(credentials.get("unit", 0)), + currency=credentials.get("currency", "USD"), + ), + ) + + if credentials["mode"] == "chat": + entity.model_properties[ModelPropertyKey.MODE] = LLMMode.CHAT.value + elif credentials["mode"] == "completion": + entity.model_properties[ModelPropertyKey.MODE] = LLMMode.COMPLETION.value + else: + raise ValueError(f"Unknown completion type {credentials['completion_type']}") + + return entity diff --git a/api/core/model_runtime/model_providers/vessl_ai/vessl_ai.py b/api/core/model_runtime/model_providers/vessl_ai/vessl_ai.py new file mode 100644 index 0000000000..7a987c6710 --- /dev/null +++ b/api/core/model_runtime/model_providers/vessl_ai/vessl_ai.py @@ -0,0 +1,10 @@ +import logging + +from core.model_runtime.model_providers.__base.model_provider import ModelProvider + +logger = logging.getLogger(__name__) + + +class VesslAIProvider(ModelProvider): + def validate_provider_credentials(self, credentials: dict) -> None: + pass diff --git a/api/core/model_runtime/model_providers/vessl_ai/vessl_ai.yaml b/api/core/model_runtime/model_providers/vessl_ai/vessl_ai.yaml new file mode 100644 index 0000000000..6052756cae --- /dev/null +++ b/api/core/model_runtime/model_providers/vessl_ai/vessl_ai.yaml @@ -0,0 +1,56 @@ +provider: vessl_ai +label: + en_US: vessl_ai +icon_small: + en_US: icon_s_en.svg +icon_large: + en_US: icon_l_en.png +background: "#F1EFED" +help: + title: + en_US: How to deploy VESSL AI LLM Model Endpoint + url: + en_US: https://docs.vessl.ai/guides/get-started/llama3-deployment +supported_model_types: + - llm +configurate_methods: + - customizable-model +model_credential_schema: + model: + label: + en_US: Model Name + placeholder: + en_US: Enter your model name + credential_form_schemas: + - variable: endpoint_url + label: + en_US: endpoint url + type: text-input + required: true + placeholder: + en_US: Enter the url of your endpoint url + - variable: api_key + required: true + label: + en_US: API Key + type: secret-input + placeholder: + en_US: Enter your VESSL AI secret key + - variable: mode + show_on: + - variable: __model_type + value: llm + label: + en_US: Completion mode + type: select + required: false + default: chat + placeholder: + en_US: Select completion mode + options: + - value: completion + label: + en_US: Completion + - value: chat + label: + en_US: Chat diff --git a/api/tests/integration_tests/.env.example b/api/tests/integration_tests/.env.example index 6791cd891b..f95d5c2ca1 100644 --- a/api/tests/integration_tests/.env.example +++ b/api/tests/integration_tests/.env.example @@ -84,5 +84,10 @@ VOLC_EMBEDDING_ENDPOINT_ID= # 360 AI Credentials ZHINAO_API_KEY= +# VESSL AI Credentials +VESSL_AI_MODEL_NAME= +VESSL_AI_API_KEY= +VESSL_AI_ENDPOINT_URL= + # Gitee AI Credentials -GITEE_AI_API_KEY= +GITEE_AI_API_KEY= \ No newline at end of file diff --git a/api/tests/integration_tests/model_runtime/vessl_ai/__init__.py b/api/tests/integration_tests/model_runtime/vessl_ai/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/integration_tests/model_runtime/vessl_ai/test_llm.py b/api/tests/integration_tests/model_runtime/vessl_ai/test_llm.py new file mode 100644 index 0000000000..7797d0f8e4 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/vessl_ai/test_llm.py @@ -0,0 +1,131 @@ +import os +from collections.abc import Generator + +import pytest + +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + SystemPromptMessage, + UserPromptMessage, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.vessl_ai.llm.llm import VesslAILargeLanguageModel + + +def test_validate_credentials(): + model = VesslAILargeLanguageModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model=os.environ.get("VESSL_AI_MODEL_NAME"), + credentials={ + "api_key": "invalid_key", + "endpoint_url": os.environ.get("VESSL_AI_ENDPOINT_URL"), + "mode": "chat", + }, + ) + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model=os.environ.get("VESSL_AI_MODEL_NAME"), + credentials={ + "api_key": os.environ.get("VESSL_AI_API_KEY"), + "endpoint_url": "http://invalid_url", + "mode": "chat", + }, + ) + + model.validate_credentials( + model=os.environ.get("VESSL_AI_MODEL_NAME"), + credentials={ + "api_key": os.environ.get("VESSL_AI_API_KEY"), + "endpoint_url": os.environ.get("VESSL_AI_ENDPOINT_URL"), + "mode": "chat", + }, + ) + + +def test_invoke_model(): + model = VesslAILargeLanguageModel() + + response = model.invoke( + model=os.environ.get("VESSL_AI_MODEL_NAME"), + credentials={ + "api_key": os.environ.get("VESSL_AI_API_KEY"), + "endpoint_url": os.environ.get("VESSL_AI_ENDPOINT_URL"), + "mode": "chat", + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Who are you?"), + ], + model_parameters={ + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, + }, + stop=["How"], + stream=False, + user="abc-123", + ) + + assert isinstance(response, LLMResult) + assert len(response.message.content) > 0 + + +def test_invoke_stream_model(): + model = VesslAILargeLanguageModel() + + response = model.invoke( + model=os.environ.get("VESSL_AI_MODEL_NAME"), + credentials={ + "api_key": os.environ.get("VESSL_AI_API_KEY"), + "endpoint_url": os.environ.get("VESSL_AI_ENDPOINT_URL"), + "mode": "chat", + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Who are you?"), + ], + model_parameters={ + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, + }, + stop=["How"], + stream=True, + user="abc-123", + ) + + assert isinstance(response, Generator) + + for chunk in response: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + + +def test_get_num_tokens(): + model = VesslAILargeLanguageModel() + + num_tokens = model.get_num_tokens( + model=os.environ.get("VESSL_AI_MODEL_NAME"), + credentials={ + "api_key": os.environ.get("VESSL_AI_API_KEY"), + "endpoint_url": os.environ.get("VESSL_AI_ENDPOINT_URL"), + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + ) + + assert isinstance(num_tokens, int) + assert num_tokens == 21