From 4a026fa352dff40643ddf13960cc05a5fa27d27c Mon Sep 17 00:00:00 2001
From: ybalbert001 <120714773+ybalbert001@users.noreply.github.com>
Date: Thu, 18 Jul 2024 19:32:31 +0800
Subject: [PATCH] Enhancement: add model provider - Amazon Sagemaker (#6255)

Co-authored-by: Yuanbo Li
Co-authored-by: crazywoola <427733928@qq.com>
---
 .../model_providers/sagemaker/__init__.py      |   0
 .../sagemaker/_assets/icon_l_en.png            | Bin 0 -> 9395 bytes
 .../sagemaker/_assets/icon_s_en.png            | Bin 0 -> 9720 bytes
 .../model_providers/sagemaker/llm/__init__.py  |   0
 .../model_providers/sagemaker/llm/llm.py       | 238 ++++++++++++++++++
 .../sagemaker/rerank/__init__.py               |   0
 .../sagemaker/rerank/rerank.py                 | 190 ++++++++++++++
 .../model_providers/sagemaker/sagemaker.py     |  17 ++
 .../model_providers/sagemaker/sagemaker.yaml   | 125 +++++++++
 .../sagemaker/text_embedding/__init__.py       |   0
 .../text_embedding/text_embedding.py           | 214 ++++++++++++++++
 .../model_runtime/sagemaker/__init__.py        |   0
 .../model_runtime/sagemaker/test_provider.py   |  19 ++
 .../model_runtime/sagemaker/test_rerank.py     |  55 ++++
 .../sagemaker/test_text_embedding.py           |  55 ++++
 15 files changed, 913 insertions(+)
 create mode 100644 api/core/model_runtime/model_providers/sagemaker/__init__.py
 create mode 100644 api/core/model_runtime/model_providers/sagemaker/_assets/icon_l_en.png
 create mode 100644 api/core/model_runtime/model_providers/sagemaker/_assets/icon_s_en.png
 create mode 100644 api/core/model_runtime/model_providers/sagemaker/llm/__init__.py
 create mode 100644 api/core/model_runtime/model_providers/sagemaker/llm/llm.py
 create mode 100644 api/core/model_runtime/model_providers/sagemaker/rerank/__init__.py
 create mode 100644 api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py
 create mode 100644 api/core/model_runtime/model_providers/sagemaker/sagemaker.py
 create mode 100644 api/core/model_runtime/model_providers/sagemaker/sagemaker.yaml
 create mode 100644 api/core/model_runtime/model_providers/sagemaker/text_embedding/__init__.py
 create mode 100644 api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py
 create mode 100644 api/tests/integration_tests/model_runtime/sagemaker/__init__.py
 create mode 100644 api/tests/integration_tests/model_runtime/sagemaker/test_provider.py
 create mode 100644 api/tests/integration_tests/model_runtime/sagemaker/test_rerank.py
 create mode 100644 api/tests/integration_tests/model_runtime/sagemaker/test_text_embedding.py

diff --git a/api/core/model_runtime/model_providers/sagemaker/__init__.py b/api/core/model_runtime/model_providers/sagemaker/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/model_runtime/model_providers/sagemaker/_assets/icon_l_en.png b/api/core/model_runtime/model_providers/sagemaker/_assets/icon_l_en.png
new file mode 100644
index 0000000000000000000000000000000000000000..0abe07a78ff776018db04d84b7cc0ceb1c9ae81f
GIT binary patch
literal 9395
[9395 bytes of base85-encoded PNG data omitted]
diff --git a/api/core/model_runtime/model_providers/sagemaker/_assets/icon_s_en.png b/api/core/model_runtime/model_providers/sagemaker/_assets/icon_s_en.png
new file mode 100644
index 0000000000000000000000000000000000000000..6b88942a5ce27cb57e27ce876cecedf3c5dfce3b
GIT binary patch
literal 9720
[9720 bytes of base85-encoded PNG data omitted]
zlMf8cPTh!bdYUW85Zvs0kd1P6f)e;;3loVZzD?D4#y+3IAjJDE*625BtSI}hos9=X z(LBjqgdnr;Zr(Cd#3HSa&Kn|^Fe#nSi2*2q!Zz{!>cCppGm*<{47Fb;!3jVz!$b%S zSv@zh#qM{vdxyWSre+qlwqBwz{ezR<|1SJRsyJSQU;JECaEVIu5-cwf3H#B0(e2Ow OOmb4nk~QL{VgCa)i~r65 literal 0 HcmV?d00001 diff --git a/api/core/model_runtime/model_providers/sagemaker/llm/__init__.py b/api/core/model_runtime/model_providers/sagemaker/llm/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/sagemaker/llm/llm.py b/api/core/model_runtime/model_providers/sagemaker/llm/llm.py new file mode 100644 index 0000000000..f8e7757a96 --- /dev/null +++ b/api/core/model_runtime/model_providers/sagemaker/llm/llm.py @@ -0,0 +1,238 @@ +import json +import logging +from collections.abc import Generator +from typing import Any, Optional, Union + +import boto3 + +from core.model_runtime.entities.llm_entities import LLMMode, LLMResult +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageTool, +) +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, I18nObject, ModelType +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel + +logger = logging.getLogger(__name__) + + +class SageMakerLargeLanguageModel(LargeLanguageModel): + """ + Model class for Cohere large language model. + """ + sagemaker_client: Any = None + + def _invoke(self, model: str, credentials: dict, + prompt_messages: list[PromptMessage], model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, + stream: bool = True, user: Optional[str] = None) \ + -> Union[LLMResult, Generator]: + """ + Invoke large language model + + :param model: model name + :param credentials: model credentials + :param prompt_messages: prompt messages + :param model_parameters: model parameters + :param tools: tools for tool calling + :param stop: stop words + :param stream: is stream response + :param user: unique user id + :return: full response or stream response chunk generator result + """ + # get model mode + model_mode = self.get_model_mode(model, credentials) + + if not self.sagemaker_client: + access_key = credentials.get('access_key') + secret_key = credentials.get('secret_key') + aws_region = credentials.get('aws_region') + if aws_region: + if access_key and secret_key: + self.sagemaker_client = boto3.client("sagemaker-runtime", + aws_access_key_id=access_key, + aws_secret_access_key=secret_key, + region_name=aws_region) + else: + self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region) + else: + self.sagemaker_client = boto3.client("sagemaker-runtime") + + + sagemaker_endpoint = credentials.get('sagemaker_endpoint') + response_model = self.sagemaker_client.invoke_endpoint( + EndpointName=sagemaker_endpoint, + Body=json.dumps( + { + "inputs": prompt_messages[0].content, + "parameters": { "stop" : stop}, + "history" : [] + } + ), + ContentType="application/json", + ) + + assistant_text = response_model['Body'].read().decode('utf8') + + # transform assistant message to prompt message + assistant_prompt_message = AssistantPromptMessage( + content=assistant_text + ) + + usage = 
self._calc_response_usage(model, credentials, 0, 0) + + response = LLMResult( + model=model, + prompt_messages=prompt_messages, + message=assistant_prompt_message, + usage=usage + ) + + return response + + def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None) -> int: + """ + Get number of tokens for given prompt messages + + :param model: model name + :param credentials: model credentials + :param prompt_messages: prompt messages + :param tools: tools for tool calling + :return: + """ + # get model mode + model_mode = self.get_model_mode(model) + + try: + return 0 + except Exception as e: + raise self._transform_invoke_error(e) + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + try: + # get model mode + model_mode = self.get_model_mode(model) + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + The key is the error type thrown to the caller + The value is the error type thrown by the model, + which needs to be converted into a unified error type for the caller. + + :return: Invoke error mapping + """ + return { + InvokeConnectionError: [ + InvokeConnectionError + ], + InvokeServerUnavailableError: [ + InvokeServerUnavailableError + ], + InvokeRateLimitError: [ + InvokeRateLimitError + ], + InvokeAuthorizationError: [ + InvokeAuthorizationError + ], + InvokeBadRequestError: [ + InvokeBadRequestError, + KeyError, + ValueError + ] + } + + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: + """ + used to define customizable model schema + """ + rules = [ + ParameterRule( + name='temperature', + type=ParameterType.FLOAT, + use_template='temperature', + label=I18nObject( + zh_Hans='温度', + en_US='Temperature' + ), + ), + ParameterRule( + name='top_p', + type=ParameterType.FLOAT, + use_template='top_p', + label=I18nObject( + zh_Hans='Top P', + en_US='Top P' + ) + ), + ParameterRule( + name='max_tokens', + type=ParameterType.INT, + use_template='max_tokens', + min=1, + max=credentials.get('context_length', 2048), + default=512, + label=I18nObject( + zh_Hans='最大生成长度', + en_US='Max Tokens' + ) + ) + ] + + completion_type = LLMMode.value_of(credentials["mode"]) + + if completion_type == LLMMode.CHAT: + print(f"completion_type : {LLMMode.CHAT.value}") + + if completion_type == LLMMode.COMPLETION: + print(f"completion_type : {LLMMode.COMPLETION.value}") + + features = [] + + support_function_call = credentials.get('support_function_call', False) + if support_function_call: + features.append(ModelFeature.TOOL_CALL) + + support_vision = credentials.get('support_vision', False) + if support_vision: + features.append(ModelFeature.VISION) + + context_length = credentials.get('context_length', 2048) + + entity = AIModelEntity( + model=model, + label=I18nObject( + en_US=model + ), + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_type=ModelType.LLM, + features=features, + model_properties={ + ModelPropertyKey.MODE: completion_type, + ModelPropertyKey.CONTEXT_SIZE: context_length + }, + parameter_rules=rules + ) + + return entity diff --git a/api/core/model_runtime/model_providers/sagemaker/rerank/__init__.py 
b/api/core/model_runtime/model_providers/sagemaker/rerank/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py b/api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py new file mode 100644 index 0000000000..0b06f54ef1 --- /dev/null +++ b/api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py @@ -0,0 +1,190 @@ +import json +import logging +from typing import Any, Optional + +import boto3 + +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.rerank_model import RerankModel + +logger = logging.getLogger(__name__) + +class SageMakerRerankModel(RerankModel): + """ + Model class for Cohere rerank model. + """ + sagemaker_client: Any = None + + def _sagemaker_rerank(self, query_input: str, docs: list[str], rerank_endpoint:str): + inputs = [query_input]*len(docs) + response_model = self.sagemaker_client.invoke_endpoint( + EndpointName=rerank_endpoint, + Body=json.dumps( + { + "inputs": inputs, + "docs": docs + } + ), + ContentType="application/json", + ) + json_str = response_model['Body'].read().decode('utf8') + json_obj = json.loads(json_str) + scores = json_obj['scores'] + return scores if isinstance(scores, list) else [scores] + + + def _invoke(self, model: str, credentials: dict, + query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None, + user: Optional[str] = None) \ + -> RerankResult: + """ + Invoke rerank model + + :param model: model name + :param credentials: model credentials + :param query: search query + :param docs: docs for reranking + :param score_threshold: score threshold + :param top_n: top n + :param user: unique user id + :return: rerank result + """ + line = 0 + try: + if len(docs) == 0: + return RerankResult( + model=model, + docs=docs + ) + + line = 1 + if not self.sagemaker_client: + access_key = credentials.get('aws_access_key_id') + secret_key = credentials.get('aws_secret_access_key') + aws_region = credentials.get('aws_region') + if aws_region: + if access_key and secret_key: + self.sagemaker_client = boto3.client("sagemaker-runtime", + aws_access_key_id=access_key, + aws_secret_access_key=secret_key, + region_name=aws_region) + else: + self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region) + else: + self.sagemaker_client = boto3.client("sagemaker-runtime") + + line = 2 + + sagemaker_endpoint = credentials.get('sagemaker_endpoint') + candidate_docs = [] + + scores = self._sagemaker_rerank(query, docs, sagemaker_endpoint) + for idx in range(len(scores)): + candidate_docs.append({"content" : docs[idx], "score": scores[idx]}) + + sorted(candidate_docs, key=lambda x: x['score'], reverse=True) + + line = 3 + rerank_documents = [] + for idx, result in enumerate(candidate_docs): + rerank_document = RerankDocument( + index=idx, + text=result.get('content'), + score=result.get('score', -100.0) + ) + + if score_threshold is not None: + if rerank_document.score >= score_threshold: + 
rerank_documents.append(rerank_document) + else: + rerank_documents.append(rerank_document) + + return RerankResult( + model=model, + docs=rerank_documents + ) + + except Exception as e: + logger.exception(f'Exception {e}, line : {line}') + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + try: + self._invoke( + model=model, + credentials=credentials, + query="What is the capital of the United States?", + docs=[ + "Carson City is the capital city of the American state of Nevada. At the 2010 United States " + "Census, Carson City had a population of 55,274.", + "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " + "are a political division controlled by the United States. Its capital is Saipan.", + ], + score_threshold=0.8 + ) + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + The key is the error type thrown to the caller + The value is the error type thrown by the model, + which needs to be converted into a unified error type for the caller. + + :return: Invoke error mapping + """ + return { + InvokeConnectionError: [ + InvokeConnectionError + ], + InvokeServerUnavailableError: [ + InvokeServerUnavailableError + ], + InvokeRateLimitError: [ + InvokeRateLimitError + ], + InvokeAuthorizationError: [ + InvokeAuthorizationError + ], + InvokeBadRequestError: [ + InvokeBadRequestError, + KeyError, + ValueError + ] + } + + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: + """ + used to define customizable model schema + """ + entity = AIModelEntity( + model=model, + label=I18nObject( + en_US=model + ), + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_type=ModelType.RERANK, + model_properties={ }, + parameter_rules=[] + ) + + return entity \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/sagemaker/sagemaker.py b/api/core/model_runtime/model_providers/sagemaker/sagemaker.py new file mode 100644 index 0000000000..02d05f406c --- /dev/null +++ b/api/core/model_runtime/model_providers/sagemaker/sagemaker.py @@ -0,0 +1,17 @@ +import logging + +from core.model_runtime.model_providers.__base.model_provider import ModelProvider + +logger = logging.getLogger(__name__) + + +class SageMakerProvider(ModelProvider): + def validate_provider_credentials(self, credentials: dict) -> None: + """ + Validate provider credentials + + if validate failed, raise exception + + :param credentials: provider credentials, credentials form defined in `provider_credential_schema`. 
+ """ + pass diff --git a/api/core/model_runtime/model_providers/sagemaker/sagemaker.yaml b/api/core/model_runtime/model_providers/sagemaker/sagemaker.yaml new file mode 100644 index 0000000000..290cb0edab --- /dev/null +++ b/api/core/model_runtime/model_providers/sagemaker/sagemaker.yaml @@ -0,0 +1,125 @@ +provider: sagemaker +label: + zh_Hans: Sagemaker + en_US: Sagemaker +icon_small: + en_US: icon_s_en.png +icon_large: + en_US: icon_l_en.png +description: + en_US: Customized model on Sagemaker + zh_Hans: Sagemaker上的私有化部署的模型 +background: "#ECE9E3" +help: + title: + en_US: How to deploy customized model on Sagemaker + zh_Hans: 如何在Sagemaker上的私有化部署的模型 + url: + en_US: https://github.com/aws-samples/dify-aws-tool/blob/main/README.md#how-to-deploy-sagemaker-endpoint + zh_Hans: https://github.com/aws-samples/dify-aws-tool/blob/main/README_ZH.md#%E5%A6%82%E4%BD%95%E9%83%A8%E7%BD%B2sagemaker%E6%8E%A8%E7%90%86%E7%AB%AF%E7%82%B9 +supported_model_types: + - llm + - text-embedding + - rerank +configurate_methods: + - customizable-model +model_credential_schema: + model: + label: + en_US: Model Name + zh_Hans: 模型名称 + placeholder: + en_US: Enter your model name + zh_Hans: 输入模型名称 + credential_form_schemas: + - variable: mode + show_on: + - variable: __model_type + value: llm + label: + en_US: Completion mode + type: select + required: false + default: chat + placeholder: + zh_Hans: 选择对话类型 + en_US: Select completion mode + options: + - value: completion + label: + en_US: Completion + zh_Hans: 补全 + - value: chat + label: + en_US: Chat + zh_Hans: 对话 + - variable: sagemaker_endpoint + label: + en_US: sagemaker endpoint + type: text-input + required: true + placeholder: + zh_Hans: 请输出你的Sagemaker推理端点 + en_US: Enter your Sagemaker Inference endpoint + - variable: aws_access_key_id + required: false + label: + en_US: Access Key (If not provided, credentials are obtained from the running environment.) + zh_Hans: Access Key (如果未提供,凭证将从运行环境中获取。) + type: secret-input + placeholder: + en_US: Enter your Access Key + zh_Hans: 在此输入您的 Access Key + - variable: aws_secret_access_key + required: false + label: + en_US: Secret Access Key + zh_Hans: Secret Access Key + type: secret-input + placeholder: + en_US: Enter your Secret Access Key + zh_Hans: 在此输入您的 Secret Access Key + - variable: aws_region + required: false + label: + en_US: AWS Region + zh_Hans: AWS 地区 + type: select + default: us-east-1 + options: + - value: us-east-1 + label: + en_US: US East (N. 
Virginia) + zh_Hans: 美国东部 (弗吉尼亚北部) + - value: us-west-2 + label: + en_US: US West (Oregon) + zh_Hans: 美国西部 (俄勒冈州) + - value: ap-southeast-1 + label: + en_US: Asia Pacific (Singapore) + zh_Hans: 亚太地区 (新加坡) + - value: ap-northeast-1 + label: + en_US: Asia Pacific (Tokyo) + zh_Hans: 亚太地区 (东京) + - value: eu-central-1 + label: + en_US: Europe (Frankfurt) + zh_Hans: 欧洲 (法兰克福) + - value: us-gov-west-1 + label: + en_US: AWS GovCloud (US-West) + zh_Hans: AWS GovCloud (US-West) + - value: ap-southeast-2 + label: + en_US: Asia Pacific (Sydney) + zh_Hans: 亚太地区 (悉尼) + - value: cn-north-1 + label: + en_US: AWS Beijing (cn-north-1) + zh_Hans: 中国北京 (cn-north-1) + - value: cn-northwest-1 + label: + en_US: AWS Ningxia (cn-northwest-1) + zh_Hans: 中国宁夏 (cn-northwest-1) diff --git a/api/core/model_runtime/model_providers/sagemaker/text_embedding/__init__.py b/api/core/model_runtime/model_providers/sagemaker/text_embedding/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py new file mode 100644 index 0000000000..4b2858b1a2 --- /dev/null +++ b/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py @@ -0,0 +1,214 @@ +import itertools +import json +import logging +import time +from typing import Any, Optional + +import boto3 + +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType +from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel + +BATCH_SIZE = 20 +CONTEXT_SIZE=8192 + +logger = logging.getLogger(__name__) + +def batch_generator(generator, batch_size): + while True: + batch = list(itertools.islice(generator, batch_size)) + if not batch: + break + yield batch + +class SageMakerEmbeddingModel(TextEmbeddingModel): + """ + Model class for Cohere text embedding model. 
+ """ + sagemaker_client: Any = None + + def _sagemaker_embedding(self, sm_client, endpoint_name, content_list:list[str]): + response_model = sm_client.invoke_endpoint( + EndpointName=endpoint_name, + Body=json.dumps( + { + "inputs": content_list, + "parameters": {}, + "is_query" : False, + "instruction" : '' + } + ), + ContentType="application/json", + ) + json_str = response_model['Body'].read().decode('utf8') + json_obj = json.loads(json_str) + embeddings = json_obj['embeddings'] + return embeddings + + def _invoke(self, model: str, credentials: dict, + texts: list[str], user: Optional[str] = None) \ + -> TextEmbeddingResult: + """ + Invoke text embedding model + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :param user: unique user id + :return: embeddings result + """ + # get model properties + try: + line = 1 + if not self.sagemaker_client: + access_key = credentials.get('aws_access_key_id') + secret_key = credentials.get('aws_secret_access_key') + aws_region = credentials.get('aws_region') + if aws_region: + if access_key and secret_key: + self.sagemaker_client = boto3.client("sagemaker-runtime", + aws_access_key_id=access_key, + aws_secret_access_key=secret_key, + region_name=aws_region) + else: + self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region) + else: + self.sagemaker_client = boto3.client("sagemaker-runtime") + + line = 2 + sagemaker_endpoint = credentials.get('sagemaker_endpoint') + + line = 3 + truncated_texts = [ item[:CONTEXT_SIZE] for item in texts ] + + batches = batch_generator((text for text in truncated_texts), batch_size=BATCH_SIZE) + all_embeddings = [] + + line = 4 + for batch in batches: + embeddings = self._sagemaker_embedding(self.sagemaker_client, sagemaker_endpoint, batch) + all_embeddings.extend(embeddings) + + line = 5 + # calc usage + usage = self._calc_response_usage( + model=model, + credentials=credentials, + tokens=0 # It's not SAAS API, usage is meaningless + ) + line = 6 + + return TextEmbeddingResult( + embeddings=all_embeddings, + usage=usage, + model=model + ) + + except Exception as e: + logger.exception(f'Exception {e}, line : {line}') + + def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: + """ + Get number of tokens for given prompt messages + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :return: + """ + return 0 + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + try: + print("validate_credentials ok....") + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: + """ + Calculate response usage + + :param model: model name + :param credentials: model credentials + :param tokens: input tokens + :return: usage + """ + # get input price info + input_price_info = self.get_price( + model=model, + credentials=credentials, + price_type=PriceType.INPUT, + tokens=tokens + ) + + # transform usage + usage = EmbeddingUsage( + tokens=tokens, + total_tokens=tokens, + unit_price=input_price_info.unit_price, + price_unit=input_price_info.unit, + total_price=input_price_info.total_amount, + currency=input_price_info.currency, + latency=time.perf_counter() - self.started_at + ) + + return usage + + @property + def 
_invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + return { + InvokeConnectionError: [ + InvokeConnectionError + ], + InvokeServerUnavailableError: [ + InvokeServerUnavailableError + ], + InvokeRateLimitError: [ + InvokeRateLimitError + ], + InvokeAuthorizationError: [ + InvokeAuthorizationError + ], + InvokeBadRequestError: [ + KeyError + ] + } + + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: + """ + used to define customizable model schema + """ + + entity = AIModelEntity( + model=model, + label=I18nObject( + en_US=model + ), + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_type=ModelType.TEXT_EMBEDDING, + model_properties={ + ModelPropertyKey.CONTEXT_SIZE: CONTEXT_SIZE, + ModelPropertyKey.MAX_CHUNKS: BATCH_SIZE, + }, + parameter_rules=[] + ) + + return entity diff --git a/api/tests/integration_tests/model_runtime/sagemaker/__init__.py b/api/tests/integration_tests/model_runtime/sagemaker/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/integration_tests/model_runtime/sagemaker/test_provider.py b/api/tests/integration_tests/model_runtime/sagemaker/test_provider.py new file mode 100644 index 0000000000..639227e745 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/sagemaker/test_provider.py @@ -0,0 +1,19 @@ +import os + +import pytest + +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.sagemaker.sagemaker import SageMakerProvider + + +def test_validate_provider_credentials(): + provider = SageMakerProvider() + + with pytest.raises(CredentialsValidateFailedError): + provider.validate_provider_credentials( + credentials={} + ) + + provider.validate_provider_credentials( + credentials={} + ) diff --git a/api/tests/integration_tests/model_runtime/sagemaker/test_rerank.py b/api/tests/integration_tests/model_runtime/sagemaker/test_rerank.py new file mode 100644 index 0000000000..c67849dd79 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/sagemaker/test_rerank.py @@ -0,0 +1,55 @@ +import os + +import pytest + +from core.model_runtime.entities.rerank_entities import RerankResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.sagemaker.rerank.rerank import SageMakerRerankModel + + +def test_validate_credentials(): + model = SageMakerRerankModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model='bge-m3-rerank-v2', + credentials={ + "aws_region": os.getenv("AWS_REGION"), + "aws_access_key": os.getenv("AWS_ACCESS_KEY"), + "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY") + }, + query="What is the capital of the United States?", + docs=[ + "Carson City is the capital city of the American state of Nevada. At the 2010 United States " + "Census, Carson City had a population of 55,274.", + "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " + "are a political division controlled by the United States. 
Its capital is Saipan.", + ], + score_threshold=0.8 + ) + + +def test_invoke_model(): + model = SageMakerRerankModel() + + result = model.invoke( + model='bge-m3-rerank-v2', + credentials={ + "aws_region": os.getenv("AWS_REGION"), + "aws_access_key": os.getenv("AWS_ACCESS_KEY"), + "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY") + }, + query="What is the capital of the United States?", + docs=[ + "Carson City is the capital city of the American state of Nevada. At the 2010 United States " + "Census, Carson City had a population of 55,274.", + "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " + "are a political division controlled by the United States. Its capital is Saipan.", + ], + score_threshold=0.8 + ) + + assert isinstance(result, RerankResult) + assert len(result.docs) == 1 + assert result.docs[0].index == 1 + assert result.docs[0].score >= 0.8 diff --git a/api/tests/integration_tests/model_runtime/sagemaker/test_text_embedding.py b/api/tests/integration_tests/model_runtime/sagemaker/test_text_embedding.py new file mode 100644 index 0000000000..e817e8f04a --- /dev/null +++ b/api/tests/integration_tests/model_runtime/sagemaker/test_text_embedding.py @@ -0,0 +1,55 @@ +import os + +import pytest + +from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.sagemaker.text_embedding.text_embedding import SageMakerEmbeddingModel + + +def test_validate_credentials(): + model = SageMakerEmbeddingModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model='bge-m3', + credentials={ + } + ) + + model.validate_credentials( + model='bge-m3-embedding', + credentials={ + } + ) + + +def test_invoke_model(): + model = SageMakerEmbeddingModel() + + result = model.invoke( + model='bge-m3-embedding', + credentials={ + }, + texts=[ + "hello", + "world" + ], + user="abc-123" + ) + + assert isinstance(result, TextEmbeddingResult) + assert len(result.embeddings) == 2 + +def test_get_num_tokens(): + model = SageMakerEmbeddingModel() + + num_tokens = model.get_num_tokens( + model='bge-m3-embedding', + credentials={ + }, + texts=[ + ] + ) + + assert num_tokens == 0
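The rerank and text-embedding models in this patch assume endpoint contracts of their own: text_embedding.py posts {"inputs": [...], "parameters": {}, "is_query": false, "instruction": ""} and reads an "embeddings" list from the response, while rerank.py posts the query repeated once per document together with the documents and reads back a "scores" list. A minimal sketch of both contracts; the region is a placeholder and the endpoint names are only the ones used in the test fixtures above:

import json
import boto3

REGION = "us-east-1"                      # placeholder region
EMBEDDING_ENDPOINT = "bge-m3-embedding"   # endpoint names taken from the integration tests;
RERANK_ENDPOINT = "bge-m3-rerank-v2"      # substitute your own deployments

client = boto3.client("sagemaker-runtime", region_name=REGION)

def embed(texts):
    # text_embedding.py expects a JSON object with an "embeddings" list in the response.
    resp = client.invoke_endpoint(
        EndpointName=EMBEDDING_ENDPOINT,
        Body=json.dumps({"inputs": texts, "parameters": {}, "is_query": False, "instruction": ""}),
        ContentType="application/json",
    )
    return json.loads(resp["Body"].read().decode("utf8"))["embeddings"]

def rerank(query, docs):
    # rerank.py sends the query once per document and expects a "scores" list back.
    resp = client.invoke_endpoint(
        EndpointName=RERANK_ENDPOINT,
        Body=json.dumps({"inputs": [query] * len(docs), "docs": docs}),
        ContentType="application/json",
    )
    return json.loads(resp["Body"].read().decode("utf8"))["scores"]

print(len(embed(["hello", "world"])))   # expect 2 embedding vectors
print(rerank("What is the capital of the United States?",
             ["Carson City is the capital city of Nevada.",
              "The capital of the Northern Mariana Islands is Saipan."]))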
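One way to exercise these integration tests against a live endpoint, assuming pytest is available and the script is run from the repository root; the credential values are placeholders, and the variable names are the ones read by test_rerank.py (test_text_embedding.py passes empty credentials and falls back to the default boto3 credential chain):

import os
import pytest

# Placeholder credentials; replace with values for an account that can reach the endpoints.
os.environ["AWS_REGION"] = "us-east-1"
os.environ["AWS_ACCESS_KEY"] = "<your-access-key>"
os.environ["AWS_SECRET_ACCESS_KEY"] = "<your-secret-key>"

pytest.main(["api/tests/integration_tests/model_runtime/sagemaker", "-v"])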