diff --git a/api/apps/dialog_app.py b/api/apps/dialog_app.py index 822bcd848..a77e355b5 100644 --- a/api/apps/dialog_app.py +++ b/api/apps/dialog_app.py @@ -18,6 +18,7 @@ from flask import request from flask_login import login_required, current_user from api.db.services.dialog_service import DialogService from api.db import StatusEnum +from api.db.services.llm_service import TenantLLMService from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.user_service import TenantService, UserTenantService from api import settings @@ -75,7 +76,8 @@ def set_dialog(): if not e: return get_data_error_result(message="Tenant not found!") kbs = KnowledgebaseService.get_by_ids(req.get("kb_ids")) - embd_count = len(set([kb.embd_id for kb in kbs])) + embd_ids = [TenantLLMService.split_model_name_and_factory(kb.embd_id)[0] for kb in kbs] # remove vendor suffix for comparison + embd_count = len(set(embd_ids)) if embd_count != 1: return get_data_error_result(message=f'Datasets use different embedding models: {[kb.embd_id for kb in kbs]}"') diff --git a/api/apps/sdk/chat.py b/api/apps/sdk/chat.py index 49d79432d..cbd7fdb1f 100644 --- a/api/apps/sdk/chat.py +++ b/api/apps/sdk/chat.py @@ -41,7 +41,8 @@ def create(tenant_id): if kb.chunk_num == 0: return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file") kbs = KnowledgebaseService.get_by_ids(ids) - embd_count = list(set([kb.embd_id for kb in kbs])) + embd_ids = [TenantLLMService.split_model_name_and_factory(kb.embd_id)[0] for kb in kbs] # remove vendor suffix for comparison + embd_count = list(set(embd_ids)) if len(embd_count) != 1: return get_result(message='Datasets use different embedding models."', code=settings.RetCode.AUTHENTICATION_ERROR) @@ -176,7 +177,8 @@ def update(tenant_id, chat_id): if kb.chunk_num == 0: return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file") kbs = KnowledgebaseService.get_by_ids(ids) - embd_count = list(set([kb.embd_id for kb in kbs])) + embd_ids = [TenantLLMService.split_model_name_and_factory(kb.embd_id)[0] for kb in kbs] # remove vendor suffix for comparison + embd_count = list(set(embd_ids)) if len(embd_count) != 1: return get_result( message='Datasets use different embedding models."', diff --git a/api/apps/sdk/doc.py b/api/apps/sdk/doc.py index b87926335..98e2c4227 100644 --- a/api/apps/sdk/doc.py +++ b/api/apps/sdk/doc.py @@ -1305,7 +1305,7 @@ def retrieval_test(tenant_id): if not KnowledgebaseService.accessible(kb_id=id, user_id=tenant_id): return get_error_data_result(f"You don't own the dataset {id}.") kbs = KnowledgebaseService.get_by_ids(kb_ids) - embd_nms = list(set([kb.embd_id for kb in kbs])) + embd_nms = list(set([TenantLLMService.split_model_name_and_factory(kb.embd_id)[0] for kb in kbs])) # remove vendor suffix for comparison if len(embd_nms) != 1: return get_result( message='Datasets use different embedding models."',