diff --git a/api/apps/sdk/chat.py b/api/apps/sdk/chat.py index 100e97c78..da6fcc51d 100644 --- a/api/apps/sdk/chat.py +++ b/api/apps/sdk/chat.py @@ -31,9 +31,7 @@ from api.utils.api_utils import get_result @token_required def create(tenant_id): req = request.json - ids = req.get("dataset_ids") - if not ids: - return get_error_data_result(message="`dataset_ids` is required") + ids = [i for i in req.get("dataset_ids", []) if i] for kb_id in ids: kbs = KnowledgebaseService.accessible(kb_id=kb_id, user_id=tenant_id) if not kbs: @@ -42,10 +40,10 @@ def create(tenant_id): kb = kbs[0] if kb.chunk_num == 0: return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file") - kbs = KnowledgebaseService.get_by_ids(ids) + kbs = KnowledgebaseService.get_by_ids(ids) if ids else [] embd_ids = [TenantLLMService.split_model_name_and_factory(kb.embd_id)[0] for kb in kbs] # remove vendor suffix for comparison embd_count = list(set(embd_ids)) - if len(embd_count) != 1: + if len(embd_count) > 1: return get_result(message='Datasets use different embedding models."', code=settings.RetCode.AUTHENTICATION_ERROR) req["kb_ids"] = ids