mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-08-11 16:09:03 +08:00
Fix @ in model name issue. (#3821)
### What problem does this PR solve?

#3814

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
parent
e66addc82d
commit
7543047de3
@ -120,7 +120,7 @@ def message_fit_in(msg, max_length=4000):
|
|||||||
|
|
||||||
|
|
||||||
def llm_id2llm_type(llm_id):
|
def llm_id2llm_type(llm_id):
|
||||||
llm_id = llm_id.split("@")[0]
|
llm_id, _ = TenantLLMService.split_model_name_and_factory(llm_id)
|
||||||
fnm = os.path.join(get_project_base_directory(), "conf")
|
fnm = os.path.join(get_project_base_directory(), "conf")
|
||||||
llm_factories = json.load(open(os.path.join(fnm, "llm_factories.json"), "r"))
|
llm_factories = json.load(open(os.path.join(fnm, "llm_factories.json"), "r"))
|
||||||
for llm_factory in llm_factories["factory_llm_infos"]:
|
for llm_factory in llm_factories["factory_llm_infos"]:
|
||||||
@ -132,11 +132,7 @@ def llm_id2llm_type(llm_id):
|
|||||||
def chat(dialog, messages, stream=True, **kwargs):
|
def chat(dialog, messages, stream=True, **kwargs):
|
||||||
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
|
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
|
||||||
st = timer()
|
st = timer()
|
||||||
tmp = dialog.llm_id.split("@")
|
llm_id, fid = TenantLLMService.split_model_name_and_factory(dialog.llm_id)
|
||||||
fid = None
|
|
||||||
llm_id = tmp[0]
|
|
||||||
if len(tmp)>1: fid = tmp[1]
|
|
||||||
|
|
||||||
llm = LLMService.query(llm_name=llm_id) if not fid else LLMService.query(llm_name=llm_id, fid=fid)
|
llm = LLMService.query(llm_name=llm_id) if not fid else LLMService.query(llm_name=llm_id, fid=fid)
|
||||||
if not llm:
|
if not llm:
|
||||||
llm = TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id) if not fid else \
|
llm = TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id) if not fid else \
|
||||||
|
@ -13,8 +13,12 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
from api.db.services.user_service import TenantService
|
from api.db.services.user_service import TenantService
|
||||||
|
from api.utils.file_utils import get_project_base_directory
|
||||||
from rag.llm import EmbeddingModel, CvModel, ChatModel, RerankModel, Seq2txtModel, TTSModel
|
from rag.llm import EmbeddingModel, CvModel, ChatModel, RerankModel, Seq2txtModel, TTSModel
|
||||||
from api.db import LLMType
|
from api.db import LLMType
|
||||||
from api.db.db_models import DB
|
from api.db.db_models import DB
|
||||||
@ -36,11 +40,11 @@ class TenantLLMService(CommonService):
|
|||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def get_api_key(cls, tenant_id, model_name):
|
def get_api_key(cls, tenant_id, model_name):
|
||||||
arr = model_name.split("@")
|
mdlnm, fid = TenantLLMService.split_model_name_and_factory(model_name)
|
||||||
if len(arr) < 2:
|
if not fid:
|
||||||
objs = cls.query(tenant_id=tenant_id, llm_name=model_name)
|
objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm)
|
||||||
else:
|
else:
|
||||||
objs = cls.query(tenant_id=tenant_id, llm_name=arr[0], llm_factory=arr[1])
|
objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm, llm_factory=fid)
|
||||||
if not objs:
|
if not objs:
|
||||||
return
|
return
|
||||||
return objs[0]
|
return objs[0]
|
||||||
@ -61,6 +65,23 @@ class TenantLLMService(CommonService):
|
|||||||
|
|
||||||
return list(objs)
|
return list(objs)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def split_model_name_and_factory(model_name):
|
||||||
|
arr = model_name.split("@")
|
||||||
|
if len(arr) < 2:
|
||||||
|
return model_name, None
|
||||||
|
if len(arr) > 2:
|
||||||
|
return "@".join(arr[0:-1]), arr[-1]
|
||||||
|
try:
|
||||||
|
fact = json.load(open(os.path.join(get_project_base_directory(), "conf/llm_factories.json"), "r"))["factory_llm_infos"]
|
||||||
|
fact = set([f["name"] for f in fact])
|
||||||
|
if arr[-1] not in fact:
|
||||||
|
return model_name, None
|
||||||
|
return arr[0], arr[-1]
|
||||||
|
except Exception as e:
|
||||||
|
logging.exception(f"TenantLLMService.split_model_name_and_factory got exception: {e}")
|
||||||
|
return model_name, None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def model_instance(cls, tenant_id, llm_type,
|
def model_instance(cls, tenant_id, llm_type,
|
||||||
@ -85,9 +106,7 @@ class TenantLLMService(CommonService):
|
|||||||
assert False, "LLM type error"
|
assert False, "LLM type error"
|
||||||
|
|
||||||
model_config = cls.get_api_key(tenant_id, mdlnm)
|
model_config = cls.get_api_key(tenant_id, mdlnm)
|
||||||
tmp = mdlnm.split("@")
|
mdlnm, fid = TenantLLMService.split_model_name_and_factory(mdlnm)
|
||||||
fid = None if len(tmp) < 2 else tmp[1]
|
|
||||||
mdlnm = tmp[0]
|
|
||||||
if model_config: model_config = model_config.to_dict()
|
if model_config: model_config = model_config.to_dict()
|
||||||
if not model_config:
|
if not model_config:
|
||||||
if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]:
|
if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]:
|
||||||
@ -168,7 +187,7 @@ class TenantLLMService(CommonService):
|
|||||||
else:
|
else:
|
||||||
assert False, "LLM type error"
|
assert False, "LLM type error"
|
||||||
|
|
||||||
llm_name = mdlnm.split("@")[0] if "@" in mdlnm else mdlnm
|
llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(mdlnm)
|
||||||
|
|
||||||
num = 0
|
num = 0
|
||||||
try:
|
try:
|
||||||
@ -179,7 +198,7 @@ class TenantLLMService(CommonService):
|
|||||||
.where(cls.model.tenant_id == tenant_id, cls.model.llm_factory == tenant_llm.llm_factory, cls.model.llm_name == llm_name)\
|
.where(cls.model.tenant_id == tenant_id, cls.model.llm_factory == tenant_llm.llm_factory, cls.model.llm_name == llm_name)\
|
||||||
.execute()
|
.execute()
|
||||||
else:
|
else:
|
||||||
llm_factory = mdlnm.split("@")[1] if "@" in mdlnm else mdlnm
|
if not llm_factory: llm_factory = mdlnm
|
||||||
num = cls.model.create(tenant_id=tenant_id, llm_factory=llm_factory, llm_name=llm_name, used_tokens=used_tokens)
|
num = cls.model.create(tenant_id=tenant_id, llm_factory=llm_factory, llm_name=llm_name, used_tokens=used_tokens)
|
||||||
except Exception:
|
except Exception:
|
||||||
logging.exception("TenantLLMService.increase_usage got exception")
|
logging.exception("TenantLLMService.increase_usage got exception")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user