Feat: change default models (#7777)

### What problem does this PR solve?

Change the default models to the built-in models.
https://github.com/infiniflow/ragflow/issues/7774

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
liu an 2025-05-22 11:59:12 +08:00 committed by Yingfeng Zhang
parent 42f4d4dbc8
commit e166f132b3
7 changed files with 221 additions and 210 deletions

View File

@ -16,6 +16,7 @@
import logging import logging
from flask import request from flask import request
from api import settings from api import settings
from api.db import StatusEnum from api.db import StatusEnum
from api.db.services.dialog_service import DialogService from api.db.services.dialog_service import DialogService
@ -23,11 +24,10 @@ from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import TenantLLMService from api.db.services.llm_service import TenantLLMService
from api.db.services.user_service import TenantService from api.db.services.user_service import TenantService
from api.utils import get_uuid from api.utils import get_uuid
from api.utils.api_utils import get_error_data_result, token_required, get_result, check_duplicate_ids from api.utils.api_utils import check_duplicate_ids, get_error_data_result, get_result, token_required
@manager.route("/chats", methods=["POST"]) # noqa: F821
@manager.route('/chats', methods=['POST']) # noqa: F821
@token_required @token_required
def create(tenant_id): def create(tenant_id):
req = request.json req = request.json
@ -45,15 +45,16 @@ def create(tenant_id):
embd_ids = [TenantLLMService.split_model_name_and_factory(kb.embd_id)[0] for kb in kbs] # remove vendor suffix for comparison embd_ids = [TenantLLMService.split_model_name_and_factory(kb.embd_id)[0] for kb in kbs] # remove vendor suffix for comparison
embd_count = list(set(embd_ids)) embd_count = list(set(embd_ids))
if len(embd_count) > 1: if len(embd_count) > 1:
return get_result(message='Datasets use different embedding models."', return get_result(message='Datasets use different embedding models."', code=settings.RetCode.AUTHENTICATION_ERROR)
code=settings.RetCode.AUTHENTICATION_ERROR)
req["kb_ids"] = ids req["kb_ids"] = ids
# llm # llm
llm = req.get("llm") llm = req.get("llm")
if llm: if llm:
if "model_name" in llm: if "model_name" in llm:
req["llm_id"] = llm.pop("model_name") req["llm_id"] = llm.pop("model_name")
if not TenantLLMService.query(tenant_id=tenant_id, llm_name=req["llm_id"], model_type="chat"): if req.get("llm_id") is not None:
llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(req["llm_id"])
if not TenantLLMService.query(tenant_id=tenant_id, llm_name=llm_name, llm_factory=llm_factory, model_type="chat"):
return get_error_data_result(f"`model_name` {req.get('llm_id')} doesn't exist") return get_error_data_result(f"`model_name` {req.get('llm_id')} doesn't exist")
req["llm_setting"] = req.pop("llm") req["llm_setting"] = req.pop("llm")
e, tenant = TenantService.get_by_id(tenant_id) e, tenant = TenantService.get_by_id(tenant_id)
@ -61,13 +62,8 @@ def create(tenant_id):
return get_error_data_result(message="Tenant not found!") return get_error_data_result(message="Tenant not found!")
# prompt # prompt
prompt = req.get("prompt") prompt = req.get("prompt")
key_mapping = {"parameters": "variables", key_mapping = {"parameters": "variables", "prologue": "opener", "quote": "show_quote", "system": "prompt", "rerank_id": "rerank_model", "vector_similarity_weight": "keywords_similarity_weight"}
"prologue": "opener", key_list = ["similarity_threshold", "vector_similarity_weight", "top_n", "rerank_id", "top_k"]
"quote": "show_quote",
"system": "prompt",
"rerank_id": "rerank_model",
"vector_similarity_weight": "keywords_similarity_weight"}
key_list = ["similarity_threshold", "vector_similarity_weight", "top_n", "rerank_id","top_k"]
if prompt: if prompt:
for new_key, old_key in key_mapping.items(): for new_key, old_key in key_mapping.items():
if old_key in prompt: if old_key in prompt:
@ -85,9 +81,7 @@ def create(tenant_id):
req["rerank_id"] = req.get("rerank_id", "") req["rerank_id"] = req.get("rerank_id", "")
if req.get("rerank_id"): if req.get("rerank_id"):
value_rerank_model = ["BAAI/bge-reranker-v2-m3", "maidalun1020/bce-reranker-base_v1"] value_rerank_model = ["BAAI/bge-reranker-v2-m3", "maidalun1020/bce-reranker-base_v1"]
if req["rerank_id"] not in value_rerank_model and not TenantLLMService.query(tenant_id=tenant_id, if req["rerank_id"] not in value_rerank_model and not TenantLLMService.query(tenant_id=tenant_id, llm_name=req.get("rerank_id"), model_type="rerank"):
llm_name=req.get("rerank_id"),
model_type="rerank"):
return get_error_data_result(f"`rerank_model` {req.get('rerank_id')} doesn't exist") return get_error_data_result(f"`rerank_model` {req.get('rerank_id')} doesn't exist")
if not req.get("llm_id"): if not req.get("llm_id"):
req["llm_id"] = tenant.llm_id req["llm_id"] = tenant.llm_id
@ -106,27 +100,24 @@ def create(tenant_id):
{knowledge} {knowledge}
The above is the knowledge base.""", The above is the knowledge base.""",
"prologue": "Hi! I'm your assistant, what can I do for you?", "prologue": "Hi! I'm your assistant, what can I do for you?",
"parameters": [ "parameters": [{"key": "knowledge", "optional": False}],
{"key": "knowledge", "optional": False}
],
"empty_response": "Sorry! No relevant content was found in the knowledge base!", "empty_response": "Sorry! No relevant content was found in the knowledge base!",
"quote": True, "quote": True,
"tts": False, "tts": False,
"refine_multiturn": True "refine_multiturn": True,
} }
key_list_2 = ["system", "prologue", "parameters", "empty_response", "quote", "tts", "refine_multiturn"] key_list_2 = ["system", "prologue", "parameters", "empty_response", "quote", "tts", "refine_multiturn"]
if "prompt_config" not in req: if "prompt_config" not in req:
req['prompt_config'] = {} req["prompt_config"] = {}
for key in key_list_2: for key in key_list_2:
temp = req['prompt_config'].get(key) temp = req["prompt_config"].get(key)
if (not temp and key == 'system') or (key not in req["prompt_config"]): if (not temp and key == "system") or (key not in req["prompt_config"]):
req['prompt_config'][key] = default_prompt[key] req["prompt_config"][key] = default_prompt[key]
for p in req['prompt_config']["parameters"]: for p in req["prompt_config"]["parameters"]:
if p["optional"]: if p["optional"]:
continue continue
if req['prompt_config']["system"].find("{%s}" % p["key"]) < 0: if req["prompt_config"]["system"].find("{%s}" % p["key"]) < 0:
return get_error_data_result( return get_error_data_result(message="Parameter '{}' is not used".format(p["key"]))
message="Parameter '{}' is not used".format(p["key"]))
# save # save
if not DialogService.save(**req): if not DialogService.save(**req):
return get_error_data_result(message="Fail to new a chat!") return get_error_data_result(message="Fail to new a chat!")
@ -141,10 +132,7 @@ def create(tenant_id):
renamed_dict[new_key] = value renamed_dict[new_key] = value
res["prompt"] = renamed_dict res["prompt"] = renamed_dict
del res["prompt_config"] del res["prompt_config"]
new_dict = {"similarity_threshold": res["similarity_threshold"], new_dict = {"similarity_threshold": res["similarity_threshold"], "keywords_similarity_weight": 1 - res["vector_similarity_weight"], "top_n": res["top_n"], "rerank_model": res["rerank_id"]}
"keywords_similarity_weight": 1-res["vector_similarity_weight"],
"top_n": res["top_n"],
"rerank_model": res['rerank_id']}
res["prompt"].update(new_dict) res["prompt"].update(new_dict)
for key in key_list: for key in key_list:
del res[key] del res[key]
@ -156,11 +144,11 @@ def create(tenant_id):
return get_result(data=res) return get_result(data=res)
@manager.route('/chats/<chat_id>', methods=['PUT']) # noqa: F821 @manager.route("/chats/<chat_id>", methods=["PUT"]) # noqa: F821
@token_required @token_required
def update(tenant_id, chat_id): def update(tenant_id, chat_id):
if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value): if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
return get_error_data_result(message='You do not own the chat') return get_error_data_result(message="You do not own the chat")
req = request.json req = request.json
ids = req.get("dataset_ids") ids = req.get("dataset_ids")
if "show_quotation" in req: if "show_quotation" in req:
@ -179,9 +167,7 @@ def update(tenant_id, chat_id):
embd_ids = [TenantLLMService.split_model_name_and_factory(kb.embd_id)[0] for kb in kbs] # remove vendor suffix for comparison embd_ids = [TenantLLMService.split_model_name_and_factory(kb.embd_id)[0] for kb in kbs] # remove vendor suffix for comparison
embd_count = list(set(embd_ids)) embd_count = list(set(embd_ids))
if len(embd_count) != 1: if len(embd_count) != 1:
return get_result( return get_result(message='Datasets use different embedding models."', code=settings.RetCode.AUTHENTICATION_ERROR)
message='Datasets use different embedding models."',
code=settings.RetCode.AUTHENTICATION_ERROR)
req["kb_ids"] = ids req["kb_ids"] = ids
llm = req.get("llm") llm = req.get("llm")
if llm: if llm:
@ -195,13 +181,8 @@ def update(tenant_id, chat_id):
return get_error_data_result(message="Tenant not found!") return get_error_data_result(message="Tenant not found!")
# prompt # prompt
prompt = req.get("prompt") prompt = req.get("prompt")
key_mapping = {"parameters": "variables", key_mapping = {"parameters": "variables", "prologue": "opener", "quote": "show_quote", "system": "prompt", "rerank_id": "rerank_model", "vector_similarity_weight": "keywords_similarity_weight"}
"prologue": "opener", key_list = ["similarity_threshold", "vector_similarity_weight", "top_n", "rerank_id", "top_k"]
"quote": "show_quote",
"system": "prompt",
"rerank_id": "rerank_model",
"vector_similarity_weight": "keywords_similarity_weight"}
key_list = ["similarity_threshold", "vector_similarity_weight", "top_n", "rerank_id","top_k"]
if prompt: if prompt:
for new_key, old_key in key_mapping.items(): for new_key, old_key in key_mapping.items():
if old_key in prompt: if old_key in prompt:
@ -214,16 +195,12 @@ def update(tenant_id, chat_id):
res = res.to_json() res = res.to_json()
if req.get("rerank_id"): if req.get("rerank_id"):
value_rerank_model = ["BAAI/bge-reranker-v2-m3", "maidalun1020/bce-reranker-base_v1"] value_rerank_model = ["BAAI/bge-reranker-v2-m3", "maidalun1020/bce-reranker-base_v1"]
if req["rerank_id"] not in value_rerank_model and not TenantLLMService.query(tenant_id=tenant_id, if req["rerank_id"] not in value_rerank_model and not TenantLLMService.query(tenant_id=tenant_id, llm_name=req.get("rerank_id"), model_type="rerank"):
llm_name=req.get("rerank_id"),
model_type="rerank"):
return get_error_data_result(f"`rerank_model` {req.get('rerank_id')} doesn't exist") return get_error_data_result(f"`rerank_model` {req.get('rerank_id')} doesn't exist")
if "name" in req: if "name" in req:
if not req.get("name"): if not req.get("name"):
return get_error_data_result(message="`name` cannot be empty.") return get_error_data_result(message="`name` cannot be empty.")
if req["name"].lower() != res["name"].lower() \ if req["name"].lower() != res["name"].lower() and len(DialogService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value)) > 0:
and len(
DialogService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value)) > 0:
return get_error_data_result(message="Duplicated chat name in updating chat.") return get_error_data_result(message="Duplicated chat name in updating chat.")
if "prompt_config" in req: if "prompt_config" in req:
res["prompt_config"].update(req["prompt_config"]) res["prompt_config"].update(req["prompt_config"])
@ -246,7 +223,7 @@ def update(tenant_id, chat_id):
return get_result() return get_result()
@manager.route('/chats', methods=['DELETE']) # noqa: F821 @manager.route("/chats", methods=["DELETE"]) # noqa: F821
@token_required @token_required
def delete(tenant_id): def delete(tenant_id):
errors = [] errors = []
@ -276,27 +253,20 @@ def delete(tenant_id):
if errors: if errors:
if success_count > 0: if success_count > 0:
return get_result( return get_result(data={"success_count": success_count, "errors": errors}, message=f"Partially deleted {success_count} chats with {len(errors)} errors")
data={"success_count": success_count, "errors": errors},
message=f"Partially deleted {success_count} chats with {len(errors)} errors"
)
else: else:
return get_error_data_result(message="; ".join(errors)) return get_error_data_result(message="; ".join(errors))
if duplicate_messages: if duplicate_messages:
if success_count > 0: if success_count > 0:
return get_result( return get_result(message=f"Partially deleted {success_count} chats with {len(duplicate_messages)} errors", data={"success_count": success_count, "errors": duplicate_messages})
message=f"Partially deleted {success_count} chats with {len(duplicate_messages)} errors",
data={"success_count": success_count, "errors": duplicate_messages}
)
else: else:
return get_error_data_result(message=";".join(duplicate_messages)) return get_error_data_result(message=";".join(duplicate_messages))
return get_result() return get_result()
@manager.route("/chats", methods=["GET"]) # noqa: F821
@manager.route('/chats', methods=['GET']) # noqa: F821
@token_required @token_required
def list_chat(tenant_id): def list_chat(tenant_id):
id = request.args.get("id") id = request.args.get("id")
@ -316,13 +286,15 @@ def list_chat(tenant_id):
if not chats: if not chats:
return get_result(data=[]) return get_result(data=[])
list_assts = [] list_assts = []
key_mapping = {"parameters": "variables", key_mapping = {
"parameters": "variables",
"prologue": "opener", "prologue": "opener",
"quote": "show_quote", "quote": "show_quote",
"system": "prompt", "system": "prompt",
"rerank_id": "rerank_model", "rerank_id": "rerank_model",
"vector_similarity_weight": "keywords_similarity_weight", "vector_similarity_weight": "keywords_similarity_weight",
"do_refer": "show_quotation"} "do_refer": "show_quotation",
}
key_list = ["similarity_threshold", "vector_similarity_weight", "top_n", "rerank_id"] key_list = ["similarity_threshold", "vector_similarity_weight", "top_n", "rerank_id"]
for res in chats: for res in chats:
renamed_dict = {} renamed_dict = {}
@ -331,10 +303,7 @@ def list_chat(tenant_id):
renamed_dict[new_key] = value renamed_dict[new_key] = value
res["prompt"] = renamed_dict res["prompt"] = renamed_dict
del res["prompt_config"] del res["prompt_config"]
new_dict = {"similarity_threshold": res["similarity_threshold"], new_dict = {"similarity_threshold": res["similarity_threshold"], "keywords_similarity_weight": 1 - res["vector_similarity_weight"], "top_n": res["top_n"], "rerank_model": res["rerank_id"]}
"keywords_similarity_weight": 1-res["vector_similarity_weight"],
"top_n": res["top_n"],
"rerank_model": res['rerank_id']}
res["prompt"].update(new_dict) res["prompt"].update(new_dict)
for key in key_list: for key in key_list:
del res[key] del res[key]

