Add api for sessions and add max_tokens for tenant_llm (#3472)

### What problem does this PR solve?

Add an API for sessions and add `max_tokens` support for `tenant_llm`.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: liuhua <10215101452@stu.ecun.edu.cn>
liuhua 2024-11-19 14:51:33 +08:00 committed by GitHub
parent 883fafde72
commit d42362deb6
21 changed files with 401 additions and 67 deletions

View File

@@ -74,9 +74,9 @@ def set_api_key():
mdl = ChatModel[factory](
req["api_key"], llm.llm_name, base_url=req.get("base_url"))
try:
- m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}],
- {"temperature": 0.9,'max_tokens':50})
- if m.find("**ERROR**") >=0:
+ m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}],
+ {"temperature": 0.9, 'max_tokens': 50})
+ if m.find("**ERROR**") >= 0:
raise Exception(m)
chat_passed = True
except Exception as e:
@@ -110,6 +110,7 @@ def set_api_key():
llm_config[n] = req[n]
for llm in LLMService.query(fid=factory):
llm_config["max_tokens"]=llm.max_tokens
if not TenantLLMService.filter_update(
[TenantLLM.tenant_id == current_user.id,
TenantLLM.llm_factory == factory,
@@ -121,7 +122,8 @@ def set_api_key():
llm_name=llm.llm_name,
model_type=llm.model_type,
api_key=llm_config["api_key"],
- api_base=llm_config["api_base"]
+ api_base=llm_config["api_base"],
+ max_tokens=llm_config["max_tokens"]
)
return get_json_result(data=True)
@@ -158,23 +160,23 @@ def add_llm():
api_key = apikey_json(["bedrock_ak", "bedrock_sk", "bedrock_region"])
elif factory == "LocalAI":
- llm_name = req["llm_name"]+"___LocalAI"
+ llm_name = req["llm_name"] + "___LocalAI"
api_key = "xxxxxxxxxxxxxxx"
elif factory == "HuggingFace":
- llm_name = req["llm_name"]+"___HuggingFace"
+ llm_name = req["llm_name"] + "___HuggingFace"
api_key = "xxxxxxxxxxxxxxx"
elif factory == "OpenAI-API-Compatible":
- llm_name = req["llm_name"]+"___OpenAI-API"
- api_key = req.get("api_key","xxxxxxxxxxxxxxx")
+ llm_name = req["llm_name"] + "___OpenAI-API"
+ api_key = req.get("api_key", "xxxxxxxxxxxxxxx")
- elif factory =="XunFei Spark":
+ elif factory == "XunFei Spark":
llm_name = req["llm_name"]
if req["model_type"] == "chat":
api_key = req.get("spark_api_password", "xxxxxxxxxxxxxxx")
elif req["model_type"] == "tts":
- api_key = apikey_json(["spark_app_id", "spark_api_secret","spark_api_key"])
+ api_key = apikey_json(["spark_app_id", "spark_api_secret", "spark_api_key"])
elif factory == "BaiduYiyan":
llm_name = req["llm_name"]
@@ -202,14 +204,15 @@ def add_llm():
"model_type": req["model_type"],
"llm_name": llm_name,
"api_base": req.get("api_base", ""),
- "api_key": api_key
+ "api_key": api_key,
+ "max_tokens": req.get("max_tokens")
}
msg = ""
if llm["model_type"] == LLMType.EMBEDDING.value:
mdl = EmbeddingModel[factory](
key=llm['api_key'],
- model_name=llm["llm_name"],
+ model_name=llm["llm_name"],
base_url=llm["api_base"])
try:
arr, tc = mdl.encode(["Test if the api key is available"])
@@ -225,7 +228,7 @@ def add_llm():
)
try:
m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {
- "temperature": 0.9})
+ "temperature": 0.9})
if not tc:
raise Exception(m)
except Exception as e:
@@ -233,8 +236,8 @@ def add_llm():
e)
elif llm["model_type"] == LLMType.RERANK:
mdl = RerankModel[factory](
- key=llm["api_key"],
- model_name=llm["llm_name"],
+ key=llm["api_key"],
+ model_name=llm["llm_name"],
base_url=llm["api_base"]
)
try:
@@ -246,8 +249,8 @@ def add_llm():
e)
elif llm["model_type"] == LLMType.IMAGE2TEXT.value:
mdl = CvModel[factory](
- key=llm["api_key"],
- model_name=llm["llm_name"],
+ key=llm["api_key"],
+ model_name=llm["llm_name"],
base_url=llm["api_base"]
)
try:
@@ -282,7 +285,8 @@ def add_llm():
return get_data_error_result(message=msg)
if not TenantLLMService.filter_update(
- [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == factory, TenantLLM.llm_name == llm["llm_name"]], llm):
+ [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == factory,
+ TenantLLM.llm_name == llm["llm_name"]], llm):
TenantLLMService.save(**llm)
return get_json_result(data=True)
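For reference, a request body that exercises the new max_tokens handling in add_llm might look like the sketch below. The keys mirror the llm dict assembled above; the concrete values, route path, and authentication are assumptions, not part of this diff.

```python
# Hypothetical add_llm payload; key names mirror the llm dict built above.
# The values, route path, and auth mechanism are assumptions.
payload = {
    "llm_factory": "OpenAI-API-Compatible",
    "llm_name": "my-model",           # the handler appends "___OpenAI-API"
    "model_type": "chat",
    "api_base": "http://localhost:8000/v1",
    "api_key": "xxxxxxxxxxxxxxx",
    "max_tokens": 8192,               # new in this PR, read via req.get("max_tokens")
}
```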
@@ -294,7 +298,8 @@ def add_llm():
def delete_llm():
req = request.json
TenantLLMService.filter_delete(
- [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"], TenantLLM.llm_name == req["llm_name"]])
+ [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"],
+ TenantLLM.llm_name == req["llm_name"]])
return get_json_result(data=True)
@@ -304,7 +309,7 @@ def delete_llm():
def delete_factory():
req = request.json
TenantLLMService.filter_delete(
- [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"]])
+ [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"]])
return get_json_result(data=True)
@@ -332,8 +337,8 @@ def my_llms():
@manager.route('/list', methods=['GET'])
@login_required
def list_app():
- self_deploied = ["Youdao","FastEmbed", "BAAI", "Ollama", "Xinference", "LocalAI", "LM-Studio"]
- weighted = ["Youdao","FastEmbed", "BAAI"] if settings.LIGHTEN != 0 else []
+ self_deploied = ["Youdao", "FastEmbed", "BAAI", "Ollama", "Xinference", "LocalAI", "LM-Studio"]
+ weighted = ["Youdao", "FastEmbed", "BAAI"] if settings.LIGHTEN != 0 else []
model_type = request.args.get("model_type")
try:
objs = TenantLLMService.query(tenant_id=current_user.id)
@@ -344,15 +349,15 @@ def list_app():
for m in llms:
m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in self_deploied
- llm_set = set([m["llm_name"]+"@"+m["fid"] for m in llms])
+ llm_set = set([m["llm_name"] + "@" + m["fid"] for m in llms])
for o in objs:
- if not o.api_key:continue
- if o.llm_name+"@"+o.llm_factory in llm_set:continue
+ if not o.api_key: continue
+ if o.llm_name + "@" + o.llm_factory in llm_set: continue
llms.append({"llm_name": o.llm_name, "model_type": o.model_type, "fid": o.llm_factory, "available": True})
res = {}
for m in llms:
- if model_type and m["model_type"].find(model_type)<0:
+ if model_type and m["model_type"].find(model_type) < 0:
continue
if m["fid"] not in res:
res[m["fid"]] = []
@@ -360,4 +365,4 @@ def list_app():
return get_json_result(data=res)
except Exception as e:
- return server_error_response(e)
+ return server_error_response(e)

View File

@@ -13,21 +13,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
import json
from functools import partial
from uuid import uuid4
from api.db import LLMType
from flask import request, Response
from api.db.services.dialog_service import ask
from agent.canvas import Canvas
from api.db import StatusEnum
from api.db.db_models import API4Conversation
from api.db.services.api_service import API4ConversationService
from api.db.services.canvas_service import UserCanvasService
from api.db.services.dialog_service import DialogService, ConversationService, chat
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.utils import get_uuid
from api.utils.api_utils import get_error_data_result
from api.utils.api_utils import get_result, token_required
from api.db.services.llm_service import LLMBundle
@manager.route('/chats/<chat_id>/sessions', methods=['POST'])
@@ -342,7 +345,7 @@ def agent_completion(tenant_id, agent_id):
yield "data:" + json.dumps({"code": 500, "message": str(e),
"data": {"answer": "**ERROR**: " + str(e), "reference": []}},
ensure_ascii=False) + "\n\n"
- yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
+ yield "data:" + json.dumps({"code": 0, "data": True}, ensure_ascii=False) + "\n\n"
resp = Response(sse(), mimetype="text/event-stream")
resp.headers.add_header("Cache-control", "no-cache")
@@ -366,7 +369,7 @@ def agent_completion(tenant_id, agent_id):
@manager.route('/chats/<chat_id>/sessions', methods=['GET'])
@token_required
- def list(chat_id,tenant_id):
+ def list_session(chat_id,tenant_id):
if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
return get_error_data_result(message=f"You don't own the assistant {chat_id}.")
id = request.args.get("id")
@@ -441,4 +444,80 @@ def delete(tenant_id,chat_id):
if not conv:
return get_error_data_result(message="The chat doesn't own the session")
ConversationService.delete_by_id(id)
- return get_result()
+ return get_result()
@manager.route('/sessions/ask', methods=['POST'])
@token_required
def ask_about(tenant_id):
req = request.json
if not req.get("question"):
return get_error_data_result("`question` is required.")
if not req.get("dataset_ids"):
return get_error_data_result("`dataset_ids` is required.")
if not isinstance(req.get("dataset_ids"),list):
return get_error_data_result("`dataset_ids` should be a list.")
req["kb_ids"]=req.pop("dataset_ids")
for kb_id in req["kb_ids"]:
if not KnowledgebaseService.accessible(kb_id,tenant_id):
return get_error_data_result(f"You don't own the dataset {kb_id}.")
kbs = KnowledgebaseService.query(id=kb_id)
kb = kbs[0]
if kb.chunk_num == 0:
return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file")
uid = tenant_id
def stream():
nonlocal req, uid
try:
for ans in ask(req["question"], req["kb_ids"], uid):
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
except Exception as e:
yield "data:" + json.dumps({"code": 500, "message": str(e),
"data": {"answer": "**ERROR**: " + str(e), "reference": []}},
ensure_ascii=False) + "\n\n"
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
resp = Response(stream(), mimetype="text/event-stream")
resp.headers.add_header("Cache-control", "no-cache")
resp.headers.add_header("Connection", "keep-alive")
resp.headers.add_header("X-Accel-Buffering", "no")
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
return resp
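A minimal client sketch for the /sessions/ask endpoint above, assuming the service runs on localhost, the blueprint is mounted under /api/v1, and token_required reads a bearer token from the Authorization header; none of those details are shown in this diff.

```python
import json
import requests  # assumed available; any streaming HTTP client works

resp = requests.post(
    "http://localhost:9380/api/v1/sessions/ask",  # host and URL prefix are assumptions
    headers={"Authorization": "Bearer <API_KEY>"},  # header format is an assumption
    json={"question": "What is RAGFlow?", "dataset_ids": ["<dataset_id>"]},
    stream=True,
)
for line in resp.iter_lines():
    if line.startswith(b"data:"):
        event = json.loads(line[len(b"data:"):])
        # Per the generator above: answer payloads stream first, then data == True.
        print(event["data"])
```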
@manager.route('/sessions/related_questions', methods=['POST'])
@token_required
def related_questions(tenant_id):
req = request.json
if not req.get("question"):
return get_error_data_result("`question` is required.")
question = req["question"]
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
prompt = """
Objective: To generate search terms related to the user's search keywords, helping users find more valuable information.
Instructions:
- Based on the keywords provided by the user, generate 5-10 related search terms.
- Each search term should be directly or indirectly related to the keyword, guiding the user to find more valuable information.
- Use common, general terms as much as possible, avoiding obscure words or technical jargon.
- Keep the term length between 2-4 words, concise and clear.
- DO NOT translate, use the language of the original keywords.
### Example:
Keywords: Chinese football
Related search terms:
1. Current status of Chinese football
2. Reform of Chinese football
3. Youth training of Chinese football
4. Chinese football in the Asian Cup
5. Chinese football in the World Cup
Reason:
- When searching, users often only use one or two keywords, making it difficult to fully express their information needs.
- Generating related search terms can help users dig deeper into relevant information and improve search efficiency.
- At the same time, related terms can also help search engines better understand user needs and return more accurate search results.
"""
ans = chat_mdl.chat(prompt, [{"role": "user", "content": f"""
Keywords: {question}
Related search terms:
"""}], {"temperature": 0.9})
return get_result(data=[re.sub(r"^[0-9]\. ", "", a) for a in ans.split("\n") if re.match(r"^[0-9]\. ", a)])
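The return line above keeps only the numbered lines of the model's answer and strips the "N. " prefix. A standalone sketch of that post-processing on a made-up answer:

```python
import re

# Made-up model output in the shape the prompt above requests.
ans = """Related search terms:
1. Current status of Chinese football
2. Reform of Chinese football
A stray unnumbered line the model added
3. Youth training of Chinese football"""

terms = [re.sub(r"^[0-9]\. ", "", a) for a in ans.split("\n") if re.match(r"^[0-9]\. ", a)]
print(terms)
# ['Current status of Chinese football', 'Reform of Chinese football',
#  'Youth training of Chinese football']
```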

View File

@@ -17,6 +17,7 @@ import logging
import inspect
import os
import sys
import typing
import operator
from enum import Enum
from functools import wraps
@@ -29,10 +30,13 @@ from peewee import (
Field, Model, Metadata
)
from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase
from api.db import SerializedType, ParserType
from api import settings
from api import utils
def singleton(cls, *args, **kw):
instances = {}
@@ -120,13 +124,13 @@ class SerializedField(LongTextField):
f"the serialized type {self._serialized_type} is not supported")
- def is_continuous_field(cls: type) -> bool:
+ def is_continuous_field(cls: typing.Type) -> bool:
if cls in CONTINUOUS_FIELD_TYPE:
return True
for p in cls.__bases__:
if p in CONTINUOUS_FIELD_TYPE:
return True
- elif p is not Field and p is not object:
+ elif p != Field and p != object:
if is_continuous_field(p):
return True
else:
@@ -158,7 +162,7 @@ class BaseModel(Model):
def to_dict(self):
return self.__dict__['__data__']
- def to_human_model_dict(self, only_primary_with: list | None = None):
+ def to_human_model_dict(self, only_primary_with: list = None):
model_dict = self.__dict__['__data__']
if not only_primary_with:
@@ -268,6 +272,7 @@ class JsonSerializedField(SerializedField):
super(JsonSerializedField, self).__init__(serialized_type=SerializedType.JSON, object_hook=object_hook,
object_pairs_hook=object_pairs_hook, **kwargs)
class PooledDatabase(Enum):
MYSQL = PooledMySQLDatabase
POSTGRES = PooledPostgresqlDatabase
@@ -286,6 +291,7 @@ class BaseDataBase:
self.database_connection = PooledDatabase[settings.DATABASE_TYPE.upper()].value(db_name, **database_config)
logging.info('init database on cluster mode successfully')
class PostgresDatabaseLock:
def __init__(self, lock_name, timeout=10, db=None):
self.lock_name = lock_name
@@ -330,6 +336,7 @@ class PostgresDatabaseLock:
return magic
class MysqlDatabaseLock:
def __init__(self, lock_name, timeout=10, db=None):
self.lock_name = lock_name
@@ -644,7 +651,7 @@ class TenantLLM(DataBaseModel):
index=True)
api_key = CharField(max_length=1024, null=True, help_text="API KEY", index=True)
api_base = CharField(max_length=255, null=True, help_text="API Base")
max_tokens = IntegerField(default=8192, index=True)
used_tokens = IntegerField(default=0, index=True)
def __str__(self):
@@ -875,8 +882,10 @@ class Dialog(DataBaseModel):
default="simple",
help_text="simple|advanced",
index=True)
- prompt_config = JSONField(null=False, default={"system": "", "prologue": "Hi! I'm your assistant, what can I do for you?",
- "parameters": [], "empty_response": "Sorry! No relevant content was found in the knowledge base!"})
+ prompt_config = JSONField(null=False,
+ default={"system": "", "prologue": "Hi! I'm your assistant, what can I do for you?",
+ "parameters": [],
+ "empty_response": "Sorry! No relevant content was found in the knowledge base!"})
similarity_threshold = FloatField(default=0.2)
vector_similarity_weight = FloatField(default=0.3)
@@ -890,7 +899,7 @@ class Dialog(DataBaseModel):
null=False,
default="1",
help_text="it needs to insert reference index into answer or not")
rerank_id = CharField(
max_length=128,
null=False,
@@ -1025,8 +1034,8 @@ def migrate_db():
pass
try:
migrate(
- migrator.add_column("tenant","tts_id",
- CharField(max_length=256,null=True,help_text="default tts model ID",index=True))
+ migrator.add_column("tenant", "tts_id",
+ CharField(max_length=256, null=True, help_text="default tts model ID", index=True))
)
except Exception:
pass
@@ -1055,4 +1064,9 @@ def migrate_db():
)
except Exception:
pass
try:
migrate(
migrator.add_column("tenant_llm","max_tokens",IntegerField(default=8192,index=True))
)
except Exception:
pass
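Once this migration has run, pre-existing tenant_llm rows pick up the 8192 default. A minimal peewee read sketch; TenantLLM and max_tokens come from this diff, while the lookup values are placeholders and the snippet assumes the module context above (model classes and an initialized database):

```python
# Hypothetical lookup; tenant id and model name are placeholders.
row = TenantLLM.get(TenantLLM.tenant_id == "<tenant_id>",
                    TenantLLM.llm_name == "<llm_name>")
print(row.max_tokens)  # 8192 unless the tenant saved an override
```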

View File

@@ -567,7 +567,7 @@ class TogetherAIEmbed(OllamaEmbed):
def __init__(self, key, model_name, base_url="https://api.together.xyz/v1"):
if not base_url:
base_url = "https://api.together.xyz/v1"
- super().__init__(key, model_name, base_url)
+ super().__init__(key, model_name, base_url=base_url)
class PerfXCloudEmbed(OpenAIEmbed):

View File

@@ -4,6 +4,7 @@ export interface IAddLlmRequestBody {
model_type: string;
api_base?: string; // chat|embedding|speech2text|image2text
api_key: string;
max_tokens: number;
}
export interface IDeleteLlmRequestBody {

View File

@@ -393,6 +393,8 @@ The above is the content you need to summarize.`,
maxTokensMessage: 'Max Tokens is required',
maxTokensTip:
'This sets the maximum length of the models output, measured in the number of tokens (words or pieces of words).',
maxTokensInvalidMessage: 'Please enter a valid number for Max Tokens.',
maxTokensMinMessage: 'Max Tokens cannot be less than 0.',
quote: 'Show Quote',
quoteTip: 'Should the source of the original text be displayed?',
selfRag: 'Self-RAG',
@@ -441,6 +443,12 @@ The above is the content you need to summarize.`,
setting: {
profile: 'Profile',
profileDescription: 'Update your photo and personal details here.',
maxTokens: 'Max Tokens',
maxTokensMessage: 'Max Tokens is required',
maxTokensTip:
'This sets the maximum length of the models output, measured in the number of tokens (words or pieces of words).',
maxTokensInvalidMessage: 'Please enter a valid number for Max Tokens.',
maxTokensMinMessage: 'Max Tokens cannot be less than 0.',
password: 'Password',
passwordDescription:
'Please enter your current password to change your password.',

View File

@@ -231,6 +231,8 @@ export default {
maxTokensMessage: 'El máximo de tokens es obligatorio',
maxTokensTip:
'Esto establece la longitud máxima de la salida del modelo, medida en el número de tokens (palabras o piezas de palabras).',
maxTokensInvalidMessage: 'Por favor, ingresa un número válido para Max Tokens.',
maxTokensMinMessage: 'Max Tokens no puede ser menor que 0.',
quote: 'Mostrar cita',
quoteTip: '¿Debe mostrarse la fuente del texto original?',
selfRag: 'Self-RAG',
@@ -278,6 +280,12 @@ export default {
setting: {
profile: 'Perfil',
profileDescription: 'Actualiza tu foto y tus datos personales aquí.',
maxTokens: 'Máximo de tokens',
maxTokensMessage: 'El máximo de tokens es obligatorio',
maxTokensTip:
'Esto establece la longitud máxima de la salida del modelo, medida en el número de tokens (palabras o piezas de palabras).',
maxTokensInvalidMessage: 'Por favor, ingresa un número válido para Max Tokens.',
maxTokensMinMessage: 'Max Tokens no puede ser menor que 0.',
password: 'Contraseña',
passwordDescription:
'Por favor ingresa tu contraseña actual para cambiarla.',

View File

@@ -401,6 +401,8 @@ export default {
maxTokensMessage: 'Token Maksimum diperlukan',
maxTokensTip:
'Ini menetapkan panjang maksimum keluaran model, diukur dalam jumlah token (kata atau potongan kata).',
maxTokensInvalidMessage: 'Silakan masukkan angka yang valid untuk Max Tokens.',
maxTokensMinMessage: 'Max Tokens tidak boleh kurang dari 0.',
quote: 'Tampilkan Kutipan',
quoteTip: 'Haruskah sumber teks asli ditampilkan?',
selfRag: 'Self-RAG',
@@ -450,6 +452,12 @@ export default {
setting: {
profile: 'Profil',
profileDescription: 'Perbarui foto dan detail pribadi Anda di sini.',
maxTokens: 'Token Maksimum',
maxTokensMessage: 'Token Maksimum diperlukan',
maxTokensTip:
'Ini menetapkan panjang maksimum keluaran model, diukur dalam jumlah token (kata atau potongan kata).',
maxTokensInvalidMessage: 'Silakan masukkan angka yang valid untuk Max Tokens.',
maxTokensMinMessage: 'Max Tokens tidak boleh kurang dari 0.',
password: 'Kata Sandi',
passwordDescription:
'Silakan masukkan kata sandi Anda saat ini untuk mengubah kata sandi Anda.',

View File

@@ -376,6 +376,8 @@ export default {
maxTokensMessage: '最大token數是必填項',
maxTokensTip:
'這設置了模型輸出的最大長度,以標記(單詞或單詞片段)的數量來衡量。',
maxTokensInvalidMessage: '請輸入有效的最大標記數。',
maxTokensMinMessage: '最大標記數不能小於 0。',
quote: '顯示引文',
quoteTip: '是否應該顯示原文出處?',
selfRag: 'Self-RAG',
@@ -422,6 +424,12 @@ export default {
setting: {
profile: '概述',
profileDescription: '在此更新您的照片和個人詳細信息。',
maxTokens: '最大token數',
maxTokensMessage: '最大token數是必填項',
maxTokensTip:
'這設置了模型輸出的最大長度,以標記(單詞或單詞片段)的數量來衡量。',
maxTokensInvalidMessage: '請輸入有效的最大標記數。',
maxTokensMinMessage: '最大標記數不能小於 0。',
password: '密碼',
passwordDescription: '請輸入您當前的密碼以更改您的密碼。',
model: '模型提供商',

View File

@@ -393,6 +393,8 @@ export default {
maxTokensMessage: '最大token数是必填项',
maxTokensTip:
'这设置了模型输出的最大长度,以标记(单词或单词片段)的数量来衡量。',
maxTokensInvalidMessage: '请输入有效的最大令牌数。',
maxTokensMinMessage: '最大令牌数不能小于 0。',
quote: '显示引文',
quoteTip: '是否应该显示原文出处?',
selfRag: 'Self-RAG',
@@ -439,6 +441,12 @@ export default {
setting: {
profile: '概要',
profileDescription: '在此更新您的照片和个人详细信息。',
maxTokens: '最大token数',
maxTokensMessage: '最大token数是必填项',
maxTokensTip:
'这设置了模型输出的最大长度,以标记(单词或单词片段)的数量来衡量。',
maxTokensInvalidMessage: '请输入有效的最大令牌数。',
maxTokensMinMessage: '最大令牌数不能小于 0。',
password: '密码',
passwordDescription: '请输入您当前的密码以更改您的密码。',
model: '模型提供商',

View File

@@ -1,7 +1,7 @@
import { useTranslate } from '@/hooks/common-hooks';
import { IModalProps } from '@/interfaces/common';
import { IAddLlmRequestBody } from '@/interfaces/request/llm';
- import { Flex, Form, Input, Modal, Select, Space } from 'antd';
+ import { Flex, Form, Input, Modal, Select, Space, InputNumber } from 'antd';
import omit from 'lodash/omit';
type FieldType = IAddLlmRequestBody & {
@@ -30,6 +30,7 @@ const TencentCloudModal = ({
...omit(values),
model_type: modelType,
llm_factory: llmFactory,
max_tokens:16000,
};
console.info(data);

View File

@@ -1,7 +1,7 @@
import { useTranslate } from '@/hooks/common-hooks';
import { IModalProps } from '@/interfaces/common';
import { IAddLlmRequestBody } from '@/interfaces/request/llm';
- import { Form, Input, Modal, Select, Switch } from 'antd';
+ import { Form, Input, Modal, Select, Switch, InputNumber } from 'antd';
import omit from 'lodash/omit';
type FieldType = IAddLlmRequestBody & {
@@ -33,6 +33,7 @@ const AzureOpenAIModal = ({
...omit(values, ['vision']),
model_type: modelType,
llm_factory: llmFactory,
max_tokens:values.max_tokens,
};
console.info(data);
@@ -107,6 +108,31 @@ const AzureOpenAIModal = ({
>
<Input placeholder={t('apiVersionMessage')} />
</Form.Item>
<Form.Item<FieldType>
label={t('maxTokens')}
name="max_tokens"
rules={[
{ required: true, message: t('maxTokensMessage') },
{
type: 'number',
message: t('maxTokensInvalidMessage'),
},
({ getFieldValue }) => ({
validator(_, value) {
if (value < 0) {
return Promise.reject(new Error(t('maxTokensMinMessage')));
}
return Promise.resolve();
},
}),
]}
>
<InputNumber
placeholder={t('maxTokensTip')}
style={{ width: '100%' }}
/>
</Form.Item>
<Form.Item noStyle dependencies={['model_type']}>
{({ getFieldValue }) =>
getFieldValue('model_type') === 'chat' && (

View File

@@ -1,7 +1,7 @@
import { useTranslate } from '@/hooks/common-hooks';
import { IModalProps } from '@/interfaces/common';
import { IAddLlmRequestBody } from '@/interfaces/request/llm';
- import { Flex, Form, Input, Modal, Select, Space } from 'antd';
+ import { Flex, Form, Input, Modal, Select, Space, InputNumber } from 'antd';
import { useMemo } from 'react';
import { BedrockRegionList } from '../constant';
@@ -34,6 +34,7 @@ const BedrockModal = ({
const data = {
...values,
llm_factory: llmFactory,
max_tokens:values.max_tokens,
};
onOk?.(data);
@@ -111,6 +112,31 @@ const BedrockModal = ({
allowClear
></Select>
</Form.Item>
<Form.Item<FieldType>
label={t('maxTokens')}
name="max_tokens"
rules={[
{ required: true, message: t('maxTokensMessage') },
{
type: 'number',
message: t('maxTokensInvalidMessage'),
},
({ getFieldValue }) => ({
validator(_, value) {
if (value < 0) {
return Promise.reject(new Error(t('maxTokensMinMessage')));
}
return Promise.resolve();
},
}),
]}
>
<InputNumber
placeholder={t('maxTokensTip')}
style={{ width: '100%' }}
/>
</Form.Item>
</Form>
</Modal>
);

View File

@@ -1,7 +1,7 @@
import { useTranslate } from '@/hooks/common-hooks';
import { IModalProps } from '@/interfaces/common';
import { IAddLlmRequestBody } from '@/interfaces/request/llm';
- import { Flex, Form, Input, Modal, Select, Space } from 'antd';
+ import { Flex, Form, Input, Modal, Select, Space, InputNumber } from 'antd';
import omit from 'lodash/omit';
type FieldType = IAddLlmRequestBody & {
@@ -30,6 +30,7 @@ const FishAudioModal = ({
...omit(values),
model_type: modelType,
llm_factory: llmFactory,
max_tokens:values.max_tokens,
};
console.info(data);
@@ -93,6 +94,31 @@ const FishAudioModal = ({
>
<Input placeholder={t('FishAudioRefIDMessage')} />
</Form.Item>
<Form.Item<FieldType>
label={t('maxTokens')}
name="max_tokens"
rules={[
{ required: true, message: t('maxTokensMessage') },
{
type: 'number',
message: t('maxTokensInvalidMessage'),
},
({ getFieldValue }) => ({
validator(_, value) {
if (value < 0) {
return Promise.reject(new Error(t('maxTokensMinMessage')));
}
return Promise.resolve();
},
}),
]}
>
<InputNumber
placeholder={t('maxTokensTip')}
style={{ width: '100%' }}
/>
</Form.Item>
</Form>
</Modal>
);

View File

@@ -1,7 +1,7 @@
import { useTranslate } from '@/hooks/common-hooks';
import { IModalProps } from '@/interfaces/common';
import { IAddLlmRequestBody } from '@/interfaces/request/llm';
- import { Form, Input, Modal, Select } from 'antd';
+ import { Form, Input, Modal, Select, InputNumber } from 'antd';
type FieldType = IAddLlmRequestBody & {
google_project_id: string;
@@ -27,6 +27,7 @@ const GoogleModal = ({
const data = {
...values,
llm_factory: llmFactory,
max_tokens:values.max_tokens,
};
onOk?.(data);
@@ -87,6 +88,31 @@ const GoogleModal = ({
>
<Input placeholder={t('GoogleServiceAccountKeyMessage')} />
</Form.Item>
<Form.Item<FieldType>
label={t('maxTokens')}
name="max_tokens"
rules={[
{ required: true, message: t('maxTokensMessage') },
{
type: 'number',
message: t('maxTokensInvalidMessage'),
},
({ getFieldValue }) => ({
validator(_, value) {
if (value < 0) {
return Promise.reject(new Error(t('maxTokensMinMessage')));
}
return Promise.resolve();
},
}),
]}
>
<InputNumber
placeholder={t('maxTokensTip')}
style={{ width: '100%' }}
/>
</Form.Item>
</Form>
</Modal>
);

View File

@@ -1,7 +1,7 @@
import { useTranslate } from '@/hooks/common-hooks';
import { IModalProps } from '@/interfaces/common';
import { IAddLlmRequestBody } from '@/interfaces/request/llm';
- import { Form, Input, Modal, Select } from 'antd';
+ import { Form, Input, Modal, Select} from 'antd';
import omit from 'lodash/omit';
type FieldType = IAddLlmRequestBody & {

View File

@@ -402,7 +402,7 @@ const UserSettingModel = () => {
hideModal={hideTencentCloudAddingModal}
onOk={onTencentCloudAddingOk}
loading={TencentCloudAddingLoading}
- llmFactory={'Tencent TencentCloud'}
+ llmFactory={'Tencent Cloud'}
></TencentCloudModal>
<SparkModal
visible={SparkAddingVisible}

View File

@@ -1,7 +1,7 @@
import { useTranslate } from '@/hooks/common-hooks';
import { IModalProps } from '@/interfaces/common';
import { IAddLlmRequestBody } from '@/interfaces/request/llm';
- import { Flex, Form, Input, Modal, Select, Space, Switch } from 'antd';
+ import { Flex, Form, Input, Modal, Select, Space, Switch, InputNumber } from 'antd';
import omit from 'lodash/omit';
type FieldType = IAddLlmRequestBody & { vision: boolean };
@@ -45,6 +45,7 @@ const OllamaModal = ({
...omit(values, ['vision']),
model_type: modelType,
llm_factory: llmFactory,
max_tokens:values.max_tokens,
};
console.info(data);
@@ -136,6 +137,31 @@ const OllamaModal = ({
>
<Input placeholder={t('apiKeyMessage')} />
</Form.Item>
<Form.Item<FieldType>
label={t('maxTokens')}
name="max_tokens"
rules={[
{ required: true, message: t('maxTokensMessage') },
{
type: 'number',
message: t('maxTokensInvalidMessage'),
},
({ getFieldValue }) => ({
validator(_, value) {
if (value < 0) {
return Promise.reject(new Error(t('maxTokensMinMessage')));
}
return Promise.resolve();
},
}),
]}
>
<InputNumber
placeholder={t('maxTokensTip')}
style={{ width: '100%' }}
/>
</Form.Item>
<Form.Item noStyle dependencies={['model_type']}>
{({ getFieldValue }) =>
getFieldValue('model_type') === 'chat' && (

View File

@@ -1,7 +1,7 @@
import { useTranslate } from '@/hooks/common-hooks';
import { IModalProps } from '@/interfaces/common';
import { IAddLlmRequestBody } from '@/interfaces/request/llm';
- import { Form, Input, Modal, Select } from 'antd';
+ import { Form, Input, Modal, Select, InputNumber } from 'antd';
import omit from 'lodash/omit';
type FieldType = IAddLlmRequestBody & {
@@ -36,6 +36,7 @@ const SparkModal = ({
...omit(values, ['vision']),
model_type: modelType,
llm_factory: llmFactory,
max_tokens:values.max_tokens,
};
console.info(data);
@@ -128,6 +129,31 @@ const SparkModal = ({
)
}
</Form.Item>
<Form.Item<FieldType>
label={t('maxTokens')}
name="max_tokens"
rules={[
{ required: true, message: t('maxTokensMessage') },
{
type: 'number',
message: t('maxTokensInvalidMessage'),
},
({ getFieldValue }) => ({
validator(_, value) {
if (value < 0) {
return Promise.reject(new Error(t('maxTokensMinMessage')));
}
return Promise.resolve();
},
}),
]}
>
<InputNumber
placeholder={t('maxTokensTip')}
style={{ width: '100%' }}
/>
</Form.Item>
</Form>
</Modal>
);

View File

@@ -1,7 +1,7 @@
import { useTranslate } from '@/hooks/common-hooks';
import { IModalProps } from '@/interfaces/common';
import { IAddLlmRequestBody } from '@/interfaces/request/llm';
- import { Flex, Form, Input, Modal, Select, Space, Switch } from 'antd';
+ import { Flex, Form, Input, Modal, Select, Space, Switch, InputNumber } from 'antd';
import omit from 'lodash/omit';
type FieldType = IAddLlmRequestBody & {
@@ -36,6 +36,7 @@ const VolcEngineModal = ({
...omit(values, ['vision']),
model_type: modelType,
llm_factory: llmFactory,
max_tokens:values.max_tokens,
};
console.info(data);
@@ -103,19 +104,31 @@ const VolcEngineModal = ({
>
<Input placeholder={t('ArkApiKeyMessage')} />
</Form.Item>
- <Form.Item noStyle dependencies={['model_type']}>
- {({ getFieldValue }) =>
- getFieldValue('model_type') === 'chat' && (
- <Form.Item
- label={t('vision')}
- valuePropName="checked"
- name={'vision'}
- >
- <Switch />
- </Form.Item>
- )
- }
+ <Form.Item<FieldType>
+ label={t('maxTokens')}
+ name="max_tokens"
+ rules={[
+ { required: true, message: t('maxTokensMessage') },
+ {
+ type: 'number',
+ message: t('maxTokensInvalidMessage'),
+ },
+ ({ getFieldValue }) => ({
+ validator(_, value) {
+ if (value < 0) {
+ return Promise.reject(new Error(t('maxTokensMinMessage')));
+ }
+ return Promise.resolve();
+ },
+ }),
+ ]}
+ >
+ <InputNumber
+ placeholder={t('maxTokensTip')}
+ style={{ width: '100%' }}
+ />
</Form.Item>
</Form>
</Modal>
);

View File

@@ -1,7 +1,7 @@
import { useTranslate } from '@/hooks/common-hooks';
import { IModalProps } from '@/interfaces/common';
import { IAddLlmRequestBody } from '@/interfaces/request/llm';
- import { Form, Input, Modal, Select } from 'antd';
+ import { Form, Input, Modal, Select, InputNumber } from 'antd';
import omit from 'lodash/omit';
type FieldType = IAddLlmRequestBody & {
@@ -34,6 +34,7 @@ const YiyanModal = ({
...omit(values, ['vision']),
model_type: modelType,
llm_factory: llmFactory,
max_tokens:values.max_tokens,
};
console.info(data);
@@ -89,6 +90,30 @@ const YiyanModal = ({
>
<Input placeholder={t('yiyanSKMessage')} />
</Form.Item>
<Form.Item<FieldType>
label={t('maxTokens')}
name="max_tokens"
rules={[
{ required: true, message: t('maxTokensMessage') },
{
type: 'number',
message: t('maxTokensInvalidMessage'),
},
({ getFieldValue }) => ({
validator(_, value) {
if (value < 0) {
return Promise.reject(new Error(t('maxTokensMinMessage')));
}
return Promise.resolve();
},
}),
]}
>
<InputNumber
placeholder={t('maxTokensTip')}
style={{ width: '100%' }}
/>
</Form.Item>
</Form>
</Modal>
);