Test: Update test cases to reduce execution time (#6470)

### What problem does this PR solve?

_Briefly describe what this PR aims to solve. Include background context
that will help reviewers understand the purpose of the PR._

### Type of change

- [x] update test cases
This commit is contained in:
liu an 2025-03-25 09:17:05 +08:00 committed by GitHub
parent 390086c6ab
commit b6f3242c6c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 704 additions and 695 deletions

View File

@ -74,7 +74,7 @@ def delete_dataset(auth, payload=None):
return res.json()
def create_datasets(auth, num):
def batch_create_datasets(auth, num):
ids = []
for i in range(num):
res = create_dataset(auth, {"name": f"dataset_{i}"})
@ -111,18 +111,6 @@ def upload_documnets(auth, dataset_id, files_path=None):
f.close()
def batch_upload_documents(auth, dataset_id, num, tmp_path):
fps = []
for i in range(num):
fp = create_txt_file(tmp_path / f"ragflow_test_upload_{i}.txt")
fps.append(fp)
res = upload_documnets(auth, dataset_id, fps)
document_ids = []
for document in res["data"]:
document_ids.append(document["id"])
return document_ids
def download_document(auth, dataset_id, document_id, save_path):
url = f"{HOST_ADDRESS}{FILE_API_URL}/{document_id}".format(dataset_id=dataset_id)
res = requests.get(url=url, auth=auth, stream=True)
@ -172,8 +160,39 @@ def stop_parse_documnet(auth, dataset_id, payload=None):
return res.json()
def bulk_upload_documents(auth, dataset_id, num, tmp_path):
fps = []
for i in range(num):
fp = create_txt_file(tmp_path / f"ragflow_test_upload_{i}.txt")
fps.append(fp)
res = upload_documnets(auth, dataset_id, fps)
document_ids = []
for document in res["data"]:
document_ids.append(document["id"])
return document_ids
# CHUNK MANAGEMENT WITHIN DATASET
def add_chunk(auth, dataset_id, document_id, payload=None):
url = f"{HOST_ADDRESS}{CHUNK_API_URL}".format(dataset_id=dataset_id, document_id=document_id)
res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload)
return res.json()
def list_chunks(auth, dataset_id, document_id, params=None):
url = f"{HOST_ADDRESS}{CHUNK_API_URL}".format(dataset_id=dataset_id, document_id=document_id)
res = requests.get(
url=url,
headers=HEADERS,
auth=auth,
params=params,
)
return res.json()
def batch_add_chunks(auth, dataset_id, document_id, num):
ids = []
for i in range(num):
res = add_chunk(auth, dataset_id, document_id, {"content": f"ragflow test {i}"})
ids.append(res["data"]["chunk"]["id"])
return ids

View File

@ -31,7 +31,7 @@ from libs.utils.file_utils import (
)
@pytest.fixture(scope="function", autouse=True)
@pytest.fixture(scope="function")
def clear_datasets(get_http_api_auth):
yield
delete_dataset(get_http_api_auth)

View File

@ -0,0 +1,45 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pytest
from common import batch_create_datasets, bulk_upload_documents, delete_dataset, list_documnet, parse_documnet
from libs.utils import wait_for
@wait_for(10, 1, "Document parsing timeout")
def condition(_auth, _dataset_id):
res = list_documnet(_auth, _dataset_id)
for doc in res["data"]["docs"]:
if doc["run"] != "DONE":
return False
return True
@pytest.fixture(scope="class")
def chunk_management_tmp_dir(tmp_path_factory):
return tmp_path_factory.mktemp("chunk_management")
@pytest.fixture(scope="class")
def get_dataset_id_and_document_id(get_http_api_auth, chunk_management_tmp_dir):
dataset_ids = batch_create_datasets(get_http_api_auth, 1)
document_ids = bulk_upload_documents(get_http_api_auth, dataset_ids[0], 1, chunk_management_tmp_dir)
parse_documnet(get_http_api_auth, dataset_ids[0], {"document_ids": document_ids})
condition(get_http_api_auth, dataset_ids[0])
yield dataset_ids[0], document_ids[0]
delete_dataset(get_http_api_auth)

View File

@ -16,7 +16,7 @@
from concurrent.futures import ThreadPoolExecutor
import pytest
from common import INVALID_API_TOKEN, add_chunk, batch_upload_documents, create_datasets, delete_documnet
from common import INVALID_API_TOKEN, add_chunk, delete_documnet
from libs.auth import RAGFlowHttpApiAuth
@ -43,11 +43,9 @@ class TestAuthorization:
),
],
)
def test_invalid_auth(self, get_http_api_auth, tmp_path, auth, expected_code, expected_message):
ids = create_datasets(get_http_api_auth, 1)
dataset_id = ids[0]
document_ids = batch_upload_documents(get_http_api_auth, dataset_id, 1, tmp_path)
res = add_chunk(auth, dataset_id, document_ids[0], {})
def test_invalid_auth(self, get_dataset_id_and_document_id, auth, expected_code, expected_message):
dataset_id, document_id = get_dataset_id_and_document_id
res = add_chunk(auth, dataset_id, document_id, {})
assert res["code"] == expected_code
assert res["message"] == expected_message
@ -69,14 +67,12 @@ class TestAddChunk:
({"content": "\n!?。;!?\"'"}, 0, ""),
],
)
def test_content(self, get_http_api_auth, tmp_path, payload, expected_code, expected_message):
ids = create_datasets(get_http_api_auth, 1)
dataset_id = ids[0]
document_ids = batch_upload_documents(get_http_api_auth, dataset_id, 1, tmp_path)
res = add_chunk(get_http_api_auth, dataset_id, document_ids[0], payload)
def test_content(self, get_http_api_auth, get_dataset_id_and_document_id, payload, expected_code, expected_message):
dataset_id, document_id = get_dataset_id_and_document_id
res = add_chunk(get_http_api_auth, dataset_id, document_id, payload)
assert res["code"] == expected_code
if expected_code == 0:
validate_chunk_details(dataset_id, document_ids[0], payload, res)
validate_chunk_details(dataset_id, document_id, payload, res)
else:
assert res["message"] == expected_message
@ -95,14 +91,12 @@ class TestAddChunk:
({"content": "a", "important_keywords": 123}, 102, "`important_keywords` is required to be a list"),
],
)
def test_important_keywords(self, get_http_api_auth, tmp_path, payload, expected_code, expected_message):
ids = create_datasets(get_http_api_auth, 1)
dataset_id = ids[0]
document_ids = batch_upload_documents(get_http_api_auth, dataset_id, 1, tmp_path)
res = add_chunk(get_http_api_auth, dataset_id, document_ids[0], payload)
def test_important_keywords(self, get_http_api_auth, get_dataset_id_and_document_id, payload, expected_code, expected_message):
dataset_id, document_id = get_dataset_id_and_document_id
res = add_chunk(get_http_api_auth, dataset_id, document_id, payload)
assert res["code"] == expected_code
if expected_code == 0:
validate_chunk_details(dataset_id, document_ids[0], payload, res)
validate_chunk_details(dataset_id, document_id, payload, res)
else:
assert res["message"] == expected_message
@ -122,14 +116,12 @@ class TestAddChunk:
({"content": "a", "questions": 123}, 102, "`questions` is required to be a list"),
],
)
def test_questions(self, get_http_api_auth, tmp_path, payload, expected_code, expected_message):
ids = create_datasets(get_http_api_auth, 1)
dataset_id = ids[0]
document_ids = batch_upload_documents(get_http_api_auth, dataset_id, 1, tmp_path)
res = add_chunk(get_http_api_auth, dataset_id, document_ids[0], payload)
def test_questions(self, get_http_api_auth, get_dataset_id_and_document_id, payload, expected_code, expected_message):
dataset_id, document_id = get_dataset_id_and_document_id
res = add_chunk(get_http_api_auth, dataset_id, document_id, payload)
assert res["code"] == expected_code
if expected_code == 0:
validate_chunk_details(dataset_id, document_ids[0], payload, res)
validate_chunk_details(dataset_id, document_id, payload, res)
else:
assert res["message"] == expected_message
@ -147,14 +139,13 @@ class TestAddChunk:
def test_invalid_dataset_id(
self,
get_http_api_auth,
tmp_path,
get_dataset_id_and_document_id,
dataset_id,
expected_code,
expected_message,
):
ids = create_datasets(get_http_api_auth, 1)
document_ids = batch_upload_documents(get_http_api_auth, ids[0], 1, tmp_path)
res = add_chunk(get_http_api_auth, dataset_id, document_ids[0], {"content": "a"})
_, document_id = get_dataset_id_and_document_id
res = add_chunk(get_http_api_auth, dataset_id, document_id, {"content": "a"})
assert res["code"] == expected_code
assert res["message"] == expected_message
@ -169,49 +160,42 @@ class TestAddChunk:
),
],
)
def test_invalid_document_id(self, get_http_api_auth, document_id, expected_code, expected_message):
ids = create_datasets(get_http_api_auth, 1)
res = add_chunk(get_http_api_auth, ids[0], document_id, {"content": "a"})
def test_invalid_document_id(self, get_http_api_auth, get_dataset_id_and_document_id, document_id, expected_code, expected_message):
dataset_id, _ = get_dataset_id_and_document_id
res = add_chunk(get_http_api_auth, dataset_id, document_id, {"content": "a"})
assert res["code"] == expected_code
assert res["message"] == expected_message
def test_repeated_add_chunk(self, get_http_api_auth, tmp_path):
def test_repeated_add_chunk(self, get_http_api_auth, get_dataset_id_and_document_id):
payload = {"content": "a"}
ids = create_datasets(get_http_api_auth, 1)
dataset_id = ids[0]
document_ids = batch_upload_documents(get_http_api_auth, dataset_id, 1, tmp_path)
res = add_chunk(get_http_api_auth, dataset_id, document_ids[0], payload)
dataset_id, document_id = get_dataset_id_and_document_id
res = add_chunk(get_http_api_auth, dataset_id, document_id, payload)
assert res["code"] == 0
validate_chunk_details(dataset_id, document_ids[0], payload, res)
validate_chunk_details(dataset_id, document_id, payload, res)
res = add_chunk(get_http_api_auth, dataset_id, document_ids[0], payload)
res = add_chunk(get_http_api_auth, dataset_id, document_id, payload)
assert res["code"] == 0
validate_chunk_details(dataset_id, document_ids[0], payload, res)
validate_chunk_details(dataset_id, document_id, payload, res)
def test_add_chunk_to_deleted_document(self, get_http_api_auth, tmp_path):
ids = create_datasets(get_http_api_auth, 1)
dataset_id = ids[0]
document_ids = batch_upload_documents(get_http_api_auth, dataset_id, 1, tmp_path)
delete_documnet(get_http_api_auth, ids[0], {"ids": document_ids})
res = add_chunk(get_http_api_auth, dataset_id, document_ids[0], {"content": "a"})
def test_add_chunk_to_deleted_document(self, get_http_api_auth, get_dataset_id_and_document_id):
dataset_id, document_id = get_dataset_id_and_document_id
delete_documnet(get_http_api_auth, dataset_id, {"ids": [document_id]})
res = add_chunk(get_http_api_auth, dataset_id, document_id, {"content": "a"})
assert res["code"] == 102
assert res["message"] == f"You don't own the document {document_ids[0]}."
assert res["message"] == f"You don't own the document {document_id}."
@pytest.mark.skip(reason="issues/6411")
def test_concurrent_add_chunk(self, get_http_api_auth, tmp_path):
def test_concurrent_add_chunk(self, get_http_api_auth, get_dataset_id_and_document_id):
chunk_num = 50
ids = create_datasets(get_http_api_auth, 1)
dataset_id = ids[0]
document_ids = batch_upload_documents(get_http_api_auth, dataset_id, 1, tmp_path)
dataset_id, document_id = get_dataset_id_and_document_id
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [
executor.submit(
add_chunk,
get_http_api_auth,
ids[0],
document_ids[0],
dataset_id,
document_id,
{"content": "a"},
)
for i in range(chunk_num)

View File

@ -0,0 +1,26 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pytest
from common import batch_create_datasets, delete_dataset
@pytest.fixture(scope="class")
def get_dataset_ids(get_http_api_auth):
ids = batch_create_datasets(get_http_api_auth, 5)
yield ids
delete_dataset(get_http_api_auth)

View File

@ -21,6 +21,7 @@ from libs.utils import encode_avatar
from libs.utils.file_utils import create_image_file
@pytest.mark.usefixtures("clear_datasets")
class TestAuthorization:
@pytest.mark.parametrize(
"auth, expected_code, expected_message",
@ -39,6 +40,7 @@ class TestAuthorization:
assert res["message"] == expected_message
@pytest.mark.usefixtures("clear_datasets")
class TestDatasetCreation:
@pytest.mark.parametrize(
"payload, expected_code",
@ -74,6 +76,7 @@ class TestDatasetCreation:
assert res["code"] == 0, f"Failed to create dataset {i}"
@pytest.mark.usefixtures("clear_datasets")
class TestAdvancedConfigurations:
def test_avatar(self, get_http_api_auth, tmp_path):
fn = create_image_file(tmp_path / "ragflow_test.png")
@ -172,9 +175,7 @@ class TestAdvancedConfigurations:
("other_embedding_model", "other_embedding_model", 102),
],
)
def test_embedding_model(
self, get_http_api_auth, name, embedding_model, expected_code
):
def test_embedding_model(self, get_http_api_auth, name, embedding_model, expected_code):
payload = {"name": name, "embedding_model": embedding_model}
res = create_dataset(get_http_api_auth, payload)
assert res["code"] == expected_code

View File

@ -19,13 +19,14 @@ from concurrent.futures import ThreadPoolExecutor
import pytest
from common import (
INVALID_API_TOKEN,
create_datasets,
batch_create_datasets,
delete_dataset,
list_dataset,
)
from libs.auth import RAGFlowHttpApiAuth
@pytest.mark.usefixtures("clear_datasets")
class TestAuthorization:
@pytest.mark.parametrize(
"auth, expected_code, expected_message",
@ -39,7 +40,7 @@ class TestAuthorization:
],
)
def test_invalid_auth(self, get_http_api_auth, auth, expected_code, expected_message):
ids = create_datasets(get_http_api_auth, 1)
ids = batch_create_datasets(get_http_api_auth, 1)
res = delete_dataset(auth, {"ids": ids})
assert res["code"] == expected_code
assert res["message"] == expected_message
@ -48,6 +49,7 @@ class TestAuthorization:
assert len(res["data"]) == 1
@pytest.mark.usefixtures("clear_datasets")
class TestDatasetDeletion:
@pytest.mark.parametrize(
"payload, expected_code, expected_message, remaining",
@ -72,7 +74,7 @@ class TestDatasetDeletion:
],
)
def test_basic_scenarios(self, get_http_api_auth, payload, expected_code, expected_message, remaining):
ids = create_datasets(get_http_api_auth, 3)
ids = batch_create_datasets(get_http_api_auth, 3)
if callable(payload):
payload = payload(ids)
res = delete_dataset(get_http_api_auth, payload)
@ -92,7 +94,7 @@ class TestDatasetDeletion:
],
)
def test_delete_partial_invalid_id(self, get_http_api_auth, payload):
ids = create_datasets(get_http_api_auth, 3)
ids = batch_create_datasets(get_http_api_auth, 3)
if callable(payload):
payload = payload(ids)
res = delete_dataset(get_http_api_auth, payload)
@ -104,7 +106,7 @@ class TestDatasetDeletion:
assert len(res["data"]) == 0
def test_repeated_deletion(self, get_http_api_auth):
ids = create_datasets(get_http_api_auth, 1)
ids = batch_create_datasets(get_http_api_auth, 1)
res = delete_dataset(get_http_api_auth, {"ids": ids})
assert res["code"] == 0
@ -113,7 +115,7 @@ class TestDatasetDeletion:
assert res["message"] == f"You don't own the dataset {ids[0]}"
def test_duplicate_deletion(self, get_http_api_auth):
ids = create_datasets(get_http_api_auth, 1)
ids = batch_create_datasets(get_http_api_auth, 1)
res = delete_dataset(get_http_api_auth, {"ids": ids + ids})
assert res["code"] == 0
assert res["data"]["errors"][0] == f"Duplicate dataset ids: {ids[0]}"
@ -123,7 +125,7 @@ class TestDatasetDeletion:
assert len(res["data"]) == 0
def test_concurrent_deletion(self, get_http_api_auth):
ids = create_datasets(get_http_api_auth, 100)
ids = batch_create_datasets(get_http_api_auth, 100)
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [executor.submit(delete_dataset, get_http_api_auth, {"ids": ids[i : i + 1]}) for i in range(100)]
@ -132,7 +134,7 @@ class TestDatasetDeletion:
@pytest.mark.slow
def test_delete_10k(self, get_http_api_auth):
ids = create_datasets(get_http_api_auth, 10_000)
ids = batch_create_datasets(get_http_api_auth, 10_000)
res = delete_dataset(get_http_api_auth, {"ids": ids})
assert res["code"] == 0

