Fix @ in model name issue. (#3821)
### What problem does this PR solve?

Fix #3814: an "@" inside a model name was treated as the model/factory separator, so such models were mis-parsed and could not be resolved.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
commit 7543047de3
parent e66addc82d
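RAGFlow encodes a tenant's model reference as `<model_name>@<factory>`. The pre-fix call sites split on "@" unconditionally, so a model whose name itself contains "@" was cut at the wrong place. A minimal sketch of the failure mode, using a hypothetical model id:

```python
# Hypothetical id: a model named "user/gemma@2b" served by the Ollama factory.
llm_id = "user/gemma@2b@Ollama"

# Pre-fix parsing, as it appeared in chat() below:
tmp = llm_id.split("@")
llm_id_parsed = tmp[0]                  # -> "user/gemma"  (wrong model name)
fid = tmp[1] if len(tmp) > 1 else None  # -> "2b"          (wrong factory)
print(llm_id_parsed, fid)
```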
@@ -120,7 +120,7 @@ def message_fit_in(msg, max_length=4000):
 
 
 def llm_id2llm_type(llm_id):
-    llm_id = llm_id.split("@")[0]
+    llm_id, _ = TenantLLMService.split_model_name_and_factory(llm_id)
     fnm = os.path.join(get_project_base_directory(), "conf")
     llm_factories = json.load(open(os.path.join(fnm, "llm_factories.json"), "r"))
     for llm_factory in llm_factories["factory_llm_infos"]:
@@ -132,11 +132,7 @@ def llm_id2llm_type(llm_id):
 def chat(dialog, messages, stream=True, **kwargs):
     assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
     st = timer()
-    tmp = dialog.llm_id.split("@")
-    fid = None
-    llm_id = tmp[0]
-    if len(tmp)>1: fid = tmp[1]
-
+    llm_id, fid = TenantLLMService.split_model_name_and_factory(dialog.llm_id)
     llm = LLMService.query(llm_name=llm_id) if not fid else LLMService.query(llm_name=llm_id, fid=fid)
     if not llm:
         llm = TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id) if not fid else \
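Both `llm_id2llm_type()` and `chat()` now delegate to the new `TenantLLMService.split_model_name_and_factory()` helper (added below) instead of doing their own `split("@")`, so the factory suffix is interpreted in exactly one place.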
@@ -13,8 +13,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import json
 import logging
+import os
+
 from api.db.services.user_service import TenantService
+from api.utils.file_utils import get_project_base_directory
 from rag.llm import EmbeddingModel, CvModel, ChatModel, RerankModel, Seq2txtModel, TTSModel
 from api.db import LLMType
 from api.db.db_models import DB
@@ -36,11 +40,11 @@ class TenantLLMService(CommonService):
     @classmethod
     @DB.connection_context()
     def get_api_key(cls, tenant_id, model_name):
-        arr = model_name.split("@")
-        if len(arr) < 2:
-            objs = cls.query(tenant_id=tenant_id, llm_name=model_name)
+        mdlnm, fid = TenantLLMService.split_model_name_and_factory(model_name)
+        if not fid:
+            objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm)
         else:
-            objs = cls.query(tenant_id=tenant_id, llm_name=arr[0], llm_factory=arr[1])
+            objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm, llm_factory=fid)
         if not objs:
             return
         return objs[0]
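`get_api_key()` now treats the text after "@" as a factory only when the helper recognizes it as one, so API keys for models whose names contain "@" are looked up under the full model name.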
@@ -61,6 +65,23 @@ class TenantLLMService(CommonService):
 
         return list(objs)
 
+    @staticmethod
+    def split_model_name_and_factory(model_name):
+        arr = model_name.split("@")
+        if len(arr) < 2:
+            return model_name, None
+        if len(arr) > 2:
+            return "@".join(arr[0:-1]), arr[-1]
+        try:
+            fact = json.load(open(os.path.join(get_project_base_directory(), "conf/llm_factories.json"), "r"))["factory_llm_infos"]
+            fact = set([f["name"] for f in fact])
+            if arr[-1] not in fact:
+                return model_name, None
+            return arr[0], arr[-1]
+        except Exception as e:
+            logging.exception(f"TenantLLMService.split_model_name_and_factory got exception: {e}")
+            return model_name, None
+
     @classmethod
     @DB.connection_context()
     def model_instance(cls, tenant_id, llm_type,
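To make the helper's contract concrete, here is a minimal standalone replica (the in-repo method reads the factory list from `conf/llm_factories.json`; the `factories` set and model names below are assumptions for illustration):

```python
def split_model_name_and_factory(model_name, factories):
    # Standalone replica of the new helper, for illustration only; the real
    # method loads `factories` from conf/llm_factories.json.
    arr = model_name.split("@")
    if len(arr) < 2:              # no "@": plain model name, no factory
        return model_name, None
    if len(arr) > 2:              # several "@": last segment is the factory
        return "@".join(arr[:-1]), arr[-1]
    if arr[-1] not in factories:  # exactly one "@": accept only known factories
        return model_name, None
    return arr[0], arr[-1]

factories = {"OpenAI", "Ollama"}                                        # assumed names
print(split_model_name_and_factory("gpt-4o@OpenAI", factories))         # ('gpt-4o', 'OpenAI')
print(split_model_name_and_factory("gpt-4o@NotAFactory", factories))    # ('gpt-4o@NotAFactory', None)
print(split_model_name_and_factory("user/gemma@2b@Ollama", factories))  # ('user/gemma@2b', 'Ollama')
print(split_model_name_and_factory("gpt-4o", factories))                # ('gpt-4o', None)
```

Note the asymmetry: with exactly one "@" the suffix must be a configured factory, while with two or more "@" the last segment is taken as the factory unconditionally.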
@@ -85,9 +106,7 @@ class TenantLLMService(CommonService):
             assert False, "LLM type error"
 
         model_config = cls.get_api_key(tenant_id, mdlnm)
-        tmp = mdlnm.split("@")
-        fid = None if len(tmp) < 2 else tmp[1]
-        mdlnm = tmp[0]
+        mdlnm, fid = TenantLLMService.split_model_name_and_factory(mdlnm)
        if model_config: model_config = model_config.to_dict()
         if not model_config:
             if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]:
@@ -168,7 +187,7 @@ class TenantLLMService(CommonService):
         else:
             assert False, "LLM type error"
 
-        llm_name = mdlnm.split("@")[0] if "@" in mdlnm else mdlnm
+        llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(mdlnm)
 
         num = 0
         try:
@@ -179,7 +198,7 @@ class TenantLLMService(CommonService):
                    .where(cls.model.tenant_id == tenant_id, cls.model.llm_factory == tenant_llm.llm_factory, cls.model.llm_name == llm_name)\
                    .execute()
             else:
-                llm_factory = mdlnm.split("@")[1] if "@" in mdlnm else mdlnm
+                if not llm_factory: llm_factory = mdlnm
                 num = cls.model.create(tenant_id=tenant_id, llm_factory=llm_factory, llm_name=llm_name, used_tokens=used_tokens)
         except Exception:
             logging.exception("TenantLLMService.increase_usage got exception")
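The `if not llm_factory: llm_factory = mdlnm` fallback preserves the old behavior for names without a factory suffix: the full model name is recorded as the factory.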