From 3b5b548af39186e8a27347e5725fd699b1f857bb Mon Sep 17 00:00:00 2001
From: forrestlinfeng
Date: Thu, 18 Jul 2024 07:47:18 +0800
Subject: [PATCH] Add Stepfun LLM Support (#6346)

---
 .../model_providers/stepfun/__init__.py       |   0
 .../stepfun/_assets/icon_l_en.png             | Bin 0 -> 9176 bytes
 .../stepfun/_assets/icon_s_en.png             | Bin 0 -> 1991 bytes
 .../model_providers/stepfun/llm/__init__.py   |   0
 .../stepfun/llm/_position.yaml                |   6 +
 .../model_providers/stepfun/llm/llm.py        | 328 ++++++++++++++++++
 .../stepfun/llm/step-1-128k.yaml              |  25 ++
 .../stepfun/llm/step-1-256k.yaml              |  25 ++
 .../stepfun/llm/step-1-32k.yaml               |  28 ++
 .../stepfun/llm/step-1-8k.yaml                |  28 ++
 .../stepfun/llm/step-1v-32k.yaml              |  25 ++
 .../stepfun/llm/step-1v-8k.yaml               |  25 ++
 .../model_providers/stepfun/stepfun.py        |  30 ++
 .../model_providers/stepfun/stepfun.yaml      |  81 +++++
 .../model_runtime/stepfun/__init__.py         |   0
 .../model_runtime/stepfun/test_llm.py         | 176 ++++++++++
 16 files changed, 777 insertions(+)
 create mode 100644 api/core/model_runtime/model_providers/stepfun/__init__.py
 create mode 100644 api/core/model_runtime/model_providers/stepfun/_assets/icon_l_en.png
 create mode 100644 api/core/model_runtime/model_providers/stepfun/_assets/icon_s_en.png
 create mode 100644 api/core/model_runtime/model_providers/stepfun/llm/__init__.py
 create mode 100644 api/core/model_runtime/model_providers/stepfun/llm/_position.yaml
 create mode 100644 api/core/model_runtime/model_providers/stepfun/llm/llm.py
 create mode 100644 api/core/model_runtime/model_providers/stepfun/llm/step-1-128k.yaml
 create mode 100644 api/core/model_runtime/model_providers/stepfun/llm/step-1-256k.yaml
 create mode 100644 api/core/model_runtime/model_providers/stepfun/llm/step-1-32k.yaml
 create mode 100644 api/core/model_runtime/model_providers/stepfun/llm/step-1-8k.yaml
 create mode 100644 api/core/model_runtime/model_providers/stepfun/llm/step-1v-32k.yaml
 create mode 100644 api/core/model_runtime/model_providers/stepfun/llm/step-1v-8k.yaml
 create mode 100644 api/core/model_runtime/model_providers/stepfun/stepfun.py
 create mode 100644 api/core/model_runtime/model_providers/stepfun/stepfun.yaml
 create mode 100644 api/tests/integration_tests/model_runtime/stepfun/__init__.py
 create mode 100644 api/tests/integration_tests/model_runtime/stepfun/test_llm.py

diff --git a/api/core/model_runtime/model_providers/stepfun/__init__.py b/api/core/model_runtime/model_providers/stepfun/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/model_runtime/model_providers/stepfun/_assets/icon_l_en.png b/api/core/model_runtime/model_providers/stepfun/_assets/icon_l_en.png
new file mode 100644
index 0000000000000000000000000000000000000000..c118ea09bd8a84d7d43f03f6f5bc15c9a0c67e95
GIT binary patch
literal 9176
[9176 bytes of base85-encoded PNG data omitted]
z%OUi9Pj~&;w{}|R35QKw`AsDwcyHL2mb%7b#K>~h&P4ZSF5TKguNvu!ac;}|ssdZH zP?kBS``)sQ6RzzNxf?3}s0gk8S$?27xaEfD`e+ZM!6q4SJ|f`Xi&1{FNCpw+-oe(l z%3KdCkGTuFf0Mbo3vjO6FOcR#dyy~@sgaeQ!@Ae3%u^oe3kjZ375BAftjD?Bm)4TO z$U0oz(D3$aGWuug=HsWM$1S@_0)yYh2qzU0;Q=F7NuA=l*z=c0&T?$P8|QZUKaM-E zJlgRlx@7!wZ%qFOCoek_I>vf9$z6DUea Union[LLMResult, Generator]: + self._add_custom_parameters(credentials) + self._add_function_call(model, credentials) + user = user[:32] if user else None + return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) + + def validate_credentials(self, model: str, credentials: dict) -> None: + self._add_custom_parameters(credentials) + super().validate_credentials(model, credentials) + + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: + return AIModelEntity( + model=model, + label=I18nObject(en_US=model, zh_Hans=model), + model_type=ModelType.LLM, + features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL] + if credentials.get('function_calling_type') == 'tool_call' + else [], + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties={ + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 8000)), + ModelPropertyKey.MODE: LLMMode.CHAT.value, + }, + parameter_rules=[ + ParameterRule( + name='temperature', + use_template='temperature', + label=I18nObject(en_US='Temperature', zh_Hans='温度'), + type=ParameterType.FLOAT, + ), + ParameterRule( + name='max_tokens', + use_template='max_tokens', + default=512, + min=1, + max=int(credentials.get('max_tokens', 1024)), + label=I18nObject(en_US='Max Tokens', zh_Hans='最大标记'), + type=ParameterType.INT, + ), + ParameterRule( + name='top_p', + use_template='top_p', + label=I18nObject(en_US='Top P', zh_Hans='Top P'), + type=ParameterType.FLOAT, + ), + ] + ) + + def _add_custom_parameters(self, credentials: dict) -> None: + credentials['mode'] = 'chat' + credentials['endpoint_url'] = 'https://api.stepfun.com/v1' + + def _add_function_call(self, model: str, credentials: dict) -> None: + model_schema = self.get_model_schema(model, credentials) + if model_schema and { + ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL + }.intersection(model_schema.features or []): + credentials['function_calling_type'] = 'tool_call' + + def _convert_prompt_message_to_dict(self, message: PromptMessage,credentials: Optional[dict] = None) -> dict: + """ + Convert PromptMessage to dict for OpenAI API format + """ + if isinstance(message, UserPromptMessage): + message = cast(UserPromptMessage, message) + if isinstance(message.content, str): + message_dict = {"role": "user", "content": message.content} + else: + sub_messages = [] + for message_content in message.content: + if message_content.type == PromptMessageContentType.TEXT: + message_content = cast(PromptMessageContent, message_content) + sub_message_dict = { + "type": "text", + "text": message_content.data + } + sub_messages.append(sub_message_dict) + elif message_content.type == PromptMessageContentType.IMAGE: + message_content = cast(ImagePromptMessageContent, message_content) + sub_message_dict = { + "type": "image_url", + "image_url": { + "url": message_content.data, + } + } + sub_messages.append(sub_message_dict) + message_dict = {"role": "user", "content": sub_messages} + elif isinstance(message, AssistantPromptMessage): + message = cast(AssistantPromptMessage, message) + message_dict = {"role": "assistant", "content": message.content} + if message.tool_calls: + 
message_dict["tool_calls"] = [] + for function_call in message.tool_calls: + message_dict["tool_calls"].append({ + "id": function_call.id, + "type": function_call.type, + "function": { + "name": function_call.function.name, + "arguments": function_call.function.arguments + } + }) + elif isinstance(message, ToolPromptMessage): + message = cast(ToolPromptMessage, message) + message_dict = {"role": "tool", "content": message.content, "tool_call_id": message.tool_call_id} + elif isinstance(message, SystemPromptMessage): + message = cast(SystemPromptMessage, message) + message_dict = {"role": "system", "content": message.content} + else: + raise ValueError(f"Got unknown type {message}") + + if message.name: + message_dict["name"] = message.name + + return message_dict + + def _extract_response_tool_calls(self, response_tool_calls: list[dict]) -> list[AssistantPromptMessage.ToolCall]: + """ + Extract tool calls from response + + :param response_tool_calls: response tool calls + :return: list of tool calls + """ + tool_calls = [] + if response_tool_calls: + for response_tool_call in response_tool_calls: + function = AssistantPromptMessage.ToolCall.ToolCallFunction( + name=response_tool_call["function"]["name"] if response_tool_call.get("function", {}).get("name") else "", + arguments=response_tool_call["function"]["arguments"] if response_tool_call.get("function", {}).get("arguments") else "" + ) + + tool_call = AssistantPromptMessage.ToolCall( + id=response_tool_call["id"] if response_tool_call.get("id") else "", + type=response_tool_call["type"] if response_tool_call.get("type") else "", + function=function + ) + tool_calls.append(tool_call) + + return tool_calls + + def _handle_generate_stream_response(self, model: str, credentials: dict, response: requests.Response, + prompt_messages: list[PromptMessage]) -> Generator: + """ + Handle llm stream response + + :param model: model name + :param credentials: model credentials + :param response: streamed response + :param prompt_messages: prompt messages + :return: llm response chunk generator + """ + full_assistant_content = '' + chunk_index = 0 + + def create_final_llm_result_chunk(index: int, message: AssistantPromptMessage, finish_reason: str) \ + -> LLMResultChunk: + # calculate num tokens + prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content) + completion_tokens = self._num_tokens_from_string(model, full_assistant_content) + + # transform usage + usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + + return LLMResultChunk( + model=model, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta( + index=index, + message=message, + finish_reason=finish_reason, + usage=usage + ) + ) + + tools_calls: list[AssistantPromptMessage.ToolCall] = [] + finish_reason = "Unknown" + + def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]): + def get_tool_call(tool_name: str): + if not tool_name: + return tools_calls[-1] + + tool_call = next((tool_call for tool_call in tools_calls if tool_call.function.name == tool_name), None) + if tool_call is None: + tool_call = AssistantPromptMessage.ToolCall( + id='', + type='', + function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tool_name, arguments="") + ) + tools_calls.append(tool_call) + + return tool_call + + for new_tool_call in new_tool_calls: + # get tool call + tool_call = get_tool_call(new_tool_call.function.name) + # update tool call + if new_tool_call.id: + tool_call.id = new_tool_call.id + if 
+    def _add_custom_parameters(self, credentials: dict) -> None:
+        credentials['mode'] = 'chat'
+        credentials['endpoint_url'] = 'https://api.stepfun.com/v1'
+
+    def _add_function_call(self, model: str, credentials: dict) -> None:
+        model_schema = self.get_model_schema(model, credentials)
+        if model_schema and {
+            ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL
+        }.intersection(model_schema.features or []):
+            credentials['function_calling_type'] = 'tool_call'
+
+    def _convert_prompt_message_to_dict(self, message: PromptMessage, credentials: Optional[dict] = None) -> dict:
+        """
+        Convert PromptMessage to dict for the OpenAI API format
+        """
+        if isinstance(message, UserPromptMessage):
+            message = cast(UserPromptMessage, message)
+            if isinstance(message.content, str):
+                message_dict = {"role": "user", "content": message.content}
+            else:
+                sub_messages = []
+                for message_content in message.content:
+                    if message_content.type == PromptMessageContentType.TEXT:
+                        message_content = cast(PromptMessageContent, message_content)
+                        sub_message_dict = {
+                            "type": "text",
+                            "text": message_content.data
+                        }
+                        sub_messages.append(sub_message_dict)
+                    elif message_content.type == PromptMessageContentType.IMAGE:
+                        message_content = cast(ImagePromptMessageContent, message_content)
+                        sub_message_dict = {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": message_content.data,
+                            }
+                        }
+                        sub_messages.append(sub_message_dict)
+                message_dict = {"role": "user", "content": sub_messages}
+        elif isinstance(message, AssistantPromptMessage):
+            message = cast(AssistantPromptMessage, message)
+            message_dict = {"role": "assistant", "content": message.content}
+            if message.tool_calls:
+                message_dict["tool_calls"] = []
+                for function_call in message.tool_calls:
+                    message_dict["tool_calls"].append({
+                        "id": function_call.id,
+                        "type": function_call.type,
+                        "function": {
+                            "name": function_call.function.name,
+                            "arguments": function_call.function.arguments
+                        }
+                    })
+        elif isinstance(message, ToolPromptMessage):
+            message = cast(ToolPromptMessage, message)
+            message_dict = {"role": "tool", "content": message.content, "tool_call_id": message.tool_call_id}
+        elif isinstance(message, SystemPromptMessage):
+            message = cast(SystemPromptMessage, message)
+            message_dict = {"role": "system", "content": message.content}
+        else:
+            raise ValueError(f"Got unknown type {message}")
+
+        if message.name:
+            message_dict["name"] = message.name
+
+        return message_dict
+
+    def _extract_response_tool_calls(self, response_tool_calls: list[dict]) -> list[AssistantPromptMessage.ToolCall]:
+        """
+        Extract tool calls from response
+
+        :param response_tool_calls: response tool calls
+        :return: list of tool calls
+        """
+        tool_calls = []
+        if response_tool_calls:
+            for response_tool_call in response_tool_calls:
+                function = AssistantPromptMessage.ToolCall.ToolCallFunction(
+                    name=response_tool_call.get("function", {}).get("name") or "",
+                    arguments=response_tool_call.get("function", {}).get("arguments") or ""
+                )
+
+                tool_call = AssistantPromptMessage.ToolCall(
+                    id=response_tool_call.get("id") or "",
+                    type=response_tool_call.get("type") or "",
+                    function=function
+                )
+                tool_calls.append(tool_call)
+
+        return tool_calls
+
+    def _handle_generate_stream_response(self, model: str, credentials: dict, response: requests.Response,
+                                         prompt_messages: list[PromptMessage]) -> Generator:
+        """
+        Handle llm stream response
+
+        :param model: model name
+        :param credentials: model credentials
+        :param response: streamed response
+        :param prompt_messages: prompt messages
+        :return: llm response chunk generator
+        """
+        full_assistant_content = ''
+        chunk_index = 0
+
+        def create_final_llm_result_chunk(index: int, message: AssistantPromptMessage, finish_reason: str) \
+                -> LLMResultChunk:
+            # calculate num tokens
+            prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content)
+            completion_tokens = self._num_tokens_from_string(model, full_assistant_content)
+
+            # transform usage
+            usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
+            return LLMResultChunk(
+                model=model,
+                prompt_messages=prompt_messages,
+                delta=LLMResultChunkDelta(
+                    index=index,
+                    message=message,
+                    finish_reason=finish_reason,
+                    usage=usage
+                )
+            )
+
+        tools_calls: list[AssistantPromptMessage.ToolCall] = []
+        finish_reason = "Unknown"
+
+        def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]):
+            def get_tool_call(tool_name: str):
+                if not tool_name:
+                    return tools_calls[-1]
+
+                tool_call = next((tool_call for tool_call in tools_calls if tool_call.function.name == tool_name), None)
+                if tool_call is None:
+                    tool_call = AssistantPromptMessage.ToolCall(
+                        id='',
+                        type='',
+                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tool_name, arguments="")
+                    )
+                    tools_calls.append(tool_call)
+
+                return tool_call
+
+            for new_tool_call in new_tool_calls:
+                # get tool call
+                tool_call = get_tool_call(new_tool_call.function.name)
+                # update tool call
+                if new_tool_call.id:
+                    tool_call.id = new_tool_call.id
+                if new_tool_call.type:
+                    tool_call.type = new_tool_call.type
+                if new_tool_call.function.name:
+                    tool_call.function.name = new_tool_call.function.name
+                if new_tool_call.function.arguments:
+                    tool_call.function.arguments += new_tool_call.function.arguments
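+
+        # The body is a server-sent-events stream: records are separated by
+        # blank lines, lines starting with ':' are SSE comments, and each data
+        # record carries a 'data: ' prefix in front of its JSON payload.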
+        for chunk in response.iter_lines(decode_unicode=True, delimiter="\n\n"):
+            if chunk:
+                # ignore sse comments
+                if chunk.startswith(':'):
+                    continue
+                decoded_chunk = chunk.strip().removeprefix('data: ')
+                try:
+                    chunk_json = json.loads(decoded_chunk)
+                # stream ended
+                except json.JSONDecodeError:
+                    yield create_final_llm_result_chunk(
+                        index=chunk_index + 1,
+                        message=AssistantPromptMessage(content=""),
+                        finish_reason="Non-JSON encountered."
+                    )
+                    break
+                if not chunk_json or len(chunk_json['choices']) == 0:
+                    continue
+
+                choice = chunk_json['choices'][0]
+                if choice.get('finish_reason'):
+                    finish_reason = choice['finish_reason']
+                chunk_index += 1
+
+                if 'delta' in choice:
+                    delta = choice['delta']
+                    delta_content = delta.get('content')
+
+                    # extract tool calls from the delta, if any
+                    assistant_message_tool_calls = delta.get('tool_calls', None)
+                    if assistant_message_tool_calls:
+                        tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
+                        increase_tool_call(tool_calls)
+
+                    if delta_content is None or delta_content == '':
+                        continue
+
+                    # transform assistant message to prompt message
+                    assistant_prompt_message = AssistantPromptMessage(
+                        content=delta_content,
+                        tool_calls=tool_calls if assistant_message_tool_calls else []
+                    )
+
+                    full_assistant_content += delta_content
+                elif 'text' in choice:
+                    choice_text = choice.get('text', '')
+                    if choice_text == '':
+                        continue
+
+                    # transform assistant message to prompt message
+                    assistant_prompt_message = AssistantPromptMessage(content=choice_text)
+                    full_assistant_content += choice_text
+                else:
+                    continue
+
+                # yield the incremental chunk
+                yield LLMResultChunk(
+                    model=model,
+                    prompt_messages=prompt_messages,
+                    delta=LLMResultChunkDelta(
+                        index=chunk_index,
+                        message=assistant_prompt_message,
+                    )
+                )
+
+            chunk_index += 1
+
+        if tools_calls:
+            yield LLMResultChunk(
+                model=model,
+                prompt_messages=prompt_messages,
+                delta=LLMResultChunkDelta(
+                    index=chunk_index,
+                    message=AssistantPromptMessage(
+                        tool_calls=tools_calls,
+                        content=""
+                    ),
+                )
+            )
+
+        yield create_final_llm_result_chunk(
+            index=chunk_index,
+            message=AssistantPromptMessage(content=""),
+            finish_reason=finish_reason
+        )
\ No newline at end of file
diff --git a/api/core/model_runtime/model_providers/stepfun/llm/step-1-128k.yaml b/api/core/model_runtime/model_providers/stepfun/llm/step-1-128k.yaml
new file mode 100644
index 0000000000..13f7b7fd26
--- /dev/null
+++ b/api/core/model_runtime/model_providers/stepfun/llm/step-1-128k.yaml
@@ -0,0 +1,25 @@
+model: step-1-128k
+label:
+  zh_Hans: step-1-128k
+  en_US: step-1-128k
+model_type: llm
+features:
+  - agent-thought
+model_properties:
+  mode: chat
+  context_size: 128000
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: max_tokens
+    use_template: max_tokens
+    default: 1024
+    min: 1
+    max: 128000
+pricing:
+  input: '0.04'
+  output: '0.20'
+  unit: '0.001'
+  currency: RMB
diff --git a/api/core/model_runtime/model_providers/stepfun/llm/step-1-256k.yaml b/api/core/model_runtime/model_providers/stepfun/llm/step-1-256k.yaml
new file mode 100644
index 0000000000..f80ec9851c
--- /dev/null
+++ b/api/core/model_runtime/model_providers/stepfun/llm/step-1-256k.yaml
@@ -0,0 +1,25 @@
+model: step-1-256k
+label:
+  zh_Hans: step-1-256k
+  en_US: step-1-256k
+model_type: llm
+features:
+  - agent-thought
+model_properties:
+  mode: chat
+  context_size: 256000
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: max_tokens
+    use_template: max_tokens
+    default: 1024
+    min: 1
+    max: 256000
+pricing:
+  input: '0.095'
+  output: '0.300'
+  unit: '0.001'
+  currency: RMB
diff --git a/api/core/model_runtime/model_providers/stepfun/llm/step-1-32k.yaml b/api/core/model_runtime/model_providers/stepfun/llm/step-1-32k.yaml
new file mode 100644
index 0000000000..96132d14a8
--- /dev/null
+++ b/api/core/model_runtime/model_providers/stepfun/llm/step-1-32k.yaml
@@ -0,0 +1,28 @@
+model: step-1-32k
+label:
+  zh_Hans: step-1-32k
+  en_US: step-1-32k
+model_type: llm
+features:
+  - agent-thought
+  - tool-call
+  - multi-tool-call
+  - stream-tool-call
+model_properties:
+  mode: chat
+  context_size: 32000
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: max_tokens
+    use_template: max_tokens
+    default: 1024
+    min: 1
+    max: 32000
+pricing:
+  input: '0.015'
+  output: '0.070'
+  unit: '0.001'
+  currency: RMB
diff --git a/api/core/model_runtime/model_providers/stepfun/llm/step-1-8k.yaml b/api/core/model_runtime/model_providers/stepfun/llm/step-1-8k.yaml
new file mode 100644
index 0000000000..4a4ba8d178
--- /dev/null
+++ b/api/core/model_runtime/model_providers/stepfun/llm/step-1-8k.yaml
@@ -0,0 +1,28 @@
+model: step-1-8k
+label:
+  zh_Hans: step-1-8k
+  en_US: step-1-8k
+model_type: llm
+features:
+  - agent-thought
+  - tool-call
+  - multi-tool-call
+  - stream-tool-call
+model_properties:
+  mode: chat
+  context_size: 8000
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: max_tokens
+    use_template: max_tokens
+    default: 512
+    min: 1
+    max: 8000
+pricing:
+  input: '0.005'
+  output: '0.020'
+  unit: '0.001'
+  currency: RMB
diff --git a/api/core/model_runtime/model_providers/stepfun/llm/step-1v-32k.yaml b/api/core/model_runtime/model_providers/stepfun/llm/step-1v-32k.yaml
new file mode 100644
index 0000000000..f878ee3e56
--- /dev/null
+++ b/api/core/model_runtime/model_providers/stepfun/llm/step-1v-32k.yaml
@@ -0,0 +1,25 @@
+model: step-1v-32k
+label:
+  zh_Hans: step-1v-32k
+  en_US: step-1v-32k
+model_type: llm
+features:
+  - vision
+model_properties:
+  mode: chat
+  context_size: 32000
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: max_tokens
+    use_template: max_tokens
+    default: 1024
+    min: 1
+    max: 32000
+pricing:
+  input: '0.015'
+  output: '0.070'
+  unit: '0.001'
+  currency: RMB
diff --git a/api/core/model_runtime/model_providers/stepfun/llm/step-1v-8k.yaml b/api/core/model_runtime/model_providers/stepfun/llm/step-1v-8k.yaml
new file mode 100644
index 0000000000..6c3cb61d2c
--- /dev/null
+++ b/api/core/model_runtime/model_providers/stepfun/llm/step-1v-8k.yaml
@@ -0,0 +1,25 @@
+model: step-1v-8k
+label:
+  zh_Hans: step-1v-8k
+  en_US: step-1v-8k
+model_type: llm
+features:
+  - vision
+model_properties:
+  mode: chat
+  context_size: 8192
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: max_tokens
+    use_template: max_tokens
+    default: 512
+    min: 1
+    max: 8192
+pricing:
+  input: '0.005'
+  output: '0.020'
+  unit: '0.001'
+  currency: RMB
diff --git a/api/core/model_runtime/model_providers/stepfun/stepfun.py b/api/core/model_runtime/model_providers/stepfun/stepfun.py
new file mode 100644
index 0000000000..50b17392b5
--- /dev/null
+++ b/api/core/model_runtime/model_providers/stepfun/stepfun.py
@@ -0,0 +1,30 @@
+import logging
+
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.model_provider import ModelProvider
+
+logger = logging.getLogger(__name__)
+
+
+class StepfunProvider(ModelProvider):
+
+    def validate_provider_credentials(self, credentials: dict) -> None:
+        """
+        Validate provider credentials.
+        If validation fails, raise an exception.
+
+        :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
+        """
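+        # Probe the API key by validating against the step-1-8k model and let
+        # any provider-side failure surface as a credential validation error.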
+        try:
+            model_instance = self.get_model_instance(ModelType.LLM)
+
+            model_instance.validate_credentials(
+                model='step-1-8k',
+                credentials=credentials
+            )
+        except CredentialsValidateFailedError as ex:
+            raise ex
+        except Exception as ex:
+            logger.exception(f'{self.get_provider_schema().provider} credentials validation failed')
+            raise ex
diff --git a/api/core/model_runtime/model_providers/stepfun/stepfun.yaml b/api/core/model_runtime/model_providers/stepfun/stepfun.yaml
new file mode 100644
index 0000000000..ccc8455adc
--- /dev/null
+++ b/api/core/model_runtime/model_providers/stepfun/stepfun.yaml
@@ -0,0 +1,81 @@
+provider: stepfun
+label:
+  zh_Hans: 阶跃星辰
+  en_US: Stepfun
+description:
+  en_US: Models provided by Stepfun, such as step-1-8k, step-1-32k, step-1v-8k, step-1v-32k, step-1-128k, and step-1-256k.
+  zh_Hans: 阶跃星辰提供的模型,例如 step-1-8k、step-1-32k、step-1v-8k、step-1v-32k、step-1-128k 和 step-1-256k。
+icon_small:
+  en_US: icon_s_en.png
+icon_large:
+  en_US: icon_l_en.png
+background: "#FFFFFF"
+help:
+  title:
+    en_US: Get your API Key from Stepfun
+    zh_Hans: 从 stepfun 获取 API Key
+  url:
+    en_US: https://platform.stepfun.com/interface-key
+supported_model_types:
+  - llm
+configurate_methods:
+  - predefined-model
+  - customizable-model
+provider_credential_schema:
+  credential_form_schemas:
+    - variable: api_key
+      label:
+        en_US: API Key
+      type: secret-input
+      required: true
+      placeholder:
+        zh_Hans: 在此输入您的 API Key
+        en_US: Enter your API Key
+model_credential_schema:
+  model:
+    label:
+      en_US: Model Name
+      zh_Hans: 模型名称
+    placeholder:
+      en_US: Enter your model name
+      zh_Hans: 输入模型名称
+  credential_form_schemas:
+    - variable: api_key
+      label:
+        en_US: API Key
+      type: secret-input
+      required: true
+      placeholder:
+        zh_Hans: 在此输入您的 API Key
+        en_US: Enter your API Key
+    - variable: context_size
+      label:
+        zh_Hans: 模型上下文长度
+        en_US: Model context size
+      required: true
+      type: text-input
+      default: '8192'
+      placeholder:
+        zh_Hans: 在此输入您的模型上下文长度
+        en_US: Enter your model context size
+    - variable: max_tokens
+      label:
+        zh_Hans: 最大 token 上限
+        en_US: Upper bound for max tokens
+      default: '8192'
+      type: text-input
+    - variable: function_calling_type
+      label:
+        en_US: Function calling
+      type: select
+      required: false
+      default: no_call
+      options:
+        - value: no_call
+          label:
+            en_US: Not supported
+            zh_Hans: 不支持
+        - value: tool_call
+          label:
+            en_US: Tool Call
+            zh_Hans: Tool Call
diff --git a/api/tests/integration_tests/model_runtime/stepfun/__init__.py b/api/tests/integration_tests/model_runtime/stepfun/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/tests/integration_tests/model_runtime/stepfun/test_llm.py b/api/tests/integration_tests/model_runtime/stepfun/test_llm.py
new file mode 100644
index 0000000000..d703147d63
--- /dev/null
+++ b/api/tests/integration_tests/model_runtime/stepfun/test_llm.py
@@ -0,0 +1,176 @@
+import os
+from collections.abc import Generator
+
+import pytest
+
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    ImagePromptMessageContent,
+    PromptMessageTool,
+    SystemPromptMessage,
+    TextPromptMessageContent,
+    UserPromptMessage,
+)
+from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.stepfun.llm.llm import StepfunLargeLanguageModel
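+
+# These tests exercise the live Stepfun API and expect a valid key in the
+# STEPFUN_API_KEY environment variable.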
+
+
+def test_validate_credentials():
+    model = StepfunLargeLanguageModel()
+
+    with pytest.raises(CredentialsValidateFailedError):
+        model.validate_credentials(
+            model='step-1-8k',
+            credentials={
+                'api_key': 'invalid_key'
+            }
+        )
+
+    model.validate_credentials(
+        model='step-1-8k',
+        credentials={
+            'api_key': os.environ.get('STEPFUN_API_KEY')
+        }
+    )
+
+
+def test_invoke_model():
+    model = StepfunLargeLanguageModel()
+
+    response = model.invoke(
+        model='step-1-8k',
+        credentials={
+            'api_key': os.environ.get('STEPFUN_API_KEY')
+        },
+        prompt_messages=[
+            UserPromptMessage(
+                content='Hello World!'
+            )
+        ],
+        model_parameters={
+            'temperature': 0.9,
+            'top_p': 0.7
+        },
+        stop=['Hi'],
+        stream=False,
+        user="abc-123"
+    )
+
+    assert isinstance(response, LLMResult)
+    assert len(response.message.content) > 0
+
+
+def test_invoke_stream_model():
+    model = StepfunLargeLanguageModel()
+
+    response = model.invoke(
+        model='step-1-8k',
+        credentials={
+            'api_key': os.environ.get('STEPFUN_API_KEY')
+        },
+        prompt_messages=[
+            SystemPromptMessage(
+                content='You are a helpful AI assistant.',
+            ),
+            UserPromptMessage(
+                content='Hello World!'
+            )
+        ],
+        model_parameters={
+            'temperature': 0.9,
+            'top_p': 0.7
+        },
+        stream=True,
+        user="abc-123"
+    )
+
+    assert isinstance(response, Generator)
+
+    for chunk in response:
+        assert isinstance(chunk, LLMResultChunk)
+        assert isinstance(chunk.delta, LLMResultChunkDelta)
+        assert isinstance(chunk.delta.message, AssistantPromptMessage)
+        assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
+
+
+def test_get_customizable_model_schema():
+    model = StepfunLargeLanguageModel()
+
+    schema = model.get_customizable_model_schema(
+        model='step-1-8k',
+        credentials={
+            'api_key': os.environ.get('STEPFUN_API_KEY')
+        }
+    )
+    assert isinstance(schema, AIModelEntity)
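+
+
+# The model is expected to answer the weather question by emitting at least
+# one tool call (e.g. get_weather) rather than a plain-text reply.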
+def test_invoke_chat_model_with_tools():
+    model = StepfunLargeLanguageModel()
+
+    result = model.invoke(
+        model='step-1-8k',
+        credentials={
+            'api_key': os.environ.get('STEPFUN_API_KEY')
+        },
+        prompt_messages=[
+            SystemPromptMessage(
+                content='You are a helpful AI assistant.',
+            ),
+            UserPromptMessage(
+                content="what's the weather today in Shanghai?",
+            )
+        ],
+        model_parameters={
+            'temperature': 0.9,
+            'max_tokens': 100
+        },
+        tools=[
+            PromptMessageTool(
+                name='get_weather',
+                description='Determine weather in my location',
+                parameters={
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "The city and state e.g. San Francisco, CA"
+                        },
+                        "unit": {
+                            "type": "string",
+                            "enum": [
+                                "c",
+                                "f"
+                            ]
+                        }
+                    },
+                    "required": [
+                        "location"
+                    ]
+                }
+            ),
+            PromptMessageTool(
+                name='get_stock_price',
+                description='Get the current stock price',
+                parameters={
+                    "type": "object",
+                    "properties": {
+                        "symbol": {
+                            "type": "string",
+                            "description": "The stock symbol"
+                        }
+                    },
+                    "required": [
+                        "symbol"
+                    ]
+                }
+            )
+        ],
+        stream=False,
+        user="abc-123"
+    )
+
+    assert isinstance(result, LLMResult)
+    assert isinstance(result.message, AssistantPromptMessage)
+    assert len(result.message.tool_calls) > 0
\ No newline at end of file