View File

@ -13,36 +13,37 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import logging
import json import json
import logging
import re import re
from datetime import datetime from datetime import datetime
from flask import request, session, redirect from flask import redirect, request, session
from werkzeug.security import generate_password_hash, check_password_hash from flask_login import current_user, login_required, login_user, logout_user
from flask_login import login_required, current_user, login_user, logout_user from werkzeug.security import check_password_hash, generate_password_hash
from api import settings
from api.apps.auth import get_auth_client
from api.db import FileType, UserTenantRole
from api.db.db_models import TenantLLM from api.db.db_models import TenantLLM
from api.db.services.llm_service import TenantLLMService, LLMService from api.db.services.file_service import FileService
from api.utils.api_utils import ( from api.db.services.llm_service import LLMService, TenantLLMService
server_error_response, from api.db.services.user_service import TenantService, UserService, UserTenantService
validate_request,
get_data_error_result,
)
from api.utils import ( from api.utils import (
get_uuid,
get_format_time,
decrypt,
download_img,
current_timestamp, current_timestamp,
datetime_format, datetime_format,
decrypt,
download_img,
get_format_time,
get_uuid,
)
from api.utils.api_utils import (
construct_response,
get_data_error_result,
get_json_result,
server_error_response,
validate_request,
) )
from api.db import UserTenantRole, FileType
from api import settings
from api.db.services.user_service import UserService, TenantService, UserTenantService
from api.db.services.file_service import FileService
from api.utils.api_utils import get_json_result, construct_response
from api.apps.auth import get_auth_client
@manager.route("/login", methods=["POST", "GET"]) # noqa: F821 @manager.route("/login", methods=["POST", "GET"]) # noqa: F821
@ -77,9 +78,7 @@ def login():
type: object type: object
""" """
if not request.json: if not request.json:
return get_json_result( return get_json_result(data=False, code=settings.RetCode.AUTHENTICATION_ERROR, message="Unauthorized!")
data=False, code=settings.RetCode.AUTHENTICATION_ERROR, message="Unauthorized!"
)
email = request.json.get("email", "") email = request.json.get("email", "")
users = UserService.query(email=email) users = UserService.query(email=email)
@ -94,9 +93,7 @@ def login():
try: try:
password = decrypt(password) password = decrypt(password)
except BaseException: except BaseException:
return get_json_result( return get_json_result(data=False, code=settings.RetCode.SERVER_ERROR, message="Fail to crypt password")
data=False, code=settings.RetCode.SERVER_ERROR, message="Fail to crypt password"
)
user = UserService.query_user(email, password) user = UserService.query_user(email, password)
if user: if user:
@ -124,19 +121,17 @@ def get_login_channels():
try: try:
channels = [] channels = []
for channel, config in settings.OAUTH_CONFIG.items(): for channel, config in settings.OAUTH_CONFIG.items():
channels.append({ channels.append(
{
"channel": channel, "channel": channel,
"display_name": config.get("display_name", channel.title()), "display_name": config.get("display_name", channel.title()),
"icon": config.get("icon", "sso"), "icon": config.get("icon", "sso"),
}) }
)
return get_json_result(data=channels) return get_json_result(data=channels)
except Exception as e: except Exception as e:
logging.exception(e) logging.exception(e)
return get_json_result( return get_json_result(data=[], message=f"Load channels failure, error: {str(e)}", code=settings.RetCode.EXCEPTION_ERROR)
data=[],
message=f"Load channels failure, error: {str(e)}",
code=settings.RetCode.EXCEPTION_ERROR
)
@manager.route("/login/<channel>", methods=["GET"]) # noqa: F821 @manager.route("/login/<channel>", methods=["GET"]) # noqa: F821
@ -434,9 +429,7 @@ def user_info_from_feishu(access_token):
"Content-Type": "application/json; charset=utf-8", "Content-Type": "application/json; charset=utf-8",
"Authorization": f"Bearer {access_token}", "Authorization": f"Bearer {access_token}",
} }
res = requests.get( res = requests.get("https://open.feishu.cn/open-apis/authen/v1/user_info", headers=headers)
"https://open.feishu.cn/open-apis/authen/v1/user_info", headers=headers
)
user_info = res.json()["data"] user_info = res.json()["data"]
user_info["email"] = None if user_info.get("email") == "" else user_info["email"] user_info["email"] = None if user_info.get("email") == "" else user_info["email"]
return user_info return user_info
@ -446,17 +439,13 @@ def user_info_from_github(access_token):
import requests import requests
headers = {"Accept": "application/json", "Authorization": f"token {access_token}"} headers = {"Accept": "application/json", "Authorization": f"token {access_token}"}
res = requests.get( res = requests.get(f"https://api.github.com/user?access_token={access_token}", headers=headers)
f"https://api.github.com/user?access_token={access_token}", headers=headers
)
user_info = res.json() user_info = res.json()
email_info = requests.get( email_info = requests.get(
f"https://api.github.com/user/emails?access_token={access_token}", f"https://api.github.com/user/emails?access_token={access_token}",
headers=headers, headers=headers,
).json() ).json()
user_info["email"] = next( user_info["email"] = next((email for email in email_info if email["primary"]), None)["email"]
(email for email in email_info if email["primary"]), None
)["email"]
return user_info return user_info
@ -516,9 +505,7 @@ def setting_user():
request_data = request.json request_data = request.json
if request_data.get("password"): if request_data.get("password"):
new_password = request_data.get("new_password") new_password = request_data.get("new_password")
if not check_password_hash( if not check_password_hash(current_user.password, decrypt(request_data["password"])):
current_user.password, decrypt(request_data["password"])
):
return get_json_result( return get_json_result(
data=False, data=False,
code=settings.RetCode.AUTHENTICATION_ERROR, code=settings.RetCode.AUTHENTICATION_ERROR,
@ -549,9 +536,7 @@ def setting_user():
return get_json_result(data=True) return get_json_result(data=True)
except Exception as e: except Exception as e:
logging.exception(e) logging.exception(e)
return get_json_result( return get_json_result(data=False, message="Update failure!", code=settings.RetCode.EXCEPTION_ERROR)
data=False, message="Update failure!", code=settings.RetCode.EXCEPTION_ERROR
)
@manager.route("/info", methods=["GET"]) # noqa: F821 @manager.route("/info", methods=["GET"]) # noqa: F821
@ -643,7 +628,21 @@ def user_register(user_id, user):
"model_type": llm.model_type, "model_type": llm.model_type,
"api_key": settings.API_KEY, "api_key": settings.API_KEY,
"api_base": settings.LLM_BASE_URL, "api_base": settings.LLM_BASE_URL,
"max_tokens": llm.max_tokens if llm.max_tokens else 8192 "max_tokens": llm.max_tokens if llm.max_tokens else 8192,
}
)
if settings.LIGHTEN != 1:
for buildin_embedding_model in settings.BUILTIN_EMBEDDING_MODELS:
mdlnm, fid = TenantLLMService.split_model_name_and_factory(buildin_embedding_model)
tenant_llm.append(
{
"tenant_id": user_id,
"llm_factory": fid,
"llm_name": mdlnm,
"model_type": "embedding",
"api_key": "",
"api_base": "",
"max_tokens": 1024 if buildin_embedding_model == "BAAI/bge-large-zh-v1.5@BAAI" else 512,
} }
) )

