Fix @ in model name issue. (#3821)

### What problem does this PR solve?

#3814

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
Kevin Hu 2024-12-03 12:41:39 +08:00 committed by GitHub
parent e66addc82d
commit 7543047de3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 30 additions and 15 deletions

View File

@@ -120,7 +120,7 @@ def message_fit_in(msg, max_length=4000):
def llm_id2llm_type(llm_id):
llm_id = llm_id.split("@")[0]
llm_id, _ = TenantLLMService.split_model_name_and_factory(llm_id)
fnm = os.path.join(get_project_base_directory(), "conf")
llm_factories = json.load(open(os.path.join(fnm, "llm_factories.json"), "r"))
for llm_factory in llm_factories["factory_llm_infos"]:
@@ -132,11 +132,7 @@ def llm_id2llm_type(llm_id):
def chat(dialog, messages, stream=True, **kwargs):
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
st = timer()
tmp = dialog.llm_id.split("@")
fid = None
llm_id = tmp[0]
if len(tmp)>1: fid = tmp[1]
llm_id, fid = TenantLLMService.split_model_name_and_factory(dialog.llm_id)
llm = LLMService.query(llm_name=llm_id) if not fid else LLMService.query(llm_name=llm_id, fid=fid)
if not llm:
llm = TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id) if not fid else \

View File

@@ -13,8 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import logging
import os
from api.db.services.user_service import TenantService
from api.utils.file_utils import get_project_base_directory
from rag.llm import EmbeddingModel, CvModel, ChatModel, RerankModel, Seq2txtModel, TTSModel
from api.db import LLMType
from api.db.db_models import DB
@@ -36,11 +40,11 @@ class TenantLLMService(CommonService):
@classmethod
@DB.connection_context()
def get_api_key(cls, tenant_id, model_name):
    """Return the tenant's stored LLM configuration record for *model_name*.

    ``model_name`` may carry a trailing ``@factory`` suffix; it is split off
    via ``split_model_name_and_factory`` and, when a factory is present, the
    lookup is narrowed to that factory.

    Returns the first matching record, or ``None`` when nothing matches.
    """
    # NOTE(review): the original span interleaved the superseded
    # split("@")-based lookup with this one; the dead duplicate branch
    # (arr = model_name.split("@") / cls.query(..., llm_factory=arr[1]))
    # is removed here so only one coherent code path remains.
    mdlnm, fid = TenantLLMService.split_model_name_and_factory(model_name)
    if not fid:
        objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm)
    else:
        objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm, llm_factory=fid)
    if not objs:
        return
    return objs[0]
@@ -61,6 +65,23 @@ class TenantLLMService(CommonService):
return list(objs)
@staticmethod
def split_model_name_and_factory(model_name):
arr = model_name.split("@")
if len(arr) < 2:
return model_name, None
if len(arr) > 2:
return "@".join(arr[0:-1]), arr[-1]
try:
fact = json.load(open(os.path.join(get_project_base_directory(), "conf/llm_factories.json"), "r"))["factory_llm_infos"]
fact = set([f["name"] for f in fact])
if arr[-1] not in fact:
return model_name, None
return arr[0], arr[-1]
except Exception as e:
logging.exception(f"TenantLLMService.split_model_name_and_factory got exception: {e}")
return model_name, None
@classmethod
@DB.connection_context()
def model_instance(cls, tenant_id, llm_type,
@@ -85,9 +106,7 @@ class TenantLLMService(CommonService):
assert False, "LLM type error"
model_config = cls.get_api_key(tenant_id, mdlnm)
tmp = mdlnm.split("@")
fid = None if len(tmp) < 2 else tmp[1]
mdlnm = tmp[0]
mdlnm, fid = TenantLLMService.split_model_name_and_factory(mdlnm)
if model_config: model_config = model_config.to_dict()
if not model_config:
if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]:
@@ -168,7 +187,7 @@ class TenantLLMService(CommonService):
else:
assert False, "LLM type error"
llm_name = mdlnm.split("@")[0] if "@" in mdlnm else mdlnm
llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(mdlnm)
num = 0
try:
@@ -179,7 +198,7 @@ class TenantLLMService(CommonService):
.where(cls.model.tenant_id == tenant_id, cls.model.llm_factory == tenant_llm.llm_factory, cls.model.llm_name == llm_name)\
.execute()
else:
llm_factory = mdlnm.split("@")[1] if "@" in mdlnm else mdlnm
if not llm_factory: llm_factory = mdlnm
num = cls.model.create(tenant_id=tenant_id, llm_factory=llm_factory, llm_name=llm_name, used_tokens=used_tokens)
except Exception:
logging.exception("TenantLLMService.increase_usage got exception")