From ec1659cba091ce1dcf7a374a9749e574b87aa991 Mon Sep 17 00:00:00 2001 From: crazywoola <100913391+crazywoola@users.noreply.github.com> Date: Fri, 19 Jan 2024 20:12:04 +0800 Subject: [PATCH] fix: saving error in empty dataset (#2098) --- api/controllers/console/datasets/datasets.py | 14 +++++++++----- api/controllers/service_api/dataset/dataset.py | 3 ++- api/models/dataset.py | 2 +- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 7be8e87ce0..8d315460d3 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -19,7 +19,7 @@ from flask import current_app, request from flask_login import current_user from flask_restful import Resource, marshal, marshal_with, reqparse from libs.login import login_required -from models.dataset import Document, DocumentSegment +from models.dataset import Dataset, Document, DocumentSegment from models.model import ApiToken, UploadFile from services.dataset_service import DatasetService, DocumentService from werkzeug.exceptions import Forbidden, NotFound @@ -97,7 +97,8 @@ class DatasetListApi(Resource): help='type is required. Name must be between 1 to 40 characters.', type=_validate_name) parser.add_argument('indexing_technique', type=str, location='json', - choices=('high_quality', 'economy'), + choices=Dataset.INDEXING_TECHNIQUE_LIST, + nullable=True, help='Invalid indexing technique.') args = parser.parse_args() @@ -177,8 +178,9 @@ class DatasetApi(Resource): location='json', store_missing=False, type=_validate_description_length) parser.add_argument('indexing_technique', type=str, location='json', - choices=('high_quality', 'economy'), - help='Invalid indexing technique.') + choices=Dataset.INDEXING_TECHNIQUE_LIST, + nullable=True, + help='Invalid indexing technique.') parser.add_argument('permission', type=str, location='json', choices=( 'only_me', 'all_team_members'), help='Invalid permission.') parser.add_argument('retrieval_model', type=dict, location='json', help='Invalid retrieval model.') @@ -256,7 +258,9 @@ class DatasetIndexingEstimateApi(Resource): parser = reqparse.RequestParser() parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json') parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json') - parser.add_argument('indexing_technique', type=str, required=True, nullable=True, location='json') + parser.add_argument('indexing_technique', type=str, required=True, + choices=Dataset.INDEXING_TECHNIQUE_LIST, + nullable=True, location='json') parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json') parser.add_argument('dataset_id', type=str, required=False, nullable=False, location='json') parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 6028d7c341..6827d47dfc 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -1,3 +1,4 @@ +from models.dataset import Dataset import services.dataset_service from controllers.service_api import api from controllers.service_api.dataset.error import DatasetNameDuplicateError @@ -68,7 +69,7 @@ class DatasetApi(DatasetApiResource): help='type is required. Name must be between 1 to 40 characters.', type=_validate_name) parser.add_argument('indexing_technique', type=str, location='json', - choices=('high_quality', 'economy'), + choices=Dataset.INDEXING_TECHNIQUE_LIST, help='Invalid indexing technique.') args = parser.parse_args() diff --git a/api/models/dataset.py b/api/models/dataset.py index 5e67b2b8b8..b9f8eacca6 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -17,7 +17,7 @@ class Dataset(db.Model): db.Index('retrieval_model_idx', "retrieval_model", postgresql_using='gin') ) - INDEXING_TECHNIQUE_LIST = ['high_quality', 'economy'] + INDEXING_TECHNIQUE_LIST = ['high_quality', 'economy', None] id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) tenant_id = db.Column(UUID, nullable=False)