Mirror of https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git, synced 2025-06-04 11:24:00 +08:00
fix duplicated llm name between different suppliers (#2477)
### What problem does this PR solve?

#2465

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
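The underlying issue (#2465): two model suppliers can publish entries under the same `llm_name`, so any lookup keyed on the bare name may resolve to the wrong supplier's record. The fix threads a composite ID of the form `llm_name@factory` through the lookup paths and splits on `@` wherever a bare name was used before, falling back to the old behavior when no suffix is present. A minimal sketch of the convention this commit adopts (the helper and the sample IDs are illustrative, not from the codebase):

```python
def split_model_id(model_id: str):
    """Split a composite "llm_name@factory" ID into (name, factory).

    Mirrors the split("@") pattern used throughout this commit:
    factory is None when the ID carries no supplier suffix, so
    pre-existing bare IDs keep working unchanged.
    """
    name, _, factory = model_id.partition("@")
    return name, factory or None


# Hypothetical IDs, for illustration only:
print(split_model_id("bge-reranker-v2-m3@BAAI"))  # ('bge-reranker-v2-m3', 'BAAI')
print(split_model_id("glm-4"))                    # ('glm-4', None)
```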
parent 2484e26cb5
commit 01acc3fd5a
```diff
@@ -27,7 +27,7 @@ from rag.utils.es_conn import ELASTICSEARCH
 from rag.utils import rmSpace
 from api.db import LLMType, ParserType
 from api.db.services.knowledgebase_service import KnowledgebaseService
-from api.db.services.llm_service import TenantLLMService
+from api.db.services.llm_service import LLMBundle
 from api.db.services.user_service import UserTenantService
 from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
 from api.db.services.document_service import DocumentService
```
```diff
@@ -141,8 +141,7 @@ def set():
             return get_data_error_result(retmsg="Tenant not found!")
 
         embd_id = DocumentService.get_embd_id(req["doc_id"])
-        embd_mdl = TenantLLMService.model_instance(
-            tenant_id, LLMType.EMBEDDING.value, embd_id)
+        embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_id)
 
         e, doc = DocumentService.get_by_id(req["doc_id"])
         if not e:
```
```diff
@@ -235,8 +234,7 @@ def create():
             return get_data_error_result(retmsg="Tenant not found!")
 
         embd_id = DocumentService.get_embd_id(req["doc_id"])
-        embd_mdl = TenantLLMService.model_instance(
-            tenant_id, LLMType.EMBEDDING.value, embd_id)
+        embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING.value, embd_id)
 
         v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
         v = 0.1 * v[0] + 0.9 * v[1]
```
```diff
@@ -281,16 +279,14 @@ def retrieval_test():
         if not e:
             return get_data_error_result(retmsg="Knowledgebase not found!")
 
-        embd_mdl = TenantLLMService.model_instance(
-            kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
+        embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
 
         rerank_mdl = None
         if req.get("rerank_id"):
-            rerank_mdl = TenantLLMService.model_instance(
-                kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
+            rerank_mdl = LLMBundle(kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
 
         if req.get("keyword", False):
-            chat_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.CHAT)
+            chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
             question += keyword_extraction(chat_mdl, question)
 
         retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
```
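With the import swap above, `set`, `create`, and `retrieval_test` all construct models through `LLMBundle` instead of calling `TenantLLMService.model_instance` directly, so the factory-aware name resolution added to `TenantLLMService` below reaches every call site through a single wrapper.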
```diff
@@ -78,6 +78,7 @@ def message_fit_in(msg, max_length=4000):
 
 
 def llm_id2llm_type(llm_id):
+    llm_id = llm_id.split("@")[0]
     fnm = os.path.join(get_project_base_directory(), "conf")
     llm_factories = json.load(open(os.path.join(fnm, "llm_factories.json"), "r"))
     for llm_factory in llm_factories["factory_llm_infos"]:
```
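`llm_id2llm_type` matches against the model catalog in `conf/llm_factories.json`, which presumably lists models by bare name; stripping any `@factory` suffix first keeps a composite ID from silently matching no catalog entry.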
```diff
@@ -89,9 +90,15 @@ def llm_id2llm_type(llm_id):
 def chat(dialog, messages, stream=True, **kwargs):
     assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
     st = timer()
-    llm = LLMService.query(llm_name=dialog.llm_id)
+    tmp = dialog.llm_id.split("@")
+    fid = None
+    llm_id = tmp[0]
+    if len(tmp)>1: fid = tmp[1]
+
+    llm = LLMService.query(llm_name=llm_id) if not fid else LLMService.query(llm_name=llm_id, fid=fid)
     if not llm:
-        llm = TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=dialog.llm_id)
+        llm = TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id) if not fid else \
+            TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id, llm_factory=fid)
         if not llm:
             raise LookupError("LLM(%s) not found" % dialog.llm_id)
         max_tokens = 8192
```
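The parsing in `chat` stays backward compatible: IDs stored before this change contain no `@`, so `fid` remains `None` and both queries degenerate to the old name-only lookup. A worked example of the split (the model IDs are hypothetical):

```python
for dialog_llm_id in ("deepseek-chat@DeepSeek", "deepseek-chat"):
    tmp = dialog_llm_id.split("@")
    fid = None
    llm_id = tmp[0]
    if len(tmp) > 1:
        fid = tmp[1]
    print(llm_id, fid)

# Output:
#   deepseek-chat DeepSeek
#   deepseek-chat None
```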
```diff
@@ -17,7 +17,7 @@ from api.db.services.user_service import TenantService
 from api.settings import database_logger
 from rag.llm import EmbeddingModel, CvModel, ChatModel, RerankModel, Seq2txtModel, TTSModel
 from api.db import LLMType
-from api.db.db_models import DB, UserTenant
+from api.db.db_models import DB
 from api.db.db_models import LLMFactories, LLM, TenantLLM
 from api.db.services.common_service import CommonService
 
```
```diff
@@ -36,7 +36,11 @@ class TenantLLMService(CommonService):
     @classmethod
     @DB.connection_context()
     def get_api_key(cls, tenant_id, model_name):
-        objs = cls.query(tenant_id=tenant_id, llm_name=model_name)
+        arr = model_name.split("@")
+        if len(arr) < 2:
+            objs = cls.query(tenant_id=tenant_id, llm_name=model_name)
+        else:
+            objs = cls.query(tenant_id=tenant_id, llm_name=arr[0], llm_factory=arr[1])
         if not objs:
             return
         return objs[0]
```
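`get_api_key` is where the duplicate-name bug actually bit: two `TenantLLM` rows can share an `llm_name` while differing in `llm_factory`, and the old single-key query simply returned `objs[0]`, whichever row matched first. The composite form adds the factory to the query; a bare name still takes the original path.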
```diff
@@ -81,14 +85,17 @@ class TenantLLMService(CommonService):
             assert False, "LLM type error"
 
         model_config = cls.get_api_key(tenant_id, mdlnm)
+        tmp = mdlnm.split("@")
+        fid = None if len(tmp) < 2 else tmp[1]
+        mdlnm = tmp[0]
         if model_config: model_config = model_config.to_dict()
         if not model_config:
             if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]:
-                llm = LLMService.query(llm_name=llm_name if llm_name else mdlnm)
+                llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid)
                 if llm and llm[0].fid in ["Youdao", "FastEmbed", "BAAI"]:
-                    model_config = {"llm_factory": llm[0].fid, "api_key":"", "llm_name": llm_name if llm_name else mdlnm, "api_base": ""}
+                    model_config = {"llm_factory": llm[0].fid, "api_key":"", "llm_name": mdlnm, "api_base": ""}
             if not model_config:
-                if llm_name == "flag-embedding":
+                if mdlnm == "flag-embedding":
                     model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "",
                                     "llm_name": llm_name, "api_base": ""}
                 else:
```
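Ordering matters in this hunk: `cls.get_api_key(tenant_id, mdlnm)` receives the still-composite name and relies on the split inside `get_api_key` above; only afterwards is `mdlnm` rebound to the bare name, which the fallback branches (the `LLMService.query` call and the `flag-embedding` check) then use.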
```diff
@@ -76,7 +76,7 @@ class Docx(DocxParser):
                 if last_image:
                     image_list.insert(0, last_image)
                     last_image = None
-                lines.append((self.__clean(p.text), image_list, p.style.name))
+                lines.append((self.__clean(p.text), image_list, p.style.name if p.style else ""))
             else:
                 if current_image := self.get_picture(self.doc, p):
                     if lines:
```
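The final hunk is a small hardening fix riding along with the commit: a paragraph's `p.style` can apparently come back as `None` for some .docx inputs, so reading `p.style.name` unguarded raised `AttributeError`; the change substitutes an empty style name in that case.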