Test: Refactor test fixtures and test cases (#6709)

### What problem does this PR solve?

 Refactor test fixtures and test cases

### Type of change

- [ ] Refactoring test cases
This commit is contained in:
liu an 2025-04-01 13:39:07 +08:00 committed by GitHub
parent 20b8ccd1e9
commit 58e6e7b668
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
22 changed files with 881 additions and 837 deletions

View File

@ -15,26 +15,27 @@
#
import os
import pytest
import requests
from libs.auth import RAGFlowHttpApiAuth
HOST_ADDRESS = os.getenv('HOST_ADDRESS', 'http://127.0.0.1:9380')
HOST_ADDRESS = os.getenv("HOST_ADDRESS", "http://127.0.0.1:9380")
# def generate_random_email():
# return 'user_' + ''.join(random.choices(string.ascii_lowercase + string.digits, k=8))+'@1.com'
def generate_email():
return 'user_123@1.com'
return "user_123@1.com"
EMAIL = generate_email()
# password is "123"
PASSWORD = '''ctAseGvejiaSWWZ88T/m4FQVOpQyUvP+x7sXtdv3feqZACiQleuewkUi35E16wSd5C5QcnkkcV9cYc8TKPTRZlxappDuirxghxoOvFcJxFU4ixLsD
PASSWORD = """ctAseGvejiaSWWZ88T/m4FQVOpQyUvP+x7sXtdv3feqZACiQleuewkUi35E16wSd5C5QcnkkcV9cYc8TKPTRZlxappDuirxghxoOvFcJxFU4ixLsD
fN33jCHRoDUW81IH9zjij/vaw8IbVyb6vuwg6MX6inOEBRRzVbRYxXOu1wkWY6SsI8X70oF9aeLFp/PzQpjoe/YbSqpTq8qqrmHzn9vO+yvyYyvmDsphXe
X8f7fp9c7vUsfOCkM+gHY3PadG+QHa7KI7mzTKgUTZImK6BZtfRBATDTthEUbbaTewY4H0MnWiCeeDhcbeQao6cFy1To8pE3RpmxnGnS8BsBn8w=='''
X8f7fp9c7vUsfOCkM+gHY3PadG+QHa7KI7mzTKgUTZImK6BZtfRBATDTthEUbbaTewY4H0MnWiCeeDhcbeQao6cFy1To8pE3RpmxnGnS8BsBn8w=="""
def register():
@ -92,3 +93,64 @@ def get_email():
@pytest.fixture(scope="session")
def get_http_api_auth(get_api_key_fixture):
return RAGFlowHttpApiAuth(get_api_key_fixture)
def get_my_llms(auth, name):
url = HOST_ADDRESS + "/v1/llm/my_llms"
authorization = {"Authorization": auth}
response = requests.get(url=url, headers=authorization)
res = response.json()
if res.get("code") != 0:
raise Exception(res.get("message"))
if name in res.get("data"):
return True
return False
def add_models(auth):
url = HOST_ADDRESS + "/v1/llm/set_api_key"
authorization = {"Authorization": auth}
models_info = {
"ZHIPU-AI": {"llm_factory": "ZHIPU-AI", "api_key": "d06253dacd404180aa8afb096fcb6c30.KatwBIUpvCSml9sU"},
}
for name, model_info in models_info.items():
if not get_my_llms(auth, name):
response = requests.post(url=url, headers=authorization, json=model_info)
res = response.json()
if res.get("code") != 0:
raise Exception(res.get("message"))
def get_tenant_info(auth):
url = HOST_ADDRESS + "/v1/user/tenant_info"
authorization = {"Authorization": auth}
response = requests.get(url=url, headers=authorization)
res = response.json()
if res.get("code") != 0:
raise Exception(res.get("message"))
return res["data"].get("tenant_id")
@pytest.fixture(scope="session", autouse=True)
def set_tenant_info(get_auth):
auth = get_auth
try:
add_models(auth)
tenant_id = get_tenant_info(auth)
except Exception as e:
raise Exception(e)
url = HOST_ADDRESS + "/v1/user/set_tenant_info"
authorization = {"Authorization": get_auth}
tenant_info = {
"tenant_id": tenant_id,
"llm_id": "glm-4-flash@ZHIPU-AI",
"embd_id": "embedding-3@ZHIPU-AI",
"img2txt_id": "glm-4v@ZHIPU-AI",
"asr_id": "",
"tts_id": None,
}
response = requests.post(url=url, headers=authorization, json=tenant_info)
res = response.json()
if res.get("code") != 0:
raise Exception(res.get("message"))

View File

@ -27,6 +27,7 @@ DATASETS_API_URL = "/api/v1/datasets"
FILE_API_URL = "/api/v1/datasets/{dataset_id}/documents"
FILE_CHUNK_API_URL = "/api/v1/datasets/{dataset_id}/chunks"
CHUNK_API_URL = "/api/v1/datasets/{dataset_id}/documents/{document_id}/chunks"
CHAT_ASSISTANT_API_URL = "/api/v1/chats"
INVALID_API_TOKEN = "invalid_key_123"
DATASET_NAME_LIMIT = 128
@ -39,7 +40,7 @@ def create_dataset(auth, payload=None):
return res.json()
def list_dataset(auth, params=None):
def list_datasets(auth, params=None):
res = requests.get(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=HEADERS, auth=auth, params=params)
return res.json()
@ -49,7 +50,7 @@ def update_dataset(auth, dataset_id, payload=None):
return res.json()
def delete_dataset(auth, payload=None):
def delete_datasets(auth, payload=None):
res = requests.delete(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=HEADERS, auth=auth, json=payload)
return res.json()
@ -105,7 +106,7 @@ def download_document(auth, dataset_id, document_id, save_path):
return res
def list_documnet(auth, dataset_id, params=None):
def list_documnets(auth, dataset_id, params=None):
url = f"{HOST_ADDRESS}{FILE_API_URL}".format(dataset_id=dataset_id)
res = requests.get(url=url, headers=HEADERS, auth=auth, params=params)
return res.json()
@ -117,19 +118,19 @@ def update_documnet(auth, dataset_id, document_id, payload=None):
return res.json()
def delete_documnet(auth, dataset_id, payload=None):
def delete_documnets(auth, dataset_id, payload=None):
url = f"{HOST_ADDRESS}{FILE_API_URL}".format(dataset_id=dataset_id)
res = requests.delete(url=url, headers=HEADERS, auth=auth, json=payload)
return res.json()
def parse_documnet(auth, dataset_id, payload=None):
def parse_documnets(auth, dataset_id, payload=None):
url = f"{HOST_ADDRESS}{FILE_CHUNK_API_URL}".format(dataset_id=dataset_id)
res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload)
return res.json()
def stop_parse_documnet(auth, dataset_id, payload=None):
def stop_parse_documnets(auth, dataset_id, payload=None):
url = f"{HOST_ADDRESS}{FILE_CHUNK_API_URL}".format(dataset_id=dataset_id)
res = requests.delete(url=url, headers=HEADERS, auth=auth, json=payload)
return res.json()
@ -184,3 +185,36 @@ def batch_add_chunks(auth, dataset_id, document_id, num):
res = add_chunk(auth, dataset_id, document_id, {"content": f"chunk test {i}"})
chunk_ids.append(res["data"]["chunk"]["id"])
return chunk_ids
# CHAT ASSISTANT MANAGEMENT
def create_chat_assistant(auth, payload=None):
url = f"{HOST_ADDRESS}{CHAT_ASSISTANT_API_URL}"
res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload)
return res.json()
def list_chat_assistants(auth, params=None):
url = f"{HOST_ADDRESS}{CHAT_ASSISTANT_API_URL}"
res = requests.get(url=url, headers=HEADERS, auth=auth, params=params)
return res.json()
def update_chat_assistant(auth, chat_assistant_id, payload=None):
url = f"{HOST_ADDRESS}{CHAT_ASSISTANT_API_URL}/{chat_assistant_id}"
res = requests.put(url=url, headers=HEADERS, auth=auth, json=payload)
return res.json()
def delete_chat_assistants(auth, payload=None):
url = f"{HOST_ADDRESS}{CHAT_ASSISTANT_API_URL}"
res = requests.delete(url=url, headers=HEADERS, auth=auth, json=payload)
return res.json()
def batch_create_chat_assistants(auth, num):
chat_assistant_ids = []
for i in range(num):
res = create_chat_assistant(auth, {"name": f"test_chat_assistant_{i}"})
chat_assistant_ids.append(res["data"]["id"])
return chat_assistant_ids

View File

@ -14,9 +14,8 @@
# limitations under the License.
#
import pytest
from common import delete_dataset
from common import batch_create_datasets, bulk_upload_documents, delete_datasets
from libs.utils.file_utils import (
create_docx_file,
create_eml_file,
@ -34,7 +33,7 @@ from libs.utils.file_utils import (
@pytest.fixture(scope="function")
def clear_datasets(get_http_api_auth):
yield
delete_dataset(get_http_api_auth)
delete_datasets(get_http_api_auth)
@pytest.fixture
@ -58,3 +57,38 @@ def generate_test_files(request, tmp_path):
creator_func(file_path)
files[file_type] = file_path
return files
@pytest.fixture(scope="class")
def ragflow_tmp_dir(request, tmp_path_factory):
class_name = request.cls.__name__
return tmp_path_factory.mktemp(class_name)
@pytest.fixture(scope="class")
def add_dataset(request, get_http_api_auth):
def cleanup():
delete_datasets(get_http_api_auth)
request.addfinalizer(cleanup)
dataset_ids = batch_create_datasets(get_http_api_auth, 1)
return dataset_ids[0]
@pytest.fixture(scope="function")
def add_dataset_func(request, get_http_api_auth):
def cleanup():
delete_datasets(get_http_api_auth)
request.addfinalizer(cleanup)
dataset_ids = batch_create_datasets(get_http_api_auth, 1)
return dataset_ids[0]
@pytest.fixture(scope="class")
def add_document(get_http_api_auth, add_dataset, ragflow_tmp_dir):
dataset_id = add_dataset
document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 1, ragflow_tmp_dir)
return dataset_id, document_ids[0]

View File

@ -16,13 +16,13 @@
import pytest
from common import add_chunk, batch_create_datasets, bulk_upload_documents, delete_chunks, delete_dataset, list_documnet, parse_documnet
from common import add_chunk, delete_chunks, list_documnets, parse_documnets
from libs.utils import wait_for
@wait_for(10, 1, "Document parsing timeout")
def condition(_auth, _dataset_id):
res = list_documnet(_auth, _dataset_id)
res = list_documnets(_auth, _dataset_id)
for doc in res["data"]["docs"]:
if doc["run"] != "DONE":
return False
@ -30,29 +30,11 @@ def condition(_auth, _dataset_id):
@pytest.fixture(scope="class")
def chunk_management_tmp_dir(tmp_path_factory):
return tmp_path_factory.mktemp("chunk_management")
@pytest.fixture(scope="class")
def get_dataset_id_and_document_id(get_http_api_auth, chunk_management_tmp_dir, request):
def cleanup():
delete_dataset(get_http_api_auth)
request.addfinalizer(cleanup)
dataset_ids = batch_create_datasets(get_http_api_auth, 1)
dataset_id = dataset_ids[0]
document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 1, chunk_management_tmp_dir)
parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids})
def add_chunks(get_http_api_auth, add_document):
dataset_id, document_id = add_document
parse_documnets(get_http_api_auth, dataset_id, {"document_ids": [document_id]})
condition(get_http_api_auth, dataset_id)
return dataset_id, document_ids[0]
@pytest.fixture(scope="class")
def add_chunks(get_http_api_auth, get_dataset_id_and_document_id):
dataset_id, document_id = get_dataset_id_and_document_id
chunk_ids = []
for i in range(4):
res = add_chunk(get_http_api_auth, dataset_id, document_id, {"content": f"chunk test {i}"})
@ -66,8 +48,10 @@ def add_chunks(get_http_api_auth, get_dataset_id_and_document_id):
@pytest.fixture(scope="function")
def add_chunks_func(get_http_api_auth, get_dataset_id_and_document_id, request):
dataset_id, document_id = get_dataset_id_and_document_id
def add_chunks_func(request, get_http_api_auth, add_document):
dataset_id, document_id = add_document
parse_documnets(get_http_api_auth, dataset_id, {"document_ids": [document_id]})
condition(get_http_api_auth, dataset_id)
chunk_ids = []
for i in range(4):

View File

@ -16,7 +16,7 @@
from concurrent.futures import ThreadPoolExecutor
import pytest
from common import INVALID_API_TOKEN, add_chunk, delete_documnet, list_chunks
from common import INVALID_API_TOKEN, add_chunk, delete_documnets, list_chunks
from libs.auth import RAGFlowHttpApiAuth
@ -44,7 +44,7 @@ class TestAuthorization:
],
)
def test_invalid_auth(self, auth, expected_code, expected_message):
res = add_chunk(auth, "dataset_id", "document_id", {})
res = add_chunk(auth, "dataset_id", "document_id")
assert res["code"] == expected_code
assert res["message"] == expected_message
@ -66,8 +66,8 @@ class TestAddChunk:
({"content": "\n!?。;!?\"'"}, 0, ""),
],
)
def test_content(self, get_http_api_auth, get_dataset_id_and_document_id, payload, expected_code, expected_message):
dataset_id, document_id = get_dataset_id_and_document_id
def test_content(self, get_http_api_auth, add_document, payload, expected_code, expected_message):
dataset_id, document_id = add_document
res = list_chunks(get_http_api_auth, dataset_id, document_id)
if res["code"] != 0:
assert False, res
@ -98,8 +98,8 @@ class TestAddChunk:
({"content": "chunk test", "important_keywords": 123}, 102, "`important_keywords` is required to be a list"),
],
)
def test_important_keywords(self, get_http_api_auth, get_dataset_id_and_document_id, payload, expected_code, expected_message):
dataset_id, document_id = get_dataset_id_and_document_id
def test_important_keywords(self, get_http_api_auth, add_document, payload, expected_code, expected_message):
dataset_id, document_id = add_document
res = list_chunks(get_http_api_auth, dataset_id, document_id)
if res["code"] != 0:
assert False, res
@ -126,8 +126,8 @@ class TestAddChunk:
({"content": "chunk test", "questions": 123}, 102, "`questions` is required to be a list"),
],
)
def test_questions(self, get_http_api_auth, get_dataset_id_and_document_id, payload, expected_code, expected_message):
dataset_id, document_id = get_dataset_id_and_document_id
def test_questions(self, get_http_api_auth, add_document, payload, expected_code, expected_message):
dataset_id, document_id = add_document
res = list_chunks(get_http_api_auth, dataset_id, document_id)
if res["code"] != 0:
assert False, res
@ -157,12 +157,12 @@ class TestAddChunk:
def test_invalid_dataset_id(
self,
get_http_api_auth,
get_dataset_id_and_document_id,
add_document,
dataset_id,
expected_code,
expected_message,
):
_, document_id = get_dataset_id_and_document_id
_, document_id = add_document
res = add_chunk(get_http_api_auth, dataset_id, document_id, {"content": "a"})
assert res["code"] == expected_code
assert res["message"] == expected_message
@ -178,15 +178,15 @@ class TestAddChunk:
),
],
)
def test_invalid_document_id(self, get_http_api_auth, get_dataset_id_and_document_id, document_id, expected_code, expected_message):
dataset_id, _ = get_dataset_id_and_document_id
def test_invalid_document_id(self, get_http_api_auth, add_document, document_id, expected_code, expected_message):
dataset_id, _ = add_document
res = add_chunk(get_http_api_auth, dataset_id, document_id, {"content": "chunk test"})
assert res["code"] == expected_code
assert res["message"] == expected_message
def test_repeated_add_chunk(self, get_http_api_auth, get_dataset_id_and_document_id):
def test_repeated_add_chunk(self, get_http_api_auth, add_document):
payload = {"content": "chunk test"}
dataset_id, document_id = get_dataset_id_and_document_id
dataset_id, document_id = add_document
res = list_chunks(get_http_api_auth, dataset_id, document_id)
if res["code"] != 0:
assert False, res
@ -207,17 +207,17 @@ class TestAddChunk:
assert False, res
assert res["data"]["doc"]["chunk_count"] == chunks_count + 2
def test_add_chunk_to_deleted_document(self, get_http_api_auth, get_dataset_id_and_document_id):
dataset_id, document_id = get_dataset_id_and_document_id
delete_documnet(get_http_api_auth, dataset_id, {"ids": [document_id]})
def test_add_chunk_to_deleted_document(self, get_http_api_auth, add_document):
dataset_id, document_id = add_document
delete_documnets(get_http_api_auth, dataset_id, {"ids": [document_id]})
res = add_chunk(get_http_api_auth, dataset_id, document_id, {"content": "chunk test"})
assert res["code"] == 102
assert res["message"] == f"You don't own the document {document_id}."
@pytest.mark.skip(reason="issues/6411")
def test_concurrent_add_chunk(self, get_http_api_auth, get_dataset_id_and_document_id):
def test_concurrent_add_chunk(self, get_http_api_auth, add_document):
chunk_num = 50
dataset_id, document_id = get_dataset_id_and_document_id
dataset_id, document_id = add_document
res = list_chunks(get_http_api_auth, dataset_id, document_id)
if res["code"] != 0:
assert False, res

View File

@ -39,7 +39,7 @@ class TestAuthorization:
assert res["message"] == expected_message
class TestChunkstDeletion:
class TestChunksDeletion:
@pytest.mark.parametrize(
"dataset_id, expected_code, expected_message",
[
@ -61,25 +61,14 @@ class TestChunkstDeletion:
"document_id, expected_code, expected_message",
[
("", 100, "<MethodNotAllowed '405: Method Not Allowed'>"),
pytest.param(
"invalid_document_id",
100,
"LookupError('Document not found which is supposed to be there')",
marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="issues/6611"),
),
pytest.param(
"invalid_document_id",
100,
"rm_chunk deleted chunks 0, expect 4",
marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in [None, "elasticsearch"], reason="issues/6611"),
),
("invalid_document_id", 100, """LookupError("Can't find the document with ID invalid_document_id!")"""),
],
)
def test_invalid_document_id(self, get_http_api_auth, add_chunks_func, document_id, expected_code, expected_message):
dataset_id, _, chunk_ids = add_chunks_func
res = delete_chunks(get_http_api_auth, dataset_id, document_id, {"chunk_ids": chunk_ids})
assert res["code"] == expected_code
#assert res["message"] == expected_message
assert res["message"] == expected_message
@pytest.mark.parametrize(
"payload",

View File

@ -17,11 +17,7 @@ import os
from concurrent.futures import ThreadPoolExecutor
import pytest
from common import (
INVALID_API_TOKEN,
batch_add_chunks,
list_chunks,
)
from common import INVALID_API_TOKEN, batch_add_chunks, list_chunks
from libs.auth import RAGFlowHttpApiAuth
@ -153,8 +149,9 @@ class TestChunksList:
assert all(r["code"] == 0 for r in responses)
assert all(len(r["data"]["chunks"]) == 5 for r in responses)
def test_default(self, get_http_api_auth, get_dataset_id_and_document_id):
dataset_id, document_id = get_dataset_id_and_document_id
def test_default(self, get_http_api_auth, add_document):
dataset_id, document_id = add_document
res = list_chunks(get_http_api_auth, dataset_id, document_id)
chunks_count = res["data"]["doc"]["chunk_count"]
batch_add_chunks(get_http_api_auth, dataset_id, document_id, 31)

View File

@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import pytest
@ -52,9 +51,7 @@ class TestChunksRetrieval:
({"question": "chunk"}, 102, 0, "`dataset_ids` is required."),
],
)
def test_basic_scenarios(
self, get_http_api_auth, add_chunks, payload, expected_code, expected_page_size, expected_message
):
def test_basic_scenarios(self, get_http_api_auth, add_chunks, payload, expected_code, expected_page_size, expected_message):
dataset_id, document_id, _ = add_chunks
if "dataset_ids" in payload:
payload["dataset_ids"] = [dataset_id]
@ -137,9 +134,7 @@ class TestChunksRetrieval:
),
],
)
def test_page_size(
self, get_http_api_auth, add_chunks, payload, expected_code, expected_page_size, expected_message
):
def test_page_size(self, get_http_api_auth, add_chunks, payload, expected_code, expected_page_size, expected_message):
dataset_id, _, _ = add_chunks
payload.update({"question": "chunk", "dataset_ids": [dataset_id]})
@ -165,9 +160,7 @@ class TestChunksRetrieval:
),
],
)
def test_vector_similarity_weight(
self, get_http_api_auth, add_chunks, payload, expected_code, expected_page_size, expected_message
):
def test_vector_similarity_weight(self, get_http_api_auth, add_chunks, payload, expected_code, expected_page_size, expected_message):
dataset_id, _, _ = add_chunks
payload.update({"question": "chunk", "dataset_ids": [dataset_id]})
res = retrieval_chunks(get_http_api_auth, payload)
@ -233,9 +226,7 @@ class TestChunksRetrieval:
"payload, expected_code, expected_message",
[
({"rerank_id": "BAAI/bge-reranker-v2-m3"}, 0, ""),
pytest.param(
{"rerank_id": "unknown"}, 100, "LookupError('Model(unknown) not authorized')", marks=pytest.mark.skip
),
pytest.param({"rerank_id": "unknown"}, 100, "LookupError('Model(unknown) not authorized')", marks=pytest.mark.skip),
],
)
def test_rerank_id(self, get_http_api_auth, add_chunks, payload, expected_code, expected_message):
@ -248,7 +239,6 @@ class TestChunksRetrieval:
else:
assert expected_message in res["message"]
@pytest.mark.skip(reason="chat model is not set")
@pytest.mark.parametrize(
"payload, expected_code, expected_page_size, expected_message",
[
@ -279,9 +269,7 @@ class TestChunksRetrieval:
pytest.param({"highlight": None}, 0, False, "", marks=pytest.mark.skip(reason="issues/6648")),
],
)
def test_highlight(
self, get_http_api_auth, add_chunks, payload, expected_code, expected_highlight, expected_message
):
def test_highlight(self, get_http_api_auth, add_chunks, payload, expected_code, expected_highlight, expected_message):
dataset_id, _, _ = add_chunks
payload.update({"question": "chunk", "dataset_ids": [dataset_id]})
res = retrieval_chunks(get_http_api_auth, payload)
@ -302,3 +290,14 @@ class TestChunksRetrieval:
res = retrieval_chunks(get_http_api_auth, payload)
assert res["code"] == 0
assert len(res["data"]["chunks"]) == 4
def test_concurrent_retrieval(self, get_http_api_auth, add_chunks):
from concurrent.futures import ThreadPoolExecutor
dataset_id, _, _ = add_chunks
payload = {"question": "chunk", "dataset_ids": [dataset_id]}
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [executor.submit(retrieval_chunks, get_http_api_auth, payload) for i in range(100)]
responses = [f.result() for f in futures]
assert all(r["code"] == 0 for r in responses)

View File

@ -18,7 +18,7 @@ from concurrent.futures import ThreadPoolExecutor
from random import randint
import pytest
from common import INVALID_API_TOKEN, delete_documnet, update_chunk
from common import INVALID_API_TOKEN, delete_documnets, update_chunk
from libs.auth import RAGFlowHttpApiAuth
@ -233,7 +233,7 @@ class TestUpdatedChunk:
def test_update_chunk_to_deleted_document(self, get_http_api_auth, add_chunks):
dataset_id, document_id, chunk_ids = add_chunks
delete_documnet(get_http_api_auth, dataset_id, {"ids": [document_id]})
delete_documnets(get_http_api_auth, dataset_id, {"ids": [document_id]})
res = update_chunk(get_http_api_auth, dataset_id, document_id, chunk_ids[0])
assert res["code"] == 102
assert res["message"] == f"Can't find this chunk {chunk_ids[0]}"

View File

@ -16,14 +16,24 @@
import pytest
from common import batch_create_datasets, delete_dataset
from common import batch_create_datasets, delete_datasets
@pytest.fixture(scope="class")
def get_dataset_ids(get_http_api_auth, request):
def add_datasets(get_http_api_auth, request):
def cleanup():
delete_dataset(get_http_api_auth)
delete_datasets(get_http_api_auth)
request.addfinalizer(cleanup)
return batch_create_datasets(get_http_api_auth, 5)
@pytest.fixture(scope="function")
def add_datasets_func(get_http_api_auth, request):
def cleanup():
delete_datasets(get_http_api_auth)
request.addfinalizer(cleanup)
return batch_create_datasets(get_http_api_auth, 3)

View File

@ -75,9 +75,6 @@ class TestDatasetCreation:
res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 0, f"Failed to create dataset {i}"
@pytest.mark.usefixtures("clear_datasets")
class TestAdvancedConfigurations:
def test_avatar(self, get_http_api_auth, tmp_path):
fn = create_image_file(tmp_path / "ragflow_test.png")
payload = {

View File

@ -20,13 +20,12 @@ import pytest
from common import (
INVALID_API_TOKEN,
batch_create_datasets,
delete_dataset,
list_dataset,
delete_datasets,
list_datasets,
)
from libs.auth import RAGFlowHttpApiAuth
@pytest.mark.usefixtures("clear_datasets")
class TestAuthorization:
@pytest.mark.parametrize(
"auth, expected_code, expected_message",
@ -39,18 +38,13 @@ class TestAuthorization:
),
],
)
def test_invalid_auth(self, get_http_api_auth, auth, expected_code, expected_message):
ids = batch_create_datasets(get_http_api_auth, 1)
res = delete_dataset(auth, {"ids": ids})
def test_invalid_auth(self, auth, expected_code, expected_message):
res = delete_datasets(auth)
assert res["code"] == expected_code
assert res["message"] == expected_message
res = list_dataset(get_http_api_auth)
assert len(res["data"]) == 1
@pytest.mark.usefixtures("clear_datasets")
class TestDatasetDeletion:
class TestDatasetsDeletion:
@pytest.mark.parametrize(
"payload, expected_code, expected_message, remaining",
[
@ -73,16 +67,16 @@ class TestDatasetDeletion:
(lambda r: {"ids": r}, 0, "", 0),
],
)
def test_basic_scenarios(self, get_http_api_auth, payload, expected_code, expected_message, remaining):
ids = batch_create_datasets(get_http_api_auth, 3)
def test_basic_scenarios(self, get_http_api_auth, add_datasets_func, payload, expected_code, expected_message, remaining):
dataset_ids = add_datasets_func
if callable(payload):
payload = payload(ids)
res = delete_dataset(get_http_api_auth, payload)
payload = payload(dataset_ids)
res = delete_datasets(get_http_api_auth, payload)
assert res["code"] == expected_code
if res["code"] != 0:
assert res["message"] == expected_message
res = list_dataset(get_http_api_auth)
res = list_datasets(get_http_api_auth)
assert len(res["data"]) == remaining
@pytest.mark.parametrize(
@ -93,50 +87,50 @@ class TestDatasetDeletion:
lambda r: {"ids": r + ["invalid_id"]},
],
)
def test_delete_partial_invalid_id(self, get_http_api_auth, payload):
ids = batch_create_datasets(get_http_api_auth, 3)
def test_delete_partial_invalid_id(self, get_http_api_auth, add_datasets_func, payload):
dataset_ids = add_datasets_func
if callable(payload):
payload = payload(ids)
res = delete_dataset(get_http_api_auth, payload)
payload = payload(dataset_ids)
res = delete_datasets(get_http_api_auth, payload)
assert res["code"] == 0
assert res["data"]["errors"][0] == "You don't own the dataset invalid_id"
assert res["data"]["success_count"] == 3
res = list_dataset(get_http_api_auth)
res = list_datasets(get_http_api_auth)
assert len(res["data"]) == 0
def test_repeated_deletion(self, get_http_api_auth):
ids = batch_create_datasets(get_http_api_auth, 1)
res = delete_dataset(get_http_api_auth, {"ids": ids})
def test_repeated_deletion(self, get_http_api_auth, add_datasets_func):
dataset_ids = add_datasets_func
res = delete_datasets(get_http_api_auth, {"ids": dataset_ids})
assert res["code"] == 0
res = delete_dataset(get_http_api_auth, {"ids": ids})
res = delete_datasets(get_http_api_auth, {"ids": dataset_ids})
assert res["code"] == 102
assert res["message"] == f"You don't own the dataset {ids[0]}"
assert "You don't own the dataset" in res["message"]
def test_duplicate_deletion(self, get_http_api_auth):
ids = batch_create_datasets(get_http_api_auth, 1)
res = delete_dataset(get_http_api_auth, {"ids": ids + ids})
def test_duplicate_deletion(self, get_http_api_auth, add_datasets_func):
dataset_ids = add_datasets_func
res = delete_datasets(get_http_api_auth, {"ids": dataset_ids + dataset_ids})
assert res["code"] == 0
assert res["data"]["errors"][0] == f"Duplicate dataset ids: {ids[0]}"
assert res["data"]["success_count"] == 1
assert "Duplicate dataset ids" in res["data"]["errors"][0]
assert res["data"]["success_count"] == 3
res = list_dataset(get_http_api_auth)
res = list_datasets(get_http_api_auth)
assert len(res["data"]) == 0
def test_concurrent_deletion(self, get_http_api_auth):
ids = batch_create_datasets(get_http_api_auth, 100)
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [executor.submit(delete_dataset, get_http_api_auth, {"ids": ids[i : i + 1]}) for i in range(100)]
futures = [executor.submit(delete_datasets, get_http_api_auth, {"ids": ids[i : i + 1]}) for i in range(100)]
responses = [f.result() for f in futures]
assert all(r["code"] == 0 for r in responses)
@pytest.mark.slow
def test_delete_10k(self, get_http_api_auth):
ids = batch_create_datasets(get_http_api_auth, 10_000)
res = delete_dataset(get_http_api_auth, {"ids": ids})
res = delete_datasets(get_http_api_auth, {"ids": ids})
assert res["code"] == 0
res = list_dataset(get_http_api_auth)
res = list_datasets(get_http_api_auth)
assert len(res["data"]) == 0

View File

@ -16,7 +16,7 @@
from concurrent.futures import ThreadPoolExecutor
import pytest
from common import INVALID_API_TOKEN, list_dataset
from common import INVALID_API_TOKEN, list_datasets
from libs.auth import RAGFlowHttpApiAuth
@ -25,7 +25,6 @@ def is_sorted(data, field, descending=True):
return all(a >= b for a, b in zip(timestamps, timestamps[1:])) if descending else all(a <= b for a, b in zip(timestamps, timestamps[1:]))
@pytest.mark.usefixtures("clear_datasets")
class TestAuthorization:
@pytest.mark.parametrize(
"auth, expected_code, expected_message",
@ -39,15 +38,15 @@ class TestAuthorization:
],
)
def test_invalid_auth(self, auth, expected_code, expected_message):
res = list_dataset(auth)
res = list_datasets(auth)
assert res["code"] == expected_code
assert res["message"] == expected_message
@pytest.mark.usefixtures("get_dataset_ids")
class TestDatasetList:
@pytest.mark.usefixtures("add_datasets")
class TestDatasetsList:
def test_default(self, get_http_api_auth):
res = list_dataset(get_http_api_auth, params={})
res = list_datasets(get_http_api_auth, params={})
assert res["code"] == 0
assert len(res["data"]) == 5
@ -77,7 +76,7 @@ class TestDatasetList:
],
)
def test_page(self, get_http_api_auth, params, expected_code, expected_page_size, expected_message):
res = list_dataset(get_http_api_auth, params=params)
res = list_datasets(get_http_api_auth, params=params)
assert res["code"] == expected_code
if expected_code == 0:
assert len(res["data"]) == expected_page_size
@ -116,7 +115,7 @@ class TestDatasetList:
expected_page_size,
expected_message,
):
res = list_dataset(get_http_api_auth, params=params)
res = list_datasets(get_http_api_auth, params=params)
assert res["code"] == expected_code
if expected_code == 0:
assert len(res["data"]) == expected_page_size
@ -168,7 +167,7 @@ class TestDatasetList:
assertions,
expected_message,
):
res = list_dataset(get_http_api_auth, params=params)
res = list_datasets(get_http_api_auth, params=params)
assert res["code"] == expected_code
if expected_code == 0:
if callable(assertions):
@ -244,7 +243,7 @@ class TestDatasetList:
assertions,
expected_message,
):
res = list_dataset(get_http_api_auth, params=params)
res = list_datasets(get_http_api_auth, params=params)
assert res["code"] == expected_code
if expected_code == 0:
if callable(assertions):
@ -262,7 +261,7 @@ class TestDatasetList:
],
)
def test_name(self, get_http_api_auth, params, expected_code, expected_num, expected_message):
res = list_dataset(get_http_api_auth, params=params)
res = list_datasets(get_http_api_auth, params=params)
assert res["code"] == expected_code
if expected_code == 0:
if params["name"] in [None, ""]:
@ -284,19 +283,19 @@ class TestDatasetList:
def test_id(
self,
get_http_api_auth,
get_dataset_ids,
add_datasets,
dataset_id,
expected_code,
expected_num,
expected_message,
):
dataset_ids = get_dataset_ids
dataset_ids = add_datasets
if callable(dataset_id):
params = {"id": dataset_id(dataset_ids)}
else:
params = {"id": dataset_id}
res = list_dataset(get_http_api_auth, params=params)
res = list_datasets(get_http_api_auth, params=params)
assert res["code"] == expected_code
if expected_code == 0:
if params["id"] in [None, ""]:
@ -318,20 +317,20 @@ class TestDatasetList:
def test_name_and_id(
self,
get_http_api_auth,
get_dataset_ids,
add_datasets,
dataset_id,
name,
expected_code,
expected_num,
expected_message,
):
dataset_ids = get_dataset_ids
dataset_ids = add_datasets
if callable(dataset_id):
params = {"id": dataset_id(dataset_ids), "name": name}
else:
params = {"id": dataset_id, "name": name}
res = list_dataset(get_http_api_auth, params=params)
res = list_datasets(get_http_api_auth, params=params)
if expected_code == 0:
assert len(res["data"]) == expected_num
else:
@ -339,12 +338,12 @@ class TestDatasetList:
def test_concurrent_list(self, get_http_api_auth):
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [executor.submit(list_dataset, get_http_api_auth) for i in range(100)]
futures = [executor.submit(list_datasets, get_http_api_auth) for i in range(100)]
responses = [f.result() for f in futures]
assert all(r["code"] == 0 for r in responses)
def test_invalid_params(self, get_http_api_auth):
params = {"a": "b"}
res = list_dataset(get_http_api_auth, params=params)
res = list_datasets(get_http_api_auth, params=params)
assert res["code"] == 0
assert len(res["data"]) == 5

View File

@ -19,8 +19,7 @@ import pytest
from common import (
DATASET_NAME_LIMIT,
INVALID_API_TOKEN,
batch_create_datasets,
list_dataset,
list_datasets,
update_dataset,
)
from libs.auth import RAGFlowHttpApiAuth
@ -30,7 +29,6 @@ from libs.utils.file_utils import create_image_file
# TODO: Missing scenario for updating embedding_model with chunk_count != 0
@pytest.mark.usefixtures("clear_datasets")
class TestAuthorization:
@pytest.mark.parametrize(
"auth, expected_code, expected_message",
@ -43,14 +41,12 @@ class TestAuthorization:
),
],
)
def test_invalid_auth(self, get_http_api_auth, auth, expected_code, expected_message):
ids = batch_create_datasets(get_http_api_auth, 1)
res = update_dataset(auth, ids[0], {"name": "new_name"})
def test_invalid_auth(self, auth, expected_code, expected_message):
res = update_dataset(auth, "dataset_id")
assert res["code"] == expected_code
assert res["message"] == expected_message
@pytest.mark.usefixtures("clear_datasets")
class TestDatasetUpdate:
@pytest.mark.parametrize(
"name, expected_code, expected_message",
@ -72,12 +68,12 @@ class TestDatasetUpdate:
("DATASET_1", 102, "Duplicated dataset name in updating dataset."),
],
)
def test_name(self, get_http_api_auth, name, expected_code, expected_message):
ids = batch_create_datasets(get_http_api_auth, 2)
res = update_dataset(get_http_api_auth, ids[0], {"name": name})
def test_name(self, get_http_api_auth, add_datasets_func, name, expected_code, expected_message):
dataset_ids = add_datasets_func
res = update_dataset(get_http_api_auth, dataset_ids[0], {"name": name})
assert res["code"] == expected_code
if expected_code == 0:
res = list_dataset(get_http_api_auth, {"id": ids[0]})
res = list_datasets(get_http_api_auth, {"id": dataset_ids[0]})
assert res["data"][0]["name"] == name
else:
assert res["message"] == expected_message
@ -95,12 +91,12 @@ class TestDatasetUpdate:
(None, 102, "`embedding_model` can't be empty"),
],
)
def test_embedding_model(self, get_http_api_auth, embedding_model, expected_code, expected_message):
ids = batch_create_datasets(get_http_api_auth, 1)
res = update_dataset(get_http_api_auth, ids[0], {"embedding_model": embedding_model})
def test_embedding_model(self, get_http_api_auth, add_dataset_func, embedding_model, expected_code, expected_message):
dataset_id = add_dataset_func
res = update_dataset(get_http_api_auth, dataset_id, {"embedding_model": embedding_model})
assert res["code"] == expected_code
if expected_code == 0:
res = list_dataset(get_http_api_auth, {"id": ids[0]})
res = list_datasets(get_http_api_auth, {"id": dataset_id})
assert res["data"][0]["embedding_model"] == embedding_model
else:
assert res["message"] == expected_message
@ -129,12 +125,12 @@ class TestDatasetUpdate:
),
],
)
def test_chunk_method(self, get_http_api_auth, chunk_method, expected_code, expected_message):
ids = batch_create_datasets(get_http_api_auth, 1)
res = update_dataset(get_http_api_auth, ids[0], {"chunk_method": chunk_method})
def test_chunk_method(self, get_http_api_auth, add_dataset_func, chunk_method, expected_code, expected_message):
dataset_id = add_dataset_func
res = update_dataset(get_http_api_auth, dataset_id, {"chunk_method": chunk_method})
assert res["code"] == expected_code
if expected_code == 0:
res = list_dataset(get_http_api_auth, {"id": ids[0]})
res = list_datasets(get_http_api_auth, {"id": dataset_id})
if chunk_method != "":
assert res["data"][0]["chunk_method"] == chunk_method
else:
@ -142,38 +138,38 @@ class TestDatasetUpdate:
else:
assert res["message"] == expected_message
def test_avatar(self, get_http_api_auth, tmp_path):
ids = batch_create_datasets(get_http_api_auth, 1)
def test_avatar(self, get_http_api_auth, add_dataset_func, tmp_path):
dataset_id = add_dataset_func
fn = create_image_file(tmp_path / "ragflow_test.png")
payload = {"avatar": encode_avatar(fn)}
res = update_dataset(get_http_api_auth, ids[0], payload)
res = update_dataset(get_http_api_auth, dataset_id, payload)
assert res["code"] == 0
def test_description(self, get_http_api_auth):
ids = batch_create_datasets(get_http_api_auth, 1)
def test_description(self, get_http_api_auth, add_dataset_func):
dataset_id = add_dataset_func
payload = {"description": "description"}
res = update_dataset(get_http_api_auth, ids[0], payload)
res = update_dataset(get_http_api_auth, dataset_id, payload)
assert res["code"] == 0
res = list_dataset(get_http_api_auth, {"id": ids[0]})
res = list_datasets(get_http_api_auth, {"id": dataset_id})
assert res["data"][0]["description"] == "description"
def test_pagerank(self, get_http_api_auth):
ids = batch_create_datasets(get_http_api_auth, 1)
def test_pagerank(self, get_http_api_auth, add_dataset_func):
dataset_id = add_dataset_func
payload = {"pagerank": 1}
res = update_dataset(get_http_api_auth, ids[0], payload)
res = update_dataset(get_http_api_auth, dataset_id, payload)
assert res["code"] == 0
res = list_dataset(get_http_api_auth, {"id": ids[0]})
res = list_datasets(get_http_api_auth, {"id": dataset_id})
assert res["data"][0]["pagerank"] == 1
def test_similarity_threshold(self, get_http_api_auth):
ids = batch_create_datasets(get_http_api_auth, 1)
def test_similarity_threshold(self, get_http_api_auth, add_dataset_func):
dataset_id = add_dataset_func
payload = {"similarity_threshold": 1}
res = update_dataset(get_http_api_auth, ids[0], payload)
res = update_dataset(get_http_api_auth, dataset_id, payload)
assert res["code"] == 0
res = list_dataset(get_http_api_auth, {"id": ids[0]})
res = list_datasets(get_http_api_auth, {"id": dataset_id})
assert res["data"][0]["similarity_threshold"] == 1
@pytest.mark.parametrize(
@ -187,29 +183,28 @@ class TestDatasetUpdate:
("other_permission", 102),
],
)
def test_permission(self, get_http_api_auth, permission, expected_code):
ids = batch_create_datasets(get_http_api_auth, 1)
def test_permission(self, get_http_api_auth, add_dataset_func, permission, expected_code):
dataset_id = add_dataset_func
payload = {"permission": permission}
res = update_dataset(get_http_api_auth, ids[0], payload)
res = update_dataset(get_http_api_auth, dataset_id, payload)
assert res["code"] == expected_code
res = list_dataset(get_http_api_auth, {"id": ids[0]})
res = list_datasets(get_http_api_auth, {"id": dataset_id})
if expected_code == 0 and permission != "":
assert res["data"][0]["permission"] == permission
if permission == "":
assert res["data"][0]["permission"] == "me"
def test_vector_similarity_weight(self, get_http_api_auth):
ids = batch_create_datasets(get_http_api_auth, 1)
def test_vector_similarity_weight(self, get_http_api_auth, add_dataset_func):
dataset_id = add_dataset_func
payload = {"vector_similarity_weight": 1}
res = update_dataset(get_http_api_auth, ids[0], payload)
res = update_dataset(get_http_api_auth, dataset_id, payload)
assert res["code"] == 0
res = list_dataset(get_http_api_auth, {"id": ids[0]})
res = list_datasets(get_http_api_auth, {"id": dataset_id})
assert res["data"][0]["vector_similarity_weight"] == 1
def test_invalid_dataset_id(self, get_http_api_auth):
batch_create_datasets(get_http_api_auth, 1)
res = update_dataset(get_http_api_auth, "invalid_dataset_id", {"name": "invalid_dataset_id"})
assert res["code"] == 102
assert res["message"] == "You don't own the dataset"
@ -230,21 +225,21 @@ class TestDatasetUpdate:
{"update_time": 1741671443339},
],
)
def test_modify_read_only_field(self, get_http_api_auth, payload):
ids = batch_create_datasets(get_http_api_auth, 1)
res = update_dataset(get_http_api_auth, ids[0], payload)
def test_modify_read_only_field(self, get_http_api_auth, add_dataset_func, payload):
dataset_id = add_dataset_func
res = update_dataset(get_http_api_auth, dataset_id, payload)
assert res["code"] == 101
assert "is readonly" in res["message"]
def test_modify_unknown_field(self, get_http_api_auth):
ids = batch_create_datasets(get_http_api_auth, 1)
res = update_dataset(get_http_api_auth, ids[0], {"unknown_field": 0})
def test_modify_unknown_field(self, get_http_api_auth, add_dataset_func):
dataset_id = add_dataset_func
res = update_dataset(get_http_api_auth, dataset_id, {"unknown_field": 0})
assert res["code"] == 100
def test_concurrent_update(self, get_http_api_auth):
ids = batch_create_datasets(get_http_api_auth, 1)
def test_concurrent_update(self, get_http_api_auth, add_dataset_func):
dataset_id = add_dataset_func
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [executor.submit(update_dataset, get_http_api_auth, ids[0], {"name": f"dataset_{i}"}) for i in range(100)]
futures = [executor.submit(update_dataset, get_http_api_auth, dataset_id, {"name": f"dataset_{i}"}) for i in range(100)]
responses = [f.result() for f in futures]
assert all(r["code"] == 0 for r in responses)

View File

@ -16,22 +16,36 @@
import pytest
from common import batch_create_datasets, bulk_upload_documents, delete_dataset
from common import bulk_upload_documents, delete_documnets
@pytest.fixture(scope="class")
def file_management_tmp_dir(tmp_path_factory):
return tmp_path_factory.mktemp("file_management")
@pytest.fixture(scope="function")
def add_document_func(request, get_http_api_auth, add_dataset, ragflow_tmp_dir):
dataset_id = add_dataset
document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 1, ragflow_tmp_dir)
@pytest.fixture(scope="class")
def get_dataset_id_and_document_ids(get_http_api_auth, file_management_tmp_dir, request):
def cleanup():
delete_dataset(get_http_api_auth)
delete_documnets(get_http_api_auth, dataset_id, {"ids": document_ids})
request.addfinalizer(cleanup)
return dataset_id, document_ids[0]
@pytest.fixture(scope="class")
def add_documents(request, get_http_api_auth, add_dataset, ragflow_tmp_dir):
dataset_id = add_dataset
document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 5, ragflow_tmp_dir)
def cleanup():
delete_documnets(get_http_api_auth, dataset_id, {"ids": document_ids})
request.addfinalizer(cleanup)
return dataset_id, document_ids
@pytest.fixture(scope="function")
def add_documents_func(get_http_api_auth, add_dataset_func, ragflow_tmp_dir):
dataset_id = add_dataset_func
document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 3, ragflow_tmp_dir)
dataset_ids = batch_create_datasets(get_http_api_auth, 1)
dataset_id = dataset_ids[0]
document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 5, file_management_tmp_dir)
return dataset_id, document_ids

View File

@ -16,13 +16,7 @@
from concurrent.futures import ThreadPoolExecutor
import pytest
from common import (
INVALID_API_TOKEN,
batch_create_datasets,
bulk_upload_documents,
delete_documnet,
list_documnet,
)
from common import INVALID_API_TOKEN, bulk_upload_documents, delete_documnets, list_documnets
from libs.auth import RAGFlowHttpApiAuth
@ -38,15 +32,13 @@ class TestAuthorization:
),
],
)
def test_invalid_auth(self, get_dataset_id_and_document_ids, auth, expected_code, expected_message):
dataset_id, document_ids = get_dataset_id_and_document_ids
res = delete_documnet(auth, dataset_id, {"ids": document_ids})
def test_invalid_auth(self, auth, expected_code, expected_message):
res = delete_documnets(auth, "dataset_id")
assert res["code"] == expected_code
assert res["message"] == expected_message
@pytest.mark.usefixtures("clear_datasets")
class TestDocumentDeletion:
class TestDocumentsDeletion:
@pytest.mark.parametrize(
"payload, expected_code, expected_message, remaining",
[
@ -72,22 +64,21 @@ class TestDocumentDeletion:
def test_basic_scenarios(
self,
get_http_api_auth,
tmp_path,
add_documents_func,
payload,
expected_code,
expected_message,
remaining,
):
ids = batch_create_datasets(get_http_api_auth, 1)
document_ids = bulk_upload_documents(get_http_api_auth, ids[0], 3, tmp_path)
dataset_id, document_ids = add_documents_func
if callable(payload):
payload = payload(document_ids)
res = delete_documnet(get_http_api_auth, ids[0], payload)
res = delete_documnets(get_http_api_auth, dataset_id, payload)
assert res["code"] == expected_code
if res["code"] != 0:
assert res["message"] == expected_message
res = list_documnet(get_http_api_auth, ids[0])
res = list_documnets(get_http_api_auth, dataset_id)
assert len(res["data"]["docs"]) == remaining
assert res["data"]["total"] == remaining
@ -102,10 +93,9 @@ class TestDocumentDeletion:
),
],
)
def test_invalid_dataset_id(self, get_http_api_auth, tmp_path, dataset_id, expected_code, expected_message):
ids = batch_create_datasets(get_http_api_auth, 1)
document_ids = bulk_upload_documents(get_http_api_auth, ids[0], 3, tmp_path)
res = delete_documnet(get_http_api_auth, dataset_id, {"ids": document_ids[:1]})
def test_invalid_dataset_id(self, get_http_api_auth, add_documents_func, dataset_id, expected_code, expected_message):
_, document_ids = add_documents_func
res = delete_documnets(get_http_api_auth, dataset_id, {"ids": document_ids[:1]})
assert res["code"] == expected_code
assert res["message"] == expected_message
@ -117,69 +107,68 @@ class TestDocumentDeletion:
lambda r: {"ids": r + ["invalid_id"]},
],
)
def test_delete_partial_invalid_id(self, get_http_api_auth, tmp_path, payload):
ids = batch_create_datasets(get_http_api_auth, 1)
document_ids = bulk_upload_documents(get_http_api_auth, ids[0], 3, tmp_path)
def test_delete_partial_invalid_id(self, get_http_api_auth, add_documents_func, payload):
dataset_id, document_ids = add_documents_func
if callable(payload):
payload = payload(document_ids)
res = delete_documnet(get_http_api_auth, ids[0], payload)
res = delete_documnets(get_http_api_auth, dataset_id, payload)
assert res["code"] == 102
assert res["message"] == "Documents not found: ['invalid_id']"
res = list_documnet(get_http_api_auth, ids[0])
res = list_documnets(get_http_api_auth, dataset_id)
assert len(res["data"]["docs"]) == 0
assert res["data"]["total"] == 0
def test_repeated_deletion(self, get_http_api_auth, tmp_path):
ids = batch_create_datasets(get_http_api_auth, 1)
document_ids = bulk_upload_documents(get_http_api_auth, ids[0], 1, tmp_path)
res = delete_documnet(get_http_api_auth, ids[0], {"ids": document_ids})
def test_repeated_deletion(self, get_http_api_auth, add_documents_func):
dataset_id, document_ids = add_documents_func
res = delete_documnets(get_http_api_auth, dataset_id, {"ids": document_ids})
assert res["code"] == 0
res = delete_documnet(get_http_api_auth, ids[0], {"ids": document_ids})
res = delete_documnets(get_http_api_auth, dataset_id, {"ids": document_ids})
assert res["code"] == 102
assert res["message"] == f"Documents not found: {document_ids}"
assert "Documents not found" in res["message"]
def test_duplicate_deletion(self, get_http_api_auth, tmp_path):
ids = batch_create_datasets(get_http_api_auth, 1)
document_ids = bulk_upload_documents(get_http_api_auth, ids[0], 1, tmp_path)
res = delete_documnet(get_http_api_auth, ids[0], {"ids": document_ids + document_ids})
def test_duplicate_deletion(self, get_http_api_auth, add_documents_func):
dataset_id, document_ids = add_documents_func
res = delete_documnets(get_http_api_auth, dataset_id, {"ids": document_ids + document_ids})
assert res["code"] == 0
assert res["data"]["errors"][0] == f"Duplicate document ids: {document_ids[0]}"
assert res["data"]["success_count"] == 1
assert "Duplicate document ids" in res["data"]["errors"][0]
assert res["data"]["success_count"] == 3
res = list_documnet(get_http_api_auth, ids[0])
res = list_documnets(get_http_api_auth, dataset_id)
assert len(res["data"]["docs"]) == 0
assert res["data"]["total"] == 0
def test_concurrent_deletion(self, get_http_api_auth, tmp_path):
documnets_num = 100
ids = batch_create_datasets(get_http_api_auth, 1)
document_ids = bulk_upload_documents(get_http_api_auth, ids[0], documnets_num, tmp_path)
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [
executor.submit(
delete_documnet,
get_http_api_auth,
ids[0],
{"ids": document_ids[i : i + 1]},
)
for i in range(documnets_num)
]
responses = [f.result() for f in futures]
assert all(r["code"] == 0 for r in responses)
def test_concurrent_deletion(get_http_api_auth, add_dataset, tmp_path):
documnets_num = 100
dataset_id = add_dataset
document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, documnets_num, tmp_path)
@pytest.mark.slow
def test_delete_1k(self, get_http_api_auth, tmp_path):
documnets_num = 1_000
ids = batch_create_datasets(get_http_api_auth, 1)
document_ids = bulk_upload_documents(get_http_api_auth, ids[0], documnets_num, tmp_path)
res = list_documnet(get_http_api_auth, ids[0])
assert res["data"]["total"] == documnets_num
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [
executor.submit(
delete_documnets,
get_http_api_auth,
dataset_id,
{"ids": document_ids[i : i + 1]},
)
for i in range(documnets_num)
]
responses = [f.result() for f in futures]
assert all(r["code"] == 0 for r in responses)
res = delete_documnet(get_http_api_auth, ids[0], {"ids": document_ids})
assert res["code"] == 0
res = list_documnet(get_http_api_auth, ids[0])
assert res["data"]["total"] == 0
@pytest.mark.slow
def test_delete_1k(get_http_api_auth, add_dataset, tmp_path):
documnets_num = 1_000
dataset_id = add_dataset
document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, documnets_num, tmp_path)
res = list_documnets(get_http_api_auth, dataset_id)
assert res["data"]["total"] == documnets_num
res = delete_documnets(get_http_api_auth, dataset_id, {"ids": document_ids})
assert res["code"] == 0
res = list_documnets(get_http_api_auth, dataset_id)
assert res["data"]["total"] == 0

View File

@ -18,7 +18,7 @@ import json
from concurrent.futures import ThreadPoolExecutor
import pytest
from common import INVALID_API_TOKEN, batch_create_datasets, bulk_upload_documents, download_document, upload_documnets
from common import INVALID_API_TOKEN, bulk_upload_documents, download_document, upload_documnets
from libs.auth import RAGFlowHttpApiAuth
from libs.utils import compare_by_hash
from requests import codes
@ -36,9 +36,8 @@ class TestAuthorization:
),
],
)
def test_invalid_auth(self, get_dataset_id_and_document_ids, tmp_path, auth, expected_code, expected_message):
dataset_id, document_ids = get_dataset_id_and_document_ids
res = download_document(auth, dataset_id, document_ids[0], tmp_path / "ragflow_tes.txt")
def test_invalid_auth(self, tmp_path, auth, expected_code, expected_message):
res = download_document(auth, "dataset_id", "document_id", tmp_path / "ragflow_tes.txt")
assert res.status_code == codes.ok
with (tmp_path / "ragflow_tes.txt").open("r") as f:
response_json = json.load(f)
@ -46,7 +45,6 @@ class TestAuthorization:
assert response_json["message"] == expected_message
@pytest.mark.usefixtures("clear_datasets")
@pytest.mark.parametrize(
"generate_test_files",
[
@ -63,15 +61,15 @@ class TestAuthorization:
],
indirect=True,
)
def test_file_type_validation(get_http_api_auth, generate_test_files, request):
ids = batch_create_datasets(get_http_api_auth, 1)
def test_file_type_validation(get_http_api_auth, add_dataset, generate_test_files, request):
dataset_id = add_dataset
fp = generate_test_files[request.node.callspec.params["generate_test_files"]]
res = upload_documnets(get_http_api_auth, ids[0], [fp])
res = upload_documnets(get_http_api_auth, dataset_id, [fp])
document_id = res["data"][0]["id"]
res = download_document(
get_http_api_auth,
ids[0],
dataset_id,
document_id,
fp.with_stem("ragflow_test_download"),
)
@ -93,8 +91,8 @@ class TestDocumentDownload:
),
],
)
def test_invalid_document_id(self, get_http_api_auth, get_dataset_id_and_document_ids, tmp_path, document_id, expected_code, expected_message):
dataset_id, _ = get_dataset_id_and_document_ids
def test_invalid_document_id(self, get_http_api_auth, add_documents, tmp_path, document_id, expected_code, expected_message):
dataset_id, _ = add_documents
res = download_document(
get_http_api_auth,
dataset_id,
@ -118,8 +116,8 @@ class TestDocumentDownload:
),
],
)
def test_invalid_dataset_id(self, get_http_api_auth, get_dataset_id_and_document_ids, tmp_path, dataset_id, expected_code, expected_message):
_, document_ids = get_dataset_id_and_document_ids
def test_invalid_dataset_id(self, get_http_api_auth, add_documents, tmp_path, dataset_id, expected_code, expected_message):
_, document_ids = add_documents
res = download_document(
get_http_api_auth,
dataset_id,
@ -132,9 +130,9 @@ class TestDocumentDownload:
assert response_json["code"] == expected_code
assert response_json["message"] == expected_message
def test_same_file_repeat(self, get_http_api_auth, get_dataset_id_and_document_ids, tmp_path, file_management_tmp_dir):
def test_same_file_repeat(self, get_http_api_auth, add_documents, tmp_path, ragflow_tmp_dir):
num = 5
dataset_id, document_ids = get_dataset_id_and_document_ids
dataset_id, document_ids = add_documents
for i in range(num):
res = download_document(
get_http_api_auth,
@ -144,23 +142,22 @@ class TestDocumentDownload:
)
assert res.status_code == codes.ok
assert compare_by_hash(
file_management_tmp_dir / "ragflow_test_upload_0.txt",
ragflow_tmp_dir / "ragflow_test_upload_0.txt",
tmp_path / f"ragflow_test_download_{i}.txt",
)
@pytest.mark.usefixtures("clear_datasets")
def test_concurrent_download(get_http_api_auth, tmp_path):
def test_concurrent_download(get_http_api_auth, add_dataset, tmp_path):
document_count = 20
ids = batch_create_datasets(get_http_api_auth, 1)
document_ids = bulk_upload_documents(get_http_api_auth, ids[0], document_count, tmp_path)
dataset_id = add_dataset
document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, document_count, tmp_path)
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [
executor.submit(
download_document,
get_http_api_auth,
ids[0],
dataset_id,
document_ids[i],
tmp_path / f"ragflow_test_download_{i}.txt",
)

View File

@ -16,10 +16,7 @@
from concurrent.futures import ThreadPoolExecutor
import pytest
from common import (
INVALID_API_TOKEN,
list_documnet,
)
from common import INVALID_API_TOKEN, list_documnets
from libs.auth import RAGFlowHttpApiAuth
@ -40,17 +37,16 @@ class TestAuthorization:
),
],
)
def test_invalid_auth(self, get_dataset_id_and_document_ids, auth, expected_code, expected_message):
dataset_id, _ = get_dataset_id_and_document_ids
res = list_documnet(auth, dataset_id)
def test_invalid_auth(self, auth, expected_code, expected_message):
res = list_documnets(auth, "dataset_id")
assert res["code"] == expected_code
assert res["message"] == expected_message
class TestDocumentList:
def test_default(self, get_http_api_auth, get_dataset_id_and_document_ids):
dataset_id, _ = get_dataset_id_and_document_ids
res = list_documnet(get_http_api_auth, dataset_id)
class TestDocumentsList:
def test_default(self, get_http_api_auth, add_documents):
dataset_id, _ = add_documents
res = list_documnets(get_http_api_auth, dataset_id)
assert res["code"] == 0
assert len(res["data"]["docs"]) == 5
assert res["data"]["total"] == 5
@ -66,8 +62,8 @@ class TestDocumentList:
),
],
)
def test_invalid_dataset_id(self, get_http_api_auth, get_dataset_id_and_document_ids, dataset_id, expected_code, expected_message):
res = list_documnet(get_http_api_auth, dataset_id)
def test_invalid_dataset_id(self, get_http_api_auth, dataset_id, expected_code, expected_message):
res = list_documnets(get_http_api_auth, dataset_id)
assert res["code"] == expected_code
assert res["message"] == expected_message
@ -98,14 +94,14 @@ class TestDocumentList:
def test_page(
self,
get_http_api_auth,
get_dataset_id_and_document_ids,
add_documents,
params,
expected_code,
expected_page_size,
expected_message,
):
dataset_id, _ = get_dataset_id_and_document_ids
res = list_documnet(get_http_api_auth, dataset_id, params=params)
dataset_id, _ = add_documents
res = list_documnets(get_http_api_auth, dataset_id, params=params)
assert res["code"] == expected_code
if expected_code == 0:
assert len(res["data"]["docs"]) == expected_page_size
@ -140,14 +136,14 @@ class TestDocumentList:
def test_page_size(
self,
get_http_api_auth,
get_dataset_id_and_document_ids,
add_documents,
params,
expected_code,
expected_page_size,
expected_message,
):
dataset_id, _ = get_dataset_id_and_document_ids
res = list_documnet(get_http_api_auth, dataset_id, params=params)
dataset_id, _ = add_documents
res = list_documnets(get_http_api_auth, dataset_id, params=params)
assert res["code"] == expected_code
if expected_code == 0:
assert len(res["data"]["docs"]) == expected_page_size
@ -194,14 +190,14 @@ class TestDocumentList:
def test_orderby(
self,
get_http_api_auth,
get_dataset_id_and_document_ids,
add_documents,
params,
expected_code,
assertions,
expected_message,
):
dataset_id, _ = get_dataset_id_and_document_ids
res = list_documnet(get_http_api_auth, dataset_id, params=params)
dataset_id, _ = add_documents
res = list_documnets(get_http_api_auth, dataset_id, params=params)
assert res["code"] == expected_code
if expected_code == 0:
if callable(assertions):
@ -273,14 +269,14 @@ class TestDocumentList:
def test_desc(
self,
get_http_api_auth,
get_dataset_id_and_document_ids,
add_documents,
params,
expected_code,
assertions,
expected_message,
):
dataset_id, _ = get_dataset_id_and_document_ids
res = list_documnet(get_http_api_auth, dataset_id, params=params)
dataset_id, _ = add_documents
res = list_documnets(get_http_api_auth, dataset_id, params=params)
assert res["code"] == expected_code
if expected_code == 0:
if callable(assertions):
@ -298,9 +294,9 @@ class TestDocumentList:
({"keywords": "unknown"}, 0),
],
)
def test_keywords(self, get_http_api_auth, get_dataset_id_and_document_ids, params, expected_num):
dataset_id, _ = get_dataset_id_and_document_ids
res = list_documnet(get_http_api_auth, dataset_id, params=params)
def test_keywords(self, get_http_api_auth, add_documents, params, expected_num):
dataset_id, _ = add_documents
res = list_documnets(get_http_api_auth, dataset_id, params=params)
assert res["code"] == 0
assert len(res["data"]["docs"]) == expected_num
assert res["data"]["total"] == expected_num
@ -322,14 +318,14 @@ class TestDocumentList:
def test_name(
self,
get_http_api_auth,
get_dataset_id_and_document_ids,
add_documents,
params,
expected_code,
expected_num,
expected_message,
):
dataset_id, _ = get_dataset_id_and_document_ids
res = list_documnet(get_http_api_auth, dataset_id, params=params)
dataset_id, _ = add_documents
res = list_documnets(get_http_api_auth, dataset_id, params=params)
assert res["code"] == expected_code
if expected_code == 0:
if params["name"] in [None, ""]:
@ -351,18 +347,18 @@ class TestDocumentList:
def test_id(
self,
get_http_api_auth,
get_dataset_id_and_document_ids,
add_documents,
document_id,
expected_code,
expected_num,
expected_message,
):
dataset_id, document_ids = get_dataset_id_and_document_ids
dataset_id, document_ids = add_documents
if callable(document_id):
params = {"id": document_id(document_ids)}
else:
params = {"id": document_id}
res = list_documnet(get_http_api_auth, dataset_id, params=params)
res = list_documnets(get_http_api_auth, dataset_id, params=params)
assert res["code"] == expected_code
if expected_code == 0:
@ -391,36 +387,36 @@ class TestDocumentList:
def test_name_and_id(
self,
get_http_api_auth,
get_dataset_id_and_document_ids,
add_documents,
document_id,
name,
expected_code,
expected_num,
expected_message,
):
dataset_id, document_ids = get_dataset_id_and_document_ids
dataset_id, document_ids = add_documents
if callable(document_id):
params = {"id": document_id(document_ids), "name": name}
else:
params = {"id": document_id, "name": name}
res = list_documnet(get_http_api_auth, dataset_id, params=params)
res = list_documnets(get_http_api_auth, dataset_id, params=params)
if expected_code == 0:
assert len(res["data"]["docs"]) == expected_num
else:
assert res["message"] == expected_message
def test_concurrent_list(self, get_http_api_auth, get_dataset_id_and_document_ids):
dataset_id, _ = get_dataset_id_and_document_ids
def test_concurrent_list(self, get_http_api_auth, add_documents):
dataset_id, _ = add_documents
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [executor.submit(list_documnet, get_http_api_auth, dataset_id) for i in range(100)]
futures = [executor.submit(list_documnets, get_http_api_auth, dataset_id) for i in range(100)]
responses = [f.result() for f in futures]
assert all(r["code"] == 0 for r in responses)
def test_invalid_params(self, get_http_api_auth, get_dataset_id_and_document_ids):
dataset_id, _ = get_dataset_id_and_document_ids
def test_invalid_params(self, get_http_api_auth, add_documents):
dataset_id, _ = add_documents
params = {"a": "b"}
res = list_documnet(get_http_api_auth, dataset_id, params=params)
res = list_documnets(get_http_api_auth, dataset_id, params=params)
assert res["code"] == 0
assert len(res["data"]["docs"]) == 5

View File

@ -16,20 +16,14 @@
from concurrent.futures import ThreadPoolExecutor
import pytest
from common import (
INVALID_API_TOKEN,
batch_create_datasets,
bulk_upload_documents,
list_documnet,
parse_documnet,
)
from common import INVALID_API_TOKEN, bulk_upload_documents, list_documnets, parse_documnets
from libs.auth import RAGFlowHttpApiAuth
from libs.utils import wait_for
def validate_document_details(auth, dataset_id, document_ids):
for document_id in document_ids:
res = list_documnet(auth, dataset_id, params={"id": document_id})
res = list_documnets(auth, dataset_id, params={"id": document_id})
doc = res["data"]["docs"][0]
assert doc["run"] == "DONE"
assert len(doc["process_begin_at"]) > 0
@ -50,14 +44,12 @@ class TestAuthorization:
),
],
)
def test_invalid_auth(self, get_dataset_id_and_document_ids, auth, expected_code, expected_message):
dataset_id, document_ids = get_dataset_id_and_document_ids
res = parse_documnet(auth, dataset_id, {"document_ids": document_ids})
def test_invalid_auth(self, auth, expected_code, expected_message):
res = parse_documnets(auth, "dataset_id")
assert res["code"] == expected_code
assert res["message"] == expected_message
@pytest.mark.usefixtures("clear_datasets")
class TestDocumentsParse:
@pytest.mark.parametrize(
"payload, expected_code, expected_message",
@ -89,21 +81,19 @@ class TestDocumentsParse:
(lambda r: {"document_ids": r}, 0, ""),
],
)
def test_basic_scenarios(self, get_http_api_auth, tmp_path, payload, expected_code, expected_message):
def test_basic_scenarios(self, get_http_api_auth, add_documents_func, payload, expected_code, expected_message):
@wait_for(10, 1, "Document parsing timeout")
def condition(_auth, _dataset_id, _document_ids):
for _document_id in _document_ids:
res = list_documnet(_auth, _dataset_id, {"id": _document_id})
res = list_documnets(_auth, _dataset_id, {"id": _document_id})
if res["data"]["docs"][0]["run"] != "DONE":
return False
return True
ids = batch_create_datasets(get_http_api_auth, 1)
dataset_id = ids[0]
document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 3, tmp_path)
dataset_id, document_ids = add_documents_func
if callable(payload):
payload = payload(document_ids)
res = parse_documnet(get_http_api_auth, dataset_id, payload)
res = parse_documnets(get_http_api_auth, dataset_id, payload)
assert res["code"] == expected_code
if expected_code != 0:
assert res["message"] == expected_message
@ -125,14 +115,13 @@ class TestDocumentsParse:
def test_invalid_dataset_id(
self,
get_http_api_auth,
tmp_path,
add_documents_func,
dataset_id,
expected_code,
expected_message,
):
ids = batch_create_datasets(get_http_api_auth, 1)
document_ids = bulk_upload_documents(get_http_api_auth, ids[0], 1, tmp_path)
res = parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids})
_, document_ids = add_documents_func
res = parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids})
assert res["code"] == expected_code
assert res["message"] == expected_message
@ -144,21 +133,19 @@ class TestDocumentsParse:
lambda r: {"document_ids": r + ["invalid_id"]},
],
)
def test_parse_partial_invalid_document_id(self, get_http_api_auth, tmp_path, payload):
def test_parse_partial_invalid_document_id(self, get_http_api_auth, add_documents_func, payload):
@wait_for(10, 1, "Document parsing timeout")
def condition(_auth, _dataset_id):
res = list_documnet(_auth, _dataset_id)
res = list_documnets(_auth, _dataset_id)
for doc in res["data"]["docs"]:
if doc["run"] != "DONE":
return False
return True
ids = batch_create_datasets(get_http_api_auth, 1)
dataset_id = ids[0]
document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 3, tmp_path)
dataset_id, document_ids = add_documents_func
if callable(payload):
payload = payload(document_ids)
res = parse_documnet(get_http_api_auth, dataset_id, payload)
res = parse_documnets(get_http_api_auth, dataset_id, payload)
assert res["code"] == 102
assert res["message"] == "Documents not found: ['invalid_id']"
@ -166,96 +153,92 @@ class TestDocumentsParse:
validate_document_details(get_http_api_auth, dataset_id, document_ids)
def test_repeated_parse(self, get_http_api_auth, tmp_path):
def test_repeated_parse(self, get_http_api_auth, add_documents_func):
@wait_for(10, 1, "Document parsing timeout")
def condition(_auth, _dataset_id):
res = list_documnet(_auth, _dataset_id)
res = list_documnets(_auth, _dataset_id)
for doc in res["data"]["docs"]:
if doc["run"] != "DONE":
return False
return True
ids = batch_create_datasets(get_http_api_auth, 1)
dataset_id = ids[0]
document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 1, tmp_path)
res = parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids})
dataset_id, document_ids = add_documents_func
res = parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids})
assert res["code"] == 0
condition(get_http_api_auth, dataset_id)
res = parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids})
res = parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids})
assert res["code"] == 0
def test_duplicate_parse(self, get_http_api_auth, tmp_path):
def test_duplicate_parse(self, get_http_api_auth, add_documents_func):
@wait_for(10, 1, "Document parsing timeout")
def condition(_auth, _dataset_id):
res = list_documnet(_auth, _dataset_id)
res = list_documnets(_auth, _dataset_id)
for doc in res["data"]["docs"]:
if doc["run"] != "DONE":
return False
return True
ids = batch_create_datasets(get_http_api_auth, 1)
dataset_id = ids[0]
document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 1, tmp_path)
res = parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids + document_ids})
dataset_id, document_ids = add_documents_func
res = parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids + document_ids})
assert res["code"] == 0
assert res["data"]["errors"][0] == f"Duplicate document ids: {document_ids[0]}"
assert res["data"]["success_count"] == 1
assert "Duplicate document ids" in res["data"]["errors"][0]
assert res["data"]["success_count"] == 3
condition(get_http_api_auth, dataset_id)
validate_document_details(get_http_api_auth, dataset_id, document_ids)
@pytest.mark.slow
def test_parse_100_files(self, get_http_api_auth, tmp_path):
@wait_for(100, 1, "Document parsing timeout")
def condition(_auth, _dataset_id, _document_num):
res = list_documnet(_auth, _dataset_id, {"page_size": _document_num})
for doc in res["data"]["docs"]:
if doc["run"] != "DONE":
return False
return True
document_num = 100
ids = batch_create_datasets(get_http_api_auth, 1)
dataset_id = ids[0]
document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, document_num, tmp_path)
res = parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids})
assert res["code"] == 0
@pytest.mark.slow
def test_parse_100_files(get_http_api_auth, add_datase_func, tmp_path):
@wait_for(100, 1, "Document parsing timeout")
def condition(_auth, _dataset_id, _document_num):
res = list_documnets(_auth, _dataset_id, {"page_size": _document_num})
for doc in res["data"]["docs"]:
if doc["run"] != "DONE":
return False
return True
condition(get_http_api_auth, dataset_id, document_num)
document_num = 100
dataset_id = add_datase_func
document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, document_num, tmp_path)
res = parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids})
assert res["code"] == 0
validate_document_details(get_http_api_auth, dataset_id, document_ids)
condition(get_http_api_auth, dataset_id, document_num)
@pytest.mark.slow
def test_concurrent_parse(self, get_http_api_auth, tmp_path):
@wait_for(120, 1, "Document parsing timeout")
def condition(_auth, _dataset_id, _document_num):
res = list_documnet(_auth, _dataset_id, {"page_size": _document_num})
for doc in res["data"]["docs"]:
if doc["run"] != "DONE":
return False
return True
validate_document_details(get_http_api_auth, dataset_id, document_ids)
document_num = 100
ids = batch_create_datasets(get_http_api_auth, 1)
dataset_id = ids[0]
document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, document_num, tmp_path)
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [
executor.submit(
parse_documnet,
get_http_api_auth,
dataset_id,
{"document_ids": document_ids[i : i + 1]},
)
for i in range(document_num)
]
responses = [f.result() for f in futures]
assert all(r["code"] == 0 for r in responses)
@pytest.mark.slow
def test_concurrent_parse(get_http_api_auth, add_datase_func, tmp_path):
@wait_for(120, 1, "Document parsing timeout")
def condition(_auth, _dataset_id, _document_num):
res = list_documnets(_auth, _dataset_id, {"page_size": _document_num})
for doc in res["data"]["docs"]:
if doc["run"] != "DONE":
return False
return True
condition(get_http_api_auth, dataset_id, document_num)
document_num = 100
dataset_id = add_datase_func
document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, document_num, tmp_path)
validate_document_details(get_http_api_auth, dataset_id, document_ids)
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [
executor.submit(
parse_documnets,
get_http_api_auth,
dataset_id,
{"document_ids": document_ids[i : i + 1]},
)
for i in range(document_num)
]
responses = [f.result() for f in futures]
assert all(r["code"] == 0 for r in responses)
condition(get_http_api_auth, dataset_id, document_num)
validate_document_details(get_http_api_auth, dataset_id, document_ids)

View File

@ -16,21 +16,14 @@
from concurrent.futures import ThreadPoolExecutor
import pytest
from common import (
INVALID_API_TOKEN,
batch_create_datasets,
bulk_upload_documents,
list_documnet,
parse_documnet,
stop_parse_documnet,
)
from common import INVALID_API_TOKEN, bulk_upload_documents, list_documnets, parse_documnets, stop_parse_documnets
from libs.auth import RAGFlowHttpApiAuth
from libs.utils import wait_for
def validate_document_parse_done(auth, dataset_id, document_ids):
for document_id in document_ids:
res = list_documnet(auth, dataset_id, params={"id": document_id})
res = list_documnets(auth, dataset_id, params={"id": document_id})
doc = res["data"]["docs"][0]
assert doc["run"] == "DONE"
assert len(doc["process_begin_at"]) > 0
@ -41,14 +34,13 @@ def validate_document_parse_done(auth, dataset_id, document_ids):
def validate_document_parse_cancel(auth, dataset_id, document_ids):
for document_id in document_ids:
res = list_documnet(auth, dataset_id, params={"id": document_id})
res = list_documnets(auth, dataset_id, params={"id": document_id})
doc = res["data"]["docs"][0]
assert doc["run"] == "CANCEL"
assert len(doc["process_begin_at"]) > 0
assert doc["progress"] == 0.0
@pytest.mark.usefixtures("clear_datasets")
class TestAuthorization:
@pytest.mark.parametrize(
"auth, expected_code, expected_message",
@ -61,15 +53,13 @@ class TestAuthorization:
),
],
)
def test_invalid_auth(self, get_http_api_auth, auth, expected_code, expected_message):
ids = batch_create_datasets(get_http_api_auth, 1)
res = stop_parse_documnet(auth, ids[0])
def test_invalid_auth(self, auth, expected_code, expected_message):
res = stop_parse_documnets(auth, "dataset_id")
assert res["code"] == expected_code
assert res["message"] == expected_message
@pytest.mark.skip
@pytest.mark.usefixtures("clear_datasets")
class TestDocumentsParseStop:
@pytest.mark.parametrize(
"payload, expected_code, expected_message",
@ -101,24 +91,22 @@ class TestDocumentsParseStop:
(lambda r: {"document_ids": r}, 0, ""),
],
)
def test_basic_scenarios(self, get_http_api_auth, tmp_path, payload, expected_code, expected_message):
def test_basic_scenarios(self, get_http_api_auth, add_documents_func, payload, expected_code, expected_message):
@wait_for(10, 1, "Document parsing timeout")
def condition(_auth, _dataset_id, _document_ids):
for _document_id in _document_ids:
res = list_documnet(_auth, _dataset_id, {"id": _document_id})
res = list_documnets(_auth, _dataset_id, {"id": _document_id})
if res["data"]["docs"][0]["run"] != "DONE":
return False
return True
ids = batch_create_datasets(get_http_api_auth, 1)
dataset_id = ids[0]
document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 3, tmp_path)
parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids})
dataset_id, document_ids = add_documents_func
parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids})
if callable(payload):
payload = payload(document_ids)
res = stop_parse_documnet(get_http_api_auth, dataset_id, payload)
res = stop_parse_documnets(get_http_api_auth, dataset_id, payload)
assert res["code"] == expected_code
if expected_code != 0:
assert res["message"] == expected_message
@ -129,7 +117,7 @@ class TestDocumentsParseStop:
validate_document_parse_done(get_http_api_auth, dataset_id, completed_document_ids)
@pytest.mark.parametrize(
"dataset_id, expected_code, expected_message",
"invalid_dataset_id, expected_code, expected_message",
[
("", 100, "<MethodNotAllowed '405: Method Not Allowed'>"),
(
@ -142,14 +130,14 @@ class TestDocumentsParseStop:
def test_invalid_dataset_id(
self,
get_http_api_auth,
tmp_path,
dataset_id,
add_documents_func,
invalid_dataset_id,
expected_code,
expected_message,
):
ids = batch_create_datasets(get_http_api_auth, 1)
document_ids = bulk_upload_documents(get_http_api_auth, ids[0], 1, tmp_path)
res = stop_parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids})
dataset_id, document_ids = add_documents_func
parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids})
res = stop_parse_documnets(get_http_api_auth, invalid_dataset_id, {"document_ids": document_ids})
assert res["code"] == expected_code
assert res["message"] == expected_message
@ -162,71 +150,65 @@ class TestDocumentsParseStop:
lambda r: {"document_ids": r + ["invalid_id"]},
],
)
def test_stop_parse_partial_invalid_document_id(self, get_http_api_auth, tmp_path, payload):
ids = batch_create_datasets(get_http_api_auth, 1)
dataset_id = ids[0]
document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 3, tmp_path)
parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids})
def test_stop_parse_partial_invalid_document_id(self, get_http_api_auth, add_documents_func, payload):
dataset_id, document_ids = add_documents_func
parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids})
if callable(payload):
payload = payload(document_ids)
res = stop_parse_documnet(get_http_api_auth, dataset_id, payload)
res = stop_parse_documnets(get_http_api_auth, dataset_id, payload)
assert res["code"] == 102
assert res["message"] == "You don't own the document invalid_id."
validate_document_parse_cancel(get_http_api_auth, dataset_id, document_ids)
def test_repeated_stop_parse(self, get_http_api_auth, tmp_path):
ids = batch_create_datasets(get_http_api_auth, 1)
dataset_id = ids[0]
document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 1, tmp_path)
parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids})
res = stop_parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids})
def test_repeated_stop_parse(self, get_http_api_auth, add_documents_func):
dataset_id, document_ids = add_documents_func
parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids})
res = stop_parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids})
assert res["code"] == 0
res = stop_parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids})
res = stop_parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids})
assert res["code"] == 102
assert res["message"] == "Can't stop parsing document with progress at 0 or 1"
def test_duplicate_stop_parse(self, get_http_api_auth, tmp_path):
ids = batch_create_datasets(get_http_api_auth, 1)
dataset_id = ids[0]
document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 1, tmp_path)
parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids})
res = stop_parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids + document_ids})
def test_duplicate_stop_parse(self, get_http_api_auth, add_documents_func):
dataset_id, document_ids = add_documents_func
parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids})
res = stop_parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids + document_ids})
assert res["code"] == 0
assert res["data"]["success_count"] == 1
assert res["data"]["success_count"] == 3
assert f"Duplicate document ids: {document_ids[0]}" in res["data"]["errors"]
@pytest.mark.slow
def test_stop_parse_100_files(self, get_http_api_auth, tmp_path):
document_num = 100
ids = batch_create_datasets(get_http_api_auth, 1)
dataset_id = ids[0]
document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, document_num, tmp_path)
parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids})
res = stop_parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids})
assert res["code"] == 0
validate_document_parse_cancel(get_http_api_auth, dataset_id, document_ids)
@pytest.mark.slow
def test_concurrent_parse(self, get_http_api_auth, tmp_path):
document_num = 50
ids = batch_create_datasets(get_http_api_auth, 1)
dataset_id = ids[0]
document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, document_num, tmp_path)
parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids})
@pytest.mark.slow
def test_stop_parse_100_files(get_http_api_auth, add_datase_func, tmp_path):
document_num = 100
dataset_id = add_datase_func
document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, document_num, tmp_path)
parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids})
res = stop_parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids})
assert res["code"] == 0
validate_document_parse_cancel(get_http_api_auth, dataset_id, document_ids)
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [
executor.submit(
stop_parse_documnet,
get_http_api_auth,
dataset_id,
{"document_ids": document_ids[i : i + 1]},
)
for i in range(document_num)
]
responses = [f.result() for f in futures]
assert all(r["code"] == 0 for r in responses)
validate_document_parse_cancel(get_http_api_auth, dataset_id, document_ids)
@pytest.mark.slow
def test_concurrent_parse(get_http_api_auth, add_datase_func, tmp_path):
document_num = 50
dataset_id = add_datase_func
document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, document_num, tmp_path)
parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids})
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [
executor.submit(
stop_parse_documnets,
get_http_api_auth,
dataset_id,
{"document_ids": document_ids[i : i + 1]},
)
for i in range(document_num)
]
responses = [f.result() for f in futures]
assert all(r["code"] == 0 for r in responses)
validate_document_parse_cancel(get_http_api_auth, dataset_id, document_ids)

View File

@ -16,7 +16,7 @@
import pytest
from common import DOCUMENT_NAME_LIMIT, INVALID_API_TOKEN, batch_create_datasets, bulk_upload_documents, list_documnet, update_documnet
from common import DOCUMENT_NAME_LIMIT, INVALID_API_TOKEN, list_documnets, update_documnet
from libs.auth import RAGFlowHttpApiAuth
@ -32,14 +32,13 @@ class TestAuthorization:
),
],
)
def test_invalid_auth(self, get_dataset_id_and_document_ids, auth, expected_code, expected_message):
dataset_id, document_ids = get_dataset_id_and_document_ids
res = update_documnet(auth, dataset_id, document_ids[0], {"name": "auth_test.txt"})
def test_invalid_auth(self, auth, expected_code, expected_message):
res = update_documnet(auth, "dataset_id", "document_id")
assert res["code"] == expected_code
assert res["message"] == expected_message
class TestUpdatedDocument:
class TestDocumentsUpdated:
@pytest.mark.parametrize(
"name, expected_code, expected_message",
[
@ -81,12 +80,12 @@ class TestUpdatedDocument:
),
],
)
def test_name(self, get_http_api_auth, get_dataset_id_and_document_ids, name, expected_code, expected_message):
dataset_id, document_ids = get_dataset_id_and_document_ids
def test_name(self, get_http_api_auth, add_documents, name, expected_code, expected_message):
dataset_id, document_ids = add_documents
res = update_documnet(get_http_api_auth, dataset_id, document_ids[0], {"name": name})
assert res["code"] == expected_code
if expected_code == 0:
res = list_documnet(get_http_api_auth, dataset_id, {"id": document_ids[0]})
res = list_documnets(get_http_api_auth, dataset_id, {"id": document_ids[0]})
assert res["data"]["docs"][0]["name"] == name
else:
assert res["message"] == expected_message
@ -102,8 +101,8 @@ class TestUpdatedDocument:
),
],
)
def test_invalid_document_id(self, get_http_api_auth, get_dataset_id_and_document_ids, document_id, expected_code, expected_message):
dataset_id, _ = get_dataset_id_and_document_ids
def test_invalid_document_id(self, get_http_api_auth, add_documents, document_id, expected_code, expected_message):
dataset_id, _ = add_documents
res = update_documnet(get_http_api_auth, dataset_id, document_id, {"name": "new_name.txt"})
assert res["code"] == expected_code
assert res["message"] == expected_message
@ -119,8 +118,8 @@ class TestUpdatedDocument:
),
],
)
def test_invalid_dataset_id(self, get_http_api_auth, get_dataset_id_and_document_ids, dataset_id, expected_code, expected_message):
_, document_ids = get_dataset_id_and_document_ids
def test_invalid_dataset_id(self, get_http_api_auth, add_documents, dataset_id, expected_code, expected_message):
_, document_ids = add_documents
res = update_documnet(get_http_api_auth, dataset_id, document_ids[0], {"name": "new_name.txt"})
assert res["code"] == expected_code
assert res["message"] == expected_message
@ -129,11 +128,11 @@ class TestUpdatedDocument:
"meta_fields, expected_code, expected_message",
[({"test": "test"}, 0, ""), ("test", 102, "meta_fields must be a dictionary")],
)
def test_meta_fields(self, get_http_api_auth, get_dataset_id_and_document_ids, meta_fields, expected_code, expected_message):
dataset_id, document_ids = get_dataset_id_and_document_ids
def test_meta_fields(self, get_http_api_auth, add_documents, meta_fields, expected_code, expected_message):
dataset_id, document_ids = add_documents
res = update_documnet(get_http_api_auth, dataset_id, document_ids[0], {"meta_fields": meta_fields})
if expected_code == 0:
res = list_documnet(get_http_api_auth, dataset_id, {"id": document_ids[0]})
res = list_documnets(get_http_api_auth, dataset_id, {"id": document_ids[0]})
assert res["data"]["docs"][0]["meta_fields"] == meta_fields
else:
assert res["message"] == expected_message
@ -162,12 +161,12 @@ class TestUpdatedDocument:
),
],
)
def test_chunk_method(self, get_http_api_auth, get_dataset_id_and_document_ids, chunk_method, expected_code, expected_message):
dataset_id, document_ids = get_dataset_id_and_document_ids
def test_chunk_method(self, get_http_api_auth, add_documents, chunk_method, expected_code, expected_message):
dataset_id, document_ids = add_documents
res = update_documnet(get_http_api_auth, dataset_id, document_ids[0], {"chunk_method": chunk_method})
assert res["code"] == expected_code
if expected_code == 0:
res = list_documnet(get_http_api_auth, dataset_id, {"id": document_ids[0]})
res = list_documnets(get_http_api_auth, dataset_id, {"id": document_ids[0]})
if chunk_method != "":
assert res["data"]["docs"][0]["chunk_method"] == chunk_method
else:
@ -282,259 +281,259 @@ class TestUpdatedDocument:
def test_invalid_field(
self,
get_http_api_auth,
get_dataset_id_and_document_ids,
add_documents,
payload,
expected_code,
expected_message,
):
dataset_id, document_ids = get_dataset_id_and_document_ids
dataset_id, document_ids = add_documents
res = update_documnet(get_http_api_auth, dataset_id, document_ids[0], payload)
assert res["code"] == expected_code
assert res["message"] == expected_message
@pytest.mark.usefixtures("clear_datasets")
@pytest.mark.parametrize(
"chunk_method, parser_config, expected_code, expected_message",
[
("naive", {}, 0, ""),
(
"naive",
{
"chunk_token_num": 128,
"layout_recognize": "DeepDOC",
"html4excel": False,
"delimiter": "\\n!?;。;!?",
"task_page_size": 12,
"raptor": {"use_raptor": False},
},
0,
"",
),
pytest.param(
"naive",
{"chunk_token_num": -1},
100,
"AssertionError('chunk_token_num should be in range from 1 to 100000000')",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"chunk_token_num": 0},
100,
"AssertionError('chunk_token_num should be in range from 1 to 100000000')",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"chunk_token_num": 100000000},
100,
"AssertionError('chunk_token_num should be in range from 1 to 100000000')",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"chunk_token_num": 3.14},
102,
"",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"chunk_token_num": "1024"},
100,
"",
marks=pytest.mark.skip(reason="issues/6098"),
),
(
"naive",
{"layout_recognize": "DeepDOC"},
0,
"",
),
(
"naive",
{"layout_recognize": "Naive"},
0,
"",
),
("naive", {"html4excel": True}, 0, ""),
("naive", {"html4excel": False}, 0, ""),
pytest.param(
"naive",
{"html4excel": 1},
100,
"AssertionError('html4excel should be True or False')",
marks=pytest.mark.skip(reason="issues/6098"),
),
("naive", {"delimiter": ""}, 0, ""),
("naive", {"delimiter": "`##`"}, 0, ""),
pytest.param(
"naive",
{"delimiter": 1},
100,
"",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"task_page_size": -1},
100,
"AssertionError('task_page_size should be in range from 1 to 100000000')",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"task_page_size": 0},
100,
"AssertionError('task_page_size should be in range from 1 to 100000000')",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"task_page_size": 100000000},
100,
"AssertionError('task_page_size should be in range from 1 to 100000000')",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"task_page_size": 3.14},
100,
"",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"task_page_size": "1024"},
100,
"",
marks=pytest.mark.skip(reason="issues/6098"),
),
("naive", {"raptor": {"use_raptor": True}}, 0, ""),
("naive", {"raptor": {"use_raptor": False}}, 0, ""),
pytest.param(
"naive",
{"invalid_key": "invalid_value"},
100,
"""AssertionError("Abnormal \'parser_config\'. Invalid key: invalid_key")""",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"auto_keywords": -1},
100,
"AssertionError('auto_keywords should be in range from 0 to 32')",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"auto_keywords": 32},
100,
"AssertionError('auto_keywords should be in range from 0 to 32')",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"auto_questions": 3.14},
100,
"",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"auto_keywords": "1024"},
100,
"",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"auto_questions": -1},
100,
"AssertionError('auto_questions should be in range from 0 to 10')",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"auto_questions": 10},
100,
"AssertionError('auto_questions should be in range from 0 to 10')",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"auto_questions": 3.14},
100,
"",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"auto_questions": "1024"},
100,
"",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"topn_tags": -1},
100,
"AssertionError('topn_tags should be in range from 0 to 10')",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"topn_tags": 10},
100,
"AssertionError('topn_tags should be in range from 0 to 10')",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"topn_tags": 3.14},
100,
"",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"topn_tags": "1024"},
100,
"",
marks=pytest.mark.skip(reason="issues/6098"),
),
],
)
def test_parser_config(
get_http_api_auth,
tmp_path,
chunk_method,
parser_config,
expected_code,
expected_message,
):
ids = batch_create_datasets(get_http_api_auth, 1)
document_ids = bulk_upload_documents(get_http_api_auth, ids[0], 1, tmp_path)
res = update_documnet(
get_http_api_auth,
ids[0],
document_ids[0],
{"chunk_method": chunk_method, "parser_config": parser_config},
class TestUpdateDocumentParserConfig:
@pytest.mark.parametrize(
"chunk_method, parser_config, expected_code, expected_message",
[
("naive", {}, 0, ""),
(
"naive",
{
"chunk_token_num": 128,
"layout_recognize": "DeepDOC",
"html4excel": False,
"delimiter": "\\n!?;。;!?",
"task_page_size": 12,
"raptor": {"use_raptor": False},
},
0,
"",
),
pytest.param(
"naive",
{"chunk_token_num": -1},
100,
"AssertionError('chunk_token_num should be in range from 1 to 100000000')",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"chunk_token_num": 0},
100,
"AssertionError('chunk_token_num should be in range from 1 to 100000000')",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"chunk_token_num": 100000000},
100,
"AssertionError('chunk_token_num should be in range from 1 to 100000000')",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"chunk_token_num": 3.14},
102,
"",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"chunk_token_num": "1024"},
100,
"",
marks=pytest.mark.skip(reason="issues/6098"),
),
(
"naive",
{"layout_recognize": "DeepDOC"},
0,
"",
),
(
"naive",
{"layout_recognize": "Naive"},
0,
"",
),
("naive", {"html4excel": True}, 0, ""),
("naive", {"html4excel": False}, 0, ""),
pytest.param(
"naive",
{"html4excel": 1},
100,
"AssertionError('html4excel should be True or False')",
marks=pytest.mark.skip(reason="issues/6098"),
),
("naive", {"delimiter": ""}, 0, ""),
("naive", {"delimiter": "`##`"}, 0, ""),
pytest.param(
"naive",
{"delimiter": 1},
100,
"",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"task_page_size": -1},
100,
"AssertionError('task_page_size should be in range from 1 to 100000000')",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"task_page_size": 0},
100,
"AssertionError('task_page_size should be in range from 1 to 100000000')",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"task_page_size": 100000000},
100,
"AssertionError('task_page_size should be in range from 1 to 100000000')",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"task_page_size": 3.14},
100,
"",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"task_page_size": "1024"},
100,
"",
marks=pytest.mark.skip(reason="issues/6098"),
),
("naive", {"raptor": {"use_raptor": True}}, 0, ""),
("naive", {"raptor": {"use_raptor": False}}, 0, ""),
pytest.param(
"naive",
{"invalid_key": "invalid_value"},
100,
"""AssertionError("Abnormal \'parser_config\'. Invalid key: invalid_key")""",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"auto_keywords": -1},
100,
"AssertionError('auto_keywords should be in range from 0 to 32')",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"auto_keywords": 32},
100,
"AssertionError('auto_keywords should be in range from 0 to 32')",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"auto_questions": 3.14},
100,
"",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"auto_keywords": "1024"},
100,
"",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"auto_questions": -1},
100,
"AssertionError('auto_questions should be in range from 0 to 10')",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"auto_questions": 10},
100,
"AssertionError('auto_questions should be in range from 0 to 10')",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"auto_questions": 3.14},
100,
"",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"auto_questions": "1024"},
100,
"",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"topn_tags": -1},
100,
"AssertionError('topn_tags should be in range from 0 to 10')",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"topn_tags": 10},
100,
"AssertionError('topn_tags should be in range from 0 to 10')",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"topn_tags": 3.14},
100,
"",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"topn_tags": "1024"},
100,
"",
marks=pytest.mark.skip(reason="issues/6098"),
),
],
)
assert res["code"] == expected_code
if expected_code == 0:
res = list_documnet(get_http_api_auth, ids[0], {"id": document_ids[0]})
if parser_config != {}:
for k, v in parser_config.items():
assert res["data"]["docs"][0]["parser_config"][k] == v
else:
assert res["data"]["docs"][0]["parser_config"] == {
"chunk_token_num": 128,
"delimiter": "\\n!?;。;!?",
"html4excel": False,
"layout_recognize": "DeepDOC",
"raptor": {"use_raptor": False},
}
if expected_code != 0 or expected_message:
assert res["message"] == expected_message
def test_parser_config(
self,
get_http_api_auth,
add_documents,
chunk_method,
parser_config,
expected_code,
expected_message,
):
dataset_id, document_ids = add_documents
res = update_documnet(
get_http_api_auth,
dataset_id,
document_ids[0],
{"chunk_method": chunk_method, "parser_config": parser_config},
)
assert res["code"] == expected_code
if expected_code == 0:
res = list_documnets(get_http_api_auth, dataset_id, {"id": document_ids[0]})
if parser_config != {}:
for k, v in parser_config.items():
assert res["data"]["docs"][0]["parser_config"][k] == v
else:
assert res["data"]["docs"][0]["parser_config"] == {
"chunk_token_num": 128,
"delimiter": "\\n!?;。;!?",
"html4excel": False,
"layout_recognize": "DeepDOC",
"raptor": {"use_raptor": False},
}
if expected_code != 0 or expected_message:
assert res["message"] == expected_message

View File

@ -19,15 +19,7 @@ from concurrent.futures import ThreadPoolExecutor
import pytest
import requests
from common import (
DOCUMENT_NAME_LIMIT,
FILE_API_URL,
HOST_ADDRESS,
INVALID_API_TOKEN,
batch_create_datasets,
list_dataset,
upload_documnets,
)
from common import DOCUMENT_NAME_LIMIT, FILE_API_URL, HOST_ADDRESS, INVALID_API_TOKEN, list_datasets, upload_documnets
from libs.auth import RAGFlowHttpApiAuth
from libs.utils.file_utils import create_txt_file
from requests_toolbelt import MultipartEncoder
@ -46,21 +38,19 @@ class TestAuthorization:
),
],
)
def test_invalid_auth(self, get_http_api_auth, auth, expected_code, expected_message):
ids = batch_create_datasets(get_http_api_auth, 1)
res = upload_documnets(auth, ids[0])
def test_invalid_auth(self, auth, expected_code, expected_message):
res = upload_documnets(auth, "dataset_id")
assert res["code"] == expected_code
assert res["message"] == expected_message
@pytest.mark.usefixtures("clear_datasets")
class TestUploadDocuments:
def test_valid_single_upload(self, get_http_api_auth, tmp_path):
ids = batch_create_datasets(get_http_api_auth, 1)
class TestDocumentsUpload:
def test_valid_single_upload(self, get_http_api_auth, add_dataset_func, tmp_path):
dataset_id = add_dataset_func
fp = create_txt_file(tmp_path / "ragflow_test.txt")
res = upload_documnets(get_http_api_auth, ids[0], [fp])
res = upload_documnets(get_http_api_auth, dataset_id, [fp])
assert res["code"] == 0
assert res["data"][0]["dataset_id"] == ids[0]
assert res["data"][0]["dataset_id"] == dataset_id
assert res["data"][0]["name"] == fp.name
@pytest.mark.parametrize(
@ -79,45 +69,45 @@ class TestUploadDocuments:
],
indirect=True,
)
def test_file_type_validation(self, get_http_api_auth, generate_test_files, request):
ids = batch_create_datasets(get_http_api_auth, 1)
def test_file_type_validation(self, get_http_api_auth, add_dataset_func, generate_test_files, request):
dataset_id = add_dataset_func
fp = generate_test_files[request.node.callspec.params["generate_test_files"]]
res = upload_documnets(get_http_api_auth, ids[0], [fp])
res = upload_documnets(get_http_api_auth, dataset_id, [fp])
assert res["code"] == 0
assert res["data"][0]["dataset_id"] == ids[0]
assert res["data"][0]["dataset_id"] == dataset_id
assert res["data"][0]["name"] == fp.name
@pytest.mark.parametrize(
"file_type",
["exe", "unknown"],
)
def test_unsupported_file_type(self, get_http_api_auth, tmp_path, file_type):
ids = batch_create_datasets(get_http_api_auth, 1)
def test_unsupported_file_type(self, get_http_api_auth, add_dataset_func, tmp_path, file_type):
dataset_id = add_dataset_func
fp = tmp_path / f"ragflow_test.{file_type}"
fp.touch()
res = upload_documnets(get_http_api_auth, ids[0], [fp])
res = upload_documnets(get_http_api_auth, dataset_id, [fp])
assert res["code"] == 500
assert res["message"] == f"ragflow_test.{file_type}: This type of file has not been supported yet!"
def test_missing_file(self, get_http_api_auth):
ids = batch_create_datasets(get_http_api_auth, 1)
res = upload_documnets(get_http_api_auth, ids[0])
def test_missing_file(self, get_http_api_auth, add_dataset_func):
dataset_id = add_dataset_func
res = upload_documnets(get_http_api_auth, dataset_id)
assert res["code"] == 101
assert res["message"] == "No file part!"
def test_empty_file(self, get_http_api_auth, tmp_path):
ids = batch_create_datasets(get_http_api_auth, 1)
def test_empty_file(self, get_http_api_auth, add_dataset_func, tmp_path):
dataset_id = add_dataset_func
fp = tmp_path / "empty.txt"
fp.touch()
res = upload_documnets(get_http_api_auth, ids[0], [fp])
res = upload_documnets(get_http_api_auth, dataset_id, [fp])
assert res["code"] == 0
assert res["data"][0]["size"] == 0
def test_filename_empty(self, get_http_api_auth, tmp_path):
ids = batch_create_datasets(get_http_api_auth, 1)
def test_filename_empty(self, get_http_api_auth, add_dataset_func, tmp_path):
dataset_id = add_dataset_func
fp = create_txt_file(tmp_path / "ragflow_test.txt")
url = f"{HOST_ADDRESS}{FILE_API_URL}".format(dataset_id=ids[0])
url = f"{HOST_ADDRESS}{FILE_API_URL}".format(dataset_id=dataset_id)
fields = (("file", ("", fp.open("rb"))),)
m = MultipartEncoder(fields=fields)
res = requests.post(
@ -129,11 +119,11 @@ class TestUploadDocuments:
assert res.json()["code"] == 101
assert res.json()["message"] == "No file selected!"
def test_filename_exceeds_max_length(self, get_http_api_auth, tmp_path):
ids = batch_create_datasets(get_http_api_auth, 1)
def test_filename_exceeds_max_length(self, get_http_api_auth, add_dataset_func, tmp_path):
dataset_id = add_dataset_func
# filename_length = 129
fp = create_txt_file(tmp_path / f"{'a' * (DOCUMENT_NAME_LIMIT - 3)}.txt")
res = upload_documnets(get_http_api_auth, ids[0], [fp])
res = upload_documnets(get_http_api_auth, dataset_id, [fp])
assert res["code"] == 101
assert res["message"] == "File name should be less than 128 bytes."
@ -143,61 +133,61 @@ class TestUploadDocuments:
assert res["code"] == 100
assert res["message"] == """LookupError("Can\'t find the dataset with ID invalid_dataset_id!")"""
def test_duplicate_files(self, get_http_api_auth, tmp_path):
ids = batch_create_datasets(get_http_api_auth, 1)
def test_duplicate_files(self, get_http_api_auth, add_dataset_func, tmp_path):
dataset_id = add_dataset_func
fp = create_txt_file(tmp_path / "ragflow_test.txt")
res = upload_documnets(get_http_api_auth, ids[0], [fp, fp])
res = upload_documnets(get_http_api_auth, dataset_id, [fp, fp])
assert res["code"] == 0
assert len(res["data"]) == 2
for i in range(len(res["data"])):
assert res["data"][i]["dataset_id"] == ids[0]
assert res["data"][i]["dataset_id"] == dataset_id
expected_name = fp.name
if i != 0:
expected_name = f"{fp.stem}({i}){fp.suffix}"
assert res["data"][i]["name"] == expected_name
def test_same_file_repeat(self, get_http_api_auth, tmp_path):
ids = batch_create_datasets(get_http_api_auth, 1)
def test_same_file_repeat(self, get_http_api_auth, add_dataset_func, tmp_path):
dataset_id = add_dataset_func
fp = create_txt_file(tmp_path / "ragflow_test.txt")
for i in range(10):
res = upload_documnets(get_http_api_auth, ids[0], [fp])
res = upload_documnets(get_http_api_auth, dataset_id, [fp])
assert res["code"] == 0
assert len(res["data"]) == 1
assert res["data"][0]["dataset_id"] == ids[0]
assert res["data"][0]["dataset_id"] == dataset_id
expected_name = fp.name
if i != 0:
expected_name = f"{fp.stem}({i}){fp.suffix}"
assert res["data"][0]["name"] == expected_name
def test_filename_special_characters(self, get_http_api_auth, tmp_path):
ids = batch_create_datasets(get_http_api_auth, 1)
def test_filename_special_characters(self, get_http_api_auth, add_dataset_func, tmp_path):
dataset_id = add_dataset_func
illegal_chars = '<>:"/\\|?*'
translation_table = str.maketrans({char: "_" for char in illegal_chars})
safe_filename = string.punctuation.translate(translation_table)
fp = tmp_path / f"{safe_filename}.txt"
fp.write_text("Sample text content")
res = upload_documnets(get_http_api_auth, ids[0], [fp])
res = upload_documnets(get_http_api_auth, dataset_id, [fp])
assert res["code"] == 0
assert len(res["data"]) == 1
assert res["data"][0]["dataset_id"] == ids[0]
assert res["data"][0]["dataset_id"] == dataset_id
assert res["data"][0]["name"] == fp.name
def test_multiple_files(self, get_http_api_auth, tmp_path):
ids = batch_create_datasets(get_http_api_auth, 1)
def test_multiple_files(self, get_http_api_auth, add_dataset_func, tmp_path):
dataset_id = add_dataset_func
expected_document_count = 20
fps = []
for i in range(expected_document_count):
fp = create_txt_file(tmp_path / f"ragflow_test_{i}.txt")
fps.append(fp)
res = upload_documnets(get_http_api_auth, ids[0], fps)
res = upload_documnets(get_http_api_auth, dataset_id, fps)
assert res["code"] == 0
res = list_dataset(get_http_api_auth, {"id": ids[0]})
res = list_datasets(get_http_api_auth, {"id": dataset_id})
assert res["data"][0]["document_count"] == expected_document_count
def test_concurrent_upload(self, get_http_api_auth, tmp_path):
ids = batch_create_datasets(get_http_api_auth, 1)
def test_concurrent_upload(self, get_http_api_auth, add_dataset_func, tmp_path):
dataset_id = add_dataset_func
expected_document_count = 20
fps = []
@ -206,9 +196,9 @@ class TestUploadDocuments:
fps.append(fp)
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [executor.submit(upload_documnets, get_http_api_auth, ids[0], fps[i : i + 1]) for i in range(expected_document_count)]
futures = [executor.submit(upload_documnets, get_http_api_auth, dataset_id, fps[i : i + 1]) for i in range(expected_document_count)]
responses = [f.result() for f in futures]
assert all(r["code"] == 0 for r in responses)
res = list_dataset(get_http_api_auth, {"id": ids[0]})
res = list_datasets(get_http_api_auth, {"id": dataset_id})
assert res["data"][0]["document_count"] == expected_document_count