View File

@ -16,19 +16,16 @@
from concurrent.futures import ThreadPoolExecutor
import pytest
from common import INVALID_API_TOKEN, create_datasets, list_dataset
from common import INVALID_API_TOKEN, list_dataset
from libs.auth import RAGFlowHttpApiAuth
def is_sorted(data, field, descending=True):
timestamps = [ds[field] for ds in data]
return (
all(a >= b for a, b in zip(timestamps, timestamps[1:]))
if descending
else all(a <= b for a, b in zip(timestamps, timestamps[1:]))
)
return all(a >= b for a, b in zip(timestamps, timestamps[1:])) if descending else all(a <= b for a, b in zip(timestamps, timestamps[1:]))
@pytest.mark.usefixtures("clear_datasets")
class TestAuthorization:
@pytest.mark.parametrize(
"auth, expected_code, expected_message",
@ -47,13 +44,13 @@ class TestAuthorization:
assert res["message"] == expected_message
@pytest.mark.usefixtures("get_dataset_ids")
class TestDatasetList:
def test_default(self, get_http_api_auth):
create_datasets(get_http_api_auth, 31)
res = list_dataset(get_http_api_auth, params={})
assert res["code"] == 0
assert len(res["data"]) == 30
assert len(res["data"]) == 5
@pytest.mark.parametrize(
"params, expected_code, expected_page_size, expected_message",
@ -79,15 +76,7 @@ class TestDatasetList:
),
],
)
def test_page(
self,
get_http_api_auth,
params,
expected_code,
expected_page_size,
expected_message,
):
create_datasets(get_http_api_auth, 5)
def test_page(self, get_http_api_auth, params, expected_code, expected_page_size, expected_message):
res = list_dataset(get_http_api_auth, params=params)
assert res["code"] == expected_code
if expected_code == 0:
@ -98,10 +87,10 @@ class TestDatasetList:
@pytest.mark.parametrize(
"params, expected_code, expected_page_size, expected_message",
[
({"page_size": None}, 0, 30, ""),
({"page_size": None}, 0, 5, ""),
({"page_size": 0}, 0, 0, ""),
({"page_size": 1}, 0, 1, ""),
({"page_size": 32}, 0, 31, ""),
({"page_size": 6}, 0, 5, ""),
({"page_size": "1"}, 0, 1, ""),
pytest.param(
{"page_size": -1},
@ -127,7 +116,6 @@ class TestDatasetList:
expected_page_size,
expected_message,
):
create_datasets(get_http_api_auth, 31)
res = list_dataset(get_http_api_auth, params=params)
assert res["code"] == expected_code
if expected_code == 0:
@ -180,7 +168,6 @@ class TestDatasetList:
assertions,
expected_message,
):
create_datasets(get_http_api_auth, 3)
res = list_dataset(get_http_api_auth, params=params)
assert res["code"] == expected_code
if expected_code == 0:
@ -257,7 +244,6 @@ class TestDatasetList:
assertions,
expected_message,
):
create_datasets(get_http_api_auth, 3)
res = list_dataset(get_http_api_auth, params=params)
assert res["code"] == expected_code
if expected_code == 0:
@ -269,16 +255,13 @@ class TestDatasetList:
@pytest.mark.parametrize(
"params, expected_code, expected_num, expected_message",
[
({"name": None}, 0, 3, ""),
({"name": ""}, 0, 3, ""),
({"name": None}, 0, 5, ""),
({"name": ""}, 0, 5, ""),
({"name": "dataset_1"}, 0, 1, ""),
({"name": "unknown"}, 102, 0, "You don't own the dataset unknown"),
],
)
def test_name(
self, get_http_api_auth, params, expected_code, expected_num, expected_message
):
create_datasets(get_http_api_auth, 3)
def test_name(self, get_http_api_auth, params, expected_code, expected_num, expected_message):
res = list_dataset(get_http_api_auth, params=params)
assert res["code"] == expected_code
if expected_code == 0:
@ -292,8 +275,8 @@ class TestDatasetList:
@pytest.mark.parametrize(
"dataset_id, expected_code, expected_num, expected_message",
[
(None, 0, 3, ""),
("", 0, 3, ""),
(None, 0, 5, ""),
("", 0, 5, ""),
(lambda r: r[0], 0, 1, ""),
("unknown", 102, 0, "You don't own the dataset unknown"),
],
@ -301,14 +284,15 @@ class TestDatasetList:
def test_id(
self,
get_http_api_auth,
get_dataset_ids,
dataset_id,
expected_code,
expected_num,
expected_message,
):
ids = create_datasets(get_http_api_auth, 3)
dataset_ids = get_dataset_ids
if callable(dataset_id):
params = {"id": dataset_id(ids)}
params = {"id": dataset_id(dataset_ids)}
else:
params = {"id": dataset_id}
@ -334,15 +318,16 @@ class TestDatasetList:
def test_name_and_id(
self,
get_http_api_auth,
get_dataset_ids,
dataset_id,
name,
expected_code,
expected_num,
expected_message,
):
ids = create_datasets(get_http_api_auth, 3)
dataset_ids = get_dataset_ids
if callable(dataset_id):
params = {"id": dataset_id(ids), "name": name}
params = {"id": dataset_id(dataset_ids), "name": name}
else:
params = {"id": dataset_id, "name": name}
@ -353,12 +338,8 @@ class TestDatasetList:
assert res["message"] == expected_message
def test_concurrent_list(self, get_http_api_auth):
create_datasets(get_http_api_auth, 3)
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [
executor.submit(list_dataset, get_http_api_auth) for i in range(100)
]
futures = [executor.submit(list_dataset, get_http_api_auth) for i in range(100)]
responses = [f.result() for f in futures]
assert all(r["code"] == 0 for r in responses)
@ -366,4 +347,4 @@ class TestDatasetList:
params = {"a": "b"}
res = list_dataset(get_http_api_auth, params=params)
assert res["code"] == 0
assert len(res["data"]) == 0
assert len(res["data"]) == 5

View File

@ -19,7 +19,7 @@ import pytest
from common import (
DATASET_NAME_LIMIT,
INVALID_API_TOKEN,
create_datasets,
batch_create_datasets,
list_dataset,
update_dataset,
)
@ -30,6 +30,7 @@ from libs.utils.file_utils import create_image_file
# TODO: Missing scenario for updating embedding_model with chunk_count != 0
@pytest.mark.usefixtures("clear_datasets")
class TestAuthorization:
@pytest.mark.parametrize(
"auth, expected_code, expected_message",
@ -42,15 +43,14 @@ class TestAuthorization:
),
],
)
def test_invalid_auth(
self, get_http_api_auth, auth, expected_code, expected_message
):
ids = create_datasets(get_http_api_auth, 1)
def test_invalid_auth(self, get_http_api_auth, auth, expected_code, expected_message):
ids = batch_create_datasets(get_http_api_auth, 1)
res = update_dataset(auth, ids[0], {"name": "new_name"})
assert res["code"] == expected_code
assert res["message"] == expected_message
@pytest.mark.usefixtures("clear_datasets")
class TestDatasetUpdate:
@pytest.mark.parametrize(
"name, expected_code, expected_message",
@ -73,7 +73,7 @@ class TestDatasetUpdate:
],
)
def test_name(self, get_http_api_auth, name, expected_code, expected_message):
ids = create_datasets(get_http_api_auth, 2)
ids = batch_create_datasets(get_http_api_auth, 2)
res = update_dataset(get_http_api_auth, ids[0], {"name": name})
assert res["code"] == expected_code
if expected_code == 0:
@ -105,13 +105,9 @@ class TestDatasetUpdate:
(None, 102, "`embedding_model` can't be empty"),
],
)
def test_embedding_model(
self, get_http_api_auth, embedding_model, expected_code, expected_message
):
ids = create_datasets(get_http_api_auth, 1)
res = update_dataset(
get_http_api_auth, ids[0], {"embedding_model": embedding_model}
)
def test_embedding_model(self, get_http_api_auth, embedding_model, expected_code, expected_message):
ids = batch_create_datasets(get_http_api_auth, 1)
res = update_dataset(get_http_api_auth, ids[0], {"embedding_model": embedding_model})
assert res["code"] == expected_code
if expected_code == 0:
res = list_dataset(get_http_api_auth, {"id": ids[0]})
@ -139,16 +135,12 @@ class TestDatasetUpdate:
(
"other_chunk_method",
102,
"'other_chunk_method' is not in ['naive', 'manual', 'qa', 'table',"
" 'paper', 'book', 'laws', 'presentation', 'picture', 'one', "
"'knowledge_graph', 'email', 'tag']",
"'other_chunk_method' is not in ['naive', 'manual', 'qa', 'table', 'paper', 'book', 'laws', 'presentation', 'picture', 'one', 'knowledge_graph', 'email', 'tag']",
),
],
)
def test_chunk_method(
self, get_http_api_auth, chunk_method, expected_code, expected_message
):
ids = create_datasets(get_http_api_auth, 1)
def test_chunk_method(self, get_http_api_auth, chunk_method, expected_code, expected_message):
ids = batch_create_datasets(get_http_api_auth, 1)
res = update_dataset(get_http_api_auth, ids[0], {"chunk_method": chunk_method})
assert res["code"] == expected_code
if expected_code == 0:
@ -161,14 +153,14 @@ class TestDatasetUpdate:
assert res["message"] == expected_message
def test_avatar(self, get_http_api_auth, tmp_path):
ids = create_datasets(get_http_api_auth, 1)
ids = batch_create_datasets(get_http_api_auth, 1)
fn = create_image_file(tmp_path / "ragflow_test.png")
payload = {"avatar": encode_avatar(fn)}
res = update_dataset(get_http_api_auth, ids[0], payload)
assert res["code"] == 0
def test_description(self, get_http_api_auth):
ids = create_datasets(get_http_api_auth, 1)
ids = batch_create_datasets(get_http_api_auth, 1)
payload = {"description": "description"}
res = update_dataset(get_http_api_auth, ids[0], payload)
assert res["code"] == 0
@ -177,7 +169,7 @@ class TestDatasetUpdate:
assert res["data"][0]["description"] == "description"
def test_pagerank(self, get_http_api_auth):
ids = create_datasets(get_http_api_auth, 1)
ids = batch_create_datasets(get_http_api_auth, 1)
payload = {"pagerank": 1}
res = update_dataset(get_http_api_auth, ids[0], payload)
assert res["code"] == 0
@ -186,7 +178,7 @@ class TestDatasetUpdate:
assert res["data"][0]["pagerank"] == 1
def test_similarity_threshold(self, get_http_api_auth):
ids = create_datasets(get_http_api_auth, 1)
ids = batch_create_datasets(get_http_api_auth, 1)
payload = {"similarity_threshold": 1}
res = update_dataset(get_http_api_auth, ids[0], payload)
assert res["code"] == 0
@ -206,7 +198,7 @@ class TestDatasetUpdate:
],
)
def test_permission(self, get_http_api_auth, permission, expected_code):
ids = create_datasets(get_http_api_auth, 1)
ids = batch_create_datasets(get_http_api_auth, 1)
payload = {"permission": permission}
res = update_dataset(get_http_api_auth, ids[0], payload)
assert res["code"] == expected_code
@ -218,7 +210,7 @@ class TestDatasetUpdate:
assert res["data"][0]["permission"] == "me"
def test_vector_similarity_weight(self, get_http_api_auth):
ids = create_datasets(get_http_api_auth, 1)
ids = batch_create_datasets(get_http_api_auth, 1)
payload = {"vector_similarity_weight": 1}
res = update_dataset(get_http_api_auth, ids[0], payload)
assert res["code"] == 0
@ -227,10 +219,8 @@ class TestDatasetUpdate:
assert res["data"][0]["vector_similarity_weight"] == 1
def test_invalid_dataset_id(self, get_http_api_auth):
create_datasets(get_http_api_auth, 1)
res = update_dataset(
get_http_api_auth, "invalid_dataset_id", {"name": "invalid_dataset_id"}
)
batch_create_datasets(get_http_api_auth, 1)
res = update_dataset(get_http_api_auth, "invalid_dataset_id", {"name": "invalid_dataset_id"})
assert res["code"] == 102
assert res["message"] == "You don't own the dataset"
@ -251,25 +241,20 @@ class TestDatasetUpdate:
],
)
def test_modify_read_only_field(self, get_http_api_auth, payload):
ids = create_datasets(get_http_api_auth, 1)
ids = batch_create_datasets(get_http_api_auth, 1)
res = update_dataset(get_http_api_auth, ids[0], payload)
assert res["code"] == 101
assert "is readonly" in res["message"]
def test_modify_unknown_field(self, get_http_api_auth):
ids = create_datasets(get_http_api_auth, 1)
ids = batch_create_datasets(get_http_api_auth, 1)
res = update_dataset(get_http_api_auth, ids[0], {"unknown_field": 0})
assert res["code"] == 100
def test_concurrent_update(self, get_http_api_auth):
ids = create_datasets(get_http_api_auth, 1)
ids = batch_create_datasets(get_http_api_auth, 1)
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [
executor.submit(
update_dataset, get_http_api_auth, ids[0], {"name": f"dataset_{i}"}
)
for i in range(100)
]
futures = [executor.submit(update_dataset, get_http_api_auth, ids[0], {"name": f"dataset_{i}"}) for i in range(100)]
responses = [f.result() for f in futures]
assert all(r["code"] == 0 for r in responses)

