mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-08-14 04:55:54 +08:00
Fix: Normalize embedding model ID comparison across datasets (#5169)
Modify embedding model ID comparison to remove vendor suffixes, ensuring consistent model identification when working with multiple knowledge bases. This change affects dialog creation, chat operations, and document retrieval test functions. ### What problem does this PR solve? resolve this bug: https://github.com/infiniflow/ragflow/issues/5166 ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --------- Co-authored-by: wenju.li <wenju.li@deepctr.cn>
This commit is contained in:
parent
ed943b1b5b
commit
f298e55ded
@ -18,6 +18,7 @@ from flask import request
|
|||||||
from flask_login import login_required, current_user
|
from flask_login import login_required, current_user
|
||||||
from api.db.services.dialog_service import DialogService
|
from api.db.services.dialog_service import DialogService
|
||||||
from api.db import StatusEnum
|
from api.db import StatusEnum
|
||||||
|
from api.db.services.llm_service import TenantLLMService
|
||||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||||
from api.db.services.user_service import TenantService, UserTenantService
|
from api.db.services.user_service import TenantService, UserTenantService
|
||||||
from api import settings
|
from api import settings
|
||||||
@ -75,7 +76,8 @@ def set_dialog():
|
|||||||
if not e:
|
if not e:
|
||||||
return get_data_error_result(message="Tenant not found!")
|
return get_data_error_result(message="Tenant not found!")
|
||||||
kbs = KnowledgebaseService.get_by_ids(req.get("kb_ids"))
|
kbs = KnowledgebaseService.get_by_ids(req.get("kb_ids"))
|
||||||
embd_count = len(set([kb.embd_id for kb in kbs]))
|
embd_ids = [TenantLLMService.split_model_name_and_factory(kb.embd_id)[0] for kb in kbs] # remove vendor suffix for comparison
|
||||||
|
embd_count = len(set(embd_ids))
|
||||||
if embd_count != 1:
|
if embd_count != 1:
|
||||||
return get_data_error_result(message=f'Datasets use different embedding models: {[kb.embd_id for kb in kbs]}"')
|
return get_data_error_result(message=f'Datasets use different embedding models: {[kb.embd_id for kb in kbs]}"')
|
||||||
|
|
||||||
|
@ -41,7 +41,8 @@ def create(tenant_id):
|
|||||||
if kb.chunk_num == 0:
|
if kb.chunk_num == 0:
|
||||||
return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file")
|
return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file")
|
||||||
kbs = KnowledgebaseService.get_by_ids(ids)
|
kbs = KnowledgebaseService.get_by_ids(ids)
|
||||||
embd_count = list(set([kb.embd_id for kb in kbs]))
|
embd_ids = [TenantLLMService.split_model_name_and_factory(kb.embd_id)[0] for kb in kbs] # remove vendor suffix for comparison
|
||||||
|
embd_count = list(set(embd_ids))
|
||||||
if len(embd_count) != 1:
|
if len(embd_count) != 1:
|
||||||
return get_result(message='Datasets use different embedding models."',
|
return get_result(message='Datasets use different embedding models."',
|
||||||
code=settings.RetCode.AUTHENTICATION_ERROR)
|
code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||||
@ -176,7 +177,8 @@ def update(tenant_id, chat_id):
|
|||||||
if kb.chunk_num == 0:
|
if kb.chunk_num == 0:
|
||||||
return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file")
|
return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file")
|
||||||
kbs = KnowledgebaseService.get_by_ids(ids)
|
kbs = KnowledgebaseService.get_by_ids(ids)
|
||||||
embd_count = list(set([kb.embd_id for kb in kbs]))
|
embd_ids = [TenantLLMService.split_model_name_and_factory(kb.embd_id)[0] for kb in kbs] # remove vendor suffix for comparison
|
||||||
|
embd_count = list(set(embd_ids))
|
||||||
if len(embd_count) != 1:
|
if len(embd_count) != 1:
|
||||||
return get_result(
|
return get_result(
|
||||||
message='Datasets use different embedding models."',
|
message='Datasets use different embedding models."',
|
||||||
|
@ -1305,7 +1305,7 @@ def retrieval_test(tenant_id):
|
|||||||
if not KnowledgebaseService.accessible(kb_id=id, user_id=tenant_id):
|
if not KnowledgebaseService.accessible(kb_id=id, user_id=tenant_id):
|
||||||
return get_error_data_result(f"You don't own the dataset {id}.")
|
return get_error_data_result(f"You don't own the dataset {id}.")
|
||||||
kbs = KnowledgebaseService.get_by_ids(kb_ids)
|
kbs = KnowledgebaseService.get_by_ids(kb_ids)
|
||||||
embd_nms = list(set([kb.embd_id for kb in kbs]))
|
embd_nms = list(set([TenantLLMService.split_model_name_and_factory(kb.embd_id)[0] for kb in kbs])) # remove vendor suffix for comparison
|
||||||
if len(embd_nms) != 1:
|
if len(embd_nms) != 1:
|
||||||
return get_result(
|
return get_result(
|
||||||
message='Datasets use different embedding models."',
|
message='Datasets use different embedding models."',
|
||||||
|
Loading…
x
Reference in New Issue
Block a user