From 7543047de3f2173e5d47a24f48ef909d4277ebb6 Mon Sep 17 00:00:00 2001 From: Kevin Hu Date: Tue, 3 Dec 2024 12:41:39 +0800 Subject: [PATCH] Fix @ in model name issue. (#3821) ### What problem does this PR solve? #3814 ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- api/db/services/dialog_service.py | 8 ++----- api/db/services/llm_service.py | 37 +++++++++++++++++++++++-------- 2 files changed, 30 insertions(+), 15 deletions(-) diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index 0463cbb6c..9bdd53444 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -120,7 +120,7 @@ def message_fit_in(msg, max_length=4000): def llm_id2llm_type(llm_id): - llm_id = llm_id.split("@")[0] + llm_id, _ = TenantLLMService.split_model_name_and_factory(llm_id) fnm = os.path.join(get_project_base_directory(), "conf") llm_factories = json.load(open(os.path.join(fnm, "llm_factories.json"), "r")) for llm_factory in llm_factories["factory_llm_infos"]: @@ -132,11 +132,7 @@ def llm_id2llm_type(llm_id): def chat(dialog, messages, stream=True, **kwargs): assert messages[-1]["role"] == "user", "The last content of this conversation is not from user." st = timer() - tmp = dialog.llm_id.split("@") - fid = None - llm_id = tmp[0] - if len(tmp)>1: fid = tmp[1] - + llm_id, fid = TenantLLMService.split_model_name_and_factory(dialog.llm_id) llm = LLMService.query(llm_name=llm_id) if not fid else LLMService.query(llm_name=llm_id, fid=fid) if not llm: llm = TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id) if not fid else \ diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py index 90dfb0263..e7bdc455a 100644 --- a/api/db/services/llm_service.py +++ b/api/db/services/llm_service.py @@ -13,8 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import json import logging +import os + from api.db.services.user_service import TenantService +from api.utils.file_utils import get_project_base_directory from rag.llm import EmbeddingModel, CvModel, ChatModel, RerankModel, Seq2txtModel, TTSModel from api.db import LLMType from api.db.db_models import DB @@ -36,11 +40,11 @@ class TenantLLMService(CommonService): @classmethod @DB.connection_context() def get_api_key(cls, tenant_id, model_name): - arr = model_name.split("@") - if len(arr) < 2: - objs = cls.query(tenant_id=tenant_id, llm_name=model_name) + mdlnm, fid = TenantLLMService.split_model_name_and_factory(model_name) + if not fid: + objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm) else: - objs = cls.query(tenant_id=tenant_id, llm_name=arr[0], llm_factory=arr[1]) + objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm, llm_factory=fid) if not objs: return return objs[0] @@ -61,6 +65,23 @@ class TenantLLMService(CommonService): return list(objs) + @staticmethod + def split_model_name_and_factory(model_name): + arr = model_name.split("@") + if len(arr) < 2: + return model_name, None + if len(arr) > 2: + return "@".join(arr[0:-1]), arr[-1] + try: + fact = json.load(open(os.path.join(get_project_base_directory(), "conf/llm_factories.json"), "r"))["factory_llm_infos"] + fact = set([f["name"] for f in fact]) + if arr[-1] not in fact: + return model_name, None + return arr[0], arr[-1] + except Exception as e: + logging.exception(f"TenantLLMService.split_model_name_and_factory got exception: {e}") + return model_name, None + @classmethod @DB.connection_context() def model_instance(cls, tenant_id, llm_type, @@ -85,9 +106,7 @@ class TenantLLMService(CommonService): assert False, "LLM type error" model_config = cls.get_api_key(tenant_id, mdlnm) - tmp = mdlnm.split("@") - fid = None if len(tmp) < 2 else tmp[1] - mdlnm = tmp[0] + mdlnm, fid = TenantLLMService.split_model_name_and_factory(mdlnm) if model_config: model_config = model_config.to_dict() if not model_config: if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]: @@ -168,7 +187,7 @@ class TenantLLMService(CommonService): else: assert False, "LLM type error" - llm_name = mdlnm.split("@")[0] if "@" in mdlnm else mdlnm + llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(mdlnm) num = 0 try: @@ -179,7 +198,7 @@ class TenantLLMService(CommonService): .where(cls.model.tenant_id == tenant_id, cls.model.llm_factory == tenant_llm.llm_factory, cls.model.llm_name == llm_name)\ .execute() else: - llm_factory = mdlnm.split("@")[1] if "@" in mdlnm else mdlnm + if not llm_factory: llm_factory = mdlnm num = cls.model.create(tenant_id=tenant_id, llm_factory=llm_factory, llm_name=llm_name, used_tokens=used_tokens) except Exception: logging.exception("TenantLLMService.increase_usage got exception")