From f298e55dede2ae565e912062be1196dccd14e5d5 Mon Sep 17 00:00:00 2001 From: liwenju0 Date: Thu, 20 Feb 2025 12:40:59 +0800 Subject: [PATCH] Fix: Normalize embedding model ID comparison across datasets (#5169) Modify embedding model ID comparison to remove vendor suffixes, ensuring consistent model identification when working with multiple knowledge bases. This change affects dialog creation, chat operations, and document retrieval test functions. ### What problem does this PR solve? resolve this bug: https://github.com/infiniflow/ragflow/issues/5166 ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --------- Co-authored-by: wenju.li --- api/apps/dialog_app.py | 4 +++- api/apps/sdk/chat.py | 6 ++++-- api/apps/sdk/doc.py | 2 +- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/api/apps/dialog_app.py b/api/apps/dialog_app.py index 822bcd848..a77e355b5 100644 --- a/api/apps/dialog_app.py +++ b/api/apps/dialog_app.py @@ -18,6 +18,7 @@ from flask import request from flask_login import login_required, current_user from api.db.services.dialog_service import DialogService from api.db import StatusEnum +from api.db.services.llm_service import TenantLLMService from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.user_service import TenantService, UserTenantService from api import settings @@ -75,7 +76,8 @@ def set_dialog(): if not e: return get_data_error_result(message="Tenant not found!") kbs = KnowledgebaseService.get_by_ids(req.get("kb_ids")) - embd_count = len(set([kb.embd_id for kb in kbs])) + embd_ids = [TenantLLMService.split_model_name_and_factory(kb.embd_id)[0] for kb in kbs] # remove vendor suffix for comparison + embd_count = len(set(embd_ids)) if embd_count != 1: return get_data_error_result(message=f'Datasets use different embedding models: {[kb.embd_id for kb in kbs]}"') diff --git a/api/apps/sdk/chat.py b/api/apps/sdk/chat.py index 49d79432d..cbd7fdb1f 100644 --- a/api/apps/sdk/chat.py +++ b/api/apps/sdk/chat.py @@ -41,7 +41,8 @@ def create(tenant_id): if kb.chunk_num == 0: return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file") kbs = KnowledgebaseService.get_by_ids(ids) - embd_count = list(set([kb.embd_id for kb in kbs])) + embd_ids = [TenantLLMService.split_model_name_and_factory(kb.embd_id)[0] for kb in kbs] # remove vendor suffix for comparison + embd_count = list(set(embd_ids)) if len(embd_count) != 1: return get_result(message='Datasets use different embedding models."', code=settings.RetCode.AUTHENTICATION_ERROR) @@ -176,7 +177,8 @@ def update(tenant_id, chat_id): if kb.chunk_num == 0: return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file") kbs = KnowledgebaseService.get_by_ids(ids) - embd_count = list(set([kb.embd_id for kb in kbs])) + embd_ids = [TenantLLMService.split_model_name_and_factory(kb.embd_id)[0] for kb in kbs] # remove vendor suffix for comparison + embd_count = list(set(embd_ids)) if len(embd_count) != 1: return get_result( message='Datasets use different embedding models."', diff --git a/api/apps/sdk/doc.py b/api/apps/sdk/doc.py index b87926335..98e2c4227 100644 --- a/api/apps/sdk/doc.py +++ b/api/apps/sdk/doc.py @@ -1305,7 +1305,7 @@ def retrieval_test(tenant_id): if not KnowledgebaseService.accessible(kb_id=id, user_id=tenant_id): return get_error_data_result(f"You don't own the dataset {id}.") kbs = KnowledgebaseService.get_by_ids(kb_ids) - embd_nms = list(set([kb.embd_id for kb in kbs])) + embd_nms = list(set([TenantLLMService.split_model_name_and_factory(kb.embd_id)[0] for kb in kbs])) # remove vendor suffix for comparison if len(embd_nms) != 1: return get_result( message='Datasets use different embedding models."',