Fix: Normalize embedding model ID comparison across datasets (#5169)

Modify embedding model ID comparison to remove vendor suffixes, ensuring
consistent model identification when working with multiple knowledge
bases. This change affects dialog creation, chat operations, and
document retrieval test functions.

### What problem does this PR solve?

resolve this bug: https://github.com/infiniflow/ragflow/issues/5166

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---------

Co-authored-by: wenju.li <wenju.li@deepctr.cn>
This commit is contained in:
liwenju0 2025-02-20 12:40:59 +08:00 committed by GitHub
parent ed943b1b5b
commit f298e55ded
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 8 additions and 4 deletions

View File

@ -18,6 +18,7 @@ from flask import request
from flask_login import login_required, current_user from flask_login import login_required, current_user
from api.db.services.dialog_service import DialogService from api.db.services.dialog_service import DialogService
from api.db import StatusEnum from api.db import StatusEnum
from api.db.services.llm_service import TenantLLMService
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.user_service import TenantService, UserTenantService from api.db.services.user_service import TenantService, UserTenantService
from api import settings from api import settings
@ -75,7 +76,8 @@ def set_dialog():
if not e: if not e:
return get_data_error_result(message="Tenant not found!") return get_data_error_result(message="Tenant not found!")
kbs = KnowledgebaseService.get_by_ids(req.get("kb_ids")) kbs = KnowledgebaseService.get_by_ids(req.get("kb_ids"))
embd_count = len(set([kb.embd_id for kb in kbs])) embd_ids = [TenantLLMService.split_model_name_and_factory(kb.embd_id)[0] for kb in kbs] # remove vendor suffix for comparison
embd_count = len(set(embd_ids))
if embd_count != 1: if embd_count != 1:
return get_data_error_result(message=f'Datasets use different embedding models: {[kb.embd_id for kb in kbs]}"') return get_data_error_result(message=f'Datasets use different embedding models: {[kb.embd_id for kb in kbs]}"')

View File

@ -41,7 +41,8 @@ def create(tenant_id):
if kb.chunk_num == 0: if kb.chunk_num == 0:
return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file") return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file")
kbs = KnowledgebaseService.get_by_ids(ids) kbs = KnowledgebaseService.get_by_ids(ids)
embd_count = list(set([kb.embd_id for kb in kbs])) embd_ids = [TenantLLMService.split_model_name_and_factory(kb.embd_id)[0] for kb in kbs] # remove vendor suffix for comparison
embd_count = list(set(embd_ids))
if len(embd_count) != 1: if len(embd_count) != 1:
return get_result(message='Datasets use different embedding models."', return get_result(message='Datasets use different embedding models."',
code=settings.RetCode.AUTHENTICATION_ERROR) code=settings.RetCode.AUTHENTICATION_ERROR)
@ -176,7 +177,8 @@ def update(tenant_id, chat_id):
if kb.chunk_num == 0: if kb.chunk_num == 0:
return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file") return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file")
kbs = KnowledgebaseService.get_by_ids(ids) kbs = KnowledgebaseService.get_by_ids(ids)
embd_count = list(set([kb.embd_id for kb in kbs])) embd_ids = [TenantLLMService.split_model_name_and_factory(kb.embd_id)[0] for kb in kbs] # remove vendor suffix for comparison
embd_count = list(set(embd_ids))
if len(embd_count) != 1: if len(embd_count) != 1:
return get_result( return get_result(
message='Datasets use different embedding models."', message='Datasets use different embedding models."',

View File

@ -1305,7 +1305,7 @@ def retrieval_test(tenant_id):
if not KnowledgebaseService.accessible(kb_id=id, user_id=tenant_id): if not KnowledgebaseService.accessible(kb_id=id, user_id=tenant_id):
return get_error_data_result(f"You don't own the dataset {id}.") return get_error_data_result(f"You don't own the dataset {id}.")
kbs = KnowledgebaseService.get_by_ids(kb_ids) kbs = KnowledgebaseService.get_by_ids(kb_ids)
embd_nms = list(set([kb.embd_id for kb in kbs])) embd_nms = list(set([TenantLLMService.split_model_name_and_factory(kb.embd_id)[0] for kb in kbs])) # remove vendor suffix for comparison
if len(embd_nms) != 1: if len(embd_nms) != 1:
return get_result( return get_result(
message='Datasets use different embedding models."', message='Datasets use different embedding models."',