View File

@ -81,7 +81,7 @@ def init_settings():
DATABASE = decrypt_database_config(name=DATABASE_TYPE) DATABASE = decrypt_database_config(name=DATABASE_TYPE)
LLM = get_base_config("user_default_llm", {}) LLM = get_base_config("user_default_llm", {})
LLM_DEFAULT_MODELS = LLM.get("default_models", {}) LLM_DEFAULT_MODELS = LLM.get("default_models", {})
LLM_FACTORY = LLM.get("factory", "Tongyi-Qianwen") LLM_FACTORY = LLM.get("factory")
LLM_BASE_URL = LLM.get("base_url") LLM_BASE_URL = LLM.get("base_url")
try: try:
REGISTER_ENABLED = int(os.environ.get("REGISTER_ENABLED", "1")) REGISTER_ENABLED = int(os.environ.get("REGISTER_ENABLED", "1"))

View File

@ -567,7 +567,7 @@
{ {
"name": "Youdao", "name": "Youdao",
"logo": "", "logo": "",
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION", "tags": "TEXT EMBEDDING",
"status": "1", "status": "1",
"llm": [ "llm": [
{ {
@ -755,7 +755,7 @@
{ {
"name": "BAAI", "name": "BAAI",
"logo": "", "logo": "",
"tags": "TEXT EMBEDDING, TEXT RE-RANK", "tags": "TEXT EMBEDDING",
"status": "1", "status": "1",
"llm": [ "llm": [
{ {

View File

@ -20,7 +20,9 @@ import pytest
import requests import requests
HOST_ADDRESS = os.getenv("HOST_ADDRESS", "http://127.0.0.1:9380") HOST_ADDRESS = os.getenv("HOST_ADDRESS", "http://127.0.0.1:9380")
ZHIPU_AI_API_KEY = os.getenv("ZHIPU_AI_API_KEY", "ca148e43209c40109e2bc2f56281dd11.BltyA2N1B043B7Ra")
if ZHIPU_AI_API_KEY is None:
pytest.exit("Error: Environment variable ZHIPU_AI_API_KEY must be set")
# def generate_random_email(): # def generate_random_email():
# return 'user_' + ''.join(random.choices(string.ascii_lowercase + string.digits, k=8))+'@1.com' # return 'user_' + ''.join(random.choices(string.ascii_lowercase + string.digits, k=8))+'@1.com'
@ -87,3 +89,64 @@ def get_auth():
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def get_email(): def get_email():
return EMAIL return EMAIL
def get_my_llms(auth, name):
url = HOST_ADDRESS + "/v1/llm/my_llms"
authorization = {"Authorization": auth}
response = requests.get(url=url, headers=authorization)
res = response.json()
if res.get("code") != 0:
raise Exception(res.get("message"))
if name in res.get("data"):
return True
return False
def add_models(auth):
url = HOST_ADDRESS + "/v1/llm/set_api_key"
authorization = {"Authorization": auth}
models_info = {
"ZHIPU-AI": {"llm_factory": "ZHIPU-AI", "api_key": ZHIPU_AI_API_KEY},
}
for name, model_info in models_info.items():
if not get_my_llms(auth, name):
response = requests.post(url=url, headers=authorization, json=model_info)
res = response.json()
if res.get("code") != 0:
pytest.exit(f"Critical error in add_models: {res.get('message')}")
def get_tenant_info(auth):
url = HOST_ADDRESS + "/v1/user/tenant_info"
authorization = {"Authorization": auth}
response = requests.get(url=url, headers=authorization)
res = response.json()
if res.get("code") != 0:
raise Exception(res.get("message"))
return res["data"].get("tenant_id")
@pytest.fixture(scope="session", autouse=True)
def set_tenant_info(get_auth):
auth = get_auth
try:
add_models(auth)
tenant_id = get_tenant_info(auth)
except Exception as e:
pytest.exit(f"Error in set_tenant_info: {str(e)}")
url = HOST_ADDRESS + "/v1/user/set_tenant_info"
authorization = {"Authorization": get_auth}
tenant_info = {
"tenant_id": tenant_id,
"llm_id": "glm-4-flash@ZHIPU-AI",
"embd_id": "BAAI/bge-large-zh-v1.5@BAAI",
"img2txt_id": "glm-4v@ZHIPU-AI",
"asr_id": "",
"tts_id": None,
}
response = requests.post(url=url, headers=authorization, json=tenant_info)
res = response.json()
if res.get("code") != 0:
raise Exception(res.get("message"))

View File

@ -16,7 +16,6 @@
import os import os
import pytest import pytest
import requests
from common import ( from common import (
add_chunk, add_chunk,
batch_create_datasets, batch_create_datasets,
@ -49,9 +48,6 @@ MARKER_EXPRESSIONS = {
"p3": "p1 or p2 or p3", "p3": "p1 or p2 or p3",
} }
HOST_ADDRESS = os.getenv("HOST_ADDRESS", "http://127.0.0.1:9380") HOST_ADDRESS = os.getenv("HOST_ADDRESS", "http://127.0.0.1:9380")
ZHIPU_AI_API_KEY = os.getenv("ZHIPU_AI_API_KEY", "ca148e43209c40109e2bc2f56281dd11.BltyA2N1B043B7Ra")
if ZHIPU_AI_API_KEY is None:
pytest.exit("Error: Environment variable ZHIPU_AI_API_KEY must be set")
def pytest_addoption(parser: pytest.Parser) -> None: def pytest_addoption(parser: pytest.Parser) -> None:
@ -85,67 +81,6 @@ def get_http_api_auth(get_api_key_fixture):
return RAGFlowHttpApiAuth(get_api_key_fixture) return RAGFlowHttpApiAuth(get_api_key_fixture)
def get_my_llms(auth, name):
url = HOST_ADDRESS + "/v1/llm/my_llms"
authorization = {"Authorization": auth}
response = requests.get(url=url, headers=authorization)
res = response.json()
if res.get("code") != 0:
raise Exception(res.get("message"))
if name in res.get("data"):
return True
return False
def add_models(auth):
url = HOST_ADDRESS + "/v1/llm/set_api_key"
authorization = {"Authorization": auth}
models_info = {
"ZHIPU-AI": {"llm_factory": "ZHIPU-AI", "api_key": ZHIPU_AI_API_KEY},
}
for name, model_info in models_info.items():
if not get_my_llms(auth, name):
response = requests.post(url=url, headers=authorization, json=model_info)
res = response.json()
if res.get("code") != 0:
pytest.exit(f"Critical error in add_models: {res.get('message')}")
def get_tenant_info(auth):
url = HOST_ADDRESS + "/v1/user/tenant_info"
authorization = {"Authorization": auth}
response = requests.get(url=url, headers=authorization)
res = response.json()
if res.get("code") != 0:
raise Exception(res.get("message"))
return res["data"].get("tenant_id")
@pytest.fixture(scope="session", autouse=True)
def set_tenant_info(get_auth):
auth = get_auth
try:
add_models(auth)
tenant_id = get_tenant_info(auth)
except Exception as e:
pytest.exit(f"Error in set_tenant_info: {str(e)}")
url = HOST_ADDRESS + "/v1/user/set_tenant_info"
authorization = {"Authorization": get_auth}
tenant_info = {
"tenant_id": tenant_id,
"llm_id": "glm-4-flash@ZHIPU-AI",
"embd_id": "BAAI/bge-large-zh-v1.5@BAAI",
"img2txt_id": "glm-4v@ZHIPU-AI",
"asr_id": "",
"tts_id": None,
}
response = requests.post(url=url, headers=authorization, json=tenant_info)
res = response.json()
if res.get("code") != 0:
raise Exception(res.get("message"))
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def clear_datasets(request, get_http_api_auth): def clear_datasets(request, get_http_api_auth):
def cleanup(): def cleanup():

View File

@ -14,8 +14,9 @@
# limitations under the License. # limitations under the License.
# #
from ragflow_sdk import RAGFlow
from common import HOST_ADDRESS from common import HOST_ADDRESS
from ragflow_sdk import RAGFlow
from ragflow_sdk.modules.chat import Chat
def test_create_chat_with_name(get_api_key_fixture): def test_create_chat_with_name(get_api_key_fixture):
@ -31,7 +32,18 @@ def test_create_chat_with_name(get_api_key_fixture):
docs = kb.upload_documents(documents) docs = kb.upload_documents(documents)
for doc in docs: for doc in docs:
doc.add_chunk("This is a test to add chunk") doc.add_chunk("This is a test to add chunk")
rag.create_chat("test_create_chat", dataset_ids=[kb.id]) llm = Chat.LLM(
rag,
{
"model_name": "glm-4-flash@ZHIPU-AI",
"temperature": 0.1,
"top_p": 0.3,
"presence_penalty": 0.4,
"frequency_penalty": 0.7,
"max_tokens": 512,
},
)
rag.create_chat("test_create_chat", dataset_ids=[kb.id], llm=llm)
def test_update_chat_with_name(get_api_key_fixture): def test_update_chat_with_name(get_api_key_fixture):
@ -47,7 +59,18 @@ def test_update_chat_with_name(get_api_key_fixture):
docs = kb.upload_documents(documents) docs = kb.upload_documents(documents)
for doc in docs: for doc in docs:
doc.add_chunk("This is a test to add chunk") doc.add_chunk("This is a test to add chunk")
chat = rag.create_chat("test_update_chat", dataset_ids=[kb.id]) llm = Chat.LLM(
rag,
{
"model_name": "glm-4-flash@ZHIPU-AI",
"temperature": 0.1,
"top_p": 0.3,
"presence_penalty": 0.4,
"frequency_penalty": 0.7,
"max_tokens": 512,
},
)
chat = rag.create_chat("test_update_chat", dataset_ids=[kb.id], llm=llm)
chat.update({"name": "new_chat"}) chat.update({"name": "new_chat"})
@ -64,7 +87,18 @@ def test_delete_chats_with_success(get_api_key_fixture):
docs = kb.upload_documents(documents) docs = kb.upload_documents(documents)
for doc in docs: for doc in docs:
doc.add_chunk("This is a test to add chunk") doc.add_chunk("This is a test to add chunk")
chat = rag.create_chat("test_delete_chat", dataset_ids=[kb.id]) llm = Chat.LLM(
rag,
{
"model_name": "glm-4-flash@ZHIPU-AI",
"temperature": 0.1,
"top_p": 0.3,
"presence_penalty": 0.4,
"frequency_penalty": 0.7,
"max_tokens": 512,
},
)
chat = rag.create_chat("test_delete_chat", dataset_ids=[kb.id], llm=llm)
rag.delete_chats(ids=[chat.id]) rag.delete_chats(ids=[chat.id])
@ -81,6 +115,17 @@ def test_list_chats_with_success(get_api_key_fixture):
docs = kb.upload_documents(documents) docs = kb.upload_documents(documents)
for doc in docs: for doc in docs:
doc.add_chunk("This is a test to add chunk") doc.add_chunk("This is a test to add chunk")
rag.create_chat("test_list_1", dataset_ids=[kb.id]) llm = Chat.LLM(
rag.create_chat("test_list_2", dataset_ids=[kb.id]) rag,
{
"model_name": "glm-4-flash@ZHIPU-AI",
"temperature": 0.1,
"top_p": 0.3,
"presence_penalty": 0.4,
"frequency_penalty": 0.7,
"max_tokens": 512,
},
)
rag.create_chat("test_list_1", dataset_ids=[kb.id], llm=llm)
rag.create_chat("test_list_2", dataset_ids=[kb.id], llm=llm)
rag.list_chats() rag.list_chats()