View File

@ -0,0 +1,32 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pytest
from common import batch_create_datasets, bulk_upload_documents, delete_dataset
@pytest.fixture(scope="class")
def file_management_tmp_dir(tmp_path_factory):
return tmp_path_factory.mktemp("file_management")
@pytest.fixture(scope="class")
def get_dataset_id_and_document_ids(get_http_api_auth, file_management_tmp_dir):
dataset_ids = batch_create_datasets(get_http_api_auth, 1)
document_ids = bulk_upload_documents(get_http_api_auth, dataset_ids[0], 5, file_management_tmp_dir)
yield dataset_ids[0], document_ids
delete_dataset(get_http_api_auth)

View File

@ -18,8 +18,8 @@ from concurrent.futures import ThreadPoolExecutor
import pytest
from common import (
INVALID_API_TOKEN,
batch_upload_documents,
create_datasets,
batch_create_datasets,
bulk_upload_documents,
delete_documnet,
list_documnet,
)
@ -38,14 +38,14 @@ class TestAuthorization:
),
],
)
def test_invalid_auth(self, get_http_api_auth, tmp_path, auth, expected_code, expected_message):
ids = create_datasets(get_http_api_auth, 1)
document_ids = batch_upload_documents(get_http_api_auth, ids[0], 1, tmp_path)
res = delete_documnet(auth, ids[0], {"ids": document_ids[0]})
def test_invalid_auth(self, get_dataset_id_and_document_ids, auth, expected_code, expected_message):
dataset_id, document_ids = get_dataset_id_and_document_ids
res = delete_documnet(auth, dataset_id, {"ids": document_ids})
assert res["code"] == expected_code
assert res["message"] == expected_message
@pytest.mark.usefixtures("clear_datasets")
class TestDocumentDeletion:
@pytest.mark.parametrize(
"payload, expected_code, expected_message, remaining",
@ -78,8 +78,8 @@ class TestDocumentDeletion:
expected_message,
remaining,
):
ids = create_datasets(get_http_api_auth, 1)
document_ids = batch_upload_documents(get_http_api_auth, ids[0], 3, tmp_path)
ids = batch_create_datasets(get_http_api_auth, 1)
document_ids = bulk_upload_documents(get_http_api_auth, ids[0], 3, tmp_path)
if callable(payload):
payload = payload(document_ids)
res = delete_documnet(get_http_api_auth, ids[0], payload)
@ -103,13 +103,12 @@ class TestDocumentDeletion:
],
)
def test_invalid_dataset_id(self, get_http_api_auth, tmp_path, dataset_id, expected_code, expected_message):
ids = create_datasets(get_http_api_auth, 1)
document_ids = batch_upload_documents(get_http_api_auth, ids[0], 3, tmp_path)
ids = batch_create_datasets(get_http_api_auth, 1)
document_ids = bulk_upload_documents(get_http_api_auth, ids[0], 3, tmp_path)
res = delete_documnet(get_http_api_auth, dataset_id, {"ids": document_ids[:1]})
assert res["code"] == expected_code
assert res["message"] == expected_message
# @pytest.mark.xfail(reason="issues/6174")
@pytest.mark.parametrize(
"payload",
[
@ -119,8 +118,8 @@ class TestDocumentDeletion:
],
)
def test_delete_partial_invalid_id(self, get_http_api_auth, tmp_path, payload):
ids = create_datasets(get_http_api_auth, 1)
document_ids = batch_upload_documents(get_http_api_auth, ids[0], 3, tmp_path)
ids = batch_create_datasets(get_http_api_auth, 1)
document_ids = bulk_upload_documents(get_http_api_auth, ids[0], 3, tmp_path)
if callable(payload):
payload = payload(document_ids)
res = delete_documnet(get_http_api_auth, ids[0], payload)
@ -132,8 +131,8 @@ class TestDocumentDeletion:
assert res["data"]["total"] == 0
def test_repeated_deletion(self, get_http_api_auth, tmp_path):
ids = create_datasets(get_http_api_auth, 1)
document_ids = batch_upload_documents(get_http_api_auth, ids[0], 1, tmp_path)
ids = batch_create_datasets(get_http_api_auth, 1)
document_ids = bulk_upload_documents(get_http_api_auth, ids[0], 1, tmp_path)
res = delete_documnet(get_http_api_auth, ids[0], {"ids": document_ids})
assert res["code"] == 0
@ -142,8 +141,8 @@ class TestDocumentDeletion:
assert res["message"] == f"Documents not found: {document_ids}"
def test_duplicate_deletion(self, get_http_api_auth, tmp_path):
ids = create_datasets(get_http_api_auth, 1)
document_ids = batch_upload_documents(get_http_api_auth, ids[0], 1, tmp_path)
ids = batch_create_datasets(get_http_api_auth, 1)
document_ids = bulk_upload_documents(get_http_api_auth, ids[0], 1, tmp_path)
res = delete_documnet(get_http_api_auth, ids[0], {"ids": document_ids + document_ids})
assert res["code"] == 0
assert res["data"]["errors"][0] == f"Duplicate document ids: {document_ids[0]}"
@ -155,8 +154,8 @@ class TestDocumentDeletion:
def test_concurrent_deletion(self, get_http_api_auth, tmp_path):
documnets_num = 100
ids = create_datasets(get_http_api_auth, 1)
document_ids = batch_upload_documents(get_http_api_auth, ids[0], documnets_num, tmp_path)
ids = batch_create_datasets(get_http_api_auth, 1)
document_ids = bulk_upload_documents(get_http_api_auth, ids[0], documnets_num, tmp_path)
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [
@ -174,8 +173,8 @@ class TestDocumentDeletion:
@pytest.mark.slow
def test_delete_1k(self, get_http_api_auth, tmp_path):
documnets_num = 1_000
ids = create_datasets(get_http_api_auth, 1)
document_ids = batch_upload_documents(get_http_api_auth, ids[0], documnets_num, tmp_path)
ids = batch_create_datasets(get_http_api_auth, 1)
document_ids = bulk_upload_documents(get_http_api_auth, ids[0], documnets_num, tmp_path)
res = list_documnet(get_http_api_auth, ids[0])
assert res["data"]["total"] == documnets_num

View File

@ -18,13 +18,7 @@ import json
from concurrent.futures import ThreadPoolExecutor
import pytest
from common import (
INVALID_API_TOKEN,
batch_upload_documents,
create_datasets,
download_document,
upload_documnets,
)
from common import INVALID_API_TOKEN, batch_create_datasets, bulk_upload_documents, download_document, upload_documnets
from libs.auth import RAGFlowHttpApiAuth
from libs.utils import compare_by_hash
from requests import codes
@ -42,14 +36,9 @@ class TestAuthorization:
),
],
)
def test_invalid_auth(
self, get_http_api_auth, tmp_path, auth, expected_code, expected_message
):
ids = create_datasets(get_http_api_auth, 1)
document_ids = batch_upload_documents(get_http_api_auth, ids[0], 1, tmp_path)
res = download_document(
auth, ids[0], document_ids[0], tmp_path / "ragflow_tes.txt"
)
def test_invalid_auth(self, get_dataset_id_and_document_ids, tmp_path, auth, expected_code, expected_message):
dataset_id, document_ids = get_dataset_id_and_document_ids
res = download_document(auth, dataset_id, document_ids[0], tmp_path / "ragflow_tes.txt")
assert res.status_code == codes.ok
with (tmp_path / "ragflow_tes.txt").open("r") as f:
response_json = json.load(f)
@ -57,43 +46,43 @@ class TestAuthorization:
assert response_json["message"] == expected_message
class TestDocumentDownload:
@pytest.mark.parametrize(
"generate_test_files",
[
"docx",
"excel",
"ppt",
"image",
"pdf",
"txt",
"md",
"json",
"eml",
"html",
],
indirect=True,
@pytest.mark.usefixtures("clear_datasets")
@pytest.mark.parametrize(
"generate_test_files",
[
"docx",
"excel",
"ppt",
"image",
"pdf",
"txt",
"md",
"json",
"eml",
"html",
],
indirect=True,
)
def test_file_type_validation(get_http_api_auth, generate_test_files, request):
ids = batch_create_datasets(get_http_api_auth, 1)
fp = generate_test_files[request.node.callspec.params["generate_test_files"]]
res = upload_documnets(get_http_api_auth, ids[0], [fp])
document_id = res["data"][0]["id"]
res = download_document(
get_http_api_auth,
ids[0],
document_id,
fp.with_stem("ragflow_test_download"),
)
assert res.status_code == codes.ok
assert compare_by_hash(
fp,
fp.with_stem("ragflow_test_download"),
)
def test_file_type_validation(
self, get_http_api_auth, generate_test_files, request
):
ids = create_datasets(get_http_api_auth, 1)
fp = generate_test_files[request.node.callspec.params["generate_test_files"]]
res = upload_documnets(get_http_api_auth, ids[0], [fp])
document_id = res["data"][0]["id"]
res = download_document(
get_http_api_auth,
ids[0],
document_id,
fp.with_stem("ragflow_test_download"),
)
assert res.status_code == codes.ok
assert compare_by_hash(
fp,
fp.with_stem("ragflow_test_download"),
)
class TestDocumentDownload:
@pytest.mark.parametrize(
"document_id, expected_code, expected_message",
[
@ -104,13 +93,11 @@ class TestDocumentDownload:
),
],
)
def test_invalid_document_id(
self, get_http_api_auth, tmp_path, document_id, expected_code, expected_message
):
ids = create_datasets(get_http_api_auth, 1)
def test_invalid_document_id(self, get_http_api_auth, get_dataset_id_and_document_ids, tmp_path, document_id, expected_code, expected_message):
dataset_id, _ = get_dataset_id_and_document_ids
res = download_document(
get_http_api_auth,
ids[0],
dataset_id,
document_id,
tmp_path / "ragflow_test_download_1.txt",
)
@ -131,11 +118,8 @@ class TestDocumentDownload:
),
],
)
def test_invalid_dataset_id(
self, get_http_api_auth, tmp_path, dataset_id, expected_code, expected_message
):
ids = create_datasets(get_http_api_auth, 1)
document_ids = batch_upload_documents(get_http_api_auth, ids[0], 1, tmp_path)
def test_invalid_dataset_id(self, get_http_api_auth, get_dataset_id_and_document_ids, tmp_path, dataset_id, expected_code, expected_message):
_, document_ids = get_dataset_id_and_document_ids
res = download_document(
get_http_api_auth,
dataset_id,
@ -148,45 +132,44 @@ class TestDocumentDownload:
assert response_json["code"] == expected_code
assert response_json["message"] == expected_message
def test_same_file_repeat(self, get_http_api_auth, tmp_path):
def test_same_file_repeat(self, get_http_api_auth, get_dataset_id_and_document_ids, tmp_path, file_management_tmp_dir):
num = 5
ids = create_datasets(get_http_api_auth, 1)
document_ids = batch_upload_documents(get_http_api_auth, ids[0], 1, tmp_path)
dataset_id, document_ids = get_dataset_id_and_document_ids
for i in range(num):
res = download_document(
get_http_api_auth,
ids[0],
dataset_id,
document_ids[0],
tmp_path / f"ragflow_test_download_{i}.txt",
)
assert res.status_code == codes.ok
assert compare_by_hash(
tmp_path / "ragflow_test_upload_0.txt",
file_management_tmp_dir / "ragflow_test_upload_0.txt",
tmp_path / f"ragflow_test_download_{i}.txt",
)
def test_concurrent_download(self, get_http_api_auth, tmp_path):
document_count = 20
ids = create_datasets(get_http_api_auth, 1)
document_ids = batch_upload_documents(
get_http_api_auth, ids[0], document_count, tmp_path
@pytest.mark.usefixtures("clear_datasets")
def test_concurrent_download(get_http_api_auth, tmp_path):
document_count = 20
ids = batch_create_datasets(get_http_api_auth, 1)
document_ids = bulk_upload_documents(get_http_api_auth, ids[0], document_count, tmp_path)
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [
executor.submit(
download_document,
get_http_api_auth,
ids[0],
document_ids[i],
tmp_path / f"ragflow_test_download_{i}.txt",
)
for i in range(document_count)
]
responses = [f.result() for f in futures]
assert all(r.status_code == codes.ok for r in responses)
for i in range(document_count):
assert compare_by_hash(
tmp_path / f"ragflow_test_upload_{i}.txt",
tmp_path / f"ragflow_test_download_{i}.txt",
)
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [
executor.submit(
download_document,
get_http_api_auth,
ids[0],
document_ids[i],
tmp_path / f"ragflow_test_download_{i}.txt",
)
for i in range(document_count)
]
responses = [f.result() for f in futures]
assert all(r.status_code == codes.ok for r in responses)
for i in range(document_count):
assert compare_by_hash(
tmp_path / f"ragflow_test_upload_{i}.txt",
tmp_path / f"ragflow_test_download_{i}.txt",
)

View File

@ -18,8 +18,6 @@ from concurrent.futures import ThreadPoolExecutor
import pytest
from common import (
INVALID_API_TOKEN,
batch_upload_documents,
create_datasets,
list_documnet,
)
from libs.auth import RAGFlowHttpApiAuth
@ -27,11 +25,7 @@ from libs.auth import RAGFlowHttpApiAuth
def is_sorted(data, field, descending=True):
timestamps = [ds[field] for ds in data]
return (
all(a >= b for a, b in zip(timestamps, timestamps[1:]))
if descending
else all(a <= b for a, b in zip(timestamps, timestamps[1:]))
)
return all(a >= b for a, b in zip(timestamps, timestamps[1:])) if descending else all(a <= b for a, b in zip(timestamps, timestamps[1:]))
class TestAuthorization:
@ -46,23 +40,20 @@ class TestAuthorization:
),
],
)
def test_invalid_auth(
self, get_http_api_auth, auth, expected_code, expected_message
):
ids = create_datasets(get_http_api_auth, 1)
res = list_documnet(auth, ids[0])
def test_invalid_auth(self, get_dataset_id_and_document_ids, auth, expected_code, expected_message):
dataset_id, _ = get_dataset_id_and_document_ids
res = list_documnet(auth, dataset_id)
assert res["code"] == expected_code
assert res["message"] == expected_message
class TestDocumentList:
def test_default(self, get_http_api_auth, tmp_path):
ids = create_datasets(get_http_api_auth, 1)
batch_upload_documents(get_http_api_auth, ids[0], 31, tmp_path)
res = list_documnet(get_http_api_auth, ids[0])
def test_default(self, get_http_api_auth, get_dataset_id_and_document_ids):
dataset_id, _ = get_dataset_id_and_document_ids
res = list_documnet(get_http_api_auth, dataset_id)
assert res["code"] == 0
assert len(res["data"]["docs"]) == 30
assert res["data"]["total"] == 31
assert len(res["data"]["docs"]) == 5
assert res["data"]["total"] == 5
@pytest.mark.parametrize(
"dataset_id, expected_code, expected_message",
@ -75,10 +66,7 @@ class TestDocumentList:
),
],
)
def test_invalid_dataset_id(
self, get_http_api_auth, dataset_id, expected_code, expected_message
):
create_datasets(get_http_api_auth, 1)
def test_invalid_dataset_id(self, get_http_api_auth, get_dataset_id_and_document_ids, dataset_id, expected_code, expected_message):
res = list_documnet(get_http_api_auth, dataset_id)
assert res["code"] == expected_code
assert res["message"] == expected_message
@ -110,15 +98,14 @@ class TestDocumentList:
def test_page(
self,
get_http_api_auth,
tmp_path,
get_dataset_id_and_document_ids,
params,
expected_code,
expected_page_size,
expected_message,
):
ids = create_datasets(get_http_api_auth, 1)
batch_upload_documents(get_http_api_auth, ids[0], 5, tmp_path)
res = list_documnet(get_http_api_auth, ids[0], params=params)
dataset_id, _ = get_dataset_id_and_document_ids
res = list_documnet(get_http_api_auth, dataset_id, params=params)
assert res["code"] == expected_code
if expected_code == 0:
assert len(res["data"]["docs"]) == expected_page_size
@ -129,10 +116,10 @@ class TestDocumentList:
@pytest.mark.parametrize(
"params, expected_code, expected_page_size, expected_message",
[
({"page_size": None}, 0, 30, ""),
({"page_size": None}, 0, 5, ""),
({"page_size": 0}, 0, 0, ""),
({"page_size": 1}, 0, 1, ""),
({"page_size": 32}, 0, 31, ""),
({"page_size": 6}, 0, 5, ""),
({"page_size": "1"}, 0, 1, ""),
pytest.param(
{"page_size": -1},
@ -153,15 +140,14 @@ class TestDocumentList:
def test_page_size(
self,
get_http_api_auth,
tmp_path,
get_dataset_id_and_document_ids,
params,
expected_code,
expected_page_size,
expected_message,
):
ids = create_datasets(get_http_api_auth, 1)
batch_upload_documents(get_http_api_auth, ids[0], 31, tmp_path)
res = list_documnet(get_http_api_auth, ids[0], params=params)
dataset_id, _ = get_dataset_id_and_document_ids
res = list_documnet(get_http_api_auth, dataset_id, params=params)
assert res["code"] == expected_code
if expected_code == 0:
assert len(res["data"]["docs"]) == expected_page_size
@ -208,15 +194,14 @@ class TestDocumentList:
def test_orderby(
self,
get_http_api_auth,
tmp_path,
get_dataset_id_and_document_ids,
params,
expected_code,
assertions,
expected_message,
):
ids = create_datasets(get_http_api_auth, 1)
batch_upload_documents(get_http_api_auth, ids[0], 3, tmp_path)
res = list_documnet(get_http_api_auth, ids[0], params=params)
dataset_id, _ = get_dataset_id_and_document_ids
res = list_documnet(get_http_api_auth, dataset_id, params=params)
assert res["code"] == expected_code
if expected_code == 0:
if callable(assertions):
@ -288,15 +273,14 @@ class TestDocumentList:
def test_desc(
self,
get_http_api_auth,
tmp_path,
get_dataset_id_and_document_ids,
params,
expected_code,
assertions,
expected_message,
):
ids = create_datasets(get_http_api_auth, 1)
batch_upload_documents(get_http_api_auth, ids[0], 3, tmp_path)
res = list_documnet(get_http_api_auth, ids[0], params=params)
dataset_id, _ = get_dataset_id_and_document_ids
res = list_documnet(get_http_api_auth, dataset_id, params=params)
assert res["code"] == expected_code
if expected_code == 0:
if callable(assertions):
@ -307,17 +291,16 @@ class TestDocumentList:
@pytest.mark.parametrize(
"params, expected_num",
[
({"keywords": None}, 3),
({"keywords": ""}, 3),
({"keywords": None}, 5),
({"keywords": ""}, 5),
({"keywords": "0"}, 1),
({"keywords": "ragflow_test_upload"}, 3),
({"keywords": "ragflow_test_upload"}, 5),
({"keywords": "unknown"}, 0),
],
)
def test_keywords(self, get_http_api_auth, tmp_path, params, expected_num):
ids = create_datasets(get_http_api_auth, 1)
batch_upload_documents(get_http_api_auth, ids[0], 3, tmp_path)
res = list_documnet(get_http_api_auth, ids[0], params=params)
def test_keywords(self, get_http_api_auth, get_dataset_id_and_document_ids, params, expected_num):
dataset_id, _ = get_dataset_id_and_document_ids
res = list_documnet(get_http_api_auth, dataset_id, params=params)
assert res["code"] == 0
assert len(res["data"]["docs"]) == expected_num
assert res["data"]["total"] == expected_num
@ -325,8 +308,8 @@ class TestDocumentList:
@pytest.mark.parametrize(
"params, expected_code, expected_num, expected_message",
[
({"name": None}, 0, 3, ""),
({"name": ""}, 0, 3, ""),
({"name": None}, 0, 5, ""),
({"name": ""}, 0, 5, ""),
({"name": "ragflow_test_upload_0.txt"}, 0, 1, ""),
(
{"name": "unknown.txt"},
@ -339,15 +322,14 @@ class TestDocumentList:
def test_name(
self,
get_http_api_auth,
tmp_path,
get_dataset_id_and_document_ids,
params,
expected_code,
expected_num,
expected_message,
):
ids = create_datasets(get_http_api_auth, 1)
batch_upload_documents(get_http_api_auth, ids[0], 3, tmp_path)
res = list_documnet(get_http_api_auth, ids[0], params=params)
dataset_id, _ = get_dataset_id_and_document_ids
res = list_documnet(get_http_api_auth, dataset_id, params=params)
assert res["code"] == expected_code
if expected_code == 0:
if params["name"] in [None, ""]:
@ -360,8 +342,8 @@ class TestDocumentList:
@pytest.mark.parametrize(
"document_id, expected_code, expected_num, expected_message",
[
(None, 0, 3, ""),
("", 0, 3, ""),
(None, 0, 5, ""),
("", 0, 5, ""),
(lambda r: r[0], 0, 1, ""),
("unknown.txt", 102, 0, "You don't own the document unknown.txt."),
],
@ -369,19 +351,18 @@ class TestDocumentList:
def test_id(
self,
get_http_api_auth,
tmp_path,
get_dataset_id_and_document_ids,
document_id,
expected_code,
expected_num,
expected_message,
):
ids = create_datasets(get_http_api_auth, 1)
document_ids = batch_upload_documents(get_http_api_auth, ids[0], 3, tmp_path)
dataset_id, document_ids = get_dataset_id_and_document_ids
if callable(document_id):
params = {"id": document_id(document_ids)}
else:
params = {"id": document_id}
res = list_documnet(get_http_api_auth, ids[0], params=params)
res = list_documnet(get_http_api_auth, dataset_id, params=params)
assert res["code"] == expected_code
if expected_code == 0:
@ -410,41 +391,36 @@ class TestDocumentList:
def test_name_and_id(
self,
get_http_api_auth,
tmp_path,
get_dataset_id_and_document_ids,
document_id,
name,
expected_code,
expected_num,
expected_message,
):
ids = create_datasets(get_http_api_auth, 1)
document_ids = batch_upload_documents(get_http_api_auth, ids[0], 3, tmp_path)
dataset_id, document_ids = get_dataset_id_and_document_ids
if callable(document_id):
params = {"id": document_id(document_ids), "name": name}
else:
params = {"id": document_id, "name": name}
res = list_documnet(get_http_api_auth, ids[0], params=params)
res = list_documnet(get_http_api_auth, dataset_id, params=params)
if expected_code == 0:
assert len(res["data"]["docs"]) == expected_num
else:
assert res["message"] == expected_message
def test_concurrent_list(self, get_http_api_auth, tmp_path):
ids = create_datasets(get_http_api_auth, 1)
batch_upload_documents(get_http_api_auth, ids[0], 3, tmp_path)
def test_concurrent_list(self, get_http_api_auth, get_dataset_id_and_document_ids):
dataset_id, _ = get_dataset_id_and_document_ids
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [
executor.submit(list_documnet, get_http_api_auth, ids[0])
for i in range(100)
]
futures = [executor.submit(list_documnet, get_http_api_auth, dataset_id) for i in range(100)]
responses = [f.result() for f in futures]
assert all(r["code"] == 0 for r in responses)
def test_invalid_params(self, get_http_api_auth):
ids = create_datasets(get_http_api_auth, 1)
def test_invalid_params(self, get_http_api_auth, get_dataset_id_and_document_ids):
dataset_id, _ = get_dataset_id_and_document_ids
params = {"a": "b"}
res = list_documnet(get_http_api_auth, ids[0], params=params)
res = list_documnet(get_http_api_auth, dataset_id, params=params)
assert res["code"] == 0
assert len(res["data"]["docs"]) == 0
assert len(res["data"]["docs"]) == 5

View File

@ -18,8 +18,8 @@ from concurrent.futures import ThreadPoolExecutor
import pytest
from common import (
INVALID_API_TOKEN,
batch_upload_documents,
create_datasets,
batch_create_datasets,
bulk_upload_documents,
list_documnet,
parse_documnet,
)
@ -50,14 +50,14 @@ class TestAuthorization:
),
],
)
def test_invalid_auth(self, get_http_api_auth, tmp_path, auth, expected_code, expected_message):
ids = create_datasets(get_http_api_auth, 1)
document_ids = batch_upload_documents(get_http_api_auth, ids[0], 1, tmp_path)
res = parse_documnet(auth, ids[0], {"document_ids": document_ids[0]})
def test_invalid_auth(self, get_dataset_id_and_document_ids, auth, expected_code, expected_message):
dataset_id, document_ids = get_dataset_id_and_document_ids
res = parse_documnet(auth, dataset_id, {"document_ids": document_ids})
assert res["code"] == expected_code
assert res["message"] == expected_message
@pytest.mark.usefixtures("clear_datasets")
class TestDocumentsParse:
@pytest.mark.parametrize(
"payload, expected_code, expected_message",
@ -98,9 +98,9 @@ class TestDocumentsParse:
return False
return True
ids = create_datasets(get_http_api_auth, 1)
ids = batch_create_datasets(get_http_api_auth, 1)
dataset_id = ids[0]
document_ids = batch_upload_documents(get_http_api_auth, dataset_id, 3, tmp_path)
document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 3, tmp_path)
if callable(payload):
payload = payload(document_ids)
res = parse_documnet(get_http_api_auth, dataset_id, payload)
@ -130,8 +130,8 @@ class TestDocumentsParse:
expected_code,
expected_message,
):
ids = create_datasets(get_http_api_auth, 1)
document_ids = batch_upload_documents(get_http_api_auth, ids[0], 1, tmp_path)
ids = batch_create_datasets(get_http_api_auth, 1)
document_ids = bulk_upload_documents(get_http_api_auth, ids[0], 1, tmp_path)
res = parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids})
assert res["code"] == expected_code
assert res["message"] == expected_message
@ -153,9 +153,9 @@ class TestDocumentsParse:
return False
return True
ids = create_datasets(get_http_api_auth, 1)
ids = batch_create_datasets(get_http_api_auth, 1)
dataset_id = ids[0]
document_ids = batch_upload_documents(get_http_api_auth, dataset_id, 3, tmp_path)
document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 3, tmp_path)
if callable(payload):
payload = payload(document_ids)
res = parse_documnet(get_http_api_auth, dataset_id, payload)
@ -175,9 +175,9 @@ class TestDocumentsParse:
return False
return True
ids = create_datasets(get_http_api_auth, 1)
ids = batch_create_datasets(get_http_api_auth, 1)
dataset_id = ids[0]
document_ids = batch_upload_documents(get_http_api_auth, dataset_id, 1, tmp_path)
document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 1, tmp_path)
res = parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids})
assert res["code"] == 0
@ -195,9 +195,9 @@ class TestDocumentsParse:
return False
return True
ids = create_datasets(get_http_api_auth, 1)
ids = batch_create_datasets(get_http_api_auth, 1)
dataset_id = ids[0]
document_ids = batch_upload_documents(get_http_api_auth, dataset_id, 1, tmp_path)
document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 1, tmp_path)
res = parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids + document_ids})
assert res["code"] == 0
assert res["data"]["errors"][0] == f"Duplicate document ids: {document_ids[0]}"
@ -218,9 +218,9 @@ class TestDocumentsParse:
return True
document_num = 100
ids = create_datasets(get_http_api_auth, 1)
ids = batch_create_datasets(get_http_api_auth, 1)
dataset_id = ids[0]
document_ids = batch_upload_documents(get_http_api_auth, dataset_id, document_num, tmp_path)
document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, document_num, tmp_path)
res = parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids})
assert res["code"] == 0
@ -239,9 +239,9 @@ class TestDocumentsParse:
return True
document_num = 100
ids = create_datasets(get_http_api_auth, 1)
ids = batch_create_datasets(get_http_api_auth, 1)
dataset_id = ids[0]
document_ids = batch_upload_documents(get_http_api_auth, dataset_id, document_num, tmp_path)
document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, document_num, tmp_path)
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [

View File

@ -18,8 +18,8 @@ from concurrent.futures import ThreadPoolExecutor
import pytest
from common import (
INVALID_API_TOKEN,
batch_upload_documents,
create_datasets,
batch_create_datasets,
bulk_upload_documents,
list_documnet,
parse_documnet,
stop_parse_documnet,
@ -48,6 +48,7 @@ def validate_document_parse_cancel(auth, dataset_id, document_ids):
assert doc["progress"] == 0.0
@pytest.mark.usefixtures("clear_datasets")
class TestAuthorization:
@pytest.mark.parametrize(
"auth, expected_code, expected_message",
@ -61,13 +62,14 @@ class TestAuthorization:
],
)
def test_invalid_auth(self, get_http_api_auth, auth, expected_code, expected_message):
ids = create_datasets(get_http_api_auth, 1)
ids = batch_create_datasets(get_http_api_auth, 1)
res = stop_parse_documnet(auth, ids[0])
assert res["code"] == expected_code
assert res["message"] == expected_message
@pytest.mark.skip
@pytest.mark.usefixtures("clear_datasets")
class TestDocumentsParseStop:
@pytest.mark.parametrize(
"payload, expected_code, expected_message",
@ -108,9 +110,9 @@ class TestDocumentsParseStop:
return False
return True
ids = create_datasets(get_http_api_auth, 1)
ids = batch_create_datasets(get_http_api_auth, 1)
dataset_id = ids[0]
document_ids = batch_upload_documents(get_http_api_auth, dataset_id, 3, tmp_path)
document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 3, tmp_path)
parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids})
if callable(payload):
@ -145,8 +147,8 @@ class TestDocumentsParseStop:
expected_code,
expected_message,
):
ids = create_datasets(get_http_api_auth, 1)
document_ids = batch_upload_documents(get_http_api_auth, ids[0], 1, tmp_path)
ids = batch_create_datasets(get_http_api_auth, 1)
document_ids = bulk_upload_documents(get_http_api_auth, ids[0], 1, tmp_path)
res = stop_parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids})
assert res["code"] == expected_code
assert res["message"] == expected_message
@ -161,9 +163,9 @@ class TestDocumentsParseStop:
],
)
def test_stop_parse_partial_invalid_document_id(self, get_http_api_auth, tmp_path, payload):
ids = create_datasets(get_http_api_auth, 1)
ids = batch_create_datasets(get_http_api_auth, 1)
dataset_id = ids[0]
document_ids = batch_upload_documents(get_http_api_auth, dataset_id, 3, tmp_path)
document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 3, tmp_path)
parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids})
if callable(payload):
@ -175,9 +177,9 @@ class TestDocumentsParseStop:
validate_document_parse_cancel(get_http_api_auth, dataset_id, document_ids)
def test_repeated_stop_parse(self, get_http_api_auth, tmp_path):
ids = create_datasets(get_http_api_auth, 1)
ids = batch_create_datasets(get_http_api_auth, 1)
dataset_id = ids[0]
document_ids = batch_upload_documents(get_http_api_auth, dataset_id, 1, tmp_path)
document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 1, tmp_path)
parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids})
res = stop_parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids})
assert res["code"] == 0
@ -187,9 +189,9 @@ class TestDocumentsParseStop:
assert res["message"] == "Can't stop parsing document with progress at 0 or 1"
def test_duplicate_stop_parse(self, get_http_api_auth, tmp_path):
ids = create_datasets(get_http_api_auth, 1)
ids = batch_create_datasets(get_http_api_auth, 1)
dataset_id = ids[0]
document_ids = batch_upload_documents(get_http_api_auth, dataset_id, 1, tmp_path)
document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 1, tmp_path)
parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids})
res = stop_parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids + document_ids})
assert res["code"] == 0
@ -199,9 +201,9 @@ class TestDocumentsParseStop:
@pytest.mark.slow
def test_stop_parse_100_files(self, get_http_api_auth, tmp_path):
document_num = 100
ids = create_datasets(get_http_api_auth, 1)
ids = batch_create_datasets(get_http_api_auth, 1)
dataset_id = ids[0]
document_ids = batch_upload_documents(get_http_api_auth, dataset_id, document_num, tmp_path)
document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, document_num, tmp_path)
parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids})
res = stop_parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids})
assert res["code"] == 0
@ -210,9 +212,9 @@ class TestDocumentsParseStop:
@pytest.mark.slow
def test_concurrent_parse(self, get_http_api_auth, tmp_path):
document_num = 50
ids = create_datasets(get_http_api_auth, 1)
ids = batch_create_datasets(get_http_api_auth, 1)
dataset_id = ids[0]
document_ids = batch_upload_documents(get_http_api_auth, dataset_id, document_num, tmp_path)
document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, document_num, tmp_path)
parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids})
with ThreadPoolExecutor(max_workers=5) as executor:

