diff --git a/api/apps/sdk/chat.py b/api/apps/sdk/chat.py index 18e1cf83f..24d89db85 100644 --- a/api/apps/sdk/chat.py +++ b/api/apps/sdk/chat.py @@ -16,6 +16,7 @@ import logging from flask import request + from api import settings from api.db import StatusEnum from api.db.services.dialog_service import DialogService @@ -23,15 +24,14 @@ from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.llm_service import TenantLLMService from api.db.services.user_service import TenantService from api.utils import get_uuid -from api.utils.api_utils import get_error_data_result, token_required, get_result, check_duplicate_ids +from api.utils.api_utils import check_duplicate_ids, get_error_data_result, get_result, token_required - -@manager.route('/chats', methods=['POST']) # noqa: F821 +@manager.route("/chats", methods=["POST"]) # noqa: F821 @token_required def create(tenant_id): req = request.json - ids = [i for i in req.get("dataset_ids", []) if i] + ids = [i for i in req.get("dataset_ids", []) if i] for kb_id in ids: kbs = KnowledgebaseService.accessible(kb_id=kb_id, user_id=tenant_id) if not kbs: @@ -40,34 +40,30 @@ def create(tenant_id): kb = kbs[0] if kb.chunk_num == 0: return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file") - + kbs = KnowledgebaseService.get_by_ids(ids) if ids else [] embd_ids = [TenantLLMService.split_model_name_and_factory(kb.embd_id)[0] for kb in kbs] # remove vendor suffix for comparison embd_count = list(set(embd_ids)) if len(embd_count) > 1: - return get_result(message='Datasets use different embedding models."', - code=settings.RetCode.AUTHENTICATION_ERROR) + return get_result(message='Datasets use different embedding models."', code=settings.RetCode.AUTHENTICATION_ERROR) req["kb_ids"] = ids # llm llm = req.get("llm") if llm: if "model_name" in llm: req["llm_id"] = llm.pop("model_name") - if not TenantLLMService.query(tenant_id=tenant_id, llm_name=req["llm_id"], model_type="chat"): - return get_error_data_result(f"`model_name` {req.get('llm_id')} doesn't exist") + if req.get("llm_id") is not None: + llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(req["llm_id"]) + if not TenantLLMService.query(tenant_id=tenant_id, llm_name=llm_name, llm_factory=llm_factory, model_type="chat"): + return get_error_data_result(f"`model_name` {req.get('llm_id')} doesn't exist") req["llm_setting"] = req.pop("llm") e, tenant = TenantService.get_by_id(tenant_id) if not e: return get_error_data_result(message="Tenant not found!") # prompt prompt = req.get("prompt") - key_mapping = {"parameters": "variables", - "prologue": "opener", - "quote": "show_quote", - "system": "prompt", - "rerank_id": "rerank_model", - "vector_similarity_weight": "keywords_similarity_weight"} - key_list = ["similarity_threshold", "vector_similarity_weight", "top_n", "rerank_id","top_k"] + key_mapping = {"parameters": "variables", "prologue": "opener", "quote": "show_quote", "system": "prompt", "rerank_id": "rerank_model", "vector_similarity_weight": "keywords_similarity_weight"} + key_list = ["similarity_threshold", "vector_similarity_weight", "top_n", "rerank_id", "top_k"] if prompt: for new_key, old_key in key_mapping.items(): if old_key in prompt: @@ -85,9 +81,7 @@ def create(tenant_id): req["rerank_id"] = req.get("rerank_id", "") if req.get("rerank_id"): value_rerank_model = ["BAAI/bge-reranker-v2-m3", "maidalun1020/bce-reranker-base_v1"] - if req["rerank_id"] not in value_rerank_model and not TenantLLMService.query(tenant_id=tenant_id, - llm_name=req.get("rerank_id"), - model_type="rerank"): + if req["rerank_id"] not in value_rerank_model and not TenantLLMService.query(tenant_id=tenant_id, llm_name=req.get("rerank_id"), model_type="rerank"): return get_error_data_result(f"`rerank_model` {req.get('rerank_id')} doesn't exist") if not req.get("llm_id"): req["llm_id"] = tenant.llm_id @@ -106,27 +100,24 @@ def create(tenant_id): {knowledge} The above is the knowledge base.""", "prologue": "Hi! I'm your assistant, what can I do for you?", - "parameters": [ - {"key": "knowledge", "optional": False} - ], + "parameters": [{"key": "knowledge", "optional": False}], "empty_response": "Sorry! No relevant content was found in the knowledge base!", "quote": True, "tts": False, - "refine_multiturn": True + "refine_multiturn": True, } key_list_2 = ["system", "prologue", "parameters", "empty_response", "quote", "tts", "refine_multiturn"] if "prompt_config" not in req: - req['prompt_config'] = {} + req["prompt_config"] = {} for key in key_list_2: - temp = req['prompt_config'].get(key) - if (not temp and key == 'system') or (key not in req["prompt_config"]): - req['prompt_config'][key] = default_prompt[key] - for p in req['prompt_config']["parameters"]: + temp = req["prompt_config"].get(key) + if (not temp and key == "system") or (key not in req["prompt_config"]): + req["prompt_config"][key] = default_prompt[key] + for p in req["prompt_config"]["parameters"]: if p["optional"]: continue - if req['prompt_config']["system"].find("{%s}" % p["key"]) < 0: - return get_error_data_result( - message="Parameter '{}' is not used".format(p["key"])) + if req["prompt_config"]["system"].find("{%s}" % p["key"]) < 0: + return get_error_data_result(message="Parameter '{}' is not used".format(p["key"])) # save if not DialogService.save(**req): return get_error_data_result(message="Fail to new a chat!") @@ -141,10 +132,7 @@ def create(tenant_id): renamed_dict[new_key] = value res["prompt"] = renamed_dict del res["prompt_config"] - new_dict = {"similarity_threshold": res["similarity_threshold"], - "keywords_similarity_weight": 1-res["vector_similarity_weight"], - "top_n": res["top_n"], - "rerank_model": res['rerank_id']} + new_dict = {"similarity_threshold": res["similarity_threshold"], "keywords_similarity_weight": 1 - res["vector_similarity_weight"], "top_n": res["top_n"], "rerank_model": res["rerank_id"]} res["prompt"].update(new_dict) for key in key_list: del res[key] @@ -156,11 +144,11 @@ def create(tenant_id): return get_result(data=res) -@manager.route('/chats/', methods=['PUT']) # noqa: F821 +@manager.route("/chats/", methods=["PUT"]) # noqa: F821 @token_required def update(tenant_id, chat_id): if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value): - return get_error_data_result(message='You do not own the chat') + return get_error_data_result(message="You do not own the chat") req = request.json ids = req.get("dataset_ids") if "show_quotation" in req: @@ -174,14 +162,12 @@ def update(tenant_id, chat_id): kb = kbs[0] if kb.chunk_num == 0: return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file") - + kbs = KnowledgebaseService.get_by_ids(ids) embd_ids = [TenantLLMService.split_model_name_and_factory(kb.embd_id)[0] for kb in kbs] # remove vendor suffix for comparison embd_count = list(set(embd_ids)) if len(embd_count) != 1: - return get_result( - message='Datasets use different embedding models."', - code=settings.RetCode.AUTHENTICATION_ERROR) + return get_result(message='Datasets use different embedding models."', code=settings.RetCode.AUTHENTICATION_ERROR) req["kb_ids"] = ids llm = req.get("llm") if llm: @@ -195,13 +181,8 @@ def update(tenant_id, chat_id): return get_error_data_result(message="Tenant not found!") # prompt prompt = req.get("prompt") - key_mapping = {"parameters": "variables", - "prologue": "opener", - "quote": "show_quote", - "system": "prompt", - "rerank_id": "rerank_model", - "vector_similarity_weight": "keywords_similarity_weight"} - key_list = ["similarity_threshold", "vector_similarity_weight", "top_n", "rerank_id","top_k"] + key_mapping = {"parameters": "variables", "prologue": "opener", "quote": "show_quote", "system": "prompt", "rerank_id": "rerank_model", "vector_similarity_weight": "keywords_similarity_weight"} + key_list = ["similarity_threshold", "vector_similarity_weight", "top_n", "rerank_id", "top_k"] if prompt: for new_key, old_key in key_mapping.items(): if old_key in prompt: @@ -214,16 +195,12 @@ def update(tenant_id, chat_id): res = res.to_json() if req.get("rerank_id"): value_rerank_model = ["BAAI/bge-reranker-v2-m3", "maidalun1020/bce-reranker-base_v1"] - if req["rerank_id"] not in value_rerank_model and not TenantLLMService.query(tenant_id=tenant_id, - llm_name=req.get("rerank_id"), - model_type="rerank"): + if req["rerank_id"] not in value_rerank_model and not TenantLLMService.query(tenant_id=tenant_id, llm_name=req.get("rerank_id"), model_type="rerank"): return get_error_data_result(f"`rerank_model` {req.get('rerank_id')} doesn't exist") if "name" in req: if not req.get("name"): return get_error_data_result(message="`name` cannot be empty.") - if req["name"].lower() != res["name"].lower() \ - and len( - DialogService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value)) > 0: + if req["name"].lower() != res["name"].lower() and len(DialogService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value)) > 0: return get_error_data_result(message="Duplicated chat name in updating chat.") if "prompt_config" in req: res["prompt_config"].update(req["prompt_config"]) @@ -246,7 +223,7 @@ def update(tenant_id, chat_id): return get_result() -@manager.route('/chats', methods=['DELETE']) # noqa: F821 +@manager.route("/chats", methods=["DELETE"]) # noqa: F821 @token_required def delete(tenant_id): errors = [] @@ -273,30 +250,23 @@ def delete(tenant_id): temp_dict = {"status": StatusEnum.INVALID.value} DialogService.update_by_id(id, temp_dict) success_count += 1 - + if errors: if success_count > 0: - return get_result( - data={"success_count": success_count, "errors": errors}, - message=f"Partially deleted {success_count} chats with {len(errors)} errors" - ) + return get_result(data={"success_count": success_count, "errors": errors}, message=f"Partially deleted {success_count} chats with {len(errors)} errors") else: return get_error_data_result(message="; ".join(errors)) - + if duplicate_messages: if success_count > 0: - return get_result( - message=f"Partially deleted {success_count} chats with {len(duplicate_messages)} errors", - data={"success_count": success_count, "errors": duplicate_messages} - ) + return get_result(message=f"Partially deleted {success_count} chats with {len(duplicate_messages)} errors", data={"success_count": success_count, "errors": duplicate_messages}) else: return get_error_data_result(message=";".join(duplicate_messages)) - + return get_result() - -@manager.route('/chats', methods=['GET']) # noqa: F821 +@manager.route("/chats", methods=["GET"]) # noqa: F821 @token_required def list_chat(tenant_id): id = request.args.get("id") @@ -316,13 +286,15 @@ def list_chat(tenant_id): if not chats: return get_result(data=[]) list_assts = [] - key_mapping = {"parameters": "variables", - "prologue": "opener", - "quote": "show_quote", - "system": "prompt", - "rerank_id": "rerank_model", - "vector_similarity_weight": "keywords_similarity_weight", - "do_refer": "show_quotation"} + key_mapping = { + "parameters": "variables", + "prologue": "opener", + "quote": "show_quote", + "system": "prompt", + "rerank_id": "rerank_model", + "vector_similarity_weight": "keywords_similarity_weight", + "do_refer": "show_quotation", + } key_list = ["similarity_threshold", "vector_similarity_weight", "top_n", "rerank_id"] for res in chats: renamed_dict = {} @@ -331,10 +303,7 @@ def list_chat(tenant_id): renamed_dict[new_key] = value res["prompt"] = renamed_dict del res["prompt_config"] - new_dict = {"similarity_threshold": res["similarity_threshold"], - "keywords_similarity_weight": 1-res["vector_similarity_weight"], - "top_n": res["top_n"], - "rerank_model": res['rerank_id']} + new_dict = {"similarity_threshold": res["similarity_threshold"], "keywords_similarity_weight": 1 - res["vector_similarity_weight"], "top_n": res["top_n"], "rerank_model": res["rerank_id"]} res["prompt"].update(new_dict) for key in key_list: del res[key] diff --git a/api/apps/user_app.py b/api/apps/user_app.py index 401d6977d..597f50971 100644 --- a/api/apps/user_app.py +++ b/api/apps/user_app.py @@ -13,36 +13,37 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import logging import json +import logging import re from datetime import datetime -from flask import request, session, redirect -from werkzeug.security import generate_password_hash, check_password_hash -from flask_login import login_required, current_user, login_user, logout_user +from flask import redirect, request, session +from flask_login import current_user, login_required, login_user, logout_user +from werkzeug.security import check_password_hash, generate_password_hash +from api import settings +from api.apps.auth import get_auth_client +from api.db import FileType, UserTenantRole from api.db.db_models import TenantLLM -from api.db.services.llm_service import TenantLLMService, LLMService -from api.utils.api_utils import ( - server_error_response, - validate_request, - get_data_error_result, -) +from api.db.services.file_service import FileService +from api.db.services.llm_service import LLMService, TenantLLMService +from api.db.services.user_service import TenantService, UserService, UserTenantService from api.utils import ( - get_uuid, - get_format_time, - decrypt, - download_img, current_timestamp, datetime_format, + decrypt, + download_img, + get_format_time, + get_uuid, +) +from api.utils.api_utils import ( + construct_response, + get_data_error_result, + get_json_result, + server_error_response, + validate_request, ) -from api.db import UserTenantRole, FileType -from api import settings -from api.db.services.user_service import UserService, TenantService, UserTenantService -from api.db.services.file_service import FileService -from api.utils.api_utils import get_json_result, construct_response -from api.apps.auth import get_auth_client @manager.route("/login", methods=["POST", "GET"]) # noqa: F821 @@ -77,9 +78,7 @@ def login(): type: object """ if not request.json: - return get_json_result( - data=False, code=settings.RetCode.AUTHENTICATION_ERROR, message="Unauthorized!" - ) + return get_json_result(data=False, code=settings.RetCode.AUTHENTICATION_ERROR, message="Unauthorized!") email = request.json.get("email", "") users = UserService.query(email=email) @@ -94,9 +93,7 @@ def login(): try: password = decrypt(password) except BaseException: - return get_json_result( - data=False, code=settings.RetCode.SERVER_ERROR, message="Fail to crypt password" - ) + return get_json_result(data=False, code=settings.RetCode.SERVER_ERROR, message="Fail to crypt password") user = UserService.query_user(email, password) if user: @@ -116,7 +113,7 @@ def login(): ) -@manager.route("/login/channels", methods=["GET"]) # noqa: F821 +@manager.route("/login/channels", methods=["GET"]) # noqa: F821 def get_login_channels(): """ Get all supported authentication channels. @@ -124,22 +121,20 @@ def get_login_channels(): try: channels = [] for channel, config in settings.OAUTH_CONFIG.items(): - channels.append({ - "channel": channel, - "display_name": config.get("display_name", channel.title()), - "icon": config.get("icon", "sso"), - }) + channels.append( + { + "channel": channel, + "display_name": config.get("display_name", channel.title()), + "icon": config.get("icon", "sso"), + } + ) return get_json_result(data=channels) except Exception as e: logging.exception(e) - return get_json_result( - data=[], - message=f"Load channels failure, error: {str(e)}", - code=settings.RetCode.EXCEPTION_ERROR - ) + return get_json_result(data=[], message=f"Load channels failure, error: {str(e)}", code=settings.RetCode.EXCEPTION_ERROR) -@manager.route("/login/", methods=["GET"]) # noqa: F821 +@manager.route("/login/", methods=["GET"]) # noqa: F821 def oauth_login(channel): channel_config = settings.OAUTH_CONFIG.get(channel) if not channel_config: @@ -152,7 +147,7 @@ def oauth_login(channel): return redirect(auth_url) -@manager.route("/oauth/callback/", methods=["GET"]) # noqa: F821 +@manager.route("/oauth/callback/", methods=["GET"]) # noqa: F821 def oauth_callback(channel): """ Handle the OAuth/OIDC callback for various channels dynamically. @@ -190,7 +185,7 @@ def oauth_callback(channel): # Login or register users = UserService.query(email=user_info.email) user_id = get_uuid() - + if not users: try: try: @@ -434,9 +429,7 @@ def user_info_from_feishu(access_token): "Content-Type": "application/json; charset=utf-8", "Authorization": f"Bearer {access_token}", } - res = requests.get( - "https://open.feishu.cn/open-apis/authen/v1/user_info", headers=headers - ) + res = requests.get("https://open.feishu.cn/open-apis/authen/v1/user_info", headers=headers) user_info = res.json()["data"] user_info["email"] = None if user_info.get("email") == "" else user_info["email"] return user_info @@ -446,17 +439,13 @@ def user_info_from_github(access_token): import requests headers = {"Accept": "application/json", "Authorization": f"token {access_token}"} - res = requests.get( - f"https://api.github.com/user?access_token={access_token}", headers=headers - ) + res = requests.get(f"https://api.github.com/user?access_token={access_token}", headers=headers) user_info = res.json() email_info = requests.get( f"https://api.github.com/user/emails?access_token={access_token}", headers=headers, ).json() - user_info["email"] = next( - (email for email in email_info if email["primary"]), None - )["email"] + user_info["email"] = next((email for email in email_info if email["primary"]), None)["email"] return user_info @@ -516,9 +505,7 @@ def setting_user(): request_data = request.json if request_data.get("password"): new_password = request_data.get("new_password") - if not check_password_hash( - current_user.password, decrypt(request_data["password"]) - ): + if not check_password_hash(current_user.password, decrypt(request_data["password"])): return get_json_result( data=False, code=settings.RetCode.AUTHENTICATION_ERROR, @@ -549,9 +536,7 @@ def setting_user(): return get_json_result(data=True) except Exception as e: logging.exception(e) - return get_json_result( - data=False, message="Update failure!", code=settings.RetCode.EXCEPTION_ERROR - ) + return get_json_result(data=False, message="Update failure!", code=settings.RetCode.EXCEPTION_ERROR) @manager.route("/info", methods=["GET"]) # noqa: F821 @@ -643,9 +628,23 @@ def user_register(user_id, user): "model_type": llm.model_type, "api_key": settings.API_KEY, "api_base": settings.LLM_BASE_URL, - "max_tokens": llm.max_tokens if llm.max_tokens else 8192 + "max_tokens": llm.max_tokens if llm.max_tokens else 8192, } ) + if settings.LIGHTEN != 1: + for buildin_embedding_model in settings.BUILTIN_EMBEDDING_MODELS: + mdlnm, fid = TenantLLMService.split_model_name_and_factory(buildin_embedding_model) + tenant_llm.append( + { + "tenant_id": user_id, + "llm_factory": fid, + "llm_name": mdlnm, + "model_type": "embedding", + "api_key": "", + "api_base": "", + "max_tokens": 1024 if buildin_embedding_model == "BAAI/bge-large-zh-v1.5@BAAI" else 512, + } + ) if not UserService.save(**user): return diff --git a/api/settings.py b/api/settings.py index d4ce48079..2d743f904 100644 --- a/api/settings.py +++ b/api/settings.py @@ -81,7 +81,7 @@ def init_settings(): DATABASE = decrypt_database_config(name=DATABASE_TYPE) LLM = get_base_config("user_default_llm", {}) LLM_DEFAULT_MODELS = LLM.get("default_models", {}) - LLM_FACTORY = LLM.get("factory", "Tongyi-Qianwen") + LLM_FACTORY = LLM.get("factory") LLM_BASE_URL = LLM.get("base_url") try: REGISTER_ENABLED = int(os.environ.get("REGISTER_ENABLED", "1")) diff --git a/conf/llm_factories.json b/conf/llm_factories.json index 6bc9ae8bd..916ccbad0 100644 --- a/conf/llm_factories.json +++ b/conf/llm_factories.json @@ -567,7 +567,7 @@ { "name": "Youdao", "logo": "", - "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION", + "tags": "TEXT EMBEDDING", "status": "1", "llm": [ { @@ -755,7 +755,7 @@ { "name": "BAAI", "logo": "", - "tags": "TEXT EMBEDDING, TEXT RE-RANK", + "tags": "TEXT EMBEDDING", "status": "1", "llm": [ { diff --git a/sdk/python/test/conftest.py b/sdk/python/test/conftest.py index 9cb09bad2..6eae2c2de 100644 --- a/sdk/python/test/conftest.py +++ b/sdk/python/test/conftest.py @@ -20,7 +20,9 @@ import pytest import requests HOST_ADDRESS = os.getenv("HOST_ADDRESS", "http://127.0.0.1:9380") - +ZHIPU_AI_API_KEY = os.getenv("ZHIPU_AI_API_KEY", "ca148e43209c40109e2bc2f56281dd11.BltyA2N1B043B7Ra") +if ZHIPU_AI_API_KEY is None: + pytest.exit("Error: Environment variable ZHIPU_AI_API_KEY must be set") # def generate_random_email(): # return 'user_' + ''.join(random.choices(string.ascii_lowercase + string.digits, k=8))+'@1.com' @@ -87,3 +89,64 @@ def get_auth(): @pytest.fixture(scope="session") def get_email(): return EMAIL + + +def get_my_llms(auth, name): + url = HOST_ADDRESS + "/v1/llm/my_llms" + authorization = {"Authorization": auth} + response = requests.get(url=url, headers=authorization) + res = response.json() + if res.get("code") != 0: + raise Exception(res.get("message")) + if name in res.get("data"): + return True + return False + + +def add_models(auth): + url = HOST_ADDRESS + "/v1/llm/set_api_key" + authorization = {"Authorization": auth} + models_info = { + "ZHIPU-AI": {"llm_factory": "ZHIPU-AI", "api_key": ZHIPU_AI_API_KEY}, + } + + for name, model_info in models_info.items(): + if not get_my_llms(auth, name): + response = requests.post(url=url, headers=authorization, json=model_info) + res = response.json() + if res.get("code") != 0: + pytest.exit(f"Critical error in add_models: {res.get('message')}") + + +def get_tenant_info(auth): + url = HOST_ADDRESS + "/v1/user/tenant_info" + authorization = {"Authorization": auth} + response = requests.get(url=url, headers=authorization) + res = response.json() + if res.get("code") != 0: + raise Exception(res.get("message")) + return res["data"].get("tenant_id") + + +@pytest.fixture(scope="session", autouse=True) +def set_tenant_info(get_auth): + auth = get_auth + try: + add_models(auth) + tenant_id = get_tenant_info(auth) + except Exception as e: + pytest.exit(f"Error in set_tenant_info: {str(e)}") + url = HOST_ADDRESS + "/v1/user/set_tenant_info" + authorization = {"Authorization": get_auth} + tenant_info = { + "tenant_id": tenant_id, + "llm_id": "glm-4-flash@ZHIPU-AI", + "embd_id": "BAAI/bge-large-zh-v1.5@BAAI", + "img2txt_id": "glm-4v@ZHIPU-AI", + "asr_id": "", + "tts_id": None, + } + response = requests.post(url=url, headers=authorization, json=tenant_info) + res = response.json() + if res.get("code") != 0: + raise Exception(res.get("message")) diff --git a/sdk/python/test/test_http_api/conftest.py b/sdk/python/test/test_http_api/conftest.py index f8a7bc4ed..0825113b7 100644 --- a/sdk/python/test/test_http_api/conftest.py +++ b/sdk/python/test/test_http_api/conftest.py @@ -16,7 +16,6 @@ import os import pytest -import requests from common import ( add_chunk, batch_create_datasets, @@ -49,9 +48,6 @@ MARKER_EXPRESSIONS = { "p3": "p1 or p2 or p3", } HOST_ADDRESS = os.getenv("HOST_ADDRESS", "http://127.0.0.1:9380") -ZHIPU_AI_API_KEY = os.getenv("ZHIPU_AI_API_KEY", "ca148e43209c40109e2bc2f56281dd11.BltyA2N1B043B7Ra") -if ZHIPU_AI_API_KEY is None: - pytest.exit("Error: Environment variable ZHIPU_AI_API_KEY must be set") def pytest_addoption(parser: pytest.Parser) -> None: @@ -85,67 +81,6 @@ def get_http_api_auth(get_api_key_fixture): return RAGFlowHttpApiAuth(get_api_key_fixture) -def get_my_llms(auth, name): - url = HOST_ADDRESS + "/v1/llm/my_llms" - authorization = {"Authorization": auth} - response = requests.get(url=url, headers=authorization) - res = response.json() - if res.get("code") != 0: - raise Exception(res.get("message")) - if name in res.get("data"): - return True - return False - - -def add_models(auth): - url = HOST_ADDRESS + "/v1/llm/set_api_key" - authorization = {"Authorization": auth} - models_info = { - "ZHIPU-AI": {"llm_factory": "ZHIPU-AI", "api_key": ZHIPU_AI_API_KEY}, - } - - for name, model_info in models_info.items(): - if not get_my_llms(auth, name): - response = requests.post(url=url, headers=authorization, json=model_info) - res = response.json() - if res.get("code") != 0: - pytest.exit(f"Critical error in add_models: {res.get('message')}") - - -def get_tenant_info(auth): - url = HOST_ADDRESS + "/v1/user/tenant_info" - authorization = {"Authorization": auth} - response = requests.get(url=url, headers=authorization) - res = response.json() - if res.get("code") != 0: - raise Exception(res.get("message")) - return res["data"].get("tenant_id") - - -@pytest.fixture(scope="session", autouse=True) -def set_tenant_info(get_auth): - auth = get_auth - try: - add_models(auth) - tenant_id = get_tenant_info(auth) - except Exception as e: - pytest.exit(f"Error in set_tenant_info: {str(e)}") - url = HOST_ADDRESS + "/v1/user/set_tenant_info" - authorization = {"Authorization": get_auth} - tenant_info = { - "tenant_id": tenant_id, - "llm_id": "glm-4-flash@ZHIPU-AI", - "embd_id": "BAAI/bge-large-zh-v1.5@BAAI", - "img2txt_id": "glm-4v@ZHIPU-AI", - "asr_id": "", - "tts_id": None, - } - response = requests.post(url=url, headers=authorization, json=tenant_info) - res = response.json() - if res.get("code") != 0: - raise Exception(res.get("message")) - - @pytest.fixture(scope="function") def clear_datasets(request, get_http_api_auth): def cleanup(): diff --git a/sdk/python/test/test_sdk_api/t_chat.py b/sdk/python/test/test_sdk_api/t_chat.py index 5b9ed4275..f15b52f31 100644 --- a/sdk/python/test/test_sdk_api/t_chat.py +++ b/sdk/python/test/test_sdk_api/t_chat.py @@ -14,8 +14,9 @@ # limitations under the License. # -from ragflow_sdk import RAGFlow from common import HOST_ADDRESS +from ragflow_sdk import RAGFlow +from ragflow_sdk.modules.chat import Chat def test_create_chat_with_name(get_api_key_fixture): @@ -31,7 +32,18 @@ def test_create_chat_with_name(get_api_key_fixture): docs = kb.upload_documents(documents) for doc in docs: doc.add_chunk("This is a test to add chunk") - rag.create_chat("test_create_chat", dataset_ids=[kb.id]) + llm = Chat.LLM( + rag, + { + "model_name": "glm-4-flash@ZHIPU-AI", + "temperature": 0.1, + "top_p": 0.3, + "presence_penalty": 0.4, + "frequency_penalty": 0.7, + "max_tokens": 512, + }, + ) + rag.create_chat("test_create_chat", dataset_ids=[kb.id], llm=llm) def test_update_chat_with_name(get_api_key_fixture): @@ -47,7 +59,18 @@ def test_update_chat_with_name(get_api_key_fixture): docs = kb.upload_documents(documents) for doc in docs: doc.add_chunk("This is a test to add chunk") - chat = rag.create_chat("test_update_chat", dataset_ids=[kb.id]) + llm = Chat.LLM( + rag, + { + "model_name": "glm-4-flash@ZHIPU-AI", + "temperature": 0.1, + "top_p": 0.3, + "presence_penalty": 0.4, + "frequency_penalty": 0.7, + "max_tokens": 512, + }, + ) + chat = rag.create_chat("test_update_chat", dataset_ids=[kb.id], llm=llm) chat.update({"name": "new_chat"}) @@ -64,7 +87,18 @@ def test_delete_chats_with_success(get_api_key_fixture): docs = kb.upload_documents(documents) for doc in docs: doc.add_chunk("This is a test to add chunk") - chat = rag.create_chat("test_delete_chat", dataset_ids=[kb.id]) + llm = Chat.LLM( + rag, + { + "model_name": "glm-4-flash@ZHIPU-AI", + "temperature": 0.1, + "top_p": 0.3, + "presence_penalty": 0.4, + "frequency_penalty": 0.7, + "max_tokens": 512, + }, + ) + chat = rag.create_chat("test_delete_chat", dataset_ids=[kb.id], llm=llm) rag.delete_chats(ids=[chat.id]) @@ -81,6 +115,17 @@ def test_list_chats_with_success(get_api_key_fixture): docs = kb.upload_documents(documents) for doc in docs: doc.add_chunk("This is a test to add chunk") - rag.create_chat("test_list_1", dataset_ids=[kb.id]) - rag.create_chat("test_list_2", dataset_ids=[kb.id]) + llm = Chat.LLM( + rag, + { + "model_name": "glm-4-flash@ZHIPU-AI", + "temperature": 0.1, + "top_p": 0.3, + "presence_penalty": 0.4, + "frequency_penalty": 0.7, + "max_tokens": 512, + }, + ) + rag.create_chat("test_list_1", dataset_ids=[kb.id], llm=llm) + rag.create_chat("test_list_2", dataset_ids=[kb.id], llm=llm) rag.list_chats()