diff --git a/api/apps/llm_app.py b/api/apps/llm_app.py
index 75ad45a3c..dccd283db 100644
--- a/api/apps/llm_app.py
+++ b/api/apps/llm_app.py
@@ -74,9 +74,9 @@ def set_api_key():
         mdl = ChatModel[factory](
             req["api_key"], llm.llm_name, base_url=req.get("base_url"))
         try:
-            m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}],
-                             {"temperature": 0.9,'max_tokens':50})
-            if m.find("**ERROR**") >=0:
+            m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}],
+                             {"temperature": 0.9, 'max_tokens': 50})
+            if m.find("**ERROR**") >= 0:
                 raise Exception(m)
             chat_passed = True
         except Exception as e:
@@ -110,6 +110,7 @@ def set_api_key():
             llm_config[n] = req[n]

     for llm in LLMService.query(fid=factory):
+        llm_config["max_tokens"]=llm.max_tokens
         if not TenantLLMService.filter_update(
                 [TenantLLM.tenant_id == current_user.id,
                  TenantLLM.llm_factory == factory,
@@ -121,7 +122,8 @@ def set_api_key():
                 llm_name=llm.llm_name,
                 model_type=llm.model_type,
                 api_key=llm_config["api_key"],
-                api_base=llm_config["api_base"]
+                api_base=llm_config["api_base"],
+                max_tokens=llm_config["max_tokens"]
             )

     return get_json_result(data=True)
@@ -158,23 +160,23 @@ def add_llm():
         api_key = apikey_json(["bedrock_ak", "bedrock_sk", "bedrock_region"])

     elif factory == "LocalAI":
-        llm_name = req["llm_name"]+"___LocalAI"
+        llm_name = req["llm_name"] + "___LocalAI"
         api_key = "xxxxxxxxxxxxxxx"
-
+
     elif factory == "HuggingFace":
-        llm_name = req["llm_name"]+"___HuggingFace"
+        llm_name = req["llm_name"] + "___HuggingFace"
         api_key = "xxxxxxxxxxxxxxx"

     elif factory == "OpenAI-API-Compatible":
-        llm_name = req["llm_name"]+"___OpenAI-API"
-        api_key = req.get("api_key","xxxxxxxxxxxxxxx")
+        llm_name = req["llm_name"] + "___OpenAI-API"
+        api_key = req.get("api_key", "xxxxxxxxxxxxxxx")

-    elif factory =="XunFei Spark":
+    elif factory == "XunFei Spark":
         llm_name = req["llm_name"]
         if req["model_type"] == "chat":
             api_key = req.get("spark_api_password", "xxxxxxxxxxxxxxx")
         elif req["model_type"] == "tts":
-            api_key = apikey_json(["spark_app_id", "spark_api_secret","spark_api_key"])
+            api_key = apikey_json(["spark_app_id", "spark_api_secret", "spark_api_key"])

     elif factory == "BaiduYiyan":
         llm_name = req["llm_name"]
@@ -202,14 +204,15 @@ def add_llm():
         "model_type": req["model_type"],
         "llm_name": llm_name,
         "api_base": req.get("api_base", ""),
-        "api_key": api_key
+        "api_key": api_key,
+        "max_tokens": req.get("max_tokens")
     }

     msg = ""
     if llm["model_type"] == LLMType.EMBEDDING.value:
         mdl = EmbeddingModel[factory](
             key=llm['api_key'],
-            model_name=llm["llm_name"],
+            model_name=llm["llm_name"],
             base_url=llm["api_base"])
         try:
             arr, tc = mdl.encode(["Test if the api key is available"])
@@ -225,7 +228,7 @@ def add_llm():
         )
         try:
             m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {
-                "temperature": 0.9})
+                             "temperature": 0.9})
             if not tc:
                 raise Exception(m)
         except Exception as e:
@@ -233,8 +236,8 @@ def add_llm():
                 e)
     elif llm["model_type"] == LLMType.RERANK:
         mdl = RerankModel[factory](
-            key=llm["api_key"],
-            model_name=llm["llm_name"],
+            key=llm["api_key"],
+            model_name=llm["llm_name"],
             base_url=llm["api_base"]
         )
         try:
@@ -246,8 +249,8 @@ def add_llm():
                 e)
     elif llm["model_type"] == LLMType.IMAGE2TEXT.value:
         mdl = CvModel[factory](
-            key=llm["api_key"],
-            model_name=llm["llm_name"],
+            key=llm["api_key"],
+            model_name=llm["llm_name"],
             base_url=llm["api_base"]
         )
         try:
@@ -282,7 +285,8 @@ def add_llm():
         return get_data_error_result(message=msg)

     if not TenantLLMService.filter_update(
-            [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == factory, TenantLLM.llm_name == llm["llm_name"]], llm):
+            [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == factory,
+             TenantLLM.llm_name == llm["llm_name"]], llm):
         TenantLLMService.save(**llm)

     return get_json_result(data=True)
@@ -294,7 +298,8 @@ def add_llm():
 def delete_llm():
     req = request.json
     TenantLLMService.filter_delete(
-        [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"], TenantLLM.llm_name == req["llm_name"]])
+        [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"],
+         TenantLLM.llm_name == req["llm_name"]])
     return get_json_result(data=True)


@@ -304,7 +309,7 @@ def delete_llm():
 def delete_factory():
     req = request.json
     TenantLLMService.filter_delete(
-        [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"]])
+        [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"]])
     return get_json_result(data=True)


@@ -332,8 +337,8 @@ def my_llms():
 @manager.route('/list', methods=['GET'])
 @login_required
 def list_app():
-    self_deploied = ["Youdao","FastEmbed", "BAAI", "Ollama", "Xinference", "LocalAI", "LM-Studio"]
-    weighted = ["Youdao","FastEmbed", "BAAI"] if settings.LIGHTEN != 0 else []
+    self_deploied = ["Youdao", "FastEmbed", "BAAI", "Ollama", "Xinference", "LocalAI", "LM-Studio"]
+    weighted = ["Youdao", "FastEmbed", "BAAI"] if settings.LIGHTEN != 0 else []
     model_type = request.args.get("model_type")
     try:
         objs = TenantLLMService.query(tenant_id=current_user.id)
@@ -344,15 +349,15 @@ def list_app():
         for m in llms:
             m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in self_deploied

-        llm_set = set([m["llm_name"]+"@"+m["fid"] for m in llms])
+        llm_set = set([m["llm_name"] + "@" + m["fid"] for m in llms])
         for o in objs:
-            if not o.api_key:continue
-            if o.llm_name+"@"+o.llm_factory in llm_set:continue
+            if not o.api_key: continue
+            if o.llm_name + "@" + o.llm_factory in llm_set: continue
             llms.append({"llm_name": o.llm_name, "model_type": o.model_type, "fid": o.llm_factory, "available": True})

         res = {}
         for m in llms:
-            if model_type and m["model_type"].find(model_type)<0:
+            if model_type and m["model_type"].find(model_type) < 0:
                 continue
             if m["fid"] not in res:
                 res[m["fid"]] = []
@@ -360,4 +365,4 @@ def list_app():

         return get_json_result(data=res)
     except Exception as e:
-        return server_error_response(e)
+        return server_error_response(e)
\ No newline at end of file
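
Note: with this change, `add_llm` accepts a `max_tokens` field next to the existing credentials, and it is persisted on the tenant's model record. A minimal client sketch in Python; the host, port, route prefix, and session auth are assumptions for illustration (the handler is `@login_required`), while the payload keys come straight from the handler above:

    import requests  # assumed HTTP client; any would do

    payload = {
        "llm_factory": "OpenAI-API-Compatible",   # hypothetical factory/model for the example
        "llm_name": "my-local-model",
        "model_type": "chat",
        "api_base": "http://localhost:8000/v1",
        "api_key": "sk-...",
        "max_tokens": 8192,                       # new field; the DB column defaults to 8192
    }
    # Hypothetical URL; session cookie auth assumed because add_llm is @login_required.
    resp = requests.post("http://localhost:9380/v1/llm/add_llm", json=payload,
                         cookies={"session": "<your session cookie>"})
    print(resp.json())
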
How are you doing!"}], { - "temperature": 0.9}) + "temperature": 0.9}) if not tc: raise Exception(m) except Exception as e: @@ -233,8 +236,8 @@ def add_llm(): e) elif llm["model_type"] == LLMType.RERANK: mdl = RerankModel[factory]( - key=llm["api_key"], - model_name=llm["llm_name"], + key=llm["api_key"], + model_name=llm["llm_name"], base_url=llm["api_base"] ) try: @@ -246,8 +249,8 @@ def add_llm(): e) elif llm["model_type"] == LLMType.IMAGE2TEXT.value: mdl = CvModel[factory]( - key=llm["api_key"], - model_name=llm["llm_name"], + key=llm["api_key"], + model_name=llm["llm_name"], base_url=llm["api_base"] ) try: @@ -282,7 +285,8 @@ def add_llm(): return get_data_error_result(message=msg) if not TenantLLMService.filter_update( - [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == factory, TenantLLM.llm_name == llm["llm_name"]], llm): + [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == factory, + TenantLLM.llm_name == llm["llm_name"]], llm): TenantLLMService.save(**llm) return get_json_result(data=True) @@ -294,7 +298,8 @@ def add_llm(): def delete_llm(): req = request.json TenantLLMService.filter_delete( - [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"], TenantLLM.llm_name == req["llm_name"]]) + [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"], + TenantLLM.llm_name == req["llm_name"]]) return get_json_result(data=True) @@ -304,7 +309,7 @@ def delete_llm(): def delete_factory(): req = request.json TenantLLMService.filter_delete( - [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"]]) + [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"]]) return get_json_result(data=True) @@ -332,8 +337,8 @@ def my_llms(): @manager.route('/list', methods=['GET']) @login_required def list_app(): - self_deploied = ["Youdao","FastEmbed", "BAAI", "Ollama", "Xinference", "LocalAI", "LM-Studio"] - weighted = ["Youdao","FastEmbed", "BAAI"] if settings.LIGHTEN != 0 else [] + self_deploied = ["Youdao", "FastEmbed", "BAAI", "Ollama", "Xinference", "LocalAI", "LM-Studio"] + weighted = ["Youdao", "FastEmbed", "BAAI"] if settings.LIGHTEN != 0 else [] model_type = request.args.get("model_type") try: objs = TenantLLMService.query(tenant_id=current_user.id) @@ -344,15 +349,15 @@ def list_app(): for m in llms: m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in self_deploied - llm_set = set([m["llm_name"]+"@"+m["fid"] for m in llms]) + llm_set = set([m["llm_name"] + "@" + m["fid"] for m in llms]) for o in objs: - if not o.api_key:continue - if o.llm_name+"@"+o.llm_factory in llm_set:continue + if not o.api_key: continue + if o.llm_name + "@" + o.llm_factory in llm_set: continue llms.append({"llm_name": o.llm_name, "model_type": o.model_type, "fid": o.llm_factory, "available": True}) res = {} for m in llms: - if model_type and m["model_type"].find(model_type)<0: + if model_type and m["model_type"].find(model_type) < 0: continue if m["fid"] not in res: res[m["fid"]] = [] @@ -360,4 +365,4 @@ def list_app(): return get_json_result(data=res) except Exception as e: - return server_error_response(e) + return server_error_response(e) \ No newline at end of file diff --git a/api/apps/sdk/session.py b/api/apps/sdk/session.py index b6641b107..e6380607b 100644 --- a/api/apps/sdk/session.py +++ b/api/apps/sdk/session.py @@ -13,21 +13,24 @@ # See the License for the specific language governing permissions and # limitations under the 
@@ -441,4 +444,80 @@ def delete(tenant_id,chat_id):
         if not conv:
             return get_error_data_result(message="The chat doesn't own the session")
         ConversationService.delete_by_id(id)
-    return get_result()
\ No newline at end of file
+    return get_result()
+
+
+@manager.route('/sessions/ask', methods=['POST'])
+@token_required
+def ask_about(tenant_id):
+    req = request.json
+    if not req.get("question"):
+        return get_error_data_result("`question` is required.")
+    if not req.get("dataset_ids"):
+        return get_error_data_result("`dataset_ids` is required.")
+    if not isinstance(req.get("dataset_ids"),list):
+        return get_error_data_result("`dataset_ids` should be a list.")
+    req["kb_ids"]=req.pop("dataset_ids")
+    for kb_id in req["kb_ids"]:
+        if not KnowledgebaseService.accessible(kb_id,tenant_id):
+            return get_error_data_result(f"You don't own the dataset {kb_id}.")
+        kbs = KnowledgebaseService.query(id=kb_id)
+        kb = kbs[0]
+        if kb.chunk_num == 0:
+            return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file")
+    uid = tenant_id
+
+    def stream():
+        nonlocal req, uid
+        try:
+            for ans in ask(req["question"], req["kb_ids"], uid):
+                yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
+        except Exception as e:
+            yield "data:" + json.dumps({"code": 500, "message": str(e),
+                                        "data": {"answer": "**ERROR**: " + str(e), "reference": []}},
+                                       ensure_ascii=False) + "\n\n"
+        yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
+
+    resp = Response(stream(), mimetype="text/event-stream")
+    resp.headers.add_header("Cache-control", "no-cache")
+    resp.headers.add_header("Connection", "keep-alive")
+    resp.headers.add_header("X-Accel-Buffering", "no")
+    resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
+    return resp
charset=utf-8") + return resp + + +@manager.route('/sessions/related_questions', methods=['POST']) +@token_required +def related_questions(tenant_id): + req = request.json + if not req.get("question"): + return get_error_data_result("`question` is required.") + question = req["question"] + chat_mdl = LLMBundle(tenant_id, LLMType.CHAT) + prompt = """ +Objective: To generate search terms related to the user's search keywords, helping users find more valuable information. +Instructions: + - Based on the keywords provided by the user, generate 5-10 related search terms. + - Each search term should be directly or indirectly related to the keyword, guiding the user to find more valuable information. + - Use common, general terms as much as possible, avoiding obscure words or technical jargon. + - Keep the term length between 2-4 words, concise and clear. + - DO NOT translate, use the language of the original keywords. + +### Example: +Keywords: Chinese football +Related search terms: +1. Current status of Chinese football +2. Reform of Chinese football +3. Youth training of Chinese football +4. Chinese football in the Asian Cup +5. Chinese football in the World Cup + +Reason: + - When searching, users often only use one or two keywords, making it difficult to fully express their information needs. + - Generating related search terms can help users dig deeper into relevant information and improve search efficiency. + - At the same time, related terms can also help search engines better understand user needs and return more accurate search results. + +""" + ans = chat_mdl.chat(prompt, [{"role": "user", "content": f""" +Keywords: {question} +Related search terms: + """}], {"temperature": 0.9}) + return get_result(data=[re.sub(r"^[0-9]\. ", "", a) for a in ans.split("\n") if re.match(r"^[0-9]\. 
", a)]) diff --git a/api/db/db_models.py b/api/db/db_models.py index b6975afb4..f57309f0a 100644 --- a/api/db/db_models.py +++ b/api/db/db_models.py @@ -17,6 +17,7 @@ import logging import inspect import os import sys +import typing import operator from enum import Enum from functools import wraps @@ -29,10 +30,13 @@ from peewee import ( Field, Model, Metadata ) from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase + + from api.db import SerializedType, ParserType from api import settings from api import utils + def singleton(cls, *args, **kw): instances = {} @@ -120,13 +124,13 @@ class SerializedField(LongTextField): f"the serialized type {self._serialized_type} is not supported") -def is_continuous_field(cls: type) -> bool: +def is_continuous_field(cls: typing.Type) -> bool: if cls in CONTINUOUS_FIELD_TYPE: return True for p in cls.__bases__: if p in CONTINUOUS_FIELD_TYPE: return True - elif p is not Field and p is not object: + elif p != Field and p != object: if is_continuous_field(p): return True else: @@ -158,7 +162,7 @@ class BaseModel(Model): def to_dict(self): return self.__dict__['__data__'] - def to_human_model_dict(self, only_primary_with: list | None = None): + def to_human_model_dict(self, only_primary_with: list = None): model_dict = self.__dict__['__data__'] if not only_primary_with: @@ -268,6 +272,7 @@ class JsonSerializedField(SerializedField): super(JsonSerializedField, self).__init__(serialized_type=SerializedType.JSON, object_hook=object_hook, object_pairs_hook=object_pairs_hook, **kwargs) + class PooledDatabase(Enum): MYSQL = PooledMySQLDatabase POSTGRES = PooledPostgresqlDatabase @@ -286,6 +291,7 @@ class BaseDataBase: self.database_connection = PooledDatabase[settings.DATABASE_TYPE.upper()].value(db_name, **database_config) logging.info('init database on cluster mode successfully') + class PostgresDatabaseLock: def __init__(self, lock_name, timeout=10, db=None): self.lock_name = lock_name @@ -330,6 +336,7 @@ class PostgresDatabaseLock: return magic + class MysqlDatabaseLock: def __init__(self, lock_name, timeout=10, db=None): self.lock_name = lock_name @@ -644,7 +651,7 @@ class TenantLLM(DataBaseModel): index=True) api_key = CharField(max_length=1024, null=True, help_text="API KEY", index=True) api_base = CharField(max_length=255, null=True, help_text="API Base") - + max_tokens = IntegerField(default=8192, index=True) used_tokens = IntegerField(default=0, index=True) def __str__(self): @@ -875,8 +882,10 @@ class Dialog(DataBaseModel): default="simple", help_text="simple|advanced", index=True) - prompt_config = JSONField(null=False, default={"system": "", "prologue": "Hi! I'm your assistant, what can I do for you?", - "parameters": [], "empty_response": "Sorry! No relevant content was found in the knowledge base!"}) + prompt_config = JSONField(null=False, + default={"system": "", "prologue": "Hi! I'm your assistant, what can I do for you?", + "parameters": [], + "empty_response": "Sorry! 
diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py
index 813aa04db..9f9df1e45 100644
--- a/rag/llm/embedding_model.py
+++ b/rag/llm/embedding_model.py
@@ -567,7 +567,7 @@ class TogetherAIEmbed(OllamaEmbed):
     def __init__(self, key, model_name, base_url="https://api.together.xyz/v1"):
         if not base_url:
             base_url = "https://api.together.xyz/v1"
-        super().__init__(key, model_name, base_url)
+        super().__init__(key, model_name, base_url=base_url)


 class PerfXCloudEmbed(OpenAIEmbed):
diff --git a/web/src/interfaces/request/llm.ts b/web/src/interfaces/request/llm.ts
index 309fd45d0..05f8f470e 100644
--- a/web/src/interfaces/request/llm.ts
+++ b/web/src/interfaces/request/llm.ts
@@ -4,6 +4,7 @@ export interface IAddLlmRequestBody {
   model_type: string;
   api_base?: string; // chat|embedding|speech2text|image2text
   api_key: string;
+  max_tokens: number;
 }

 export interface IDeleteLlmRequestBody {
diff --git a/web/src/locales/en.ts b/web/src/locales/en.ts
index 97de86c8a..e4543cdd6 100644
--- a/web/src/locales/en.ts
+++ b/web/src/locales/en.ts
@@ -393,6 +393,8 @@ The above is the content you need to summarize.`,
     maxTokensMessage: 'Max Tokens is required',
     maxTokensTip:
       'This sets the maximum length of the model’s output, measured in the number of tokens (words or pieces of words).',
+    maxTokensInvalidMessage: 'Please enter a valid number for Max Tokens.',
+    maxTokensMinMessage: 'Max Tokens cannot be less than 0.',
     quote: 'Show Quote',
     quoteTip: 'Should the source of the original text be displayed?',
     selfRag: 'Self-RAG',
@@ -441,6 +443,12 @@ The above is the content you need to summarize.`,
   setting: {
     profile: 'Profile',
     profileDescription: 'Update your photo and personal details here.',
+    maxTokens: 'Max Tokens',
+    maxTokensMessage: 'Max Tokens is required',
+    maxTokensTip:
+      'This sets the maximum length of the model’s output, measured in the number of tokens (words or pieces of words).',
+    maxTokensInvalidMessage: 'Please enter a valid number for Max Tokens.',
+    maxTokensMinMessage: 'Max Tokens cannot be less than 0.',
     password: 'Password',
     passwordDescription:
       'Please enter your current password to change your password.',
diff --git a/web/src/locales/es.ts b/web/src/locales/es.ts
index 5785cc467..39339faed 100644
--- a/web/src/locales/es.ts
+++ b/web/src/locales/es.ts
@@ -231,6 +231,8 @@ export default {
     maxTokensMessage: 'El máximo de tokens es obligatorio',
     maxTokensTip:
       'Esto establece la longitud máxima de la salida del modelo, medida en el número de tokens (palabras o piezas de palabras).',
+    maxTokensInvalidMessage: 'Por favor, ingresa un número válido para Max Tokens.',
+    maxTokensMinMessage: 'Max Tokens no puede ser menor que 0.',
     quote: 'Mostrar cita',
     quoteTip: '¿Debe mostrarse la fuente del texto original?',
     selfRag: 'Self-RAG',
@@ -278,6 +280,12 @@ export default {
   setting: {
     profile: 'Perfil',
     profileDescription: 'Actualiza tu foto y tus datos personales aquí.',
+    maxTokens: 'Máximo de tokens',
+    maxTokensMessage: 'El máximo de tokens es obligatorio',
+    maxTokensTip:
+      'Esto establece la longitud máxima de la salida del modelo, medida en el número de tokens (palabras o piezas de palabras).',
+    maxTokensInvalidMessage: 'Por favor, ingresa un número válido para Max Tokens.',
+    maxTokensMinMessage: 'Max Tokens no puede ser menor que 0.',
     password: 'Contraseña',
     passwordDescription:
       'Por favor ingresa tu contraseña actual para cambiarla.',
diff --git a/web/src/locales/id.ts b/web/src/locales/id.ts
index 5375704ad..9179018ff 100644
--- a/web/src/locales/id.ts
+++ b/web/src/locales/id.ts
@@ -401,6 +401,8 @@ export default {
     maxTokensMessage: 'Token Maksimum diperlukan',
     maxTokensTip:
       'Ini menetapkan panjang maksimum keluaran model, diukur dalam jumlah token (kata atau potongan kata).',
+    maxTokensInvalidMessage: 'Silakan masukkan angka yang valid untuk Max Tokens.',
+    maxTokensMinMessage: 'Max Tokens tidak boleh kurang dari 0.',
     quote: 'Tampilkan Kutipan',
     quoteTip: 'Haruskah sumber teks asli ditampilkan?',
     selfRag: 'Self-RAG',
@@ -450,6 +452,12 @@ export default {
   setting: {
     profile: 'Profil',
     profileDescription: 'Perbarui foto dan detail pribadi Anda di sini.',
+    maxTokens: 'Token Maksimum',
+    maxTokensMessage: 'Token Maksimum diperlukan',
+    maxTokensTip:
+      'Ini menetapkan panjang maksimum keluaran model, diukur dalam jumlah token (kata atau potongan kata).',
+    maxTokensInvalidMessage: 'Silakan masukkan angka yang valid untuk Max Tokens.',
+    maxTokensMinMessage: 'Max Tokens tidak boleh kurang dari 0.',
     password: 'Kata Sandi',
     passwordDescription:
       'Silakan masukkan kata sandi Anda saat ini untuk mengubah kata sandi Anda.',
diff --git a/web/src/locales/zh-traditional.ts b/web/src/locales/zh-traditional.ts
index 7dc518226..305815faf 100644
--- a/web/src/locales/zh-traditional.ts
+++ b/web/src/locales/zh-traditional.ts
@@ -376,6 +376,8 @@ export default {
     maxTokensMessage: '最大token數是必填項',
     maxTokensTip:
       '這設置了模型輸出的最大長度,以標記(單詞或單詞片段)的數量來衡量。',
+    maxTokensInvalidMessage: '請輸入有效的最大標記數。',
+    maxTokensMinMessage: '最大標記數不能小於 0。',
     quote: '顯示引文',
     quoteTip: '是否應該顯示原文出處?',
     selfRag: 'Self-RAG',
@@ -422,6 +424,12 @@ export default {
   setting: {
     profile: '概述',
     profileDescription: '在此更新您的照片和個人詳細信息。',
+    maxTokens: '最大token數',
+    maxTokensMessage: '最大token數是必填項',
+    maxTokensTip:
+      '這設置了模型輸出的最大長度,以標記(單詞或單詞片段)的數量來衡量。',
+    maxTokensInvalidMessage: '請輸入有效的最大標記數。',
+    maxTokensMinMessage: '最大標記數不能小於 0。',
     password: '密碼',
     passwordDescription: '請輸入您當前的密碼以更改您的密碼。',
     model: '模型提供商',
diff --git a/web/src/locales/zh.ts b/web/src/locales/zh.ts
index 086c151f2..fe6a543c2 100644
--- a/web/src/locales/zh.ts
+++ b/web/src/locales/zh.ts
@@ -393,6 +393,8 @@ export default {
     maxTokensMessage: '最大token数是必填项',
     maxTokensTip:
       '这设置了模型输出的最大长度,以标记(单词或单词片段)的数量来衡量。',
+    maxTokensInvalidMessage: '请输入有效的最大令牌数。',
+    maxTokensMinMessage: '最大令牌数不能小于 0。',
     quote: '显示引文',
     quoteTip: '是否应该显示原文出处?',
     selfRag: 'Self-RAG',
@@ -439,6 +441,12 @@ export default {
   setting: {
     profile: '概要',
     profileDescription: '在此更新您的照片和个人详细信息。',
+    maxTokens: '最大token数',
+    maxTokensMessage: '最大token数是必填项',
+    maxTokensTip:
+      '这设置了模型输出的最大长度,以标记(单词或单词片段)的数量来衡量。',
+    maxTokensInvalidMessage: '请输入有效的最大令牌数。',
+    maxTokensMinMessage: '最大令牌数不能小于 0。',
     password: '密码',
     passwordDescription: '请输入您当前的密码以更改您的密码。',
     model: '模型提供商',
diff --git a/web/src/pages/user-setting/setting-model/Tencent-modal/index.tsx b/web/src/pages/user-setting/setting-model/Tencent-modal/index.tsx
index b278e8b8a..71d048a2c 100644
--- a/web/src/pages/user-setting/setting-model/Tencent-modal/index.tsx
+++ b/web/src/pages/user-setting/setting-model/Tencent-modal/index.tsx
@@ -1,7 +1,7 @@
 import { useTranslate } from '@/hooks/common-hooks';
 import { IModalProps } from '@/interfaces/common';
 import { IAddLlmRequestBody } from '@/interfaces/request/llm';
-import { Flex, Form, Input, Modal, Select, Space } from 'antd';
+import { Flex, Form, Input, Modal, Select, Space, InputNumber } from 'antd';
 import omit from 'lodash/omit';

 type FieldType = IAddLlmRequestBody & {
@@ -30,6 +30,7 @@ const TencentCloudModal = ({
       ...omit(values),
       model_type: modelType,
       llm_factory: llmFactory,
+      max_tokens:16000,
     };
     console.info(data);
diff --git a/web/src/pages/user-setting/setting-model/azure-openai-modal/index.tsx b/web/src/pages/user-setting/setting-model/azure-openai-modal/index.tsx
index f9fab8ab2..d58939d26 100644
--- a/web/src/pages/user-setting/setting-model/azure-openai-modal/index.tsx
+++ b/web/src/pages/user-setting/setting-model/azure-openai-modal/index.tsx
@@ -1,7 +1,7 @@
 import { useTranslate } from '@/hooks/common-hooks';
 import { IModalProps } from '@/interfaces/common';
 import { IAddLlmRequestBody } from '@/interfaces/request/llm';
-import { Form, Input, Modal, Select, Switch } from 'antd';
+import { Form, Input, Modal, Select, Switch, InputNumber } from 'antd';
 import omit from 'lodash/omit';

 type FieldType = IAddLlmRequestBody & {
@@ -33,6 +33,7 @@ const AzureOpenAIModal = ({
       ...omit(values, ['vision']),
       model_type: modelType,
       llm_factory: llmFactory,
+      max_tokens:values.max_tokens,
     };
     console.info(data);
@@ -107,6 +108,31 @@ const AzureOpenAIModal = ({
         >
        </Form.Item>
+        <Form.Item<FieldType>
+          label={t('maxTokens')}
+          name="max_tokens"
+          rules={[
+            { required: true, message: t('maxTokensMessage') },
+            {
+              type: 'number',
+              message: t('maxTokensInvalidMessage'),
+            },
+            ({ getFieldValue }) => ({
+              validator(_, value) {
+                if (value < 0) {
+                  return Promise.reject(new Error(t('maxTokensMinMessage')));
+                }
+                return Promise.resolve();
+              },
+            }),
+          ]}
+        >
+          <InputNumber
+            placeholder={t('maxTokensTip')}
+            style={{ width: '100%' }}
+          />
+        </Form.Item>
         <Form.Item noStyle dependencies={['model_type']}>
           {({ getFieldValue }) =>
             getFieldValue('model_type') === 'chat' && (
diff --git a/web/src/pages/user-setting/setting-model/bedrock-modal/index.tsx b/web/src/pages/user-setting/setting-model/bedrock-modal/index.tsx
index 4a137056b..be9b4c3af 100644
--- a/web/src/pages/user-setting/setting-model/bedrock-modal/index.tsx
+++ b/web/src/pages/user-setting/setting-model/bedrock-modal/index.tsx
@@ -1,7 +1,7 @@
 import { useTranslate } from '@/hooks/common-hooks';
 import { IModalProps } from '@/interfaces/common';
 import { IAddLlmRequestBody } from '@/interfaces/request/llm';
-import { Flex, Form, Input, Modal, Select, Space } from 'antd';
+import { Flex, Form, Input, Modal, Select, Space, InputNumber } from 'antd';
 import { useMemo } from 'react';

 import { BedrockRegionList } from '../constant';
@@ -34,6 +34,7 @@ const BedrockModal = ({
     const data = {
       ...values,
       llm_factory: llmFactory,
+      max_tokens:values.max_tokens,
     };

     onOk?.(data);
@@ -111,6 +112,31 @@ const BedrockModal = ({
           allowClear
         >
        </Select>
       </Form.Item>
+        <Form.Item<FieldType>
+          label={t('maxTokens')}
+          name="max_tokens"
+          rules={[
+            { required: true, message: t('maxTokensMessage') },
+            {
+              type: 'number',
+              message: t('maxTokensInvalidMessage'),
+            },
+            ({ getFieldValue }) => ({
+              validator(_, value) {
+                if (value < 0) {
+                  return Promise.reject(new Error(t('maxTokensMinMessage')));
+                }
+                return Promise.resolve();
+              },
+            }),
+          ]}
+        >
+          <InputNumber
+            placeholder={t('maxTokensTip')}
+            style={{ width: '100%' }}
+          />
+        </Form.Item>
       </Form>
     </Modal>
   );
diff --git a/web/src/pages/user-setting/setting-model/fish-audio-modal/index.tsx b/web/src/pages/user-setting/setting-model/fish-audio-modal/index.tsx
index dde456496..a4b85342e 100644
--- a/web/src/pages/user-setting/setting-model/fish-audio-modal/index.tsx
+++ b/web/src/pages/user-setting/setting-model/fish-audio-modal/index.tsx
@@ -1,7 +1,7 @@
 import { useTranslate } from '@/hooks/common-hooks';
 import { IModalProps } from '@/interfaces/common';
 import { IAddLlmRequestBody } from '@/interfaces/request/llm';
-import { Flex, Form, Input, Modal, Select, Space } from 'antd';
+import { Flex, Form, Input, Modal, Select, Space, InputNumber } from 'antd';
 import omit from 'lodash/omit';

 type FieldType = IAddLlmRequestBody & {
@@ -30,6 +30,7 @@ const FishAudioModal = ({
       ...omit(values),
       model_type: modelType,
       llm_factory: llmFactory,
+      max_tokens:values.max_tokens,
     };
     console.info(data);
@@ -93,6 +94,31 @@ const FishAudioModal = ({
         >
        </Form.Item>
+        <Form.Item<FieldType>
+          label={t('maxTokens')}
+          name="max_tokens"
+          rules={[
+            { required: true, message: t('maxTokensMessage') },
+            {
+              type: 'number',
+              message: t('maxTokensInvalidMessage'),
+            },
+            ({ getFieldValue }) => ({
+              validator(_, value) {
+                if (value < 0) {
+                  return Promise.reject(new Error(t('maxTokensMinMessage')));
+                }
+                return Promise.resolve();
+              },
+            }),
+          ]}
+        >
+          <InputNumber
+            placeholder={t('maxTokensTip')}
+            style={{ width: '100%' }}
+          />
+        </Form.Item>
       </Form>
     </Modal>
   );
diff --git a/web/src/pages/user-setting/setting-model/google-modal/index.tsx b/web/src/pages/user-setting/setting-model/google-modal/index.tsx
index d1d17c581..78d0ab697 100644
--- a/web/src/pages/user-setting/setting-model/google-modal/index.tsx
+++ b/web/src/pages/user-setting/setting-model/google-modal/index.tsx
@@ -1,7 +1,7 @@
 import { useTranslate } from '@/hooks/common-hooks';
 import { IModalProps } from '@/interfaces/common';
 import { IAddLlmRequestBody } from '@/interfaces/request/llm';
-import { Form, Input, Modal, Select } from 'antd';
+import { Form, Input, Modal, Select, InputNumber } from 'antd';

 type FieldType = IAddLlmRequestBody & {
   google_project_id: string;
@@ -27,6 +27,7 @@ const GoogleModal = ({
     const data = {
       ...values,
       llm_factory: llmFactory,
+      max_tokens:values.max_tokens,
     };

     onOk?.(data);
@@ -87,6 +88,31 @@ const GoogleModal = ({
         >
        </Form.Item>
+        <Form.Item<FieldType>
+          label={t('maxTokens')}
+          name="max_tokens"
+          rules={[
+            { required: true, message: t('maxTokensMessage') },
+            {
+              type: 'number',
+              message: t('maxTokensInvalidMessage'),
+            },
+            ({ getFieldValue }) => ({
+              validator(_, value) {
+                if (value < 0) {
+                  return Promise.reject(new Error(t('maxTokensMinMessage')));
+                }
+                return Promise.resolve();
+              },
+            }),
+          ]}
+        >
+          <InputNumber
+            placeholder={t('maxTokensTip')}
+            style={{ width: '100%' }}
+          />
+        </Form.Item>
       </Form>
     </Modal>
   );
diff --git a/web/src/pages/user-setting/setting-model/hunyuan-modal/index.tsx b/web/src/pages/user-setting/setting-model/hunyuan-modal/index.tsx
index baecfdd65..a22585477 100644
--- a/web/src/pages/user-setting/setting-model/hunyuan-modal/index.tsx
+++ b/web/src/pages/user-setting/setting-model/hunyuan-modal/index.tsx
@@ -1,7 +1,7 @@
 import { useTranslate } from '@/hooks/common-hooks';
 import { IModalProps } from '@/interfaces/common';
 import { IAddLlmRequestBody } from '@/interfaces/request/llm';
-import { Form, Input, Modal, Select } from 'antd';
+import { Form, Input, Modal, Select} from 'antd';
 import omit from 'lodash/omit';

 type FieldType = IAddLlmRequestBody & {
diff --git a/web/src/pages/user-setting/setting-model/index.tsx b/web/src/pages/user-setting/setting-model/index.tsx
index c25d5983a..5e05c932e 100644
--- a/web/src/pages/user-setting/setting-model/index.tsx
+++ b/web/src/pages/user-setting/setting-model/index.tsx
@@ -402,7 +402,7 @@ const UserSettingModel = () => {
         hideModal={hideTencentCloudAddingModal}
         onOk={onTencentCloudAddingOk}
         loading={TencentCloudAddingLoading}
-        llmFactory={'Tencent TencentCloud'}
+        llmFactory={'Tencent Cloud'}
       >
       </TencentCloudModal>
diff --git a/web/src/pages/user-setting/setting-model/ollama-modal/index.tsx b/web/src/pages/user-setting/setting-model/ollama-modal/index.tsx
--- a/web/src/pages/user-setting/setting-model/ollama-modal/index.tsx
+++ b/web/src/pages/user-setting/setting-model/ollama-modal/index.tsx
@@ -105,6 +106,30 @@ const OllamaModal = ({
        </Form.Item>
+        <Form.Item<FieldType>
+          label={t('maxTokens')}
+          name="max_tokens"
+          rules={[
+            { required: true, message: t('maxTokensMessage') },
+            {
+              type: 'number',
+              message: t('maxTokensInvalidMessage'),
+            },
+            ({ getFieldValue }) => ({
+              validator(_, value) {
+                if (value < 0) {
+                  return Promise.reject(new Error(t('maxTokensMinMessage')));
+                }
+                return Promise.resolve();
+              },
+            }),
+          ]}
+        >
+          <InputNumber
+            placeholder={t('maxTokensTip')}
+            style={{ width: '100%' }}
+          />
+        </Form.Item>
         <Form.Item noStyle dependencies={['model_type']}>
           {({ getFieldValue }) =>
             getFieldValue('model_type') === 'chat' && (
diff --git a/web/src/pages/user-setting/setting-model/spark-modal/index.tsx b/web/src/pages/user-setting/setting-model/spark-modal/index.tsx
index 59be63301..300ec299a 100644
--- a/web/src/pages/user-setting/setting-model/spark-modal/index.tsx
+++ b/web/src/pages/user-setting/setting-model/spark-modal/index.tsx
@@ -1,7 +1,7 @@
 import { useTranslate } from '@/hooks/common-hooks';
 import { IModalProps } from '@/interfaces/common';
 import { IAddLlmRequestBody } from '@/interfaces/request/llm';
-import { Form, Input, Modal, Select } from 'antd';
+import { Form, Input, Modal, Select, InputNumber } from 'antd';
 import omit from 'lodash/omit';

 type FieldType = IAddLlmRequestBody & {
@@ -36,6 +36,7 @@ const SparkModal = ({
       ...omit(values, ['vision']),
       model_type: modelType,
       llm_factory: llmFactory,
+      max_tokens:values.max_tokens,
     };
     console.info(data);
@@ -128,6 +129,31 @@ const SparkModal = ({
             )
           }
        </Form.Item>
+        <Form.Item<FieldType>
+          label={t('maxTokens')}
+          name="max_tokens"
+          rules={[
+            { required: true, message: t('maxTokensMessage') },
+            {
+              type: 'number',
+              message: t('maxTokensInvalidMessage'),
+            },
+            ({ getFieldValue }) => ({
+              validator(_, value) {
+                if (value < 0) {
+                  return Promise.reject(new Error(t('maxTokensMinMessage')));
+                }
+                return Promise.resolve();
+              },
+            }),
+          ]}
+        >
+          <InputNumber
+            placeholder={t('maxTokensTip')}
+            style={{ width: '100%' }}
+          />
+        </Form.Item>
       </Form>
     </Modal>
   );
diff --git a/web/src/pages/user-setting/setting-model/volcengine-modal/index.tsx b/web/src/pages/user-setting/setting-model/volcengine-modal/index.tsx
index 181fc0e1e..c3c60442a 100644
--- a/web/src/pages/user-setting/setting-model/volcengine-modal/index.tsx
+++ b/web/src/pages/user-setting/setting-model/volcengine-modal/index.tsx
@@ -1,7 +1,7 @@
 import { useTranslate } from '@/hooks/common-hooks';
 import { IModalProps } from '@/interfaces/common';
 import { IAddLlmRequestBody } from '@/interfaces/request/llm';
-import { Flex, Form, Input, Modal, Select, Space, Switch } from 'antd';
+import { Flex, Form, Input, Modal, Select, Space, Switch, InputNumber } from 'antd';
 import omit from 'lodash/omit';

 type FieldType = IAddLlmRequestBody & {
@@ -36,6 +36,7 @@ const VolcEngineModal = ({
       ...omit(values, ['vision']),
       model_type: modelType,
       llm_factory: llmFactory,
+      max_tokens:values.max_tokens,
     };
     console.info(data);
@@ -103,19 +104,31 @@ const VolcEngineModal = ({
         >
-        <Form.Item noStyle dependencies={['model_type']}>
-          {({ getFieldValue }) =>
-            getFieldValue('model_type') === 'chat' && (
-              <Form.Item
-                label={t('vision')}
-                valuePropName="checked"
-                name={'vision'}
-              >
-                <Switch />
-              </Form.Item>
-            )
-          }
-        </Form.Item>
+        <Form.Item<FieldType>
+          label={t('maxTokens')}
+          name="max_tokens"
+          rules={[
+            { required: true, message: t('maxTokensMessage') },
+            {
+              type: 'number',
+              message: t('maxTokensInvalidMessage'),
+            },
+            ({ getFieldValue }) => ({
+              validator(_, value) {
+                if (value < 0) {
+                  return Promise.reject(new Error(t('maxTokensMinMessage')));
+                }
+                return Promise.resolve();
+              },
+            }),
+          ]}
+        >
+          <InputNumber
+            placeholder={t('maxTokensTip')}
+            style={{ width: '100%' }}
+          />
+        </Form.Item>
       </Form>
     </Modal>
   );
diff --git a/web/src/pages/user-setting/setting-model/yiyan-modal/index.tsx b/web/src/pages/user-setting/setting-model/yiyan-modal/index.tsx
index 8c99bb29b..74a310d05 100644
--- a/web/src/pages/user-setting/setting-model/yiyan-modal/index.tsx
+++ b/web/src/pages/user-setting/setting-model/yiyan-modal/index.tsx
@@ -1,7 +1,7 @@
 import { useTranslate } from '@/hooks/common-hooks';
 import { IModalProps } from '@/interfaces/common';
 import { IAddLlmRequestBody } from '@/interfaces/request/llm';
-import { Form, Input, Modal, Select } from 'antd';
+import { Form, Input, Modal, Select, InputNumber } from 'antd';
 import omit from 'lodash/omit';

 type FieldType = IAddLlmRequestBody & {
@@ -34,6 +34,7 @@ const YiyanModal = ({
       ...omit(values, ['vision']),
       model_type: modelType,
       llm_factory: llmFactory,
+      max_tokens:values.max_tokens,
     };
     console.info(data);
@@ -89,6 +90,30 @@ const YiyanModal = ({
         >
        </Form.Item>
+        <Form.Item<FieldType>
+          label={t('maxTokens')}
+          name="max_tokens"
+          rules={[
+            { required: true, message: t('maxTokensMessage') },
+            {
+              type: 'number',
+              message: t('maxTokensInvalidMessage'),
+            },
+            ({ getFieldValue }) => ({
+              validator(_, value) {
+                if (value < 0) {
+                  return Promise.reject(new Error(t('maxTokensMinMessage')));
+                }
+                return Promise.resolve();
+              },
+            }),
+          ]}
+        >
+          <InputNumber
+            placeholder={t('maxTokensTip')}
+            style={{ width: '100%' }}
+          />
+        </Form.Item>
   );
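
Note: the `related_questions` endpoint returns a plain JSON list once the LLM reply is post-processed by the regex in `session.py`. A usage sketch; the host and token are again assumptions, the request shape comes from the handler:

    import requests  # assumed HTTP client

    resp = requests.post(
        "http://localhost:9380/api/v1/sessions/related_questions",  # hypothetical host; path from the decorator
        json={"question": "Chinese football"},
        headers={"Authorization": "Bearer <API_KEY>"},
    )
    print(resp.json()["data"])  # e.g. ["Current status of Chinese football", ...]

One caveat worth noting: the `^[0-9]\. ` pattern matches only single-digit numbering, so if the model returns a tenth suggestion ("10. ..."), it is silently dropped.
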