View File

@ -16,14 +16,7 @@
import pytest
from common import (
DOCUMENT_NAME_LIMIT,
INVALID_API_TOKEN,
batch_upload_documents,
create_datasets,
list_documnet,
update_documnet,
)
from common import DOCUMENT_NAME_LIMIT, INVALID_API_TOKEN, batch_create_datasets, bulk_upload_documents, list_documnet, update_documnet
from libs.auth import RAGFlowHttpApiAuth
@ -39,10 +32,9 @@ class TestAuthorization:
),
],
)
def test_invalid_auth(self, get_http_api_auth, tmp_path, auth, expected_code, expected_message):
ids = create_datasets(get_http_api_auth, 1)
document_ids = batch_upload_documents(get_http_api_auth, ids[0], 1, tmp_path)
res = update_documnet(auth, ids[0], document_ids[0], {"name": "auth_test.txt"})
def test_invalid_auth(self, get_dataset_id_and_document_ids, auth, expected_code, expected_message):
dataset_id, document_ids = get_dataset_id_and_document_ids
res = update_documnet(auth, dataset_id, document_ids[0], {"name": "auth_test.txt"})
assert res["code"] == expected_code
assert res["message"] == expected_message
@ -89,13 +81,12 @@ class TestUpdatedDocument:
),
],
)
def test_name(self, get_http_api_auth, tmp_path, name, expected_code, expected_message):
ids = create_datasets(get_http_api_auth, 1)
document_ids = batch_upload_documents(get_http_api_auth, ids[0], 2, tmp_path)
res = update_documnet(get_http_api_auth, ids[0], document_ids[0], {"name": name})
def test_name(self, get_http_api_auth, get_dataset_id_and_document_ids, name, expected_code, expected_message):
dataset_id, document_ids = get_dataset_id_and_document_ids
res = update_documnet(get_http_api_auth, dataset_id, document_ids[0], {"name": name})
assert res["code"] == expected_code
if expected_code == 0:
res = list_documnet(get_http_api_auth, ids[0], {"id": document_ids[0]})
res = list_documnet(get_http_api_auth, dataset_id, {"id": document_ids[0]})
assert res["data"]["docs"][0]["name"] == name
else:
assert res["message"] == expected_message
@ -111,9 +102,9 @@ class TestUpdatedDocument:
),
],
)
def test_invalid_document_id(self, get_http_api_auth, document_id, expected_code, expected_message):
ids = create_datasets(get_http_api_auth, 1)
res = update_documnet(get_http_api_auth, ids[0], document_id, {"name": "new_name.txt"})
def test_invalid_document_id(self, get_http_api_auth, get_dataset_id_and_document_ids, document_id, expected_code, expected_message):
dataset_id, _ = get_dataset_id_and_document_ids
res = update_documnet(get_http_api_auth, dataset_id, document_id, {"name": "new_name.txt"})
assert res["code"] == expected_code
assert res["message"] == expected_message
@ -128,9 +119,8 @@ class TestUpdatedDocument:
),
],
)
def test_invalid_dataset_id(self, get_http_api_auth, tmp_path, dataset_id, expected_code, expected_message):
ids = create_datasets(get_http_api_auth, 1)
document_ids = batch_upload_documents(get_http_api_auth, ids[0], 1, tmp_path)
def test_invalid_dataset_id(self, get_http_api_auth, get_dataset_id_and_document_ids, dataset_id, expected_code, expected_message):
_, document_ids = get_dataset_id_and_document_ids
res = update_documnet(get_http_api_auth, dataset_id, document_ids[0], {"name": "new_name.txt"})
assert res["code"] == expected_code
assert res["message"] == expected_message
@ -139,12 +129,11 @@ class TestUpdatedDocument:
"meta_fields, expected_code, expected_message",
[({"test": "test"}, 0, ""), ("test", 102, "meta_fields must be a dictionary")],
)
def test_meta_fields(self, get_http_api_auth, tmp_path, meta_fields, expected_code, expected_message):
ids = create_datasets(get_http_api_auth, 1)
document_ids = batch_upload_documents(get_http_api_auth, ids[0], 1, tmp_path)
res = update_documnet(get_http_api_auth, ids[0], document_ids[0], {"meta_fields": meta_fields})
def test_meta_fields(self, get_http_api_auth, get_dataset_id_and_document_ids, meta_fields, expected_code, expected_message):
dataset_id, document_ids = get_dataset_id_and_document_ids
res = update_documnet(get_http_api_auth, dataset_id, document_ids[0], {"meta_fields": meta_fields})
if expected_code == 0:
res = list_documnet(get_http_api_auth, ids[0], {"id": document_ids[0]})
res = list_documnet(get_http_api_auth, dataset_id, {"id": document_ids[0]})
assert res["data"]["docs"][0]["meta_fields"] == meta_fields
else:
assert res["message"] == expected_message
@ -173,14 +162,12 @@ class TestUpdatedDocument:
),
],
)
def test_chunk_method(self, get_http_api_auth, tmp_path, chunk_method, expected_code, expected_message):
ids = create_datasets(get_http_api_auth, 1)
document_ids = batch_upload_documents(get_http_api_auth, ids[0], 1, tmp_path)
res = update_documnet(get_http_api_auth, ids[0], document_ids[0], {"chunk_method": chunk_method})
print(res)
def test_chunk_method(self, get_http_api_auth, get_dataset_id_and_document_ids, chunk_method, expected_code, expected_message):
dataset_id, document_ids = get_dataset_id_and_document_ids
res = update_documnet(get_http_api_auth, dataset_id, document_ids[0], {"chunk_method": chunk_method})
assert res["code"] == expected_code
if expected_code == 0:
res = list_documnet(get_http_api_auth, ids[0], {"id": document_ids[0]})
res = list_documnet(get_http_api_auth, dataset_id, {"id": document_ids[0]})
if chunk_method != "":
assert res["data"]["docs"][0]["chunk_method"] == chunk_method
else:
@ -188,252 +175,6 @@ class TestUpdatedDocument:
else:
assert res["message"] == expected_message
@pytest.mark.parametrize(
"chunk_method, parser_config, expected_code, expected_message",
[
(
"naive",
{
"chunk_token_num": 128,
"layout_recognize": "DeepDOC",
"html4excel": False,
"delimiter": "\n!?。;!?",
"task_page_size": 12,
"raptor": {"use_raptor": False},
},
0,
"",
),
("naive", {}, 0, ""),
pytest.param(
"naive",
{"chunk_token_num": -1},
100,
"AssertionError('chunk_token_num should be in range from 1 to 100000000')",
marks=pytest.mark.xfail(reason="issues/6098"),
),
pytest.param(
"naive",
{"chunk_token_num": 0},
100,
"AssertionError('chunk_token_num should be in range from 1 to 100000000')",
marks=pytest.mark.xfail(reason="issues/6098"),
),
pytest.param(
"naive",
{"chunk_token_num": 100000000},
100,
"AssertionError('chunk_token_num should be in range from 1 to 100000000')",
marks=pytest.mark.xfail(reason="issues/6098"),
),
pytest.param(
"naive",
{"chunk_token_num": 3.14},
102,
"",
marks=pytest.mark.xfail(reason="issues/6098"),
),
pytest.param(
"naive",
{"chunk_token_num": "1024"},
100,
"",
marks=pytest.mark.xfail(reason="issues/6098"),
),
(
"naive",
{"layout_recognize": "DeepDOC"},
0,
"",
),
(
"naive",
{"layout_recognize": "Naive"},
0,
"",
),
("naive", {"html4excel": True}, 0, ""),
("naive", {"html4excel": False}, 0, ""),
pytest.param(
"naive",
{"html4excel": 1},
100,
"AssertionError('html4excel should be True or False')",
marks=pytest.mark.xfail(reason="issues/6098"),
),
("naive", {"delimiter": ""}, 0, ""),
("naive", {"delimiter": "`##`"}, 0, ""),
pytest.param(
"naive",
{"delimiter": 1},
100,
"",
marks=pytest.mark.xfail(reason="issues/6098"),
),
pytest.param(
"naive",
{"task_page_size": -1},
100,
"AssertionError('task_page_size should be in range from 1 to 100000000')",
marks=pytest.mark.xfail(reason="issues/6098"),
),
pytest.param(
"naive",
{"task_page_size": 0},
100,
"AssertionError('task_page_size should be in range from 1 to 100000000')",
marks=pytest.mark.xfail(reason="issues/6098"),
),
pytest.param(
"naive",
{"task_page_size": 100000000},
100,
"AssertionError('task_page_size should be in range from 1 to 100000000')",
marks=pytest.mark.xfail(reason="issues/6098"),
),
pytest.param(
"naive",
{"task_page_size": 3.14},
100,
"",
marks=pytest.mark.xfail(reason="issues/6098"),
),
pytest.param(
"naive",
{"task_page_size": "1024"},
100,
"",
marks=pytest.mark.xfail(reason="issues/6098"),
),
("naive", {"raptor": {"use_raptor": True}}, 0, ""),
("naive", {"raptor": {"use_raptor": False}}, 0, ""),
pytest.param(
"naive",
{"invalid_key": "invalid_value"},
100,
"""AssertionError("Abnormal \'parser_config\'. Invalid key: invalid_key")""",
marks=pytest.mark.xfail(reason="issues/6098"),
),
pytest.param(
"naive",
{"auto_keywords": -1},
100,
"AssertionError('auto_keywords should be in range from 0 to 32')",
marks=pytest.mark.xfail(reason="issues/6098"),
),
pytest.param(
"naive",
{"auto_keywords": 32},
100,
"AssertionError('auto_keywords should be in range from 0 to 32')",
marks=pytest.mark.xfail(reason="issues/6098"),
),
pytest.param(
"naive",
{"auto_questions": 3.14},
100,
"",
marks=pytest.mark.xfail(reason="issues/6098"),
),
pytest.param(
"naive",
{"auto_keywords": "1024"},
100,
"",
marks=pytest.mark.xfail(reason="issues/6098"),
),
pytest.param(
"naive",
{"auto_questions": -1},
100,
"AssertionError('auto_questions should be in range from 0 to 10')",
marks=pytest.mark.xfail(reason="issues/6098"),
),
pytest.param(
"naive",
{"auto_questions": 10},
100,
"AssertionError('auto_questions should be in range from 0 to 10')",
marks=pytest.mark.xfail(reason="issues/6098"),
),
pytest.param(
"naive",
{"auto_questions": 3.14},
100,
"",
marks=pytest.mark.xfail(reason="issues/6098"),
),
pytest.param(
"naive",
{"auto_questions": "1024"},
100,
"",
marks=pytest.mark.xfail(reason="issues/6098"),
),
pytest.param(
"naive",
{"topn_tags": -1},
100,
"AssertionError('topn_tags should be in range from 0 to 10')",
marks=pytest.mark.xfail(reason="issues/6098"),
),
pytest.param(
"naive",
{"topn_tags": 10},
100,
"AssertionError('topn_tags should be in range from 0 to 10')",
marks=pytest.mark.xfail(reason="issues/6098"),
),
pytest.param(
"naive",
{"topn_tags": 3.14},
100,
"",
marks=pytest.mark.xfail(reason="issues/6098"),
),
pytest.param(
"naive",
{"topn_tags": "1024"},
100,
"",
marks=pytest.mark.xfail(reason="issues/6098"),
),
],
)
def test_parser_config(
self,
get_http_api_auth,
tmp_path,
chunk_method,
parser_config,
expected_code,
expected_message,
):
ids = create_datasets(get_http_api_auth, 1)
document_ids = batch_upload_documents(get_http_api_auth, ids[0], 1, tmp_path)
res = update_documnet(
get_http_api_auth,
ids[0],
document_ids[0],
{"chunk_method": chunk_method, "parser_config": parser_config},
)
assert res["code"] == expected_code
if expected_code == 0:
res = list_documnet(get_http_api_auth, ids[0], {"id": document_ids[0]})
if parser_config != {}:
for k, v in parser_config.items():
assert res["data"]["docs"][0]["parser_config"][k] == v
else:
assert res["data"]["docs"][0]["parser_config"] == {
"chunk_token_num": 128,
"delimiter": "\\n!?;。;!?",
"html4excel": False,
"layout_recognize": "DeepDOC",
"raptor": {"use_raptor": False},
}
if expected_code != 0 or expected_message:
assert res["message"] == expected_message
@pytest.mark.parametrize(
"payload, expected_code, expected_message",
[
@ -541,13 +282,259 @@ class TestUpdatedDocument:
def test_invalid_field(
self,
get_http_api_auth,
tmp_path,
get_dataset_id_and_document_ids,
payload,
expected_code,
expected_message,
):
ids = create_datasets(get_http_api_auth, 1)
document_ids = batch_upload_documents(get_http_api_auth, ids[0], 2, tmp_path)
res = update_documnet(get_http_api_auth, ids[0], document_ids[0], payload)
dataset_id, document_ids = get_dataset_id_and_document_ids
res = update_documnet(get_http_api_auth, dataset_id, document_ids[0], payload)
assert res["code"] == expected_code
assert res["message"] == expected_message
@pytest.mark.usefixtures("clear_datasets")
@pytest.mark.parametrize(
"chunk_method, parser_config, expected_code, expected_message",
[
("naive", {}, 0, ""),
(
"naive",
{
"chunk_token_num": 128,
"layout_recognize": "DeepDOC",
"html4excel": False,
"delimiter": "\\n!?;。;!?",
"task_page_size": 12,
"raptor": {"use_raptor": False},
},
0,
"",
),
pytest.param(
"naive",
{"chunk_token_num": -1},
100,
"AssertionError('chunk_token_num should be in range from 1 to 100000000')",
marks=pytest.mark.xfail(reason="issues/6098"),
),
pytest.param(
"naive",
{"chunk_token_num": 0},
100,
"AssertionError('chunk_token_num should be in range from 1 to 100000000')",
marks=pytest.mark.xfail(reason="issues/6098"),
),
pytest.param(
"naive",
{"chunk_token_num": 100000000},
100,
"AssertionError('chunk_token_num should be in range from 1 to 100000000')",
marks=pytest.mark.xfail(reason="issues/6098"),
),
pytest.param(
"naive",
{"chunk_token_num": 3.14},
102,
"",
marks=pytest.mark.xfail(reason="issues/6098"),
),
pytest.param(
"naive",
{"chunk_token_num": "1024"},
100,
"",
marks=pytest.mark.xfail(reason="issues/6098"),
),
(
"naive",
{"layout_recognize": "DeepDOC"},
0,
"",
),
(
"naive",
{"layout_recognize": "Naive"},
0,
"",
),
("naive", {"html4excel": True}, 0, ""),
("naive", {"html4excel": False}, 0, ""),
pytest.param(
"naive",
{"html4excel": 1},
100,
"AssertionError('html4excel should be True or False')",
marks=pytest.mark.xfail(reason="issues/6098"),
),
("naive", {"delimiter": ""}, 0, ""),
("naive", {"delimiter": "`##`"}, 0, ""),
pytest.param(
"naive",
{"delimiter": 1},
100,
"",
marks=pytest.mark.xfail(reason="issues/6098"),
),
pytest.param(
"naive",
{"task_page_size": -1},
100,
"AssertionError('task_page_size should be in range from 1 to 100000000')",
marks=pytest.mark.xfail(reason="issues/6098"),
),
pytest.param(
"naive",
{"task_page_size": 0},
100,
"AssertionError('task_page_size should be in range from 1 to 100000000')",
marks=pytest.mark.xfail(reason="issues/6098"),
),
pytest.param(
"naive",
{"task_page_size": 100000000},
100,
"AssertionError('task_page_size should be in range from 1 to 100000000')",
marks=pytest.mark.xfail(reason="issues/6098"),
),
pytest.param(
"naive",
{"task_page_size": 3.14},
100,
"",
marks=pytest.mark.xfail(reason="issues/6098"),
),
pytest.param(
"naive",
{"task_page_size": "1024"},
100,
"",
marks=pytest.mark.xfail(reason="issues/6098"),
),
("naive", {"raptor": {"use_raptor": True}}, 0, ""),
("naive", {"raptor": {"use_raptor": False}}, 0, ""),
pytest.param(
"naive",
{"invalid_key": "invalid_value"},
100,
"""AssertionError("Abnormal \'parser_config\'. Invalid key: invalid_key")""",
marks=pytest.mark.xfail(reason="issues/6098"),
),
pytest.param(
"naive",
{"auto_keywords": -1},
100,
"AssertionError('auto_keywords should be in range from 0 to 32')",
marks=pytest.mark.xfail(reason="issues/6098"),
),
pytest.param(
"naive",
{"auto_keywords": 32},
100,
"AssertionError('auto_keywords should be in range from 0 to 32')",
marks=pytest.mark.xfail(reason="issues/6098"),
),
pytest.param(
"naive",
{"auto_questions": 3.14},
100,
"",
marks=pytest.mark.xfail(reason="issues/6098"),
),
pytest.param(
"naive",
{"auto_keywords": "1024"},
100,
"",
marks=pytest.mark.xfail(reason="issues/6098"),
),
pytest.param(
"naive",
{"auto_questions": -1},
100,
"AssertionError('auto_questions should be in range from 0 to 10')",
marks=pytest.mark.xfail(reason="issues/6098"),
),
pytest.param(
"naive",
{"auto_questions": 10},
100,
"AssertionError('auto_questions should be in range from 0 to 10')",
marks=pytest.mark.xfail(reason="issues/6098"),
),
pytest.param(
"naive",
{"auto_questions": 3.14},
100,
"",
marks=pytest.mark.xfail(reason="issues/6098"),
),
pytest.param(
"naive",
{"auto_questions": "1024"},
100,
"",
marks=pytest.mark.xfail(reason="issues/6098"),
),
pytest.param(
"naive",
{"topn_tags": -1},
100,
"AssertionError('topn_tags should be in range from 0 to 10')",
marks=pytest.mark.xfail(reason="issues/6098"),
),
pytest.param(
"naive",
{"topn_tags": 10},
100,
"AssertionError('topn_tags should be in range from 0 to 10')",
marks=pytest.mark.xfail(reason="issues/6098"),
),
pytest.param(
"naive",
{"topn_tags": 3.14},
100,
"",
marks=pytest.mark.xfail(reason="issues/6098"),
),
pytest.param(
"naive",
{"topn_tags": "1024"},
100,
"",
marks=pytest.mark.xfail(reason="issues/6098"),
),
],
)
def test_parser_config(
get_http_api_auth,
tmp_path,
chunk_method,
parser_config,
expected_code,
expected_message,
):
ids = batch_create_datasets(get_http_api_auth, 1)
document_ids = bulk_upload_documents(get_http_api_auth, ids[0], 1, tmp_path)
res = update_documnet(
get_http_api_auth,
ids[0],
document_ids[0],
{"chunk_method": chunk_method, "parser_config": parser_config},
)
assert res["code"] == expected_code
if expected_code == 0:
res = list_documnet(get_http_api_auth, ids[0], {"id": document_ids[0]})
if parser_config != {}:
for k, v in parser_config.items():
assert res["data"]["docs"][0]["parser_config"][k] == v
else:
assert res["data"]["docs"][0]["parser_config"] == {
"chunk_token_num": 128,
"delimiter": "\\n!?;。;!?",
"html4excel": False,
"layout_recognize": "DeepDOC",
"raptor": {"use_raptor": False},
}
if expected_code != 0 or expected_message:
assert res["message"] == expected_message

View File

@ -24,7 +24,7 @@ from common import (
FILE_API_URL,
HOST_ADDRESS,
INVALID_API_TOKEN,
create_datasets,
batch_create_datasets,
list_dataset,
upload_documnets,
)
@ -33,6 +33,7 @@ from libs.utils.file_utils import create_txt_file
from requests_toolbelt import MultipartEncoder
@pytest.mark.usefixtures("clear_datasets")
class TestAuthorization:
@pytest.mark.parametrize(
"auth, expected_code, expected_message",
@ -45,18 +46,17 @@ class TestAuthorization:
),
],
)
def test_invalid_auth(
self, get_http_api_auth, auth, expected_code, expected_message
):
ids = create_datasets(get_http_api_auth, 1)
def test_invalid_auth(self, get_http_api_auth, auth, expected_code, expected_message):
ids = batch_create_datasets(get_http_api_auth, 1)
res = upload_documnets(auth, ids[0])
assert res["code"] == expected_code
assert res["message"] == expected_message
@pytest.mark.usefixtures("clear_datasets")
class TestUploadDocuments:
def test_valid_single_upload(self, get_http_api_auth, tmp_path):
ids = create_datasets(get_http_api_auth, 1)
ids = batch_create_datasets(get_http_api_auth, 1)
fp = create_txt_file(tmp_path / "ragflow_test.txt")
res = upload_documnets(get_http_api_auth, ids[0], [fp])
assert res["code"] == 0
@ -79,10 +79,8 @@ class TestUploadDocuments:
],
indirect=True,
)
def test_file_type_validation(
self, get_http_api_auth, generate_test_files, request
):
ids = create_datasets(get_http_api_auth, 1)
def test_file_type_validation(self, get_http_api_auth, generate_test_files, request):
ids = batch_create_datasets(get_http_api_auth, 1)
fp = generate_test_files[request.node.callspec.params["generate_test_files"]]
res = upload_documnets(get_http_api_auth, ids[0], [fp])
assert res["code"] == 0
@ -94,24 +92,21 @@ class TestUploadDocuments:
["exe", "unknown"],
)
def test_unsupported_file_type(self, get_http_api_auth, tmp_path, file_type):
ids = create_datasets(get_http_api_auth, 1)
ids = batch_create_datasets(get_http_api_auth, 1)
fp = tmp_path / f"ragflow_test.{file_type}"
fp.touch()
res = upload_documnets(get_http_api_auth, ids[0], [fp])
assert res["code"] == 500
assert (
res["message"]
== f"ragflow_test.{file_type}: This type of file has not been supported yet!"
)
assert res["message"] == f"ragflow_test.{file_type}: This type of file has not been supported yet!"
def test_missing_file(self, get_http_api_auth):
ids = create_datasets(get_http_api_auth, 1)
ids = batch_create_datasets(get_http_api_auth, 1)
res = upload_documnets(get_http_api_auth, ids[0])
assert res["code"] == 101
assert res["message"] == "No file part!"
def test_empty_file(self, get_http_api_auth, tmp_path):
ids = create_datasets(get_http_api_auth, 1)
ids = batch_create_datasets(get_http_api_auth, 1)
fp = tmp_path / "empty.txt"
fp.touch()
@ -120,7 +115,7 @@ class TestUploadDocuments:
assert res["data"][0]["size"] == 0
def test_filename_empty(self, get_http_api_auth, tmp_path):
ids = create_datasets(get_http_api_auth, 1)
ids = batch_create_datasets(get_http_api_auth, 1)
fp = create_txt_file(tmp_path / "ragflow_test.txt")
url = f"{HOST_ADDRESS}{FILE_API_URL}".format(dataset_id=ids[0])
fields = (("file", ("", fp.open("rb"))),)
@ -135,7 +130,7 @@ class TestUploadDocuments:
assert res.json()["message"] == "No file selected!"
def test_filename_exceeds_max_length(self, get_http_api_auth, tmp_path):
ids = create_datasets(get_http_api_auth, 1)
ids = batch_create_datasets(get_http_api_auth, 1)
# filename_length = 129
fp = create_txt_file(tmp_path / f"{'a' * (DOCUMENT_NAME_LIMIT - 3)}.txt")
res = upload_documnets(get_http_api_auth, ids[0], [fp])
@ -146,13 +141,10 @@ class TestUploadDocuments:
fp = create_txt_file(tmp_path / "ragflow_test.txt")
res = upload_documnets(get_http_api_auth, "invalid_dataset_id", [fp])
assert res["code"] == 100
assert (
res["message"]
== """LookupError("Can\'t find the dataset with ID invalid_dataset_id!")"""
)
assert res["message"] == """LookupError("Can\'t find the dataset with ID invalid_dataset_id!")"""
def test_duplicate_files(self, get_http_api_auth, tmp_path):
ids = create_datasets(get_http_api_auth, 1)
ids = batch_create_datasets(get_http_api_auth, 1)
fp = create_txt_file(tmp_path / "ragflow_test.txt")
res = upload_documnets(get_http_api_auth, ids[0], [fp, fp])
assert res["code"] == 0
@ -165,7 +157,7 @@ class TestUploadDocuments:
assert res["data"][i]["name"] == expected_name
def test_same_file_repeat(self, get_http_api_auth, tmp_path):
ids = create_datasets(get_http_api_auth, 1)
ids = batch_create_datasets(get_http_api_auth, 1)
fp = create_txt_file(tmp_path / "ragflow_test.txt")
for i in range(10):
res = upload_documnets(get_http_api_auth, ids[0], [fp])
@ -178,7 +170,7 @@ class TestUploadDocuments:
assert res["data"][0]["name"] == expected_name
def test_filename_special_characters(self, get_http_api_auth, tmp_path):
ids = create_datasets(get_http_api_auth, 1)
ids = batch_create_datasets(get_http_api_auth, 1)
illegal_chars = '<>:"/\\|?*'
translation_table = str.maketrans({char: "_" for char in illegal_chars})
safe_filename = string.punctuation.translate(translation_table)
@ -192,7 +184,7 @@ class TestUploadDocuments:
assert res["data"][0]["name"] == fp.name
def test_multiple_files(self, get_http_api_auth, tmp_path):
ids = create_datasets(get_http_api_auth, 1)
ids = batch_create_datasets(get_http_api_auth, 1)
expected_document_count = 20
fps = []
for i in range(expected_document_count):
@ -205,7 +197,7 @@ class TestUploadDocuments:
assert res["data"][0]["document_count"] == expected_document_count
def test_concurrent_upload(self, get_http_api_auth, tmp_path):
ids = create_datasets(get_http_api_auth, 1)
ids = batch_create_datasets(get_http_api_auth, 1)
expected_document_count = 20
fps = []
@ -214,12 +206,7 @@ class TestUploadDocuments:
fps.append(fp)
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [
executor.submit(
upload_documnets, get_http_api_auth, ids[0], fps[i : i + 1]
)
for i in range(expected_document_count)
]
futures = [executor.submit(upload_documnets, get_http_api_auth, ids[0], fps[i : i + 1]) for i in range(expected_document_count)]
responses = [f.result() for f in futures]
assert all(r["code"] == 0 for r in responses)