mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-08-10 16:59:01 +08:00
Feat: change default models (#7777)
### What problem does this PR solve? change default models to buildin models https://github.com/infiniflow/ragflow/issues/7774 ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
parent
42f4d4dbc8
commit
e166f132b3
@ -16,6 +16,7 @@
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
|
||||
from api import settings
|
||||
from api.db import StatusEnum
|
||||
from api.db.services.dialog_service import DialogService
|
||||
@ -23,15 +24,14 @@ from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import TenantLLMService
|
||||
from api.db.services.user_service import TenantService
|
||||
from api.utils import get_uuid
|
||||
from api.utils.api_utils import get_error_data_result, token_required, get_result, check_duplicate_ids
|
||||
from api.utils.api_utils import check_duplicate_ids, get_error_data_result, get_result, token_required
|
||||
|
||||
|
||||
|
||||
@manager.route('/chats', methods=['POST']) # noqa: F821
|
||||
@manager.route("/chats", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
def create(tenant_id):
|
||||
req = request.json
|
||||
ids = [i for i in req.get("dataset_ids", []) if i]
|
||||
ids = [i for i in req.get("dataset_ids", []) if i]
|
||||
for kb_id in ids:
|
||||
kbs = KnowledgebaseService.accessible(kb_id=kb_id, user_id=tenant_id)
|
||||
if not kbs:
|
||||
@ -40,34 +40,30 @@ def create(tenant_id):
|
||||
kb = kbs[0]
|
||||
if kb.chunk_num == 0:
|
||||
return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file")
|
||||
|
||||
|
||||
kbs = KnowledgebaseService.get_by_ids(ids) if ids else []
|
||||
embd_ids = [TenantLLMService.split_model_name_and_factory(kb.embd_id)[0] for kb in kbs] # remove vendor suffix for comparison
|
||||
embd_count = list(set(embd_ids))
|
||||
if len(embd_count) > 1:
|
||||
return get_result(message='Datasets use different embedding models."',
|
||||
code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
return get_result(message='Datasets use different embedding models."', code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
req["kb_ids"] = ids
|
||||
# llm
|
||||
llm = req.get("llm")
|
||||
if llm:
|
||||
if "model_name" in llm:
|
||||
req["llm_id"] = llm.pop("model_name")
|
||||
if not TenantLLMService.query(tenant_id=tenant_id, llm_name=req["llm_id"], model_type="chat"):
|
||||
return get_error_data_result(f"`model_name` {req.get('llm_id')} doesn't exist")
|
||||
if req.get("llm_id") is not None:
|
||||
llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(req["llm_id"])
|
||||
if not TenantLLMService.query(tenant_id=tenant_id, llm_name=llm_name, llm_factory=llm_factory, model_type="chat"):
|
||||
return get_error_data_result(f"`model_name` {req.get('llm_id')} doesn't exist")
|
||||
req["llm_setting"] = req.pop("llm")
|
||||
e, tenant = TenantService.get_by_id(tenant_id)
|
||||
if not e:
|
||||
return get_error_data_result(message="Tenant not found!")
|
||||
# prompt
|
||||
prompt = req.get("prompt")
|
||||
key_mapping = {"parameters": "variables",
|
||||
"prologue": "opener",
|
||||
"quote": "show_quote",
|
||||
"system": "prompt",
|
||||
"rerank_id": "rerank_model",
|
||||
"vector_similarity_weight": "keywords_similarity_weight"}
|
||||
key_list = ["similarity_threshold", "vector_similarity_weight", "top_n", "rerank_id","top_k"]
|
||||
key_mapping = {"parameters": "variables", "prologue": "opener", "quote": "show_quote", "system": "prompt", "rerank_id": "rerank_model", "vector_similarity_weight": "keywords_similarity_weight"}
|
||||
key_list = ["similarity_threshold", "vector_similarity_weight", "top_n", "rerank_id", "top_k"]
|
||||
if prompt:
|
||||
for new_key, old_key in key_mapping.items():
|
||||
if old_key in prompt:
|
||||
@ -85,9 +81,7 @@ def create(tenant_id):
|
||||
req["rerank_id"] = req.get("rerank_id", "")
|
||||
if req.get("rerank_id"):
|
||||
value_rerank_model = ["BAAI/bge-reranker-v2-m3", "maidalun1020/bce-reranker-base_v1"]
|
||||
if req["rerank_id"] not in value_rerank_model and not TenantLLMService.query(tenant_id=tenant_id,
|
||||
llm_name=req.get("rerank_id"),
|
||||
model_type="rerank"):
|
||||
if req["rerank_id"] not in value_rerank_model and not TenantLLMService.query(tenant_id=tenant_id, llm_name=req.get("rerank_id"), model_type="rerank"):
|
||||
return get_error_data_result(f"`rerank_model` {req.get('rerank_id')} doesn't exist")
|
||||
if not req.get("llm_id"):
|
||||
req["llm_id"] = tenant.llm_id
|
||||
@ -106,27 +100,24 @@ def create(tenant_id):
|
||||
{knowledge}
|
||||
The above is the knowledge base.""",
|
||||
"prologue": "Hi! I'm your assistant, what can I do for you?",
|
||||
"parameters": [
|
||||
{"key": "knowledge", "optional": False}
|
||||
],
|
||||
"parameters": [{"key": "knowledge", "optional": False}],
|
||||
"empty_response": "Sorry! No relevant content was found in the knowledge base!",
|
||||
"quote": True,
|
||||
"tts": False,
|
||||
"refine_multiturn": True
|
||||
"refine_multiturn": True,
|
||||
}
|
||||
key_list_2 = ["system", "prologue", "parameters", "empty_response", "quote", "tts", "refine_multiturn"]
|
||||
if "prompt_config" not in req:
|
||||
req['prompt_config'] = {}
|
||||
req["prompt_config"] = {}
|
||||
for key in key_list_2:
|
||||
temp = req['prompt_config'].get(key)
|
||||
if (not temp and key == 'system') or (key not in req["prompt_config"]):
|
||||
req['prompt_config'][key] = default_prompt[key]
|
||||
for p in req['prompt_config']["parameters"]:
|
||||
temp = req["prompt_config"].get(key)
|
||||
if (not temp and key == "system") or (key not in req["prompt_config"]):
|
||||
req["prompt_config"][key] = default_prompt[key]
|
||||
for p in req["prompt_config"]["parameters"]:
|
||||
if p["optional"]:
|
||||
continue
|
||||
if req['prompt_config']["system"].find("{%s}" % p["key"]) < 0:
|
||||
return get_error_data_result(
|
||||
message="Parameter '{}' is not used".format(p["key"]))
|
||||
if req["prompt_config"]["system"].find("{%s}" % p["key"]) < 0:
|
||||
return get_error_data_result(message="Parameter '{}' is not used".format(p["key"]))
|
||||
# save
|
||||
if not DialogService.save(**req):
|
||||
return get_error_data_result(message="Fail to new a chat!")
|
||||
@ -141,10 +132,7 @@ def create(tenant_id):
|
||||
renamed_dict[new_key] = value
|
||||
res["prompt"] = renamed_dict
|
||||
del res["prompt_config"]
|
||||
new_dict = {"similarity_threshold": res["similarity_threshold"],
|
||||
"keywords_similarity_weight": 1-res["vector_similarity_weight"],
|
||||
"top_n": res["top_n"],
|
||||
"rerank_model": res['rerank_id']}
|
||||
new_dict = {"similarity_threshold": res["similarity_threshold"], "keywords_similarity_weight": 1 - res["vector_similarity_weight"], "top_n": res["top_n"], "rerank_model": res["rerank_id"]}
|
||||
res["prompt"].update(new_dict)
|
||||
for key in key_list:
|
||||
del res[key]
|
||||
@ -156,11 +144,11 @@ def create(tenant_id):
|
||||
return get_result(data=res)
|
||||
|
||||
|
||||
@manager.route('/chats/<chat_id>', methods=['PUT']) # noqa: F821
|
||||
@manager.route("/chats/<chat_id>", methods=["PUT"]) # noqa: F821
|
||||
@token_required
|
||||
def update(tenant_id, chat_id):
|
||||
if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
|
||||
return get_error_data_result(message='You do not own the chat')
|
||||
return get_error_data_result(message="You do not own the chat")
|
||||
req = request.json
|
||||
ids = req.get("dataset_ids")
|
||||
if "show_quotation" in req:
|
||||
@ -174,14 +162,12 @@ def update(tenant_id, chat_id):
|
||||
kb = kbs[0]
|
||||
if kb.chunk_num == 0:
|
||||
return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file")
|
||||
|
||||
|
||||
kbs = KnowledgebaseService.get_by_ids(ids)
|
||||
embd_ids = [TenantLLMService.split_model_name_and_factory(kb.embd_id)[0] for kb in kbs] # remove vendor suffix for comparison
|
||||
embd_count = list(set(embd_ids))
|
||||
if len(embd_count) != 1:
|
||||
return get_result(
|
||||
message='Datasets use different embedding models."',
|
||||
code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
return get_result(message='Datasets use different embedding models."', code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
req["kb_ids"] = ids
|
||||
llm = req.get("llm")
|
||||
if llm:
|
||||
@ -195,13 +181,8 @@ def update(tenant_id, chat_id):
|
||||
return get_error_data_result(message="Tenant not found!")
|
||||
# prompt
|
||||
prompt = req.get("prompt")
|
||||
key_mapping = {"parameters": "variables",
|
||||
"prologue": "opener",
|
||||
"quote": "show_quote",
|
||||
"system": "prompt",
|
||||
"rerank_id": "rerank_model",
|
||||
"vector_similarity_weight": "keywords_similarity_weight"}
|
||||
key_list = ["similarity_threshold", "vector_similarity_weight", "top_n", "rerank_id","top_k"]
|
||||
key_mapping = {"parameters": "variables", "prologue": "opener", "quote": "show_quote", "system": "prompt", "rerank_id": "rerank_model", "vector_similarity_weight": "keywords_similarity_weight"}
|
||||
key_list = ["similarity_threshold", "vector_similarity_weight", "top_n", "rerank_id", "top_k"]
|
||||
if prompt:
|
||||
for new_key, old_key in key_mapping.items():
|
||||
if old_key in prompt:
|
||||
@ -214,16 +195,12 @@ def update(tenant_id, chat_id):
|
||||
res = res.to_json()
|
||||
if req.get("rerank_id"):
|
||||
value_rerank_model = ["BAAI/bge-reranker-v2-m3", "maidalun1020/bce-reranker-base_v1"]
|
||||
if req["rerank_id"] not in value_rerank_model and not TenantLLMService.query(tenant_id=tenant_id,
|
||||
llm_name=req.get("rerank_id"),
|
||||
model_type="rerank"):
|
||||
if req["rerank_id"] not in value_rerank_model and not TenantLLMService.query(tenant_id=tenant_id, llm_name=req.get("rerank_id"), model_type="rerank"):
|
||||
return get_error_data_result(f"`rerank_model` {req.get('rerank_id')} doesn't exist")
|
||||
if "name" in req:
|
||||
if not req.get("name"):
|
||||
return get_error_data_result(message="`name` cannot be empty.")
|
||||
if req["name"].lower() != res["name"].lower() \
|
||||
and len(
|
||||
DialogService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value)) > 0:
|
||||
if req["name"].lower() != res["name"].lower() and len(DialogService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value)) > 0:
|
||||
return get_error_data_result(message="Duplicated chat name in updating chat.")
|
||||
if "prompt_config" in req:
|
||||
res["prompt_config"].update(req["prompt_config"])
|
||||
@ -246,7 +223,7 @@ def update(tenant_id, chat_id):
|
||||
return get_result()
|
||||
|
||||
|
||||
@manager.route('/chats', methods=['DELETE']) # noqa: F821
|
||||
@manager.route("/chats", methods=["DELETE"]) # noqa: F821
|
||||
@token_required
|
||||
def delete(tenant_id):
|
||||
errors = []
|
||||
@ -273,30 +250,23 @@ def delete(tenant_id):
|
||||
temp_dict = {"status": StatusEnum.INVALID.value}
|
||||
DialogService.update_by_id(id, temp_dict)
|
||||
success_count += 1
|
||||
|
||||
|
||||
if errors:
|
||||
if success_count > 0:
|
||||
return get_result(
|
||||
data={"success_count": success_count, "errors": errors},
|
||||
message=f"Partially deleted {success_count} chats with {len(errors)} errors"
|
||||
)
|
||||
return get_result(data={"success_count": success_count, "errors": errors}, message=f"Partially deleted {success_count} chats with {len(errors)} errors")
|
||||
else:
|
||||
return get_error_data_result(message="; ".join(errors))
|
||||
|
||||
|
||||
if duplicate_messages:
|
||||
if success_count > 0:
|
||||
return get_result(
|
||||
message=f"Partially deleted {success_count} chats with {len(duplicate_messages)} errors",
|
||||
data={"success_count": success_count, "errors": duplicate_messages}
|
||||
)
|
||||
return get_result(message=f"Partially deleted {success_count} chats with {len(duplicate_messages)} errors", data={"success_count": success_count, "errors": duplicate_messages})
|
||||
else:
|
||||
return get_error_data_result(message=";".join(duplicate_messages))
|
||||
|
||||
|
||||
return get_result()
|
||||
|
||||
|
||||
|
||||
@manager.route('/chats', methods=['GET']) # noqa: F821
|
||||
@manager.route("/chats", methods=["GET"]) # noqa: F821
|
||||
@token_required
|
||||
def list_chat(tenant_id):
|
||||
id = request.args.get("id")
|
||||
@ -316,13 +286,15 @@ def list_chat(tenant_id):
|
||||
if not chats:
|
||||
return get_result(data=[])
|
||||
list_assts = []
|
||||
key_mapping = {"parameters": "variables",
|
||||
"prologue": "opener",
|
||||
"quote": "show_quote",
|
||||
"system": "prompt",
|
||||
"rerank_id": "rerank_model",
|
||||
"vector_similarity_weight": "keywords_similarity_weight",
|
||||
"do_refer": "show_quotation"}
|
||||
key_mapping = {
|
||||
"parameters": "variables",
|
||||
"prologue": "opener",
|
||||
"quote": "show_quote",
|
||||
"system": "prompt",
|
||||
"rerank_id": "rerank_model",
|
||||
"vector_similarity_weight": "keywords_similarity_weight",
|
||||
"do_refer": "show_quotation",
|
||||
}
|
||||
key_list = ["similarity_threshold", "vector_similarity_weight", "top_n", "rerank_id"]
|
||||
for res in chats:
|
||||
renamed_dict = {}
|
||||
@ -331,10 +303,7 @@ def list_chat(tenant_id):
|
||||
renamed_dict[new_key] = value
|
||||
res["prompt"] = renamed_dict
|
||||
del res["prompt_config"]
|
||||
new_dict = {"similarity_threshold": res["similarity_threshold"],
|
||||
"keywords_similarity_weight": 1-res["vector_similarity_weight"],
|
||||
"top_n": res["top_n"],
|
||||
"rerank_model": res['rerank_id']}
|
||||
new_dict = {"similarity_threshold": res["similarity_threshold"], "keywords_similarity_weight": 1 - res["vector_similarity_weight"], "top_n": res["top_n"], "rerank_model": res["rerank_id"]}
|
||||
res["prompt"].update(new_dict)
|
||||
for key in key_list:
|
||||
del res[key]
|
||||
|
@ -13,36 +13,37 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime
|
||||
|
||||
from flask import request, session, redirect
|
||||
from werkzeug.security import generate_password_hash, check_password_hash
|
||||
from flask_login import login_required, current_user, login_user, logout_user
|
||||
from flask import redirect, request, session
|
||||
from flask_login import current_user, login_required, login_user, logout_user
|
||||
from werkzeug.security import check_password_hash, generate_password_hash
|
||||
|
||||
from api import settings
|
||||
from api.apps.auth import get_auth_client
|
||||
from api.db import FileType, UserTenantRole
|
||||
from api.db.db_models import TenantLLM
|
||||
from api.db.services.llm_service import TenantLLMService, LLMService
|
||||
from api.utils.api_utils import (
|
||||
server_error_response,
|
||||
validate_request,
|
||||
get_data_error_result,
|
||||
)
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.llm_service import LLMService, TenantLLMService
|
||||
from api.db.services.user_service import TenantService, UserService, UserTenantService
|
||||
from api.utils import (
|
||||
get_uuid,
|
||||
get_format_time,
|
||||
decrypt,
|
||||
download_img,
|
||||
current_timestamp,
|
||||
datetime_format,
|
||||
decrypt,
|
||||
download_img,
|
||||
get_format_time,
|
||||
get_uuid,
|
||||
)
|
||||
from api.utils.api_utils import (
|
||||
construct_response,
|
||||
get_data_error_result,
|
||||
get_json_result,
|
||||
server_error_response,
|
||||
validate_request,
|
||||
)
|
||||
from api.db import UserTenantRole, FileType
|
||||
from api import settings
|
||||
from api.db.services.user_service import UserService, TenantService, UserTenantService
|
||||
from api.db.services.file_service import FileService
|
||||
from api.utils.api_utils import get_json_result, construct_response
|
||||
from api.apps.auth import get_auth_client
|
||||
|
||||
|
||||
@manager.route("/login", methods=["POST", "GET"]) # noqa: F821
|
||||
@ -77,9 +78,7 @@ def login():
|
||||
type: object
|
||||
"""
|
||||
if not request.json:
|
||||
return get_json_result(
|
||||
data=False, code=settings.RetCode.AUTHENTICATION_ERROR, message="Unauthorized!"
|
||||
)
|
||||
return get_json_result(data=False, code=settings.RetCode.AUTHENTICATION_ERROR, message="Unauthorized!")
|
||||
|
||||
email = request.json.get("email", "")
|
||||
users = UserService.query(email=email)
|
||||
@ -94,9 +93,7 @@ def login():
|
||||
try:
|
||||
password = decrypt(password)
|
||||
except BaseException:
|
||||
return get_json_result(
|
||||
data=False, code=settings.RetCode.SERVER_ERROR, message="Fail to crypt password"
|
||||
)
|
||||
return get_json_result(data=False, code=settings.RetCode.SERVER_ERROR, message="Fail to crypt password")
|
||||
|
||||
user = UserService.query_user(email, password)
|
||||
if user:
|
||||
@ -116,7 +113,7 @@ def login():
|
||||
)
|
||||
|
||||
|
||||
@manager.route("/login/channels", methods=["GET"]) # noqa: F821
|
||||
@manager.route("/login/channels", methods=["GET"]) # noqa: F821
|
||||
def get_login_channels():
|
||||
"""
|
||||
Get all supported authentication channels.
|
||||
@ -124,22 +121,20 @@ def get_login_channels():
|
||||
try:
|
||||
channels = []
|
||||
for channel, config in settings.OAUTH_CONFIG.items():
|
||||
channels.append({
|
||||
"channel": channel,
|
||||
"display_name": config.get("display_name", channel.title()),
|
||||
"icon": config.get("icon", "sso"),
|
||||
})
|
||||
channels.append(
|
||||
{
|
||||
"channel": channel,
|
||||
"display_name": config.get("display_name", channel.title()),
|
||||
"icon": config.get("icon", "sso"),
|
||||
}
|
||||
)
|
||||
return get_json_result(data=channels)
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
return get_json_result(
|
||||
data=[],
|
||||
message=f"Load channels failure, error: {str(e)}",
|
||||
code=settings.RetCode.EXCEPTION_ERROR
|
||||
)
|
||||
return get_json_result(data=[], message=f"Load channels failure, error: {str(e)}", code=settings.RetCode.EXCEPTION_ERROR)
|
||||
|
||||
|
||||
@manager.route("/login/<channel>", methods=["GET"]) # noqa: F821
|
||||
@manager.route("/login/<channel>", methods=["GET"]) # noqa: F821
|
||||
def oauth_login(channel):
|
||||
channel_config = settings.OAUTH_CONFIG.get(channel)
|
||||
if not channel_config:
|
||||
@ -152,7 +147,7 @@ def oauth_login(channel):
|
||||
return redirect(auth_url)
|
||||
|
||||
|
||||
@manager.route("/oauth/callback/<channel>", methods=["GET"]) # noqa: F821
|
||||
@manager.route("/oauth/callback/<channel>", methods=["GET"]) # noqa: F821
|
||||
def oauth_callback(channel):
|
||||
"""
|
||||
Handle the OAuth/OIDC callback for various channels dynamically.
|
||||
@ -190,7 +185,7 @@ def oauth_callback(channel):
|
||||
# Login or register
|
||||
users = UserService.query(email=user_info.email)
|
||||
user_id = get_uuid()
|
||||
|
||||
|
||||
if not users:
|
||||
try:
|
||||
try:
|
||||
@ -434,9 +429,7 @@ def user_info_from_feishu(access_token):
|
||||
"Content-Type": "application/json; charset=utf-8",
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
}
|
||||
res = requests.get(
|
||||
"https://open.feishu.cn/open-apis/authen/v1/user_info", headers=headers
|
||||
)
|
||||
res = requests.get("https://open.feishu.cn/open-apis/authen/v1/user_info", headers=headers)
|
||||
user_info = res.json()["data"]
|
||||
user_info["email"] = None if user_info.get("email") == "" else user_info["email"]
|
||||
return user_info
|
||||
@ -446,17 +439,13 @@ def user_info_from_github(access_token):
|
||||
import requests
|
||||
|
||||
headers = {"Accept": "application/json", "Authorization": f"token {access_token}"}
|
||||
res = requests.get(
|
||||
f"https://api.github.com/user?access_token={access_token}", headers=headers
|
||||
)
|
||||
res = requests.get(f"https://api.github.com/user?access_token={access_token}", headers=headers)
|
||||
user_info = res.json()
|
||||
email_info = requests.get(
|
||||
f"https://api.github.com/user/emails?access_token={access_token}",
|
||||
headers=headers,
|
||||
).json()
|
||||
user_info["email"] = next(
|
||||
(email for email in email_info if email["primary"]), None
|
||||
)["email"]
|
||||
user_info["email"] = next((email for email in email_info if email["primary"]), None)["email"]
|
||||
return user_info
|
||||
|
||||
|
||||
@ -516,9 +505,7 @@ def setting_user():
|
||||
request_data = request.json
|
||||
if request_data.get("password"):
|
||||
new_password = request_data.get("new_password")
|
||||
if not check_password_hash(
|
||||
current_user.password, decrypt(request_data["password"])
|
||||
):
|
||||
if not check_password_hash(current_user.password, decrypt(request_data["password"])):
|
||||
return get_json_result(
|
||||
data=False,
|
||||
code=settings.RetCode.AUTHENTICATION_ERROR,
|
||||
@ -549,9 +536,7 @@ def setting_user():
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
return get_json_result(
|
||||
data=False, message="Update failure!", code=settings.RetCode.EXCEPTION_ERROR
|
||||
)
|
||||
return get_json_result(data=False, message="Update failure!", code=settings.RetCode.EXCEPTION_ERROR)
|
||||
|
||||
|
||||
@manager.route("/info", methods=["GET"]) # noqa: F821
|
||||
@ -643,9 +628,23 @@ def user_register(user_id, user):
|
||||
"model_type": llm.model_type,
|
||||
"api_key": settings.API_KEY,
|
||||
"api_base": settings.LLM_BASE_URL,
|
||||
"max_tokens": llm.max_tokens if llm.max_tokens else 8192
|
||||
"max_tokens": llm.max_tokens if llm.max_tokens else 8192,
|
||||
}
|
||||
)
|
||||
if settings.LIGHTEN != 1:
|
||||
for buildin_embedding_model in settings.BUILTIN_EMBEDDING_MODELS:
|
||||
mdlnm, fid = TenantLLMService.split_model_name_and_factory(buildin_embedding_model)
|
||||
tenant_llm.append(
|
||||
{
|
||||
"tenant_id": user_id,
|
||||
"llm_factory": fid,
|
||||
"llm_name": mdlnm,
|
||||
"model_type": "embedding",
|
||||
"api_key": "",
|
||||
"api_base": "",
|
||||
"max_tokens": 1024 if buildin_embedding_model == "BAAI/bge-large-zh-v1.5@BAAI" else 512,
|
||||
}
|
||||
)
|
||||
|
||||
if not UserService.save(**user):
|
||||
return
|
||||
|
@ -81,7 +81,7 @@ def init_settings():
|
||||
DATABASE = decrypt_database_config(name=DATABASE_TYPE)
|
||||
LLM = get_base_config("user_default_llm", {})
|
||||
LLM_DEFAULT_MODELS = LLM.get("default_models", {})
|
||||
LLM_FACTORY = LLM.get("factory", "Tongyi-Qianwen")
|
||||
LLM_FACTORY = LLM.get("factory")
|
||||
LLM_BASE_URL = LLM.get("base_url")
|
||||
try:
|
||||
REGISTER_ENABLED = int(os.environ.get("REGISTER_ENABLED", "1"))
|
||||
|
@ -567,7 +567,7 @@
|
||||
{
|
||||
"name": "Youdao",
|
||||
"logo": "",
|
||||
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
|
||||
"tags": "TEXT EMBEDDING",
|
||||
"status": "1",
|
||||
"llm": [
|
||||
{
|
||||
@ -755,7 +755,7 @@
|
||||
{
|
||||
"name": "BAAI",
|
||||
"logo": "",
|
||||
"tags": "TEXT EMBEDDING, TEXT RE-RANK",
|
||||
"tags": "TEXT EMBEDDING",
|
||||
"status": "1",
|
||||
"llm": [
|
||||
{
|
||||
|
@ -20,7 +20,9 @@ import pytest
|
||||
import requests
|
||||
|
||||
HOST_ADDRESS = os.getenv("HOST_ADDRESS", "http://127.0.0.1:9380")
|
||||
|
||||
ZHIPU_AI_API_KEY = os.getenv("ZHIPU_AI_API_KEY", "ca148e43209c40109e2bc2f56281dd11.BltyA2N1B043B7Ra")
|
||||
if ZHIPU_AI_API_KEY is None:
|
||||
pytest.exit("Error: Environment variable ZHIPU_AI_API_KEY must be set")
|
||||
|
||||
# def generate_random_email():
|
||||
# return 'user_' + ''.join(random.choices(string.ascii_lowercase + string.digits, k=8))+'@1.com'
|
||||
@ -87,3 +89,64 @@ def get_auth():
|
||||
@pytest.fixture(scope="session")
|
||||
def get_email():
|
||||
return EMAIL
|
||||
|
||||
|
||||
def get_my_llms(auth, name):
|
||||
url = HOST_ADDRESS + "/v1/llm/my_llms"
|
||||
authorization = {"Authorization": auth}
|
||||
response = requests.get(url=url, headers=authorization)
|
||||
res = response.json()
|
||||
if res.get("code") != 0:
|
||||
raise Exception(res.get("message"))
|
||||
if name in res.get("data"):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def add_models(auth):
|
||||
url = HOST_ADDRESS + "/v1/llm/set_api_key"
|
||||
authorization = {"Authorization": auth}
|
||||
models_info = {
|
||||
"ZHIPU-AI": {"llm_factory": "ZHIPU-AI", "api_key": ZHIPU_AI_API_KEY},
|
||||
}
|
||||
|
||||
for name, model_info in models_info.items():
|
||||
if not get_my_llms(auth, name):
|
||||
response = requests.post(url=url, headers=authorization, json=model_info)
|
||||
res = response.json()
|
||||
if res.get("code") != 0:
|
||||
pytest.exit(f"Critical error in add_models: {res.get('message')}")
|
||||
|
||||
|
||||
def get_tenant_info(auth):
|
||||
url = HOST_ADDRESS + "/v1/user/tenant_info"
|
||||
authorization = {"Authorization": auth}
|
||||
response = requests.get(url=url, headers=authorization)
|
||||
res = response.json()
|
||||
if res.get("code") != 0:
|
||||
raise Exception(res.get("message"))
|
||||
return res["data"].get("tenant_id")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def set_tenant_info(get_auth):
|
||||
auth = get_auth
|
||||
try:
|
||||
add_models(auth)
|
||||
tenant_id = get_tenant_info(auth)
|
||||
except Exception as e:
|
||||
pytest.exit(f"Error in set_tenant_info: {str(e)}")
|
||||
url = HOST_ADDRESS + "/v1/user/set_tenant_info"
|
||||
authorization = {"Authorization": get_auth}
|
||||
tenant_info = {
|
||||
"tenant_id": tenant_id,
|
||||
"llm_id": "glm-4-flash@ZHIPU-AI",
|
||||
"embd_id": "BAAI/bge-large-zh-v1.5@BAAI",
|
||||
"img2txt_id": "glm-4v@ZHIPU-AI",
|
||||
"asr_id": "",
|
||||
"tts_id": None,
|
||||
}
|
||||
response = requests.post(url=url, headers=authorization, json=tenant_info)
|
||||
res = response.json()
|
||||
if res.get("code") != 0:
|
||||
raise Exception(res.get("message"))
|
||||
|
@ -16,7 +16,6 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from common import (
|
||||
add_chunk,
|
||||
batch_create_datasets,
|
||||
@ -49,9 +48,6 @@ MARKER_EXPRESSIONS = {
|
||||
"p3": "p1 or p2 or p3",
|
||||
}
|
||||
HOST_ADDRESS = os.getenv("HOST_ADDRESS", "http://127.0.0.1:9380")
|
||||
ZHIPU_AI_API_KEY = os.getenv("ZHIPU_AI_API_KEY", "ca148e43209c40109e2bc2f56281dd11.BltyA2N1B043B7Ra")
|
||||
if ZHIPU_AI_API_KEY is None:
|
||||
pytest.exit("Error: Environment variable ZHIPU_AI_API_KEY must be set")
|
||||
|
||||
|
||||
def pytest_addoption(parser: pytest.Parser) -> None:
|
||||
@ -85,67 +81,6 @@ def get_http_api_auth(get_api_key_fixture):
|
||||
return RAGFlowHttpApiAuth(get_api_key_fixture)
|
||||
|
||||
|
||||
def get_my_llms(auth, name):
|
||||
url = HOST_ADDRESS + "/v1/llm/my_llms"
|
||||
authorization = {"Authorization": auth}
|
||||
response = requests.get(url=url, headers=authorization)
|
||||
res = response.json()
|
||||
if res.get("code") != 0:
|
||||
raise Exception(res.get("message"))
|
||||
if name in res.get("data"):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def add_models(auth):
|
||||
url = HOST_ADDRESS + "/v1/llm/set_api_key"
|
||||
authorization = {"Authorization": auth}
|
||||
models_info = {
|
||||
"ZHIPU-AI": {"llm_factory": "ZHIPU-AI", "api_key": ZHIPU_AI_API_KEY},
|
||||
}
|
||||
|
||||
for name, model_info in models_info.items():
|
||||
if not get_my_llms(auth, name):
|
||||
response = requests.post(url=url, headers=authorization, json=model_info)
|
||||
res = response.json()
|
||||
if res.get("code") != 0:
|
||||
pytest.exit(f"Critical error in add_models: {res.get('message')}")
|
||||
|
||||
|
||||
def get_tenant_info(auth):
|
||||
url = HOST_ADDRESS + "/v1/user/tenant_info"
|
||||
authorization = {"Authorization": auth}
|
||||
response = requests.get(url=url, headers=authorization)
|
||||
res = response.json()
|
||||
if res.get("code") != 0:
|
||||
raise Exception(res.get("message"))
|
||||
return res["data"].get("tenant_id")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def set_tenant_info(get_auth):
|
||||
auth = get_auth
|
||||
try:
|
||||
add_models(auth)
|
||||
tenant_id = get_tenant_info(auth)
|
||||
except Exception as e:
|
||||
pytest.exit(f"Error in set_tenant_info: {str(e)}")
|
||||
url = HOST_ADDRESS + "/v1/user/set_tenant_info"
|
||||
authorization = {"Authorization": get_auth}
|
||||
tenant_info = {
|
||||
"tenant_id": tenant_id,
|
||||
"llm_id": "glm-4-flash@ZHIPU-AI",
|
||||
"embd_id": "BAAI/bge-large-zh-v1.5@BAAI",
|
||||
"img2txt_id": "glm-4v@ZHIPU-AI",
|
||||
"asr_id": "",
|
||||
"tts_id": None,
|
||||
}
|
||||
response = requests.post(url=url, headers=authorization, json=tenant_info)
|
||||
res = response.json()
|
||||
if res.get("code") != 0:
|
||||
raise Exception(res.get("message"))
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def clear_datasets(request, get_http_api_auth):
|
||||
def cleanup():
|
||||
|
@ -14,8 +14,9 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from ragflow_sdk import RAGFlow
|
||||
from common import HOST_ADDRESS
|
||||
from ragflow_sdk import RAGFlow
|
||||
from ragflow_sdk.modules.chat import Chat
|
||||
|
||||
|
||||
def test_create_chat_with_name(get_api_key_fixture):
|
||||
@ -31,7 +32,18 @@ def test_create_chat_with_name(get_api_key_fixture):
|
||||
docs = kb.upload_documents(documents)
|
||||
for doc in docs:
|
||||
doc.add_chunk("This is a test to add chunk")
|
||||
rag.create_chat("test_create_chat", dataset_ids=[kb.id])
|
||||
llm = Chat.LLM(
|
||||
rag,
|
||||
{
|
||||
"model_name": "glm-4-flash@ZHIPU-AI",
|
||||
"temperature": 0.1,
|
||||
"top_p": 0.3,
|
||||
"presence_penalty": 0.4,
|
||||
"frequency_penalty": 0.7,
|
||||
"max_tokens": 512,
|
||||
},
|
||||
)
|
||||
rag.create_chat("test_create_chat", dataset_ids=[kb.id], llm=llm)
|
||||
|
||||
|
||||
def test_update_chat_with_name(get_api_key_fixture):
|
||||
@ -47,7 +59,18 @@ def test_update_chat_with_name(get_api_key_fixture):
|
||||
docs = kb.upload_documents(documents)
|
||||
for doc in docs:
|
||||
doc.add_chunk("This is a test to add chunk")
|
||||
chat = rag.create_chat("test_update_chat", dataset_ids=[kb.id])
|
||||
llm = Chat.LLM(
|
||||
rag,
|
||||
{
|
||||
"model_name": "glm-4-flash@ZHIPU-AI",
|
||||
"temperature": 0.1,
|
||||
"top_p": 0.3,
|
||||
"presence_penalty": 0.4,
|
||||
"frequency_penalty": 0.7,
|
||||
"max_tokens": 512,
|
||||
},
|
||||
)
|
||||
chat = rag.create_chat("test_update_chat", dataset_ids=[kb.id], llm=llm)
|
||||
chat.update({"name": "new_chat"})
|
||||
|
||||
|
||||
@ -64,7 +87,18 @@ def test_delete_chats_with_success(get_api_key_fixture):
|
||||
docs = kb.upload_documents(documents)
|
||||
for doc in docs:
|
||||
doc.add_chunk("This is a test to add chunk")
|
||||
chat = rag.create_chat("test_delete_chat", dataset_ids=[kb.id])
|
||||
llm = Chat.LLM(
|
||||
rag,
|
||||
{
|
||||
"model_name": "glm-4-flash@ZHIPU-AI",
|
||||
"temperature": 0.1,
|
||||
"top_p": 0.3,
|
||||
"presence_penalty": 0.4,
|
||||
"frequency_penalty": 0.7,
|
||||
"max_tokens": 512,
|
||||
},
|
||||
)
|
||||
chat = rag.create_chat("test_delete_chat", dataset_ids=[kb.id], llm=llm)
|
||||
rag.delete_chats(ids=[chat.id])
|
||||
|
||||
|
||||
@ -81,6 +115,17 @@ def test_list_chats_with_success(get_api_key_fixture):
|
||||
docs = kb.upload_documents(documents)
|
||||
for doc in docs:
|
||||
doc.add_chunk("This is a test to add chunk")
|
||||
rag.create_chat("test_list_1", dataset_ids=[kb.id])
|
||||
rag.create_chat("test_list_2", dataset_ids=[kb.id])
|
||||
llm = Chat.LLM(
|
||||
rag,
|
||||
{
|
||||
"model_name": "glm-4-flash@ZHIPU-AI",
|
||||
"temperature": 0.1,
|
||||
"top_p": 0.3,
|
||||
"presence_penalty": 0.4,
|
||||
"frequency_penalty": 0.7,
|
||||
"max_tokens": 512,
|
||||
},
|
||||
)
|
||||
rag.create_chat("test_list_1", dataset_ids=[kb.id], llm=llm)
|
||||
rag.create_chat("test_list_2", dataset_ids=[kb.id], llm=llm)
|
||||
rag.list_chats()
|
||||
|
Loading…
x
Reference in New Issue
Block a user