From 87763ef0a0c9a3ebd9f95d7143b97a234dc78fdf Mon Sep 17 00:00:00 2001
From: liu an
Date: Tue, 11 Mar 2025 18:55:11 +0800
Subject: [PATCH] TEST: Added test cases for Update Dataset HTTP API (#5924)

### What problem does this PR solve?

Cover the dataset update HTTP API endpoints with test cases.

### Type of change

- [x] Add test cases
---
 .../test_update_dataset.py | 334 ++++++++++++++++++
 1 file changed, 334 insertions(+)
 create mode 100644 sdk/python/test/test_http_api/test_dataset_mangement/test_update_dataset.py

diff --git a/sdk/python/test/test_http_api/test_dataset_mangement/test_update_dataset.py b/sdk/python/test/test_http_api/test_dataset_mangement/test_update_dataset.py
new file mode 100644
index 000000000..cb05bb8e7
--- /dev/null
+++ b/sdk/python/test/test_http_api/test_dataset_mangement/test_update_dataset.py
@@ -0,0 +1,334 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import base64
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path

import pytest
from common import (
    DATASET_NAME_LIMIT,
    INVALID_API_TOKEN,
    create_datasets,
    list_dataset,
    update_dataset,
)
from libs.auth import RAGFlowHttpApiAuth

# TODO: Missing scenario for updating embedding_model with chunk_count != 0


class TestAuthorization:
    @pytest.mark.parametrize(
        "auth, expected_code, expected_message",
        [
            (None, 0, "`Authorization` can't be empty"),
            (
                RAGFlowHttpApiAuth(INVALID_API_TOKEN),
                109,
                "Authentication error: API key is invalid!",
            ),
        ],
    )
    def test_invalid_auth(
        self, get_http_api_auth, auth, expected_code, expected_message
    ):
        ids = create_datasets(get_http_api_auth, 1)
        res = update_dataset(auth, ids[0], {"name": "new_name"})
        assert res["code"] == expected_code
        assert res["message"] == expected_message


class TestDatasetUpdate:
    @pytest.mark.parametrize(
        "name, expected_code, expected_message",
        [
            ("valid_name", 0, ""),
            (
                "a" * (DATASET_NAME_LIMIT + 1),
                102,
                "Dataset name should not be longer than 128 characters.",
            ),
            (0, 100, """AttributeError("\'int\' object has no attribute \'strip\'")"""),
            (
                None,
                100,
                """AttributeError("\'NoneType\' object has no attribute \'strip\'")""",
            ),
            pytest.param("", 102, "", marks=pytest.mark.xfail(reason="issue#5915")),
            ("dataset_1", 102, "Duplicated dataset name in updating dataset."),
            ("DATASET_1", 102, "Duplicated dataset name in updating dataset."),
        ],
    )
    def test_name(self, get_http_api_auth, name, expected_code, expected_message):
        ids = create_datasets(get_http_api_auth, 2)
        res = update_dataset(get_http_api_auth, ids[0], {"name": name})
        assert res["code"] == expected_code
        if expected_code == 0:
            res = list_dataset(get_http_api_auth, {"id": ids[0]})
            assert res["data"][0]["name"] == name
        else:
            assert res["message"] == expected_message

    @pytest.mark.parametrize(
        "embedding_model, expected_code, expected_message",
        [
            ("BAAI/bge-large-zh-v1.5", 0, ""),
            ("BAAI/bge-base-en-v1.5", 0, ""),
            ("BAAI/bge-large-en-v1.5", 0, ""),
            ("BAAI/bge-small-en-v1.5", 0, ""),
            ("BAAI/bge-small-zh-v1.5", 0, ""),
            ("jinaai/jina-embeddings-v2-base-en", 0, ""),
            ("jinaai/jina-embeddings-v2-small-en", 0, ""),
            ("nomic-ai/nomic-embed-text-v1.5", 0, ""),
            ("sentence-transformers/all-MiniLM-L6-v2", 0, ""),
            ("text-embedding-v2", 0, ""),
            ("text-embedding-v3", 0, ""),
            ("maidalun1020/bce-embedding-base_v1", 0, ""),
            (
                "other_embedding_model",
                102,
                "`embedding_model` other_embedding_model doesn't exist",
            ),
            (None, 102, "`embedding_model` can't be empty"),
        ],
    )
    def test_embedding_model(
        self, get_http_api_auth, embedding_model, expected_code, expected_message
    ):
        ids = create_datasets(get_http_api_auth, 1)
        res = update_dataset(
            get_http_api_auth, ids[0], {"embedding_model": embedding_model}
        )
        assert res["code"] == expected_code
        if expected_code == 0:
            res = list_dataset(get_http_api_auth, {"id": ids[0]})
            assert res["data"][0]["embedding_model"] == embedding_model
        else:
            assert res["message"] == expected_message

    @pytest.mark.parametrize(
        "chunk_method, expected_code, expected_message",
        [
            ("naive", 0, ""),
            ("manual", 0, ""),
            ("qa", 0, ""),
            ("table", 0, ""),
            ("paper", 0, ""),
            ("book", 0, ""),
            ("laws", 0, ""),
            ("presentation", 0, ""),
            ("picture", 0, ""),
            ("one", 0, ""),
            ("knowledge_graph", 0, ""),
            ("email", 0, ""),
            ("tag", 0, ""),
            pytest.param(
                "",
                0,
                "",
                marks=pytest.mark.xfail(reason="issue#5920"),
            ),
            (
                "other_chunk_method",
                102,
                "'other_chunk_method' is not in ['naive', 'manual', 'qa', 'table',"
                " 'paper', 'book', 'laws', 'presentation', 'picture', 'one', "
                "'knowledge_graph', 'email', 'tag']",
            ),
            pytest.param(
                None,
                0,
                "",
                marks=pytest.mark.xfail(reason="issue#5920"),
            ),
        ],
    )
    def test_chunk_method(
        self, get_http_api_auth, chunk_method, expected_code, expected_message
    ):
        ids = create_datasets(get_http_api_auth, 1)
        res = update_dataset(get_http_api_auth, ids[0], {"chunk_method": chunk_method})
        assert res["code"] == expected_code
        if expected_code == 0:
            res = list_dataset(get_http_api_auth, {"id": ids[0]})
            if chunk_method != "":
                assert res["data"][0]["chunk_method"] == chunk_method
            else:
                assert res["data"][0]["chunk_method"] == "naive"
        else:
            assert res["message"] == expected_message

    def test_avatar(self, get_http_api_auth, request):
        def encode_avatar(image_path):
            with Path.open(image_path, "rb") as file:
                binary_data = file.read()
            base64_encoded = base64.b64encode(binary_data).decode("utf-8")
            return base64_encoded

        ids = create_datasets(get_http_api_auth, 1)
        payload = {
            "avatar": encode_avatar(
                Path(request.config.rootdir) / "test/data/logo.svg"
            ),
        }
        res = update_dataset(get_http_api_auth, ids[0], payload)
        assert res["code"] == 0

    def test_description(self, get_http_api_auth):
        ids = create_datasets(get_http_api_auth, 1)
        payload = {"description": "description"}
        res = update_dataset(get_http_api_auth, ids[0], payload)
        assert res["code"] == 0

        res = list_dataset(get_http_api_auth, {"id": ids[0]})
        assert res["data"][0]["description"] == "description"

    def test_pagerank(self, get_http_api_auth):
        ids = create_datasets(get_http_api_auth, 1)
        payload = {"pagerank": 1}
        res = update_dataset(get_http_api_auth, ids[0], payload)
        assert res["code"] == 0

        res = list_dataset(get_http_api_auth, {"id": ids[0]})
        assert res["data"][0]["pagerank"] == 1

    def test_similarity_threshold(self, get_http_api_auth):
        ids = create_datasets(get_http_api_auth, 1)
        payload = {"similarity_threshold": 1}
        res = update_dataset(get_http_api_auth, ids[0], payload)
        assert res["code"] == 0

        res = list_dataset(get_http_api_auth, {"id": ids[0]})
        assert res["data"][0]["similarity_threshold"] == 1

    @pytest.mark.parametrize(
        "permission, expected_code",
        [
            ("me", 0),
            ("team", 0),
            pytest.param("", 0, marks=pytest.mark.xfail(reason="issue#5920")),
            ("ME", 102),
            ("TEAM", 102),
            ("other_permission", 102),
        ],
    )
    def test_permission(self, get_http_api_auth, permission, expected_code):
        ids = create_datasets(get_http_api_auth, 1)
        payload = {"permission": permission}
        res = update_dataset(get_http_api_auth, ids[0], payload)
        assert res["code"] == expected_code

        res = list_dataset(get_http_api_auth, {"id": ids[0]})
        if expected_code == 0 and permission != "":
            assert res["data"][0]["permission"] == permission
        if permission == "":
            assert res["data"][0]["permission"] == "me"

    def test_vector_similarity_weight(self, get_http_api_auth):
        ids = create_datasets(get_http_api_auth, 1)
        payload = {"vector_similarity_weight": 1}
        res = update_dataset(get_http_api_auth, ids[0], payload)
        assert res["code"] == 0

        res = list_dataset(get_http_api_auth, {"id": ids[0]})
        assert res["data"][0]["vector_similarity_weight"] == 1

    def test_invalid_dataset_id(self, get_http_api_auth):
        create_datasets(get_http_api_auth, 1)
        res = update_dataset(
            get_http_api_auth, "invalid_dataset_id", {"name": "invalid_dataset_id"}
        )
        assert res["code"] == 102
        assert res["message"] == "You don't own the dataset"

    @pytest.mark.parametrize(
        "payload, expected_code, expected_message",
        [
            ({"chunk_count": 1}, 102, "Can't change `chunk_count`."),
            pytest.param(
                {"create_date": "Tue, 11 Mar 2025 13:37:23 GMT"},
                102,
                "",
                marks=pytest.mark.xfail(reason="issue#5923"),
            ),
            pytest.param(
                {"create_time": 1741671443322},
                102,
                "",
                marks=pytest.mark.xfail(reason="issue#5923"),
            ),
            pytest.param(
                {"created_by": "aa"},
                102,
                "",
                marks=pytest.mark.xfail(reason="issue#5923"),
            ),
            ({"document_count": 1}, 102, "Can't change `document_count`."),
            ({"id": "id"}, 102, "The input parameters are invalid."),
            pytest.param(
                {"status": "1"}, 102, "", marks=pytest.mark.xfail(reason="issue#5923")
            ),
            (
                {"tenant_id": "e57c1966f99211efb41e9e45646e0111"},
                102,
                "Can't change `tenant_id`.",
            ),
            pytest.param(
                {"token_num": 1}, 102, "", marks=pytest.mark.xfail(reason="issue#5923")
            ),
            pytest.param(
                {"update_date": "Tue, 11 Mar 2025 13:37:23 GMT"},
                102,
                "",
                marks=pytest.mark.xfail(reason="issue#5923"),
            ),
            pytest.param(
                {"update_time": 1741671443339},
                102,
                "",
                marks=pytest.mark.xfail(reason="issue#5923"),
            ),
            pytest.param(
                {"unknown_field": 0},
                100,
                "",
                marks=pytest.mark.xfail(reason="issue#5923"),
            ),
        ],
    )
    def test_modify_unsupported_field(
        self, get_http_api_auth, payload, expected_code, expected_message
    ):
        ids = create_datasets(get_http_api_auth, 1)
        res = update_dataset(get_http_api_auth, ids[0], payload)
        assert res["code"] == expected_code
        assert res["message"] == expected_message

    def test_concurrent_update(self, get_http_api_auth):
        ids = create_datasets(get_http_api_auth, 1)

        with ThreadPoolExecutor(max_workers=5) as executor:
            futures = [
                executor.submit(
                    update_dataset, get_http_api_auth, ids[0], {"name": f"dataset_{i}"}
                )
                for i in range(100)
            ]
            responses = [f.result() for f in futures]
        assert all(r["code"] == 0 for r in responses)
        res = list_dataset(get_http_api_auth, {"id": ids[0]})
        assert res["data"][0]["name"] == "dataset_99"