mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-04-22 14:10:01 +08:00
### What problem does this PR solve? Fix: The max tokens defined by the tenant are not used (#4297) (#2817) ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --------- Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
This commit is contained in:
parent
3e0bc9e36b
commit
00c7ddbc9b
@ -29,7 +29,7 @@ from api.db.db_models import Dialog, DB
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle
|
||||
from api.db.services.llm_service import TenantLLMService, LLMBundle
|
||||
from api import settings
|
||||
from graphrag.utils import get_tags_from_cache, set_tags_to_cache
|
||||
from rag.app.resume import forbidden_select_fields4resume
|
||||
@ -172,21 +172,12 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
|
||||
chat_start_ts = timer()
|
||||
|
||||
# Get llm model name and model provider name
|
||||
llm_id, model_provider = TenantLLMService.split_model_name_and_factory(dialog.llm_id)
|
||||
|
||||
# Get llm model instance by model and provide name
|
||||
llm = LLMService.query(llm_name=llm_id) if not model_provider else LLMService.query(llm_name=llm_id, fid=model_provider)
|
||||
|
||||
if not llm:
|
||||
# Model name is provided by tenant, but not system built-in
|
||||
llm = TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id) if not model_provider else \
|
||||
TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id, llm_factory=model_provider)
|
||||
if not llm:
|
||||
raise LookupError("LLM(%s) not found" % dialog.llm_id)
|
||||
max_tokens = 8192
|
||||
if llm_id2llm_type(dialog.llm_id) == "image2text":
|
||||
llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
|
||||
else:
|
||||
max_tokens = llm[0].max_tokens
|
||||
llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
|
||||
|
||||
max_tokens = llm_model_config.get("max_tokens", 8192)
|
||||
|
||||
check_llm_ts = timer()
|
||||
|
||||
|
@ -86,8 +86,7 @@ class TenantLLMService(CommonService):
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def model_instance(cls, tenant_id, llm_type,
|
||||
llm_name=None, lang="Chinese"):
|
||||
def get_model_config(cls, tenant_id, llm_type, llm_name=None):
|
||||
e, tenant = TenantService.get_by_id(tenant_id)
|
||||
if not e:
|
||||
raise LookupError("Tenant not found")
|
||||
@ -124,7 +123,13 @@ class TenantLLMService(CommonService):
|
||||
if not mdlnm:
|
||||
raise LookupError(f"Type of {llm_type} model is not set.")
|
||||
raise LookupError("Model({}) not authorized".format(mdlnm))
|
||||
return model_config
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def model_instance(cls, tenant_id, llm_type,
|
||||
llm_name=None, lang="Chinese"):
|
||||
model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
|
||||
if llm_type == LLMType.EMBEDDING.value:
|
||||
if model_config["llm_factory"] not in EmbeddingModel:
|
||||
return
|
||||
@ -228,10 +233,8 @@ class LLMBundle(object):
|
||||
tenant_id, llm_type, llm_name, lang=lang)
|
||||
assert self.mdl, "Can't find model for {}/{}/{}".format(
|
||||
tenant_id, llm_type, llm_name)
|
||||
self.max_length = 8192
|
||||
for lm in LLMService.query(llm_name=llm_name):
|
||||
self.max_length = lm.max_tokens
|
||||
break
|
||||
model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
|
||||
self.max_length = model_config.get("max_tokens", 8192)
|
||||
|
||||
def encode(self, texts: list):
|
||||
embeddings, used_tokens = self.mdl.encode(texts)
|
||||
|
Loading…
x
Reference in New Issue
Block a user