mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-08-14 20:15:55 +08:00
TEST: Added test cases for Update Dataset HTTP API (#5924)
### What problem does this PR solve?

Cover the dataset update endpoints.

### Type of change

- [x] Add test cases
This commit is contained in:
parent
939e668096
commit
87763ef0a0
@ -0,0 +1,334 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import base64
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from common import (
|
||||
DATASET_NAME_LIMIT,
|
||||
INVALID_API_TOKEN,
|
||||
create_datasets,
|
||||
list_dataset,
|
||||
update_dataset,
|
||||
)
|
||||
from libs.auth import RAGFlowHttpApiAuth
|
||||
|
||||
# TODO: Missing scenario for updating embedding_model with chunk_count != 0
|
||||
|
||||
|
||||
class TestAuthorization:
    """Authorization checks for the update-dataset HTTP API."""

    @pytest.mark.parametrize(
        "auth, expected_code, expected_message",
        [
            # NOTE(review): code 0 is expected together with an error message
            # when no auth object is supplied -- confirm against the API.
            pytest.param(None, 0, "`Authorization` can't be empty"),
            pytest.param(
                RAGFlowHttpApiAuth(INVALID_API_TOKEN),
                109,
                "Authentication error: API key is invalid!",
            ),
        ],
    )
    def test_invalid_auth(
        self, get_http_api_auth, auth, expected_code, expected_message
    ):
        """Update a dataset with missing or invalid credentials.

        The dataset is created through the ``get_http_api_auth`` fixture, then
        the update is attempted with the parametrized ``auth`` and both the
        response code and message are compared with the expected values.
        """
        dataset_id = create_datasets(get_http_api_auth, 1)[0]
        response = update_dataset(auth, dataset_id, {"name": "new_name"})

        assert response["code"] == expected_code
        assert response["message"] == expected_message
|
||||
|
||||
|
||||
class TestDatasetUpdate:
    """Test cases for the update-dataset HTTP API.

    Every test creates its own dataset(s) with ``create_datasets``, sends one
    update through ``update_dataset`` and checks the response ``code``.  Where
    the update is expected to succeed, most tests read the dataset back with
    ``list_dataset`` to confirm the stored value.
    """

    @pytest.mark.parametrize(
        "name, expected_code, expected_message",
        [
            ("valid_name", 0, ""),
            (
                "a" * (DATASET_NAME_LIMIT + 1),
                102,
                "Dataset name should not be longer than 128 characters.",
            ),
            (0, 100, """AttributeError("\'int\' object has no attribute \'strip\'")"""),
            (
                None,
                100,
                """AttributeError("\'NoneType\' object has no attribute \'strip\'")""",
            ),
            pytest.param("", 102, "", marks=pytest.mark.xfail(reason="issue#5915")),
            ("dataset_1", 102, "Duplicated dataset name in updating dataset."),
            ("DATASET_1", 102, "Duplicated dataset name in updating dataset."),
        ],
    )
    def test_name(self, get_http_api_auth, name, expected_code, expected_message):
        """Rename the first of two datasets and check the outcome.

        On success the stored name is read back and compared; otherwise the
        error message of the update response is compared.
        """
        # Two datasets are created so that the "dataset_1" / "DATASET_1" cases
        # can be rejected as duplicated names.
        # NOTE(review): presumably create_datasets names them dataset_0 and
        # dataset_1 -- confirm in `common`.
        ids = create_datasets(get_http_api_auth, 2)
        res = update_dataset(get_http_api_auth, ids[0], {"name": name})
        assert res["code"] == expected_code
        if expected_code == 0:
            res = list_dataset(get_http_api_auth, {"id": ids[0]})
            assert res["data"][0]["name"] == name
        else:
            assert res["message"] == expected_message

    @pytest.mark.parametrize(
        "embedding_model, expected_code, expected_message",
        [
            ("BAAI/bge-large-zh-v1.5", 0, ""),
            ("BAAI/bge-base-en-v1.5", 0, ""),
            ("BAAI/bge-large-en-v1.5", 0, ""),
            ("BAAI/bge-small-en-v1.5", 0, ""),
            ("BAAI/bge-small-zh-v1.5", 0, ""),
            ("jinaai/jina-embeddings-v2-base-en", 0, ""),
            ("jinaai/jina-embeddings-v2-small-en", 0, ""),
            ("nomic-ai/nomic-embed-text-v1.5", 0, ""),
            ("sentence-transformers/all-MiniLM-L6-v2", 0, ""),
            ("text-embedding-v2", 0, ""),
            ("text-embedding-v3", 0, ""),
            ("maidalun1020/bce-embedding-base_v1", 0, ""),
            (
                "other_embedding_model",
                102,
                "`embedding_model` other_embedding_model doesn't exist",
            ),
            (None, 102, "`embedding_model` can't be empty"),
        ],
    )
    def test_embedding_model(
        self, get_http_api_auth, embedding_model, expected_code, expected_message
    ):
        """Update ``embedding_model`` and verify the stored value or the error."""
        ids = create_datasets(get_http_api_auth, 1)
        res = update_dataset(
            get_http_api_auth, ids[0], {"embedding_model": embedding_model}
        )
        assert res["code"] == expected_code
        if expected_code == 0:
            res = list_dataset(get_http_api_auth, {"id": ids[0]})
            assert res["data"][0]["embedding_model"] == embedding_model
        else:
            assert res["message"] == expected_message

    @pytest.mark.parametrize(
        "chunk_method, expected_code, expected_message",
        [
            ("naive", 0, ""),
            ("manual", 0, ""),
            ("qa", 0, ""),
            ("table", 0, ""),
            ("paper", 0, ""),
            ("book", 0, ""),
            ("laws", 0, ""),
            ("presentation", 0, ""),
            ("picture", 0, ""),
            ("one", 0, ""),
            ("knowledge_graph", 0, ""),
            ("email", 0, ""),
            ("tag", 0, ""),
            pytest.param(
                "",
                0,
                "",
                marks=pytest.mark.xfail(reason="issue#5920"),
            ),
            (
                "other_chunk_method",
                102,
                "'other_chunk_method' is not in ['naive', 'manual', 'qa', 'table',"
                " 'paper', 'book', 'laws', 'presentation', 'picture', 'one', "
                "'knowledge_graph', 'email', 'tag']",
            ),
            pytest.param(
                None,
                0,
                "",
                marks=pytest.mark.xfail(reason="issue#5920"),
            ),
        ],
    )
    def test_chunk_method(
        self, get_http_api_auth, chunk_method, expected_code, expected_message
    ):
        """Update ``chunk_method`` and verify the stored value or the error.

        An empty string is expected to be stored as ``"naive"``.
        """
        ids = create_datasets(get_http_api_auth, 1)
        res = update_dataset(get_http_api_auth, ids[0], {"chunk_method": chunk_method})
        assert res["code"] == expected_code
        if expected_code == 0:
            res = list_dataset(get_http_api_auth, {"id": ids[0]})
            if chunk_method != "":
                assert res["data"][0]["chunk_method"] == chunk_method
            else:
                assert res["data"][0]["chunk_method"] == "naive"
        else:
            assert res["message"] == expected_message

    def test_avatar(self, get_http_api_auth, request):
        """Update ``avatar`` with the base64-encoded ``test/data/logo.svg``."""
        ids = create_datasets(get_http_api_auth, 1)
        # The image is resolved relative to the pytest root directory and sent
        # as a UTF-8 string holding its base64 encoding.
        avatar_path = Path(request.config.rootdir) / "test/data/logo.svg"
        payload = {
            "avatar": base64.b64encode(avatar_path.read_bytes()).decode("utf-8"),
        }
        res = update_dataset(get_http_api_auth, ids[0], payload)
        assert res["code"] == 0

    def test_description(self, get_http_api_auth):
        """Update ``description`` and confirm the stored value."""
        ids = create_datasets(get_http_api_auth, 1)
        payload = {"description": "description"}
        res = update_dataset(get_http_api_auth, ids[0], payload)
        assert res["code"] == 0

        res = list_dataset(get_http_api_auth, {"id": ids[0]})
        assert res["data"][0]["description"] == "description"

    def test_pagerank(self, get_http_api_auth):
        """Update ``pagerank`` and confirm the stored value."""
        ids = create_datasets(get_http_api_auth, 1)
        payload = {"pagerank": 1}
        res = update_dataset(get_http_api_auth, ids[0], payload)
        assert res["code"] == 0

        res = list_dataset(get_http_api_auth, {"id": ids[0]})
        assert res["data"][0]["pagerank"] == 1

    def test_similarity_threshold(self, get_http_api_auth):
        """Update ``similarity_threshold`` and confirm the stored value."""
        ids = create_datasets(get_http_api_auth, 1)
        payload = {"similarity_threshold": 1}
        res = update_dataset(get_http_api_auth, ids[0], payload)
        assert res["code"] == 0

        res = list_dataset(get_http_api_auth, {"id": ids[0]})
        assert res["data"][0]["similarity_threshold"] == 1

    @pytest.mark.parametrize(
        "permission, expected_code",
        [
            ("me", 0),
            ("team", 0),
            pytest.param("", 0, marks=pytest.mark.xfail(reason="issue#5920")),
            ("ME", 102),
            ("TEAM", 102),
            ("other_permission", 102),
        ],
    )
    def test_permission(self, get_http_api_auth, permission, expected_code):
        """Update ``permission`` and verify the stored value for accepted input.

        An empty string is expected to be stored as ``"me"``.  For rejected
        values only the response code is checked.
        """
        ids = create_datasets(get_http_api_auth, 1)
        payload = {"permission": permission}
        res = update_dataset(get_http_api_auth, ids[0], payload)
        assert res["code"] == expected_code

        res = list_dataset(get_http_api_auth, {"id": ids[0]})
        if expected_code == 0 and permission != "":
            assert res["data"][0]["permission"] == permission
        if permission == "":
            assert res["data"][0]["permission"] == "me"

    def test_vector_similarity_weight(self, get_http_api_auth):
        """Update ``vector_similarity_weight`` and confirm the stored value."""
        ids = create_datasets(get_http_api_auth, 1)
        payload = {"vector_similarity_weight": 1}
        res = update_dataset(get_http_api_auth, ids[0], payload)
        assert res["code"] == 0

        res = list_dataset(get_http_api_auth, {"id": ids[0]})
        assert res["data"][0]["vector_similarity_weight"] == 1

    def test_invalid_dataset_id(self, get_http_api_auth):
        """Updating an unknown dataset id is rejected with code 102."""
        create_datasets(get_http_api_auth, 1)
        res = update_dataset(
            get_http_api_auth, "invalid_dataset_id", {"name": "invalid_dataset_id"}
        )
        assert res["code"] == 102
        assert res["message"] == "You don't own the dataset"

    @pytest.mark.parametrize(
        "payload, expected_code, expected_message",
        [
            ({"chunk_count": 1}, 102, "Can't change `chunk_count`."),
            pytest.param(
                {"create_date": "Tue, 11 Mar 2025 13:37:23 GMT"},
                102,
                "",
                marks=pytest.mark.xfail(reason="issue#5923"),
            ),
            pytest.param(
                {"create_time": 1741671443322},
                102,
                "",
                marks=pytest.mark.xfail(reason="issue#5923"),
            ),
            pytest.param(
                {"created_by": "aa"},
                102,
                "",
                marks=pytest.mark.xfail(reason="issue#5923"),
            ),
            ({"document_count": 1}, 102, "Can't change `document_count`."),
            ({"id": "id"}, 102, "The input parameters are invalid."),
            pytest.param(
                {"status": "1"}, 102, "", marks=pytest.mark.xfail(reason="issue#5923")
            ),
            (
                {"tenant_id": "e57c1966f99211efb41e9e45646e0111"},
                102,
                "Can't change `tenant_id`.",
            ),
            pytest.param(
                {"token_num": 1}, 102, "", marks=pytest.mark.xfail(reason="issue#5923")
            ),
            pytest.param(
                {"update_date": "Tue, 11 Mar 2025 13:37:23 GMT"},
                102,
                "",
                marks=pytest.mark.xfail(reason="issue#5923"),
            ),
            pytest.param(
                {"update_time": 1741671443339},
                102,
                "",
                marks=pytest.mark.xfail(reason="issue#5923"),
            ),
            pytest.param(
                {"unknown_field": 0},
                100,
                "",
                marks=pytest.mark.xfail(reason="issue#5923"),
            ),
        ],
    )
    def test_modify_unsupported_field(
        self, get_http_api_auth, payload, expected_code, expected_message
    ):
        """Try to update a field that must not be changed and check the error."""
        ids = create_datasets(get_http_api_auth, 1)
        res = update_dataset(get_http_api_auth, ids[0], payload)
        assert res["code"] == expected_code
        assert res["message"] == expected_message

    def test_concurrent_update(self, get_http_api_auth):
        """Send 100 rename requests through 5 worker threads.

        Every request must succeed.  Up to five requests are in flight at the
        same time, so the request that is applied last is not necessarily the
        one that was submitted last; the final name is therefore only required
        to be one of the submitted names.
        """
        ids = create_datasets(get_http_api_auth, 1)
        names = [f"dataset_{i}" for i in range(100)]

        with ThreadPoolExecutor(max_workers=5) as executor:
            futures = [
                executor.submit(
                    update_dataset, get_http_api_auth, ids[0], {"name": name}
                )
                for name in names
            ]
        # Leaving the ``with`` block waits for all submitted calls to finish.
        responses = [f.result() for f in futures]
        assert all(r["code"] == 0 for r in responses)

        res = list_dataset(get_http_api_auth, {"id": ids[0]})
        assert res["data"][0]["name"] in set(names)
|
Loading…
x
Reference in New Issue
Block a user