Mirror of https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git (synced 2025-04-19 12:39:59 +08:00)
Test: Refactor test fixtures and test cases (#6709)
### What problem does this PR solve?

Refactor test fixtures and test cases

### Type of change

- [ ] Refactoring test cases
parent 20b8ccd1e9
commit 58e6e7b668
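The diff below follows one recurring pattern: tests that created datasets and documents inline (and leaned on a `clear_datasets` fixture for cleanup) now consume shared pytest fixtures that create the resources up front and register teardown with `request.addfinalizer`. A minimal sketch of that pattern, assuming the `common` helpers shown in this diff; `test_example` is a hypothetical consumer, not part of the commit:

```python
import pytest
from common import batch_create_datasets, delete_datasets, list_datasets


@pytest.fixture(scope="function")
def add_dataset_func(request, get_http_api_auth):
    # Teardown is registered before resource creation, so cleanup
    # runs even when the test (or the creation itself) fails.
    def cleanup():
        delete_datasets(get_http_api_auth)

    request.addfinalizer(cleanup)

    dataset_ids = batch_create_datasets(get_http_api_auth, 1)
    return dataset_ids[0]


def test_example(get_http_api_auth, add_dataset_func):
    # Hypothetical test: the fixture hands back a ready-made dataset id.
    dataset_id = add_dataset_func
    res = list_datasets(get_http_api_auth, {"id": dataset_id})
    assert res["code"] == 0
```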
@@ -15,26 +15,27 @@
 #
 import os

 import pytest
 import requests

 from libs.auth import RAGFlowHttpApiAuth

-HOST_ADDRESS = os.getenv('HOST_ADDRESS', 'http://127.0.0.1:9380')
+HOST_ADDRESS = os.getenv("HOST_ADDRESS", "http://127.0.0.1:9380")


 # def generate_random_email():
 #     return 'user_' + ''.join(random.choices(string.ascii_lowercase + string.digits, k=8))+'@1.com'


 def generate_email():
-    return 'user_123@1.com'
+    return "user_123@1.com"


 EMAIL = generate_email()
 # password is "123"
-PASSWORD = '''ctAseGvejiaSWWZ88T/m4FQVOpQyUvP+x7sXtdv3feqZACiQleuewkUi35E16wSd5C5QcnkkcV9cYc8TKPTRZlxappDuirxghxoOvFcJxFU4ixLsD
+PASSWORD = """ctAseGvejiaSWWZ88T/m4FQVOpQyUvP+x7sXtdv3feqZACiQleuewkUi35E16wSd5C5QcnkkcV9cYc8TKPTRZlxappDuirxghxoOvFcJxFU4ixLsD
 fN33jCHRoDUW81IH9zjij/vaw8IbVyb6vuwg6MX6inOEBRRzVbRYxXOu1wkWY6SsI8X70oF9aeLFp/PzQpjoe/YbSqpTq8qqrmHzn9vO+yvyYyvmDsphXe
-X8f7fp9c7vUsfOCkM+gHY3PadG+QHa7KI7mzTKgUTZImK6BZtfRBATDTthEUbbaTewY4H0MnWiCeeDhcbeQao6cFy1To8pE3RpmxnGnS8BsBn8w=='''
+X8f7fp9c7vUsfOCkM+gHY3PadG+QHa7KI7mzTKgUTZImK6BZtfRBATDTthEUbbaTewY4H0MnWiCeeDhcbeQao6cFy1To8pE3RpmxnGnS8BsBn8w=="""


 def register():
@@ -92,3 +93,64 @@ def get_email():
 @pytest.fixture(scope="session")
 def get_http_api_auth(get_api_key_fixture):
     return RAGFlowHttpApiAuth(get_api_key_fixture)
+
+
+def get_my_llms(auth, name):
+    url = HOST_ADDRESS + "/v1/llm/my_llms"
+    authorization = {"Authorization": auth}
+    response = requests.get(url=url, headers=authorization)
+    res = response.json()
+    if res.get("code") != 0:
+        raise Exception(res.get("message"))
+    if name in res.get("data"):
+        return True
+    return False
+
+
+def add_models(auth):
+    url = HOST_ADDRESS + "/v1/llm/set_api_key"
+    authorization = {"Authorization": auth}
+    models_info = {
+        "ZHIPU-AI": {"llm_factory": "ZHIPU-AI", "api_key": "d06253dacd404180aa8afb096fcb6c30.KatwBIUpvCSml9sU"},
+    }
+
+    for name, model_info in models_info.items():
+        if not get_my_llms(auth, name):
+            response = requests.post(url=url, headers=authorization, json=model_info)
+            res = response.json()
+            if res.get("code") != 0:
+                raise Exception(res.get("message"))
+
+
+def get_tenant_info(auth):
+    url = HOST_ADDRESS + "/v1/user/tenant_info"
+    authorization = {"Authorization": auth}
+    response = requests.get(url=url, headers=authorization)
+    res = response.json()
+    if res.get("code") != 0:
+        raise Exception(res.get("message"))
+    return res["data"].get("tenant_id")
+
+
+@pytest.fixture(scope="session", autouse=True)
+def set_tenant_info(get_auth):
+    auth = get_auth
+    try:
+        add_models(auth)
+        tenant_id = get_tenant_info(auth)
+    except Exception as e:
+        raise Exception(e)
+    url = HOST_ADDRESS + "/v1/user/set_tenant_info"
+    authorization = {"Authorization": get_auth}
+    tenant_info = {
+        "tenant_id": tenant_id,
+        "llm_id": "glm-4-flash@ZHIPU-AI",
+        "embd_id": "embedding-3@ZHIPU-AI",
+        "img2txt_id": "glm-4v@ZHIPU-AI",
+        "asr_id": "",
+        "tts_id": None,
+    }
+    response = requests.post(url=url, headers=authorization, json=tenant_info)
+    res = response.json()
+    if res.get("code") != 0:
+        raise Exception(res.get("message"))
@@ -27,6 +27,7 @@ DATASETS_API_URL = "/api/v1/datasets"
 FILE_API_URL = "/api/v1/datasets/{dataset_id}/documents"
 FILE_CHUNK_API_URL = "/api/v1/datasets/{dataset_id}/chunks"
 CHUNK_API_URL = "/api/v1/datasets/{dataset_id}/documents/{document_id}/chunks"
+CHAT_ASSISTANT_API_URL = "/api/v1/chats"

 INVALID_API_TOKEN = "invalid_key_123"
 DATASET_NAME_LIMIT = 128
@@ -39,7 +40,7 @@ def create_dataset(auth, payload=None):
     return res.json()


-def list_dataset(auth, params=None):
+def list_datasets(auth, params=None):
     res = requests.get(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=HEADERS, auth=auth, params=params)
     return res.json()
@@ -49,7 +50,7 @@ def update_dataset(auth, dataset_id, payload=None):
     return res.json()


-def delete_dataset(auth, payload=None):
+def delete_datasets(auth, payload=None):
     res = requests.delete(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=HEADERS, auth=auth, json=payload)
     return res.json()
@@ -105,7 +106,7 @@ def download_document(auth, dataset_id, document_id, save_path):
     return res


-def list_documnet(auth, dataset_id, params=None):
+def list_documnets(auth, dataset_id, params=None):
     url = f"{HOST_ADDRESS}{FILE_API_URL}".format(dataset_id=dataset_id)
     res = requests.get(url=url, headers=HEADERS, auth=auth, params=params)
     return res.json()
@@ -117,19 +118,19 @@ def update_documnet(auth, dataset_id, document_id, payload=None):
     return res.json()


-def delete_documnet(auth, dataset_id, payload=None):
+def delete_documnets(auth, dataset_id, payload=None):
     url = f"{HOST_ADDRESS}{FILE_API_URL}".format(dataset_id=dataset_id)
     res = requests.delete(url=url, headers=HEADERS, auth=auth, json=payload)
     return res.json()


-def parse_documnet(auth, dataset_id, payload=None):
+def parse_documnets(auth, dataset_id, payload=None):
     url = f"{HOST_ADDRESS}{FILE_CHUNK_API_URL}".format(dataset_id=dataset_id)
     res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload)
     return res.json()


-def stop_parse_documnet(auth, dataset_id, payload=None):
+def stop_parse_documnets(auth, dataset_id, payload=None):
     url = f"{HOST_ADDRESS}{FILE_CHUNK_API_URL}".format(dataset_id=dataset_id)
     res = requests.delete(url=url, headers=HEADERS, auth=auth, json=payload)
     return res.json()
@@ -184,3 +185,36 @@ def batch_add_chunks(auth, dataset_id, document_id, num):
         res = add_chunk(auth, dataset_id, document_id, {"content": f"chunk test {i}"})
         chunk_ids.append(res["data"]["chunk"]["id"])
     return chunk_ids
+
+
+# CHAT ASSISTANT MANAGEMENT
+def create_chat_assistant(auth, payload=None):
+    url = f"{HOST_ADDRESS}{CHAT_ASSISTANT_API_URL}"
+    res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload)
+    return res.json()
+
+
+def list_chat_assistants(auth, params=None):
+    url = f"{HOST_ADDRESS}{CHAT_ASSISTANT_API_URL}"
+    res = requests.get(url=url, headers=HEADERS, auth=auth, params=params)
+    return res.json()
+
+
+def update_chat_assistant(auth, chat_assistant_id, payload=None):
+    url = f"{HOST_ADDRESS}{CHAT_ASSISTANT_API_URL}/{chat_assistant_id}"
+    res = requests.put(url=url, headers=HEADERS, auth=auth, json=payload)
+    return res.json()
+
+
+def delete_chat_assistants(auth, payload=None):
+    url = f"{HOST_ADDRESS}{CHAT_ASSISTANT_API_URL}"
+    res = requests.delete(url=url, headers=HEADERS, auth=auth, json=payload)
+    return res.json()
+
+
+def batch_create_chat_assistants(auth, num):
+    chat_assistant_ids = []
+    for i in range(num):
+        res = create_chat_assistant(auth, {"name": f"test_chat_assistant_{i}"})
+        chat_assistant_ids.append(res["data"]["id"])
+    return chat_assistant_ids
@@ -14,9 +14,8 @@
 # limitations under the License.
 #

 import pytest
-from common import delete_dataset
+from common import batch_create_datasets, bulk_upload_documents, delete_datasets
 from libs.utils.file_utils import (
     create_docx_file,
     create_eml_file,
@@ -34,7 +33,7 @@ from libs.utils.file_utils import (
 @pytest.fixture(scope="function")
 def clear_datasets(get_http_api_auth):
     yield
-    delete_dataset(get_http_api_auth)
+    delete_datasets(get_http_api_auth)


 @pytest.fixture
@@ -58,3 +57,38 @@ def generate_test_files(request, tmp_path):
         creator_func(file_path)
         files[file_type] = file_path
     return files
+
+
+@pytest.fixture(scope="class")
+def ragflow_tmp_dir(request, tmp_path_factory):
+    class_name = request.cls.__name__
+    return tmp_path_factory.mktemp(class_name)
+
+
+@pytest.fixture(scope="class")
+def add_dataset(request, get_http_api_auth):
+    def cleanup():
+        delete_datasets(get_http_api_auth)
+
+    request.addfinalizer(cleanup)
+
+    dataset_ids = batch_create_datasets(get_http_api_auth, 1)
+    return dataset_ids[0]
+
+
+@pytest.fixture(scope="function")
+def add_dataset_func(request, get_http_api_auth):
+    def cleanup():
+        delete_datasets(get_http_api_auth)
+
+    request.addfinalizer(cleanup)
+
+    dataset_ids = batch_create_datasets(get_http_api_auth, 1)
+    return dataset_ids[0]
+
+
+@pytest.fixture(scope="class")
+def add_document(get_http_api_auth, add_dataset, ragflow_tmp_dir):
+    dataset_id = add_dataset
+    document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 1, ragflow_tmp_dir)
+    return dataset_id, document_ids[0]
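The class-scoped fixtures added above (`ragflow_tmp_dir`, `add_dataset`, `add_document`) are built once per test class and shared by every test in it, which is what lets the chunk and document suites below drop their per-test `batch_create_datasets` calls. A hedged sketch of the consumption pattern (the class name and assertion are illustrative, not part of the commit):

```python
class TestSharedDocument:
    # Hypothetical consumer: every test in this class receives the same
    # (dataset_id, document_id) pair because add_document is class-scoped.
    def test_first(self, get_http_api_auth, add_document):
        dataset_id, document_id = add_document
        assert dataset_id and document_id
```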
@@ -16,13 +16,13 @@


 import pytest
-from common import add_chunk, batch_create_datasets, bulk_upload_documents, delete_chunks, delete_dataset, list_documnet, parse_documnet
+from common import add_chunk, delete_chunks, list_documnets, parse_documnets
 from libs.utils import wait_for


 @wait_for(10, 1, "Document parsing timeout")
 def condition(_auth, _dataset_id):
-    res = list_documnet(_auth, _dataset_id)
+    res = list_documnets(_auth, _dataset_id)
     for doc in res["data"]["docs"]:
         if doc["run"] != "DONE":
             return False
@@ -30,29 +30,11 @@ def condition(_auth, _dataset_id):


 @pytest.fixture(scope="class")
-def chunk_management_tmp_dir(tmp_path_factory):
-    return tmp_path_factory.mktemp("chunk_management")
-
-
-@pytest.fixture(scope="class")
-def get_dataset_id_and_document_id(get_http_api_auth, chunk_management_tmp_dir, request):
-    def cleanup():
-        delete_dataset(get_http_api_auth)
-
-    request.addfinalizer(cleanup)
-
-    dataset_ids = batch_create_datasets(get_http_api_auth, 1)
-    dataset_id = dataset_ids[0]
-    document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 1, chunk_management_tmp_dir)
-    parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids})
+def add_chunks(get_http_api_auth, add_document):
+    dataset_id, document_id = add_document
+    parse_documnets(get_http_api_auth, dataset_id, {"document_ids": [document_id]})
     condition(get_http_api_auth, dataset_id)

-    return dataset_id, document_ids[0]
-
-
-@pytest.fixture(scope="class")
-def add_chunks(get_http_api_auth, get_dataset_id_and_document_id):
-    dataset_id, document_id = get_dataset_id_and_document_id
     chunk_ids = []
     for i in range(4):
         res = add_chunk(get_http_api_auth, dataset_id, document_id, {"content": f"chunk test {i}"})
@@ -66,8 +48,10 @@ def add_chunks(get_http_api_auth, get_dataset_id_and_document_id):


 @pytest.fixture(scope="function")
-def add_chunks_func(get_http_api_auth, get_dataset_id_and_document_id, request):
-    dataset_id, document_id = get_dataset_id_and_document_id
+def add_chunks_func(request, get_http_api_auth, add_document):
+    dataset_id, document_id = add_document
+    parse_documnets(get_http_api_auth, dataset_id, {"document_ids": [document_id]})
+    condition(get_http_api_auth, dataset_id)
+
     chunk_ids = []
     for i in range(4):
@@ -16,7 +16,7 @@
 from concurrent.futures import ThreadPoolExecutor

 import pytest
-from common import INVALID_API_TOKEN, add_chunk, delete_documnet, list_chunks
+from common import INVALID_API_TOKEN, add_chunk, delete_documnets, list_chunks
 from libs.auth import RAGFlowHttpApiAuth
@@ -44,7 +44,7 @@ class TestAuthorization:
         ],
     )
     def test_invalid_auth(self, auth, expected_code, expected_message):
-        res = add_chunk(auth, "dataset_id", "document_id", {})
+        res = add_chunk(auth, "dataset_id", "document_id")
         assert res["code"] == expected_code
         assert res["message"] == expected_message
@@ -66,8 +66,8 @@ class TestAddChunk:
             ({"content": "\n!?。;!?\"'"}, 0, ""),
         ],
     )
-    def test_content(self, get_http_api_auth, get_dataset_id_and_document_id, payload, expected_code, expected_message):
-        dataset_id, document_id = get_dataset_id_and_document_id
+    def test_content(self, get_http_api_auth, add_document, payload, expected_code, expected_message):
+        dataset_id, document_id = add_document
         res = list_chunks(get_http_api_auth, dataset_id, document_id)
         if res["code"] != 0:
             assert False, res
@@ -98,8 +98,8 @@ class TestAddChunk:
             ({"content": "chunk test", "important_keywords": 123}, 102, "`important_keywords` is required to be a list"),
         ],
     )
-    def test_important_keywords(self, get_http_api_auth, get_dataset_id_and_document_id, payload, expected_code, expected_message):
-        dataset_id, document_id = get_dataset_id_and_document_id
+    def test_important_keywords(self, get_http_api_auth, add_document, payload, expected_code, expected_message):
+        dataset_id, document_id = add_document
         res = list_chunks(get_http_api_auth, dataset_id, document_id)
         if res["code"] != 0:
             assert False, res
@@ -126,8 +126,8 @@ class TestAddChunk:
             ({"content": "chunk test", "questions": 123}, 102, "`questions` is required to be a list"),
         ],
     )
-    def test_questions(self, get_http_api_auth, get_dataset_id_and_document_id, payload, expected_code, expected_message):
-        dataset_id, document_id = get_dataset_id_and_document_id
+    def test_questions(self, get_http_api_auth, add_document, payload, expected_code, expected_message):
+        dataset_id, document_id = add_document
         res = list_chunks(get_http_api_auth, dataset_id, document_id)
         if res["code"] != 0:
             assert False, res
@@ -157,12 +157,12 @@ class TestAddChunk:
     def test_invalid_dataset_id(
         self,
         get_http_api_auth,
-        get_dataset_id_and_document_id,
+        add_document,
         dataset_id,
         expected_code,
         expected_message,
     ):
-        _, document_id = get_dataset_id_and_document_id
+        _, document_id = add_document
         res = add_chunk(get_http_api_auth, dataset_id, document_id, {"content": "a"})
         assert res["code"] == expected_code
         assert res["message"] == expected_message
@@ -178,15 +178,15 @@ class TestAddChunk:
             ),
         ],
     )
-    def test_invalid_document_id(self, get_http_api_auth, get_dataset_id_and_document_id, document_id, expected_code, expected_message):
-        dataset_id, _ = get_dataset_id_and_document_id
+    def test_invalid_document_id(self, get_http_api_auth, add_document, document_id, expected_code, expected_message):
+        dataset_id, _ = add_document
         res = add_chunk(get_http_api_auth, dataset_id, document_id, {"content": "chunk test"})
         assert res["code"] == expected_code
         assert res["message"] == expected_message

-    def test_repeated_add_chunk(self, get_http_api_auth, get_dataset_id_and_document_id):
+    def test_repeated_add_chunk(self, get_http_api_auth, add_document):
         payload = {"content": "chunk test"}
-        dataset_id, document_id = get_dataset_id_and_document_id
+        dataset_id, document_id = add_document
         res = list_chunks(get_http_api_auth, dataset_id, document_id)
         if res["code"] != 0:
             assert False, res
@@ -207,17 +207,17 @@ class TestAddChunk:
             assert False, res
         assert res["data"]["doc"]["chunk_count"] == chunks_count + 2

-    def test_add_chunk_to_deleted_document(self, get_http_api_auth, get_dataset_id_and_document_id):
-        dataset_id, document_id = get_dataset_id_and_document_id
-        delete_documnet(get_http_api_auth, dataset_id, {"ids": [document_id]})
+    def test_add_chunk_to_deleted_document(self, get_http_api_auth, add_document):
+        dataset_id, document_id = add_document
+        delete_documnets(get_http_api_auth, dataset_id, {"ids": [document_id]})
         res = add_chunk(get_http_api_auth, dataset_id, document_id, {"content": "chunk test"})
         assert res["code"] == 102
         assert res["message"] == f"You don't own the document {document_id}."

     @pytest.mark.skip(reason="issues/6411")
-    def test_concurrent_add_chunk(self, get_http_api_auth, get_dataset_id_and_document_id):
+    def test_concurrent_add_chunk(self, get_http_api_auth, add_document):
         chunk_num = 50
-        dataset_id, document_id = get_dataset_id_and_document_id
+        dataset_id, document_id = add_document
         res = list_chunks(get_http_api_auth, dataset_id, document_id)
         if res["code"] != 0:
             assert False, res
@@ -39,7 +39,7 @@ class TestAuthorization:
         assert res["message"] == expected_message


-class TestChunkstDeletion:
+class TestChunksDeletion:
     @pytest.mark.parametrize(
         "dataset_id, expected_code, expected_message",
         [
@@ -61,25 +61,14 @@ class TestChunkstDeletion:
         "document_id, expected_code, expected_message",
         [
             ("", 100, "<MethodNotAllowed '405: Method Not Allowed'>"),
-            pytest.param(
-                "invalid_document_id",
-                100,
-                "LookupError('Document not found which is supposed to be there')",
-                marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="issues/6611"),
-            ),
-            pytest.param(
-                "invalid_document_id",
-                100,
-                "rm_chunk deleted chunks 0, expect 4",
-                marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in [None, "elasticsearch"], reason="issues/6611"),
-            ),
+            ("invalid_document_id", 100, """LookupError("Can't find the document with ID invalid_document_id!")"""),
         ],
     )
     def test_invalid_document_id(self, get_http_api_auth, add_chunks_func, document_id, expected_code, expected_message):
         dataset_id, _, chunk_ids = add_chunks_func
         res = delete_chunks(get_http_api_auth, dataset_id, document_id, {"chunk_ids": chunk_ids})
         assert res["code"] == expected_code
-        #assert res["message"] == expected_message
+        assert res["message"] == expected_message

     @pytest.mark.parametrize(
         "payload",
@@ -17,11 +17,7 @@ import os
 from concurrent.futures import ThreadPoolExecutor

 import pytest
-from common import (
-    INVALID_API_TOKEN,
-    batch_add_chunks,
-    list_chunks,
-)
+from common import INVALID_API_TOKEN, batch_add_chunks, list_chunks
 from libs.auth import RAGFlowHttpApiAuth
@@ -153,8 +149,9 @@ class TestChunksList:
         assert all(r["code"] == 0 for r in responses)
         assert all(len(r["data"]["chunks"]) == 5 for r in responses)

-    def test_default(self, get_http_api_auth, get_dataset_id_and_document_id):
-        dataset_id, document_id = get_dataset_id_and_document_id
+    def test_default(self, get_http_api_auth, add_document):
+        dataset_id, document_id = add_document
+
         res = list_chunks(get_http_api_auth, dataset_id, document_id)
         chunks_count = res["data"]["doc"]["chunk_count"]
         batch_add_chunks(get_http_api_auth, dataset_id, document_id, 31)
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
 import os

 import pytest
@@ -52,9 +51,7 @@ class TestChunksRetrieval:
             ({"question": "chunk"}, 102, 0, "`dataset_ids` is required."),
         ],
     )
-    def test_basic_scenarios(
-        self, get_http_api_auth, add_chunks, payload, expected_code, expected_page_size, expected_message
-    ):
+    def test_basic_scenarios(self, get_http_api_auth, add_chunks, payload, expected_code, expected_page_size, expected_message):
         dataset_id, document_id, _ = add_chunks
         if "dataset_ids" in payload:
             payload["dataset_ids"] = [dataset_id]
@@ -137,9 +134,7 @@ class TestChunksRetrieval:
             ),
         ],
     )
-    def test_page_size(
-        self, get_http_api_auth, add_chunks, payload, expected_code, expected_page_size, expected_message
-    ):
+    def test_page_size(self, get_http_api_auth, add_chunks, payload, expected_code, expected_page_size, expected_message):
         dataset_id, _, _ = add_chunks
         payload.update({"question": "chunk", "dataset_ids": [dataset_id]})
@@ -165,9 +160,7 @@ class TestChunksRetrieval:
             ),
         ],
     )
-    def test_vector_similarity_weight(
-        self, get_http_api_auth, add_chunks, payload, expected_code, expected_page_size, expected_message
-    ):
+    def test_vector_similarity_weight(self, get_http_api_auth, add_chunks, payload, expected_code, expected_page_size, expected_message):
         dataset_id, _, _ = add_chunks
         payload.update({"question": "chunk", "dataset_ids": [dataset_id]})
         res = retrieval_chunks(get_http_api_auth, payload)
@@ -233,9 +226,7 @@ class TestChunksRetrieval:
         "payload, expected_code, expected_message",
         [
             ({"rerank_id": "BAAI/bge-reranker-v2-m3"}, 0, ""),
-            pytest.param(
-                {"rerank_id": "unknown"}, 100, "LookupError('Model(unknown) not authorized')", marks=pytest.mark.skip
-            ),
+            pytest.param({"rerank_id": "unknown"}, 100, "LookupError('Model(unknown) not authorized')", marks=pytest.mark.skip),
         ],
     )
     def test_rerank_id(self, get_http_api_auth, add_chunks, payload, expected_code, expected_message):
@@ -248,7 +239,6 @@ class TestChunksRetrieval:
         else:
             assert expected_message in res["message"]

-    @pytest.mark.skip(reason="chat model is not set")
     @pytest.mark.parametrize(
         "payload, expected_code, expected_page_size, expected_message",
         [
@@ -279,9 +269,7 @@ class TestChunksRetrieval:
             pytest.param({"highlight": None}, 0, False, "", marks=pytest.mark.skip(reason="issues/6648")),
         ],
     )
-    def test_highlight(
-        self, get_http_api_auth, add_chunks, payload, expected_code, expected_highlight, expected_message
-    ):
+    def test_highlight(self, get_http_api_auth, add_chunks, payload, expected_code, expected_highlight, expected_message):
         dataset_id, _, _ = add_chunks
         payload.update({"question": "chunk", "dataset_ids": [dataset_id]})
         res = retrieval_chunks(get_http_api_auth, payload)
@@ -302,3 +290,14 @@ class TestChunksRetrieval:
         res = retrieval_chunks(get_http_api_auth, payload)
         assert res["code"] == 0
         assert len(res["data"]["chunks"]) == 4
+
+    def test_concurrent_retrieval(self, get_http_api_auth, add_chunks):
+        from concurrent.futures import ThreadPoolExecutor
+
+        dataset_id, _, _ = add_chunks
+        payload = {"question": "chunk", "dataset_ids": [dataset_id]}
+
+        with ThreadPoolExecutor(max_workers=5) as executor:
+            futures = [executor.submit(retrieval_chunks, get_http_api_auth, payload) for i in range(100)]
+        responses = [f.result() for f in futures]
+        assert all(r["code"] == 0 for r in responses)
@@ -18,7 +18,7 @@ from concurrent.futures import ThreadPoolExecutor
 from random import randint

 import pytest
-from common import INVALID_API_TOKEN, delete_documnet, update_chunk
+from common import INVALID_API_TOKEN, delete_documnets, update_chunk
 from libs.auth import RAGFlowHttpApiAuth
@@ -233,7 +233,7 @@ class TestUpdatedChunk:

     def test_update_chunk_to_deleted_document(self, get_http_api_auth, add_chunks):
         dataset_id, document_id, chunk_ids = add_chunks
-        delete_documnet(get_http_api_auth, dataset_id, {"ids": [document_id]})
+        delete_documnets(get_http_api_auth, dataset_id, {"ids": [document_id]})
         res = update_chunk(get_http_api_auth, dataset_id, document_id, chunk_ids[0])
         assert res["code"] == 102
         assert res["message"] == f"Can't find this chunk {chunk_ids[0]}"
@@ -16,14 +16,24 @@


 import pytest
-from common import batch_create_datasets, delete_dataset
+from common import batch_create_datasets, delete_datasets


 @pytest.fixture(scope="class")
-def get_dataset_ids(get_http_api_auth, request):
+def add_datasets(get_http_api_auth, request):
     def cleanup():
-        delete_dataset(get_http_api_auth)
+        delete_datasets(get_http_api_auth)

     request.addfinalizer(cleanup)

     return batch_create_datasets(get_http_api_auth, 5)
+
+
+@pytest.fixture(scope="function")
+def add_datasets_func(get_http_api_auth, request):
+    def cleanup():
+        delete_datasets(get_http_api_auth)
+
+    request.addfinalizer(cleanup)
+
+    return batch_create_datasets(get_http_api_auth, 3)
@@ -75,9 +75,6 @@ class TestDatasetCreation:
         res = create_dataset(get_http_api_auth, payload)
         assert res["code"] == 0, f"Failed to create dataset {i}"


-@pytest.mark.usefixtures("clear_datasets")
 class TestAdvancedConfigurations:
     def test_avatar(self, get_http_api_auth, tmp_path):
         fn = create_image_file(tmp_path / "ragflow_test.png")
         payload = {
@@ -20,13 +20,12 @@ import pytest
 from common import (
     INVALID_API_TOKEN,
     batch_create_datasets,
-    delete_dataset,
-    list_dataset,
+    delete_datasets,
+    list_datasets,
 )
 from libs.auth import RAGFlowHttpApiAuth


-@pytest.mark.usefixtures("clear_datasets")
 class TestAuthorization:
     @pytest.mark.parametrize(
         "auth, expected_code, expected_message",
@@ -39,18 +38,13 @@ class TestAuthorization:
             ),
         ],
     )
-    def test_invalid_auth(self, get_http_api_auth, auth, expected_code, expected_message):
-        ids = batch_create_datasets(get_http_api_auth, 1)
-        res = delete_dataset(auth, {"ids": ids})
+    def test_invalid_auth(self, auth, expected_code, expected_message):
+        res = delete_datasets(auth)
         assert res["code"] == expected_code
         assert res["message"] == expected_message

-        res = list_dataset(get_http_api_auth)
-        assert len(res["data"]) == 1
-

 @pytest.mark.usefixtures("clear_datasets")
-class TestDatasetDeletion:
+class TestDatasetsDeletion:
     @pytest.mark.parametrize(
         "payload, expected_code, expected_message, remaining",
         [
@@ -73,16 +67,16 @@ class TestDatasetDeletion:
             (lambda r: {"ids": r}, 0, "", 0),
         ],
     )
-    def test_basic_scenarios(self, get_http_api_auth, payload, expected_code, expected_message, remaining):
-        ids = batch_create_datasets(get_http_api_auth, 3)
+    def test_basic_scenarios(self, get_http_api_auth, add_datasets_func, payload, expected_code, expected_message, remaining):
+        dataset_ids = add_datasets_func
         if callable(payload):
-            payload = payload(ids)
-        res = delete_dataset(get_http_api_auth, payload)
+            payload = payload(dataset_ids)
+        res = delete_datasets(get_http_api_auth, payload)
         assert res["code"] == expected_code
         if res["code"] != 0:
             assert res["message"] == expected_message

-        res = list_dataset(get_http_api_auth)
+        res = list_datasets(get_http_api_auth)
         assert len(res["data"]) == remaining

     @pytest.mark.parametrize(
|
||||
lambda r: {"ids": r + ["invalid_id"]},
|
||||
],
|
||||
)
|
||||
def test_delete_partial_invalid_id(self, get_http_api_auth, payload):
|
||||
ids = batch_create_datasets(get_http_api_auth, 3)
|
||||
def test_delete_partial_invalid_id(self, get_http_api_auth, add_datasets_func, payload):
|
||||
dataset_ids = add_datasets_func
|
||||
if callable(payload):
|
||||
payload = payload(ids)
|
||||
res = delete_dataset(get_http_api_auth, payload)
|
||||
payload = payload(dataset_ids)
|
||||
res = delete_datasets(get_http_api_auth, payload)
|
||||
assert res["code"] == 0
|
||||
assert res["data"]["errors"][0] == "You don't own the dataset invalid_id"
|
||||
assert res["data"]["success_count"] == 3
|
||||
|
||||
res = list_dataset(get_http_api_auth)
|
||||
res = list_datasets(get_http_api_auth)
|
||||
assert len(res["data"]) == 0
|
||||
|
||||
def test_repeated_deletion(self, get_http_api_auth):
|
||||
ids = batch_create_datasets(get_http_api_auth, 1)
|
||||
res = delete_dataset(get_http_api_auth, {"ids": ids})
|
||||
def test_repeated_deletion(self, get_http_api_auth, add_datasets_func):
|
||||
dataset_ids = add_datasets_func
|
||||
res = delete_datasets(get_http_api_auth, {"ids": dataset_ids})
|
||||
assert res["code"] == 0
|
||||
|
||||
res = delete_dataset(get_http_api_auth, {"ids": ids})
|
||||
res = delete_datasets(get_http_api_auth, {"ids": dataset_ids})
|
||||
assert res["code"] == 102
|
||||
assert res["message"] == f"You don't own the dataset {ids[0]}"
|
||||
assert "You don't own the dataset" in res["message"]
|
||||
|
||||
def test_duplicate_deletion(self, get_http_api_auth):
|
||||
ids = batch_create_datasets(get_http_api_auth, 1)
|
||||
res = delete_dataset(get_http_api_auth, {"ids": ids + ids})
|
||||
def test_duplicate_deletion(self, get_http_api_auth, add_datasets_func):
|
||||
dataset_ids = add_datasets_func
|
||||
res = delete_datasets(get_http_api_auth, {"ids": dataset_ids + dataset_ids})
|
||||
assert res["code"] == 0
|
||||
assert res["data"]["errors"][0] == f"Duplicate dataset ids: {ids[0]}"
|
||||
assert res["data"]["success_count"] == 1
|
||||
assert "Duplicate dataset ids" in res["data"]["errors"][0]
|
||||
assert res["data"]["success_count"] == 3
|
||||
|
||||
res = list_dataset(get_http_api_auth)
|
||||
res = list_datasets(get_http_api_auth)
|
||||
assert len(res["data"]) == 0
|
||||
|
||||
def test_concurrent_deletion(self, get_http_api_auth):
|
||||
ids = batch_create_datasets(get_http_api_auth, 100)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=5) as executor:
|
||||
futures = [executor.submit(delete_dataset, get_http_api_auth, {"ids": ids[i : i + 1]}) for i in range(100)]
|
||||
futures = [executor.submit(delete_datasets, get_http_api_auth, {"ids": ids[i : i + 1]}) for i in range(100)]
|
||||
responses = [f.result() for f in futures]
|
||||
assert all(r["code"] == 0 for r in responses)
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_delete_10k(self, get_http_api_auth):
|
||||
ids = batch_create_datasets(get_http_api_auth, 10_000)
|
||||
res = delete_dataset(get_http_api_auth, {"ids": ids})
|
||||
res = delete_datasets(get_http_api_auth, {"ids": ids})
|
||||
assert res["code"] == 0
|
||||
|
||||
res = list_dataset(get_http_api_auth)
|
||||
res = list_datasets(get_http_api_auth)
|
||||
assert len(res["data"]) == 0
|
||||
|
@@ -16,7 +16,7 @@
 from concurrent.futures import ThreadPoolExecutor

 import pytest
-from common import INVALID_API_TOKEN, list_dataset
+from common import INVALID_API_TOKEN, list_datasets
 from libs.auth import RAGFlowHttpApiAuth
@@ -25,7 +25,6 @@ def is_sorted(data, field, descending=True):
     return all(a >= b for a, b in zip(timestamps, timestamps[1:])) if descending else all(a <= b for a, b in zip(timestamps, timestamps[1:]))


-@pytest.mark.usefixtures("clear_datasets")
 class TestAuthorization:
     @pytest.mark.parametrize(
         "auth, expected_code, expected_message",
@@ -39,15 +38,15 @@ class TestAuthorization:
         ],
     )
     def test_invalid_auth(self, auth, expected_code, expected_message):
-        res = list_dataset(auth)
+        res = list_datasets(auth)
         assert res["code"] == expected_code
         assert res["message"] == expected_message


-@pytest.mark.usefixtures("get_dataset_ids")
-class TestDatasetList:
+@pytest.mark.usefixtures("add_datasets")
+class TestDatasetsList:
     def test_default(self, get_http_api_auth):
-        res = list_dataset(get_http_api_auth, params={})
+        res = list_datasets(get_http_api_auth, params={})

         assert res["code"] == 0
         assert len(res["data"]) == 5
@@ -77,7 +76,7 @@ class TestDatasetList:
         ],
     )
     def test_page(self, get_http_api_auth, params, expected_code, expected_page_size, expected_message):
-        res = list_dataset(get_http_api_auth, params=params)
+        res = list_datasets(get_http_api_auth, params=params)
         assert res["code"] == expected_code
         if expected_code == 0:
             assert len(res["data"]) == expected_page_size
@@ -116,7 +115,7 @@ class TestDatasetList:
         expected_page_size,
         expected_message,
     ):
-        res = list_dataset(get_http_api_auth, params=params)
+        res = list_datasets(get_http_api_auth, params=params)
         assert res["code"] == expected_code
         if expected_code == 0:
             assert len(res["data"]) == expected_page_size
@@ -168,7 +167,7 @@ class TestDatasetList:
         assertions,
         expected_message,
     ):
-        res = list_dataset(get_http_api_auth, params=params)
+        res = list_datasets(get_http_api_auth, params=params)
         assert res["code"] == expected_code
         if expected_code == 0:
             if callable(assertions):
@@ -244,7 +243,7 @@ class TestDatasetList:
         assertions,
         expected_message,
     ):
-        res = list_dataset(get_http_api_auth, params=params)
+        res = list_datasets(get_http_api_auth, params=params)
         assert res["code"] == expected_code
         if expected_code == 0:
             if callable(assertions):
@@ -262,7 +261,7 @@ class TestDatasetList:
         ],
     )
     def test_name(self, get_http_api_auth, params, expected_code, expected_num, expected_message):
-        res = list_dataset(get_http_api_auth, params=params)
+        res = list_datasets(get_http_api_auth, params=params)
        assert res["code"] == expected_code
        if expected_code == 0:
            if params["name"] in [None, ""]:
@@ -284,19 +283,19 @@ class TestDatasetList:
     def test_id(
         self,
         get_http_api_auth,
-        get_dataset_ids,
+        add_datasets,
         dataset_id,
         expected_code,
         expected_num,
         expected_message,
     ):
-        dataset_ids = get_dataset_ids
+        dataset_ids = add_datasets
         if callable(dataset_id):
             params = {"id": dataset_id(dataset_ids)}
         else:
             params = {"id": dataset_id}

-        res = list_dataset(get_http_api_auth, params=params)
+        res = list_datasets(get_http_api_auth, params=params)
         assert res["code"] == expected_code
         if expected_code == 0:
             if params["id"] in [None, ""]:
@@ -318,20 +317,20 @@ class TestDatasetList:
     def test_name_and_id(
         self,
         get_http_api_auth,
-        get_dataset_ids,
+        add_datasets,
         dataset_id,
         name,
         expected_code,
         expected_num,
         expected_message,
     ):
-        dataset_ids = get_dataset_ids
+        dataset_ids = add_datasets
         if callable(dataset_id):
             params = {"id": dataset_id(dataset_ids), "name": name}
         else:
             params = {"id": dataset_id, "name": name}

-        res = list_dataset(get_http_api_auth, params=params)
+        res = list_datasets(get_http_api_auth, params=params)
         if expected_code == 0:
             assert len(res["data"]) == expected_num
         else:
@@ -339,12 +338,12 @@ class TestDatasetList:

     def test_concurrent_list(self, get_http_api_auth):
         with ThreadPoolExecutor(max_workers=5) as executor:
-            futures = [executor.submit(list_dataset, get_http_api_auth) for i in range(100)]
+            futures = [executor.submit(list_datasets, get_http_api_auth) for i in range(100)]
         responses = [f.result() for f in futures]
         assert all(r["code"] == 0 for r in responses)

     def test_invalid_params(self, get_http_api_auth):
         params = {"a": "b"}
-        res = list_dataset(get_http_api_auth, params=params)
+        res = list_datasets(get_http_api_auth, params=params)
         assert res["code"] == 0
         assert len(res["data"]) == 5
@@ -19,8 +19,7 @@ import pytest
 from common import (
     DATASET_NAME_LIMIT,
     INVALID_API_TOKEN,
-    batch_create_datasets,
-    list_dataset,
+    list_datasets,
     update_dataset,
 )
 from libs.auth import RAGFlowHttpApiAuth
@@ -30,7 +29,6 @@ from libs.utils.file_utils import create_image_file
 # TODO: Missing scenario for updating embedding_model with chunk_count != 0


-@pytest.mark.usefixtures("clear_datasets")
 class TestAuthorization:
     @pytest.mark.parametrize(
         "auth, expected_code, expected_message",
@@ -43,14 +41,12 @@ class TestAuthorization:
             ),
         ],
     )
-    def test_invalid_auth(self, get_http_api_auth, auth, expected_code, expected_message):
-        ids = batch_create_datasets(get_http_api_auth, 1)
-        res = update_dataset(auth, ids[0], {"name": "new_name"})
+    def test_invalid_auth(self, auth, expected_code, expected_message):
+        res = update_dataset(auth, "dataset_id")
         assert res["code"] == expected_code
         assert res["message"] == expected_message


-@pytest.mark.usefixtures("clear_datasets")
 class TestDatasetUpdate:
     @pytest.mark.parametrize(
         "name, expected_code, expected_message",
@@ -72,12 +68,12 @@ class TestDatasetUpdate:
             ("DATASET_1", 102, "Duplicated dataset name in updating dataset."),
         ],
     )
-    def test_name(self, get_http_api_auth, name, expected_code, expected_message):
-        ids = batch_create_datasets(get_http_api_auth, 2)
-        res = update_dataset(get_http_api_auth, ids[0], {"name": name})
+    def test_name(self, get_http_api_auth, add_datasets_func, name, expected_code, expected_message):
+        dataset_ids = add_datasets_func
+        res = update_dataset(get_http_api_auth, dataset_ids[0], {"name": name})
         assert res["code"] == expected_code
         if expected_code == 0:
-            res = list_dataset(get_http_api_auth, {"id": ids[0]})
+            res = list_datasets(get_http_api_auth, {"id": dataset_ids[0]})
             assert res["data"][0]["name"] == name
         else:
             assert res["message"] == expected_message
@@ -95,12 +91,12 @@ class TestDatasetUpdate:
             (None, 102, "`embedding_model` can't be empty"),
         ],
     )
-    def test_embedding_model(self, get_http_api_auth, embedding_model, expected_code, expected_message):
-        ids = batch_create_datasets(get_http_api_auth, 1)
-        res = update_dataset(get_http_api_auth, ids[0], {"embedding_model": embedding_model})
+    def test_embedding_model(self, get_http_api_auth, add_dataset_func, embedding_model, expected_code, expected_message):
+        dataset_id = add_dataset_func
+        res = update_dataset(get_http_api_auth, dataset_id, {"embedding_model": embedding_model})
         assert res["code"] == expected_code
         if expected_code == 0:
-            res = list_dataset(get_http_api_auth, {"id": ids[0]})
+            res = list_datasets(get_http_api_auth, {"id": dataset_id})
             assert res["data"][0]["embedding_model"] == embedding_model
         else:
             assert res["message"] == expected_message
@@ -129,12 +125,12 @@ class TestDatasetUpdate:
             ),
         ],
     )
-    def test_chunk_method(self, get_http_api_auth, chunk_method, expected_code, expected_message):
-        ids = batch_create_datasets(get_http_api_auth, 1)
-        res = update_dataset(get_http_api_auth, ids[0], {"chunk_method": chunk_method})
+    def test_chunk_method(self, get_http_api_auth, add_dataset_func, chunk_method, expected_code, expected_message):
+        dataset_id = add_dataset_func
+        res = update_dataset(get_http_api_auth, dataset_id, {"chunk_method": chunk_method})
         assert res["code"] == expected_code
         if expected_code == 0:
-            res = list_dataset(get_http_api_auth, {"id": ids[0]})
+            res = list_datasets(get_http_api_auth, {"id": dataset_id})
             if chunk_method != "":
                 assert res["data"][0]["chunk_method"] == chunk_method
             else:
@@ -142,38 +138,38 @@ class TestDatasetUpdate:
         else:
             assert res["message"] == expected_message

-    def test_avatar(self, get_http_api_auth, tmp_path):
-        ids = batch_create_datasets(get_http_api_auth, 1)
+    def test_avatar(self, get_http_api_auth, add_dataset_func, tmp_path):
+        dataset_id = add_dataset_func
         fn = create_image_file(tmp_path / "ragflow_test.png")
         payload = {"avatar": encode_avatar(fn)}
-        res = update_dataset(get_http_api_auth, ids[0], payload)
+        res = update_dataset(get_http_api_auth, dataset_id, payload)
         assert res["code"] == 0

-    def test_description(self, get_http_api_auth):
-        ids = batch_create_datasets(get_http_api_auth, 1)
+    def test_description(self, get_http_api_auth, add_dataset_func):
+        dataset_id = add_dataset_func
         payload = {"description": "description"}
-        res = update_dataset(get_http_api_auth, ids[0], payload)
+        res = update_dataset(get_http_api_auth, dataset_id, payload)
         assert res["code"] == 0

-        res = list_dataset(get_http_api_auth, {"id": ids[0]})
+        res = list_datasets(get_http_api_auth, {"id": dataset_id})
         assert res["data"][0]["description"] == "description"

-    def test_pagerank(self, get_http_api_auth):
-        ids = batch_create_datasets(get_http_api_auth, 1)
+    def test_pagerank(self, get_http_api_auth, add_dataset_func):
+        dataset_id = add_dataset_func
         payload = {"pagerank": 1}
-        res = update_dataset(get_http_api_auth, ids[0], payload)
+        res = update_dataset(get_http_api_auth, dataset_id, payload)
         assert res["code"] == 0

-        res = list_dataset(get_http_api_auth, {"id": ids[0]})
+        res = list_datasets(get_http_api_auth, {"id": dataset_id})
         assert res["data"][0]["pagerank"] == 1

-    def test_similarity_threshold(self, get_http_api_auth):
-        ids = batch_create_datasets(get_http_api_auth, 1)
+    def test_similarity_threshold(self, get_http_api_auth, add_dataset_func):
+        dataset_id = add_dataset_func
         payload = {"similarity_threshold": 1}
-        res = update_dataset(get_http_api_auth, ids[0], payload)
+        res = update_dataset(get_http_api_auth, dataset_id, payload)
         assert res["code"] == 0

-        res = list_dataset(get_http_api_auth, {"id": ids[0]})
+        res = list_datasets(get_http_api_auth, {"id": dataset_id})
         assert res["data"][0]["similarity_threshold"] == 1

     @pytest.mark.parametrize(
@@ -187,29 +183,28 @@ class TestDatasetUpdate:
             ("other_permission", 102),
         ],
     )
-    def test_permission(self, get_http_api_auth, permission, expected_code):
-        ids = batch_create_datasets(get_http_api_auth, 1)
+    def test_permission(self, get_http_api_auth, add_dataset_func, permission, expected_code):
+        dataset_id = add_dataset_func
         payload = {"permission": permission}
-        res = update_dataset(get_http_api_auth, ids[0], payload)
+        res = update_dataset(get_http_api_auth, dataset_id, payload)
         assert res["code"] == expected_code

-        res = list_dataset(get_http_api_auth, {"id": ids[0]})
+        res = list_datasets(get_http_api_auth, {"id": dataset_id})
         if expected_code == 0 and permission != "":
             assert res["data"][0]["permission"] == permission
         if permission == "":
             assert res["data"][0]["permission"] == "me"

-    def test_vector_similarity_weight(self, get_http_api_auth):
-        ids = batch_create_datasets(get_http_api_auth, 1)
+    def test_vector_similarity_weight(self, get_http_api_auth, add_dataset_func):
+        dataset_id = add_dataset_func
         payload = {"vector_similarity_weight": 1}
-        res = update_dataset(get_http_api_auth, ids[0], payload)
+        res = update_dataset(get_http_api_auth, dataset_id, payload)
         assert res["code"] == 0

-        res = list_dataset(get_http_api_auth, {"id": ids[0]})
+        res = list_datasets(get_http_api_auth, {"id": dataset_id})
         assert res["data"][0]["vector_similarity_weight"] == 1

     def test_invalid_dataset_id(self, get_http_api_auth):
-        batch_create_datasets(get_http_api_auth, 1)
         res = update_dataset(get_http_api_auth, "invalid_dataset_id", {"name": "invalid_dataset_id"})
         assert res["code"] == 102
         assert res["message"] == "You don't own the dataset"
|
||||
{"update_time": 1741671443339},
|
||||
],
|
||||
)
|
||||
def test_modify_read_only_field(self, get_http_api_auth, payload):
|
||||
ids = batch_create_datasets(get_http_api_auth, 1)
|
||||
res = update_dataset(get_http_api_auth, ids[0], payload)
|
||||
def test_modify_read_only_field(self, get_http_api_auth, add_dataset_func, payload):
|
||||
dataset_id = add_dataset_func
|
||||
res = update_dataset(get_http_api_auth, dataset_id, payload)
|
||||
assert res["code"] == 101
|
||||
assert "is readonly" in res["message"]
|
||||
|
||||
def test_modify_unknown_field(self, get_http_api_auth):
|
||||
ids = batch_create_datasets(get_http_api_auth, 1)
|
||||
res = update_dataset(get_http_api_auth, ids[0], {"unknown_field": 0})
|
||||
def test_modify_unknown_field(self, get_http_api_auth, add_dataset_func):
|
||||
dataset_id = add_dataset_func
|
||||
res = update_dataset(get_http_api_auth, dataset_id, {"unknown_field": 0})
|
||||
assert res["code"] == 100
|
||||
|
||||
def test_concurrent_update(self, get_http_api_auth):
|
||||
ids = batch_create_datasets(get_http_api_auth, 1)
|
||||
def test_concurrent_update(self, get_http_api_auth, add_dataset_func):
|
||||
dataset_id = add_dataset_func
|
||||
|
||||
with ThreadPoolExecutor(max_workers=5) as executor:
|
||||
futures = [executor.submit(update_dataset, get_http_api_auth, ids[0], {"name": f"dataset_{i}"}) for i in range(100)]
|
||||
futures = [executor.submit(update_dataset, get_http_api_auth, dataset_id, {"name": f"dataset_{i}"}) for i in range(100)]
|
||||
responses = [f.result() for f in futures]
|
||||
assert all(r["code"] == 0 for r in responses)
|
||||
|
@@ -16,22 +16,36 @@


 import pytest
-from common import batch_create_datasets, bulk_upload_documents, delete_dataset
+from common import bulk_upload_documents, delete_documnets


-@pytest.fixture(scope="class")
-def file_management_tmp_dir(tmp_path_factory):
-    return tmp_path_factory.mktemp("file_management")
+@pytest.fixture(scope="function")
+def add_document_func(request, get_http_api_auth, add_dataset, ragflow_tmp_dir):
+    dataset_id = add_dataset
+    document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 1, ragflow_tmp_dir)


-@pytest.fixture(scope="class")
-def get_dataset_id_and_document_ids(get_http_api_auth, file_management_tmp_dir, request):
     def cleanup():
-        delete_dataset(get_http_api_auth)
+        delete_documnets(get_http_api_auth, dataset_id, {"ids": document_ids})

     request.addfinalizer(cleanup)
+    return dataset_id, document_ids[0]
+
+
+@pytest.fixture(scope="class")
+def add_documents(request, get_http_api_auth, add_dataset, ragflow_tmp_dir):
+    dataset_id = add_dataset
+    document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 5, ragflow_tmp_dir)
+
+    def cleanup():
+        delete_documnets(get_http_api_auth, dataset_id, {"ids": document_ids})
+
+    request.addfinalizer(cleanup)
+    return dataset_id, document_ids
+
+
+@pytest.fixture(scope="function")
+def add_documents_func(get_http_api_auth, add_dataset_func, ragflow_tmp_dir):
+    dataset_id = add_dataset_func
+    document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 3, ragflow_tmp_dir)

-    dataset_ids = batch_create_datasets(get_http_api_auth, 1)
-    dataset_id = dataset_ids[0]
-    document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 5, file_management_tmp_dir)
     return dataset_id, document_ids
@@ -16,13 +16,7 @@
 from concurrent.futures import ThreadPoolExecutor

 import pytest
-from common import (
-    INVALID_API_TOKEN,
-    batch_create_datasets,
-    bulk_upload_documents,
-    delete_documnet,
-    list_documnet,
-)
+from common import INVALID_API_TOKEN, bulk_upload_documents, delete_documnets, list_documnets
 from libs.auth import RAGFlowHttpApiAuth
@@ -38,15 +32,13 @@ class TestAuthorization:
             ),
         ],
     )
-    def test_invalid_auth(self, get_dataset_id_and_document_ids, auth, expected_code, expected_message):
-        dataset_id, document_ids = get_dataset_id_and_document_ids
-        res = delete_documnet(auth, dataset_id, {"ids": document_ids})
+    def test_invalid_auth(self, auth, expected_code, expected_message):
+        res = delete_documnets(auth, "dataset_id")
         assert res["code"] == expected_code
         assert res["message"] == expected_message


-@pytest.mark.usefixtures("clear_datasets")
-class TestDocumentDeletion:
+class TestDocumentsDeletion:
     @pytest.mark.parametrize(
         "payload, expected_code, expected_message, remaining",
         [
@@ -72,22 +64,21 @@ class TestDocumentDeletion:
     def test_basic_scenarios(
         self,
         get_http_api_auth,
-        tmp_path,
+        add_documents_func,
         payload,
         expected_code,
         expected_message,
         remaining,
     ):
-        ids = batch_create_datasets(get_http_api_auth, 1)
-        document_ids = bulk_upload_documents(get_http_api_auth, ids[0], 3, tmp_path)
+        dataset_id, document_ids = add_documents_func
         if callable(payload):
             payload = payload(document_ids)
-        res = delete_documnet(get_http_api_auth, ids[0], payload)
+        res = delete_documnets(get_http_api_auth, dataset_id, payload)
         assert res["code"] == expected_code
         if res["code"] != 0:
             assert res["message"] == expected_message

-        res = list_documnet(get_http_api_auth, ids[0])
+        res = list_documnets(get_http_api_auth, dataset_id)
         assert len(res["data"]["docs"]) == remaining
         assert res["data"]["total"] == remaining
@@ -102,10 +93,9 @@ class TestDocumentDeletion:
             ),
         ],
     )
-    def test_invalid_dataset_id(self, get_http_api_auth, tmp_path, dataset_id, expected_code, expected_message):
-        ids = batch_create_datasets(get_http_api_auth, 1)
-        document_ids = bulk_upload_documents(get_http_api_auth, ids[0], 3, tmp_path)
-        res = delete_documnet(get_http_api_auth, dataset_id, {"ids": document_ids[:1]})
+    def test_invalid_dataset_id(self, get_http_api_auth, add_documents_func, dataset_id, expected_code, expected_message):
+        _, document_ids = add_documents_func
+        res = delete_documnets(get_http_api_auth, dataset_id, {"ids": document_ids[:1]})
         assert res["code"] == expected_code
         assert res["message"] == expected_message
@ -117,69 +107,68 @@ class TestDocumentDeletion:
|
||||
lambda r: {"ids": r + ["invalid_id"]},
|
||||
],
|
||||
)
|
||||
def test_delete_partial_invalid_id(self, get_http_api_auth, tmp_path, payload):
|
||||
ids = batch_create_datasets(get_http_api_auth, 1)
|
||||
document_ids = bulk_upload_documents(get_http_api_auth, ids[0], 3, tmp_path)
|
||||
def test_delete_partial_invalid_id(self, get_http_api_auth, add_documents_func, payload):
|
||||
dataset_id, document_ids = add_documents_func
|
||||
if callable(payload):
|
||||
payload = payload(document_ids)
res = delete_documnet(get_http_api_auth, ids[0], payload)
res = delete_documnets(get_http_api_auth, dataset_id, payload)
assert res["code"] == 102
assert res["message"] == "Documents not found: ['invalid_id']"

res = list_documnet(get_http_api_auth, ids[0])
res = list_documnets(get_http_api_auth, dataset_id)
assert len(res["data"]["docs"]) == 0
assert res["data"]["total"] == 0

def test_repeated_deletion(self, get_http_api_auth, tmp_path):
ids = batch_create_datasets(get_http_api_auth, 1)
document_ids = bulk_upload_documents(get_http_api_auth, ids[0], 1, tmp_path)
res = delete_documnet(get_http_api_auth, ids[0], {"ids": document_ids})
def test_repeated_deletion(self, get_http_api_auth, add_documents_func):
dataset_id, document_ids = add_documents_func
res = delete_documnets(get_http_api_auth, dataset_id, {"ids": document_ids})
assert res["code"] == 0

res = delete_documnet(get_http_api_auth, ids[0], {"ids": document_ids})
res = delete_documnets(get_http_api_auth, dataset_id, {"ids": document_ids})
assert res["code"] == 102
assert res["message"] == f"Documents not found: {document_ids}"
assert "Documents not found" in res["message"]

def test_duplicate_deletion(self, get_http_api_auth, tmp_path):
ids = batch_create_datasets(get_http_api_auth, 1)
document_ids = bulk_upload_documents(get_http_api_auth, ids[0], 1, tmp_path)
res = delete_documnet(get_http_api_auth, ids[0], {"ids": document_ids + document_ids})
def test_duplicate_deletion(self, get_http_api_auth, add_documents_func):
dataset_id, document_ids = add_documents_func
res = delete_documnets(get_http_api_auth, dataset_id, {"ids": document_ids + document_ids})
assert res["code"] == 0
assert res["data"]["errors"][0] == f"Duplicate document ids: {document_ids[0]}"
assert res["data"]["success_count"] == 1
assert "Duplicate document ids" in res["data"]["errors"][0]
assert res["data"]["success_count"] == 3

res = list_documnet(get_http_api_auth, ids[0])
res = list_documnets(get_http_api_auth, dataset_id)
assert len(res["data"]["docs"]) == 0
assert res["data"]["total"] == 0

def test_concurrent_deletion(self, get_http_api_auth, tmp_path):
documnets_num = 100
ids = batch_create_datasets(get_http_api_auth, 1)
document_ids = bulk_upload_documents(get_http_api_auth, ids[0], documnets_num, tmp_path)

with ThreadPoolExecutor(max_workers=5) as executor:
futures = [
executor.submit(
delete_documnet,
get_http_api_auth,
ids[0],
{"ids": document_ids[i : i + 1]},
)
for i in range(documnets_num)
]
responses = [f.result() for f in futures]
assert all(r["code"] == 0 for r in responses)
def test_concurrent_deletion(get_http_api_auth, add_dataset, tmp_path):
documnets_num = 100
dataset_id = add_dataset
document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, documnets_num, tmp_path)

@pytest.mark.slow
def test_delete_1k(self, get_http_api_auth, tmp_path):
documnets_num = 1_000
ids = batch_create_datasets(get_http_api_auth, 1)
document_ids = bulk_upload_documents(get_http_api_auth, ids[0], documnets_num, tmp_path)
res = list_documnet(get_http_api_auth, ids[0])
assert res["data"]["total"] == documnets_num
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [
executor.submit(
delete_documnets,
get_http_api_auth,
dataset_id,
{"ids": document_ids[i : i + 1]},
)
for i in range(documnets_num)
]
responses = [f.result() for f in futures]
assert all(r["code"] == 0 for r in responses)

res = delete_documnet(get_http_api_auth, ids[0], {"ids": document_ids})
assert res["code"] == 0

res = list_documnet(get_http_api_auth, ids[0])
assert res["data"]["total"] == 0
@pytest.mark.slow
def test_delete_1k(get_http_api_auth, add_dataset, tmp_path):
documnets_num = 1_000
dataset_id = add_dataset
document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, documnets_num, tmp_path)
res = list_documnets(get_http_api_auth, dataset_id)
assert res["data"]["total"] == documnets_num

res = delete_documnets(get_http_api_auth, dataset_id, {"ids": document_ids})
assert res["code"] == 0

res = list_documnets(get_http_api_auth, dataset_id)
assert res["data"]["total"] == 0
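The deletion tests above stop building their own data inline and instead receive a ready (dataset_id, document_ids) pair from add_documents_func. The fixture body is not part of this diff; a minimal sketch of what a function-scoped version could look like, reusing the batch_create_datasets and bulk_upload_documents helpers these suites already import (the three-document count matches the success_count == 3 assertions above, and the fixture name placement in conftest.py is an assumption):

import pytest

from common import batch_create_datasets, bulk_upload_documents


# Hypothetical sketch -- the real fixture lives in the suite's conftest.py,
# which this diff does not show.
@pytest.fixture
def add_documents_func(get_http_api_auth, tmp_path):
    # Function scope: every test gets a fresh dataset holding 3 documents,
    # so mutation tests cannot leak state into one another.
    dataset_ids = batch_create_datasets(get_http_api_auth, 1)
    document_ids = bulk_upload_documents(get_http_api_auth, dataset_ids[0], 3, tmp_path)
    return dataset_ids[0], document_ids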
@@ -18,7 +18,7 @@ import json
from concurrent.futures import ThreadPoolExecutor

import pytest
from common import INVALID_API_TOKEN, batch_create_datasets, bulk_upload_documents, download_document, upload_documnets
from common import INVALID_API_TOKEN, bulk_upload_documents, download_document, upload_documnets
from libs.auth import RAGFlowHttpApiAuth
from libs.utils import compare_by_hash
from requests import codes
@@ -36,9 +36,8 @@ class TestAuthorization:
),
],
)
def test_invalid_auth(self, get_dataset_id_and_document_ids, tmp_path, auth, expected_code, expected_message):
dataset_id, document_ids = get_dataset_id_and_document_ids
res = download_document(auth, dataset_id, document_ids[0], tmp_path / "ragflow_tes.txt")
def test_invalid_auth(self, tmp_path, auth, expected_code, expected_message):
res = download_document(auth, "dataset_id", "document_id", tmp_path / "ragflow_tes.txt")
assert res.status_code == codes.ok
with (tmp_path / "ragflow_tes.txt").open("r") as f:
response_json = json.load(f)
@@ -46,7 +45,6 @@ class TestAuthorization:
assert response_json["message"] == expected_message


@pytest.mark.usefixtures("clear_datasets")
@pytest.mark.parametrize(
"generate_test_files",
[
@@ -63,15 +61,15 @@ class TestAuthorization:
],
indirect=True,
)
def test_file_type_validation(get_http_api_auth, generate_test_files, request):
ids = batch_create_datasets(get_http_api_auth, 1)
def test_file_type_validation(get_http_api_auth, add_dataset, generate_test_files, request):
dataset_id = add_dataset
fp = generate_test_files[request.node.callspec.params["generate_test_files"]]
res = upload_documnets(get_http_api_auth, ids[0], [fp])
res = upload_documnets(get_http_api_auth, dataset_id, [fp])
document_id = res["data"][0]["id"]

res = download_document(
get_http_api_auth,
ids[0],
dataset_id,
document_id,
fp.with_stem("ragflow_test_download"),
)
@@ -93,8 +91,8 @@ class TestDocumentDownload:
),
],
)
def test_invalid_document_id(self, get_http_api_auth, get_dataset_id_and_document_ids, tmp_path, document_id, expected_code, expected_message):
dataset_id, _ = get_dataset_id_and_document_ids
def test_invalid_document_id(self, get_http_api_auth, add_documents, tmp_path, document_id, expected_code, expected_message):
dataset_id, _ = add_documents
res = download_document(
get_http_api_auth,
dataset_id,
@@ -118,8 +116,8 @@ class TestDocumentDownload:
),
],
)
def test_invalid_dataset_id(self, get_http_api_auth, get_dataset_id_and_document_ids, tmp_path, dataset_id, expected_code, expected_message):
_, document_ids = get_dataset_id_and_document_ids
def test_invalid_dataset_id(self, get_http_api_auth, add_documents, tmp_path, dataset_id, expected_code, expected_message):
_, document_ids = add_documents
res = download_document(
get_http_api_auth,
dataset_id,
@@ -132,9 +130,9 @@ class TestDocumentDownload:
assert response_json["code"] == expected_code
assert response_json["message"] == expected_message

def test_same_file_repeat(self, get_http_api_auth, get_dataset_id_and_document_ids, tmp_path, file_management_tmp_dir):
def test_same_file_repeat(self, get_http_api_auth, add_documents, tmp_path, ragflow_tmp_dir):
num = 5
dataset_id, document_ids = get_dataset_id_and_document_ids
dataset_id, document_ids = add_documents
for i in range(num):
res = download_document(
get_http_api_auth,
@@ -144,23 +142,22 @@ class TestDocumentDownload:
)
assert res.status_code == codes.ok
assert compare_by_hash(
file_management_tmp_dir / "ragflow_test_upload_0.txt",
ragflow_tmp_dir / "ragflow_test_upload_0.txt",
tmp_path / f"ragflow_test_download_{i}.txt",
)


@pytest.mark.usefixtures("clear_datasets")
def test_concurrent_download(get_http_api_auth, tmp_path):
def test_concurrent_download(get_http_api_auth, add_dataset, tmp_path):
document_count = 20
ids = batch_create_datasets(get_http_api_auth, 1)
document_ids = bulk_upload_documents(get_http_api_auth, ids[0], document_count, tmp_path)
dataset_id = add_dataset
document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, document_count, tmp_path)

with ThreadPoolExecutor(max_workers=5) as executor:
futures = [
executor.submit(
download_document,
get_http_api_auth,
ids[0],
dataset_id,
document_ids[i],
tmp_path / f"ragflow_test_download_{i}.txt",
)
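compare_by_hash is imported from libs.utils but never defined in this diff. A plausible sketch, assuming it simply compares two files by content digest (the algorithm choice is a guess):

import hashlib
from pathlib import Path


def compare_by_hash(path1: Path, path2: Path) -> bool:
    # Hypothetical sketch of the libs.utils helper: two files are considered
    # identical when their SHA-256 digests match.
    def digest(path: Path) -> str:
        h = hashlib.sha256()
        with path.open("rb") as f:
            for chunk in iter(lambda: f.read(8192), b""):
                h.update(chunk)
        return h.hexdigest()

    return digest(path1) == digest(path2)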
@@ -16,10 +16,7 @@
from concurrent.futures import ThreadPoolExecutor

import pytest
from common import (
INVALID_API_TOKEN,
list_documnet,
)
from common import INVALID_API_TOKEN, list_documnets
from libs.auth import RAGFlowHttpApiAuth


@@ -40,17 +37,16 @@ class TestAuthorization:
),
],
)
def test_invalid_auth(self, get_dataset_id_and_document_ids, auth, expected_code, expected_message):
dataset_id, _ = get_dataset_id_and_document_ids
res = list_documnet(auth, dataset_id)
def test_invalid_auth(self, auth, expected_code, expected_message):
res = list_documnets(auth, "dataset_id")
assert res["code"] == expected_code
assert res["message"] == expected_message


class TestDocumentList:
def test_default(self, get_http_api_auth, get_dataset_id_and_document_ids):
dataset_id, _ = get_dataset_id_and_document_ids
res = list_documnet(get_http_api_auth, dataset_id)
class TestDocumentsList:
def test_default(self, get_http_api_auth, add_documents):
dataset_id, _ = add_documents
res = list_documnets(get_http_api_auth, dataset_id)
assert res["code"] == 0
assert len(res["data"]["docs"]) == 5
assert res["data"]["total"] == 5
@@ -66,8 +62,8 @@ class TestDocumentList:
),
],
)
def test_invalid_dataset_id(self, get_http_api_auth, get_dataset_id_and_document_ids, dataset_id, expected_code, expected_message):
res = list_documnet(get_http_api_auth, dataset_id)
def test_invalid_dataset_id(self, get_http_api_auth, dataset_id, expected_code, expected_message):
res = list_documnets(get_http_api_auth, dataset_id)
assert res["code"] == expected_code
assert res["message"] == expected_message

@@ -98,14 +94,14 @@ class TestDocumentList:
def test_page(
self,
get_http_api_auth,
get_dataset_id_and_document_ids,
add_documents,
params,
expected_code,
expected_page_size,
expected_message,
):
dataset_id, _ = get_dataset_id_and_document_ids
res = list_documnet(get_http_api_auth, dataset_id, params=params)
dataset_id, _ = add_documents
res = list_documnets(get_http_api_auth, dataset_id, params=params)
assert res["code"] == expected_code
if expected_code == 0:
assert len(res["data"]["docs"]) == expected_page_size
@@ -140,14 +136,14 @@ class TestDocumentList:
def test_page_size(
self,
get_http_api_auth,
get_dataset_id_and_document_ids,
add_documents,
params,
expected_code,
expected_page_size,
expected_message,
):
dataset_id, _ = get_dataset_id_and_document_ids
res = list_documnet(get_http_api_auth, dataset_id, params=params)
dataset_id, _ = add_documents
res = list_documnets(get_http_api_auth, dataset_id, params=params)
assert res["code"] == expected_code
if expected_code == 0:
assert len(res["data"]["docs"]) == expected_page_size
@@ -194,14 +190,14 @@ class TestDocumentList:
def test_orderby(
self,
get_http_api_auth,
get_dataset_id_and_document_ids,
add_documents,
params,
expected_code,
assertions,
expected_message,
):
dataset_id, _ = get_dataset_id_and_document_ids
res = list_documnet(get_http_api_auth, dataset_id, params=params)
dataset_id, _ = add_documents
res = list_documnets(get_http_api_auth, dataset_id, params=params)
assert res["code"] == expected_code
if expected_code == 0:
if callable(assertions):
@@ -273,14 +269,14 @@ class TestDocumentList:
def test_desc(
self,
get_http_api_auth,
get_dataset_id_and_document_ids,
add_documents,
params,
expected_code,
assertions,
expected_message,
):
dataset_id, _ = get_dataset_id_and_document_ids
res = list_documnet(get_http_api_auth, dataset_id, params=params)
dataset_id, _ = add_documents
res = list_documnets(get_http_api_auth, dataset_id, params=params)
assert res["code"] == expected_code
if expected_code == 0:
if callable(assertions):
@@ -298,9 +294,9 @@ class TestDocumentList:
({"keywords": "unknown"}, 0),
],
)
def test_keywords(self, get_http_api_auth, get_dataset_id_and_document_ids, params, expected_num):
dataset_id, _ = get_dataset_id_and_document_ids
res = list_documnet(get_http_api_auth, dataset_id, params=params)
def test_keywords(self, get_http_api_auth, add_documents, params, expected_num):
dataset_id, _ = add_documents
res = list_documnets(get_http_api_auth, dataset_id, params=params)
assert res["code"] == 0
assert len(res["data"]["docs"]) == expected_num
assert res["data"]["total"] == expected_num
@@ -322,14 +318,14 @@ class TestDocumentList:
def test_name(
self,
get_http_api_auth,
get_dataset_id_and_document_ids,
add_documents,
params,
expected_code,
expected_num,
expected_message,
):
dataset_id, _ = get_dataset_id_and_document_ids
res = list_documnet(get_http_api_auth, dataset_id, params=params)
dataset_id, _ = add_documents
res = list_documnets(get_http_api_auth, dataset_id, params=params)
assert res["code"] == expected_code
if expected_code == 0:
if params["name"] in [None, ""]:
@@ -351,18 +347,18 @@ class TestDocumentList:
def test_id(
self,
get_http_api_auth,
get_dataset_id_and_document_ids,
add_documents,
document_id,
expected_code,
expected_num,
expected_message,
):
dataset_id, document_ids = get_dataset_id_and_document_ids
dataset_id, document_ids = add_documents
if callable(document_id):
params = {"id": document_id(document_ids)}
else:
params = {"id": document_id}
res = list_documnet(get_http_api_auth, dataset_id, params=params)
res = list_documnets(get_http_api_auth, dataset_id, params=params)

assert res["code"] == expected_code
if expected_code == 0:
@@ -391,36 +387,36 @@ class TestDocumentList:
def test_name_and_id(
self,
get_http_api_auth,
get_dataset_id_and_document_ids,
add_documents,
document_id,
name,
expected_code,
expected_num,
expected_message,
):
dataset_id, document_ids = get_dataset_id_and_document_ids
dataset_id, document_ids = add_documents
if callable(document_id):
params = {"id": document_id(document_ids), "name": name}
else:
params = {"id": document_id, "name": name}

res = list_documnet(get_http_api_auth, dataset_id, params=params)
res = list_documnets(get_http_api_auth, dataset_id, params=params)
if expected_code == 0:
assert len(res["data"]["docs"]) == expected_num
else:
assert res["message"] == expected_message

def test_concurrent_list(self, get_http_api_auth, get_dataset_id_and_document_ids):
dataset_id, _ = get_dataset_id_and_document_ids
def test_concurrent_list(self, get_http_api_auth, add_documents):
dataset_id, _ = add_documents

with ThreadPoolExecutor(max_workers=5) as executor:
futures = [executor.submit(list_documnet, get_http_api_auth, dataset_id) for i in range(100)]
futures = [executor.submit(list_documnets, get_http_api_auth, dataset_id) for i in range(100)]
responses = [f.result() for f in futures]
assert all(r["code"] == 0 for r in responses)

def test_invalid_params(self, get_http_api_auth, get_dataset_id_and_document_ids):
dataset_id, _ = get_dataset_id_and_document_ids
def test_invalid_params(self, get_http_api_auth, add_documents):
dataset_id, _ = add_documents
params = {"a": "b"}
res = list_documnet(get_http_api_auth, dataset_id, params=params)
res = list_documnets(get_http_api_auth, dataset_id, params=params)
assert res["code"] == 0
assert len(res["data"]["docs"]) == 5
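The read-only list tests take add_documents rather than add_documents_func: the data is built once and shared, which is why every test can assert against the same five uploaded documents. A sketch of how the shared variant might be scoped, assuming class scope and the same helpers (tmp_path_factory stands in for the function-scoped tmp_path; the real definition is not in this diff):

import pytest

from common import batch_create_datasets, bulk_upload_documents


# Hypothetical sketch of the shared fixture.
@pytest.fixture(scope="class")
def add_documents(get_http_api_auth, tmp_path_factory):
    # Uploaded once per test class and reused by every read-only test,
    # matching the res["data"]["total"] == 5 assertions above.
    tmp_path = tmp_path_factory.mktemp("ragflow_documents")
    dataset_ids = batch_create_datasets(get_http_api_auth, 1)
    document_ids = bulk_upload_documents(get_http_api_auth, dataset_ids[0], 5, tmp_path)
    return dataset_ids[0], document_ids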
@@ -16,20 +16,14 @@
from concurrent.futures import ThreadPoolExecutor

import pytest
from common import (
INVALID_API_TOKEN,
batch_create_datasets,
bulk_upload_documents,
list_documnet,
parse_documnet,
)
from common import INVALID_API_TOKEN, bulk_upload_documents, list_documnets, parse_documnets
from libs.auth import RAGFlowHttpApiAuth
from libs.utils import wait_for


def validate_document_details(auth, dataset_id, document_ids):
for document_id in document_ids:
res = list_documnet(auth, dataset_id, params={"id": document_id})
res = list_documnets(auth, dataset_id, params={"id": document_id})
doc = res["data"]["docs"][0]
assert doc["run"] == "DONE"
assert len(doc["process_begin_at"]) > 0
@@ -50,14 +44,12 @@ class TestAuthorization:
),
],
)
def test_invalid_auth(self, get_dataset_id_and_document_ids, auth, expected_code, expected_message):
dataset_id, document_ids = get_dataset_id_and_document_ids
res = parse_documnet(auth, dataset_id, {"document_ids": document_ids})
def test_invalid_auth(self, auth, expected_code, expected_message):
res = parse_documnets(auth, "dataset_id")
assert res["code"] == expected_code
assert res["message"] == expected_message


@pytest.mark.usefixtures("clear_datasets")
class TestDocumentsParse:
@pytest.mark.parametrize(
"payload, expected_code, expected_message",
@@ -89,21 +81,19 @@ class TestDocumentsParse:
(lambda r: {"document_ids": r}, 0, ""),
],
)
def test_basic_scenarios(self, get_http_api_auth, tmp_path, payload, expected_code, expected_message):
def test_basic_scenarios(self, get_http_api_auth, add_documents_func, payload, expected_code, expected_message):
@wait_for(10, 1, "Document parsing timeout")
def condition(_auth, _dataset_id, _document_ids):
for _document_id in _document_ids:
res = list_documnet(_auth, _dataset_id, {"id": _document_id})
res = list_documnets(_auth, _dataset_id, {"id": _document_id})
if res["data"]["docs"][0]["run"] != "DONE":
return False
return True

ids = batch_create_datasets(get_http_api_auth, 1)
dataset_id = ids[0]
document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 3, tmp_path)
dataset_id, document_ids = add_documents_func
if callable(payload):
payload = payload(document_ids)
res = parse_documnet(get_http_api_auth, dataset_id, payload)
res = parse_documnets(get_http_api_auth, dataset_id, payload)
assert res["code"] == expected_code
if expected_code != 0:
assert res["message"] == expected_message
@@ -125,14 +115,13 @@ class TestDocumentsParse:
def test_invalid_dataset_id(
self,
get_http_api_auth,
tmp_path,
add_documents_func,
dataset_id,
expected_code,
expected_message,
):
ids = batch_create_datasets(get_http_api_auth, 1)
document_ids = bulk_upload_documents(get_http_api_auth, ids[0], 1, tmp_path)
res = parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids})
_, document_ids = add_documents_func
res = parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids})
assert res["code"] == expected_code
assert res["message"] == expected_message

@@ -144,21 +133,19 @@ class TestDocumentsParse:
lambda r: {"document_ids": r + ["invalid_id"]},
],
)
def test_parse_partial_invalid_document_id(self, get_http_api_auth, tmp_path, payload):
def test_parse_partial_invalid_document_id(self, get_http_api_auth, add_documents_func, payload):
@wait_for(10, 1, "Document parsing timeout")
def condition(_auth, _dataset_id):
res = list_documnet(_auth, _dataset_id)
res = list_documnets(_auth, _dataset_id)
for doc in res["data"]["docs"]:
if doc["run"] != "DONE":
return False
return True

ids = batch_create_datasets(get_http_api_auth, 1)
dataset_id = ids[0]
document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 3, tmp_path)
dataset_id, document_ids = add_documents_func
if callable(payload):
payload = payload(document_ids)
res = parse_documnet(get_http_api_auth, dataset_id, payload)
res = parse_documnets(get_http_api_auth, dataset_id, payload)
assert res["code"] == 102
assert res["message"] == "Documents not found: ['invalid_id']"

@@ -166,96 +153,92 @@ class TestDocumentsParse:

validate_document_details(get_http_api_auth, dataset_id, document_ids)

def test_repeated_parse(self, get_http_api_auth, tmp_path):
def test_repeated_parse(self, get_http_api_auth, add_documents_func):
@wait_for(10, 1, "Document parsing timeout")
def condition(_auth, _dataset_id):
res = list_documnet(_auth, _dataset_id)
res = list_documnets(_auth, _dataset_id)
for doc in res["data"]["docs"]:
if doc["run"] != "DONE":
return False
return True

ids = batch_create_datasets(get_http_api_auth, 1)
dataset_id = ids[0]
document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 1, tmp_path)
res = parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids})
dataset_id, document_ids = add_documents_func
res = parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids})
assert res["code"] == 0

condition(get_http_api_auth, dataset_id)

res = parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids})
res = parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids})
assert res["code"] == 0

def test_duplicate_parse(self, get_http_api_auth, tmp_path):
def test_duplicate_parse(self, get_http_api_auth, add_documents_func):
@wait_for(10, 1, "Document parsing timeout")
def condition(_auth, _dataset_id):
res = list_documnet(_auth, _dataset_id)
res = list_documnets(_auth, _dataset_id)
for doc in res["data"]["docs"]:
if doc["run"] != "DONE":
return False
return True

ids = batch_create_datasets(get_http_api_auth, 1)
dataset_id = ids[0]
document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 1, tmp_path)
res = parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids + document_ids})
dataset_id, document_ids = add_documents_func
res = parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids + document_ids})
assert res["code"] == 0
assert res["data"]["errors"][0] == f"Duplicate document ids: {document_ids[0]}"
assert res["data"]["success_count"] == 1
assert "Duplicate document ids" in res["data"]["errors"][0]
assert res["data"]["success_count"] == 3

condition(get_http_api_auth, dataset_id)

validate_document_details(get_http_api_auth, dataset_id, document_ids)

@pytest.mark.slow
def test_parse_100_files(self, get_http_api_auth, tmp_path):
@wait_for(100, 1, "Document parsing timeout")
def condition(_auth, _dataset_id, _document_num):
res = list_documnet(_auth, _dataset_id, {"page_size": _document_num})
for doc in res["data"]["docs"]:
if doc["run"] != "DONE":
return False
return True

document_num = 100
ids = batch_create_datasets(get_http_api_auth, 1)
dataset_id = ids[0]
document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, document_num, tmp_path)
res = parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids})
assert res["code"] == 0
@pytest.mark.slow
def test_parse_100_files(get_http_api_auth, add_datase_func, tmp_path):
@wait_for(100, 1, "Document parsing timeout")
def condition(_auth, _dataset_id, _document_num):
res = list_documnets(_auth, _dataset_id, {"page_size": _document_num})
for doc in res["data"]["docs"]:
if doc["run"] != "DONE":
return False
return True

condition(get_http_api_auth, dataset_id, document_num)
document_num = 100
dataset_id = add_datase_func
document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, document_num, tmp_path)
res = parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids})
assert res["code"] == 0

validate_document_details(get_http_api_auth, dataset_id, document_ids)
condition(get_http_api_auth, dataset_id, document_num)

@pytest.mark.slow
def test_concurrent_parse(self, get_http_api_auth, tmp_path):
@wait_for(120, 1, "Document parsing timeout")
def condition(_auth, _dataset_id, _document_num):
res = list_documnet(_auth, _dataset_id, {"page_size": _document_num})
for doc in res["data"]["docs"]:
if doc["run"] != "DONE":
return False
return True
validate_document_details(get_http_api_auth, dataset_id, document_ids)

document_num = 100
ids = batch_create_datasets(get_http_api_auth, 1)
dataset_id = ids[0]
document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, document_num, tmp_path)

with ThreadPoolExecutor(max_workers=5) as executor:
futures = [
executor.submit(
parse_documnet,
get_http_api_auth,
dataset_id,
{"document_ids": document_ids[i : i + 1]},
)
for i in range(document_num)
]
responses = [f.result() for f in futures]
assert all(r["code"] == 0 for r in responses)
@pytest.mark.slow
def test_concurrent_parse(get_http_api_auth, add_datase_func, tmp_path):
@wait_for(120, 1, "Document parsing timeout")
def condition(_auth, _dataset_id, _document_num):
res = list_documnets(_auth, _dataset_id, {"page_size": _document_num})
for doc in res["data"]["docs"]:
if doc["run"] != "DONE":
return False
return True

condition(get_http_api_auth, dataset_id, document_num)
document_num = 100
dataset_id = add_datase_func
document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, document_num, tmp_path)

validate_document_details(get_http_api_auth, dataset_id, document_ids)
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [
executor.submit(
parse_documnets,
get_http_api_auth,
dataset_id,
{"document_ids": document_ids[i : i + 1]},
)
for i in range(document_num)
]
responses = [f.result() for f in futures]
assert all(r["code"] == 0 for r in responses)

condition(get_http_api_auth, dataset_id, document_num)

validate_document_details(get_http_api_auth, dataset_id, document_ids)
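The @wait_for(10, 1, "Document parsing timeout") decorator from libs.utils polls the wrapped condition() until it returns True or the timeout elapses. Its implementation sits outside this diff; a minimal sketch, assuming (timeout, interval, message) semantics:

import functools
import time


# Hypothetical sketch of the libs.utils decorator used throughout these tests.
def wait_for(timeout: float, interval: float, error_message: str):
    # Re-run the wrapped predicate every `interval` seconds until it returns
    # truthy; raise with `error_message` once `timeout` seconds have passed.
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            deadline = time.monotonic() + timeout
            while time.monotonic() < deadline:
                if func(*args, **kwargs):
                    return True
                time.sleep(interval)
            raise TimeoutError(error_message)

        return wrapper

    return decorator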
@@ -16,21 +16,14 @@
from concurrent.futures import ThreadPoolExecutor

import pytest
from common import (
INVALID_API_TOKEN,
batch_create_datasets,
bulk_upload_documents,
list_documnet,
parse_documnet,
stop_parse_documnet,
)
from common import INVALID_API_TOKEN, bulk_upload_documents, list_documnets, parse_documnets, stop_parse_documnets
from libs.auth import RAGFlowHttpApiAuth
from libs.utils import wait_for


def validate_document_parse_done(auth, dataset_id, document_ids):
for document_id in document_ids:
res = list_documnet(auth, dataset_id, params={"id": document_id})
res = list_documnets(auth, dataset_id, params={"id": document_id})
doc = res["data"]["docs"][0]
assert doc["run"] == "DONE"
assert len(doc["process_begin_at"]) > 0
@@ -41,14 +34,13 @@ def validate_document_parse_done(auth, dataset_id, document_ids):

def validate_document_parse_cancel(auth, dataset_id, document_ids):
for document_id in document_ids:
res = list_documnet(auth, dataset_id, params={"id": document_id})
res = list_documnets(auth, dataset_id, params={"id": document_id})
doc = res["data"]["docs"][0]
assert doc["run"] == "CANCEL"
assert len(doc["process_begin_at"]) > 0
assert doc["progress"] == 0.0


@pytest.mark.usefixtures("clear_datasets")
class TestAuthorization:
@pytest.mark.parametrize(
"auth, expected_code, expected_message",
@@ -61,15 +53,13 @@ class TestAuthorization:
),
],
)
def test_invalid_auth(self, get_http_api_auth, auth, expected_code, expected_message):
ids = batch_create_datasets(get_http_api_auth, 1)
res = stop_parse_documnet(auth, ids[0])
def test_invalid_auth(self, auth, expected_code, expected_message):
res = stop_parse_documnets(auth, "dataset_id")
assert res["code"] == expected_code
assert res["message"] == expected_message


@pytest.mark.skip
@pytest.mark.usefixtures("clear_datasets")
class TestDocumentsParseStop:
@pytest.mark.parametrize(
"payload, expected_code, expected_message",
@@ -101,24 +91,22 @@ class TestDocumentsParseStop:
(lambda r: {"document_ids": r}, 0, ""),
],
)
def test_basic_scenarios(self, get_http_api_auth, tmp_path, payload, expected_code, expected_message):
def test_basic_scenarios(self, get_http_api_auth, add_documents_func, payload, expected_code, expected_message):
@wait_for(10, 1, "Document parsing timeout")
def condition(_auth, _dataset_id, _document_ids):
for _document_id in _document_ids:
res = list_documnet(_auth, _dataset_id, {"id": _document_id})
res = list_documnets(_auth, _dataset_id, {"id": _document_id})
if res["data"]["docs"][0]["run"] != "DONE":
return False
return True

ids = batch_create_datasets(get_http_api_auth, 1)
dataset_id = ids[0]
document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 3, tmp_path)
parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids})
dataset_id, document_ids = add_documents_func
parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids})

if callable(payload):
payload = payload(document_ids)

res = stop_parse_documnet(get_http_api_auth, dataset_id, payload)
res = stop_parse_documnets(get_http_api_auth, dataset_id, payload)
assert res["code"] == expected_code
if expected_code != 0:
assert res["message"] == expected_message
@@ -129,7 +117,7 @@ class TestDocumentsParseStop:
validate_document_parse_done(get_http_api_auth, dataset_id, completed_document_ids)

@pytest.mark.parametrize(
"dataset_id, expected_code, expected_message",
"invalid_dataset_id, expected_code, expected_message",
[
("", 100, "<MethodNotAllowed '405: Method Not Allowed'>"),
(
@@ -142,14 +130,14 @@ class TestDocumentsParseStop:
def test_invalid_dataset_id(
self,
get_http_api_auth,
tmp_path,
dataset_id,
add_documents_func,
invalid_dataset_id,
expected_code,
expected_message,
):
ids = batch_create_datasets(get_http_api_auth, 1)
document_ids = bulk_upload_documents(get_http_api_auth, ids[0], 1, tmp_path)
res = stop_parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids})
dataset_id, document_ids = add_documents_func
parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids})
res = stop_parse_documnets(get_http_api_auth, invalid_dataset_id, {"document_ids": document_ids})
assert res["code"] == expected_code
assert res["message"] == expected_message

@@ -162,71 +150,65 @@ class TestDocumentsParseStop:
lambda r: {"document_ids": r + ["invalid_id"]},
],
)
def test_stop_parse_partial_invalid_document_id(self, get_http_api_auth, tmp_path, payload):
ids = batch_create_datasets(get_http_api_auth, 1)
dataset_id = ids[0]
document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 3, tmp_path)
parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids})
def test_stop_parse_partial_invalid_document_id(self, get_http_api_auth, add_documents_func, payload):
dataset_id, document_ids = add_documents_func
parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids})

if callable(payload):
payload = payload(document_ids)
res = stop_parse_documnet(get_http_api_auth, dataset_id, payload)
res = stop_parse_documnets(get_http_api_auth, dataset_id, payload)
assert res["code"] == 102
assert res["message"] == "You don't own the document invalid_id."

validate_document_parse_cancel(get_http_api_auth, dataset_id, document_ids)

def test_repeated_stop_parse(self, get_http_api_auth, tmp_path):
ids = batch_create_datasets(get_http_api_auth, 1)
dataset_id = ids[0]
document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 1, tmp_path)
parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids})
res = stop_parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids})
def test_repeated_stop_parse(self, get_http_api_auth, add_documents_func):
dataset_id, document_ids = add_documents_func
parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids})
res = stop_parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids})
assert res["code"] == 0

res = stop_parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids})
res = stop_parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids})
assert res["code"] == 102
assert res["message"] == "Can't stop parsing document with progress at 0 or 1"

def test_duplicate_stop_parse(self, get_http_api_auth, tmp_path):
ids = batch_create_datasets(get_http_api_auth, 1)
dataset_id = ids[0]
document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 1, tmp_path)
parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids})
res = stop_parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids + document_ids})
def test_duplicate_stop_parse(self, get_http_api_auth, add_documents_func):
dataset_id, document_ids = add_documents_func
parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids})
res = stop_parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids + document_ids})
assert res["code"] == 0
assert res["data"]["success_count"] == 1
assert res["data"]["success_count"] == 3
assert f"Duplicate document ids: {document_ids[0]}" in res["data"]["errors"]

@pytest.mark.slow
def test_stop_parse_100_files(self, get_http_api_auth, tmp_path):
document_num = 100
ids = batch_create_datasets(get_http_api_auth, 1)
dataset_id = ids[0]
document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, document_num, tmp_path)
parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids})
res = stop_parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids})
assert res["code"] == 0
validate_document_parse_cancel(get_http_api_auth, dataset_id, document_ids)

@pytest.mark.slow
def test_concurrent_parse(self, get_http_api_auth, tmp_path):
document_num = 50
ids = batch_create_datasets(get_http_api_auth, 1)
dataset_id = ids[0]
document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, document_num, tmp_path)
parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids})
@pytest.mark.slow
def test_stop_parse_100_files(get_http_api_auth, add_datase_func, tmp_path):
document_num = 100
dataset_id = add_datase_func
document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, document_num, tmp_path)
parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids})
res = stop_parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids})
assert res["code"] == 0
validate_document_parse_cancel(get_http_api_auth, dataset_id, document_ids)

with ThreadPoolExecutor(max_workers=5) as executor:
futures = [
executor.submit(
stop_parse_documnet,
get_http_api_auth,
dataset_id,
{"document_ids": document_ids[i : i + 1]},
)
for i in range(document_num)
]
responses = [f.result() for f in futures]
assert all(r["code"] == 0 for r in responses)
validate_document_parse_cancel(get_http_api_auth, dataset_id, document_ids)

@pytest.mark.slow
def test_concurrent_parse(get_http_api_auth, add_datase_func, tmp_path):
document_num = 50
dataset_id = add_datase_func
document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, document_num, tmp_path)
parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids})

with ThreadPoolExecutor(max_workers=5) as executor:
futures = [
executor.submit(
stop_parse_documnets,
get_http_api_auth,
dataset_id,
{"document_ids": document_ids[i : i + 1]},
)
for i in range(document_num)
]
responses = [f.result() for f in futures]
assert all(r["code"] == 0 for r in responses)
validate_document_parse_cancel(get_http_api_auth, dataset_id, document_ids)
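Nearly every suite above calls bulk_upload_documents(auth, dataset_id, num, tmp_path). Its body lives in common.py, which this diff does not touch; a sketch of the likely shape, assuming it builds one text file per document via create_txt_file (imported by the upload suite below) and returns the new ids:

from pathlib import Path

from common import upload_documnets
from libs.utils.file_utils import create_txt_file


# Hypothetical sketch; the real helper lives in common.py.
def bulk_upload_documents(auth, dataset_id, num, tmp_path: Path):
    # Create `num` throwaway text files (the download tests expect names like
    # ragflow_test_upload_0.txt), upload them in one request, and return the
    # server-assigned document ids in upload order.
    fps = [create_txt_file(tmp_path / f"ragflow_test_upload_{i}.txt") for i in range(num)]
    res = upload_documnets(auth, dataset_id, fps)
    assert res["code"] == 0, res["message"]
    return [doc["id"] for doc in res["data"]]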
@@ -16,7 +16,7 @@


import pytest
from common import DOCUMENT_NAME_LIMIT, INVALID_API_TOKEN, batch_create_datasets, bulk_upload_documents, list_documnet, update_documnet
from common import DOCUMENT_NAME_LIMIT, INVALID_API_TOKEN, list_documnets, update_documnet
from libs.auth import RAGFlowHttpApiAuth


@@ -32,14 +32,13 @@ class TestAuthorization:
),
],
)
def test_invalid_auth(self, get_dataset_id_and_document_ids, auth, expected_code, expected_message):
dataset_id, document_ids = get_dataset_id_and_document_ids
res = update_documnet(auth, dataset_id, document_ids[0], {"name": "auth_test.txt"})
def test_invalid_auth(self, auth, expected_code, expected_message):
res = update_documnet(auth, "dataset_id", "document_id")
assert res["code"] == expected_code
assert res["message"] == expected_message


class TestUpdatedDocument:
class TestDocumentsUpdated:
@pytest.mark.parametrize(
"name, expected_code, expected_message",
[
@@ -81,12 +80,12 @@ class TestUpdatedDocument:
),
],
)
def test_name(self, get_http_api_auth, get_dataset_id_and_document_ids, name, expected_code, expected_message):
dataset_id, document_ids = get_dataset_id_and_document_ids
def test_name(self, get_http_api_auth, add_documents, name, expected_code, expected_message):
dataset_id, document_ids = add_documents
res = update_documnet(get_http_api_auth, dataset_id, document_ids[0], {"name": name})
assert res["code"] == expected_code
if expected_code == 0:
res = list_documnet(get_http_api_auth, dataset_id, {"id": document_ids[0]})
res = list_documnets(get_http_api_auth, dataset_id, {"id": document_ids[0]})
assert res["data"]["docs"][0]["name"] == name
else:
assert res["message"] == expected_message
@@ -102,8 +101,8 @@ class TestUpdatedDocument:
),
],
)
def test_invalid_document_id(self, get_http_api_auth, get_dataset_id_and_document_ids, document_id, expected_code, expected_message):
dataset_id, _ = get_dataset_id_and_document_ids
def test_invalid_document_id(self, get_http_api_auth, add_documents, document_id, expected_code, expected_message):
dataset_id, _ = add_documents
res = update_documnet(get_http_api_auth, dataset_id, document_id, {"name": "new_name.txt"})
assert res["code"] == expected_code
assert res["message"] == expected_message
@@ -119,8 +118,8 @@ class TestUpdatedDocument:
),
],
)
def test_invalid_dataset_id(self, get_http_api_auth, get_dataset_id_and_document_ids, dataset_id, expected_code, expected_message):
_, document_ids = get_dataset_id_and_document_ids
def test_invalid_dataset_id(self, get_http_api_auth, add_documents, dataset_id, expected_code, expected_message):
_, document_ids = add_documents
res = update_documnet(get_http_api_auth, dataset_id, document_ids[0], {"name": "new_name.txt"})
assert res["code"] == expected_code
assert res["message"] == expected_message
@@ -129,11 +128,11 @@ class TestUpdatedDocument:
"meta_fields, expected_code, expected_message",
[({"test": "test"}, 0, ""), ("test", 102, "meta_fields must be a dictionary")],
)
def test_meta_fields(self, get_http_api_auth, get_dataset_id_and_document_ids, meta_fields, expected_code, expected_message):
dataset_id, document_ids = get_dataset_id_and_document_ids
def test_meta_fields(self, get_http_api_auth, add_documents, meta_fields, expected_code, expected_message):
dataset_id, document_ids = add_documents
res = update_documnet(get_http_api_auth, dataset_id, document_ids[0], {"meta_fields": meta_fields})
if expected_code == 0:
res = list_documnet(get_http_api_auth, dataset_id, {"id": document_ids[0]})
res = list_documnets(get_http_api_auth, dataset_id, {"id": document_ids[0]})
assert res["data"]["docs"][0]["meta_fields"] == meta_fields
else:
assert res["message"] == expected_message
@@ -162,12 +161,12 @@ class TestUpdatedDocument:
),
],
)
def test_chunk_method(self, get_http_api_auth, get_dataset_id_and_document_ids, chunk_method, expected_code, expected_message):
dataset_id, document_ids = get_dataset_id_and_document_ids
def test_chunk_method(self, get_http_api_auth, add_documents, chunk_method, expected_code, expected_message):
dataset_id, document_ids = add_documents
res = update_documnet(get_http_api_auth, dataset_id, document_ids[0], {"chunk_method": chunk_method})
assert res["code"] == expected_code
if expected_code == 0:
res = list_documnet(get_http_api_auth, dataset_id, {"id": document_ids[0]})
res = list_documnets(get_http_api_auth, dataset_id, {"id": document_ids[0]})
if chunk_method != "":
assert res["data"]["docs"][0]["chunk_method"] == chunk_method
else:
@@ -282,259 +281,259 @@ class TestUpdatedDocument:
def test_invalid_field(
self,
get_http_api_auth,
get_dataset_id_and_document_ids,
add_documents,
payload,
expected_code,
expected_message,
):
dataset_id, document_ids = get_dataset_id_and_document_ids
dataset_id, document_ids = add_documents
res = update_documnet(get_http_api_auth, dataset_id, document_ids[0], payload)
assert res["code"] == expected_code
assert res["message"] == expected_message


@pytest.mark.usefixtures("clear_datasets")
@pytest.mark.parametrize(
"chunk_method, parser_config, expected_code, expected_message",
[
("naive", {}, 0, ""),
(
"naive",
{
"chunk_token_num": 128,
"layout_recognize": "DeepDOC",
"html4excel": False,
"delimiter": "\\n!?;。;!?",
"task_page_size": 12,
"raptor": {"use_raptor": False},
},
0,
"",
),
pytest.param(
"naive",
{"chunk_token_num": -1},
100,
"AssertionError('chunk_token_num should be in range from 1 to 100000000')",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"chunk_token_num": 0},
100,
"AssertionError('chunk_token_num should be in range from 1 to 100000000')",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"chunk_token_num": 100000000},
100,
"AssertionError('chunk_token_num should be in range from 1 to 100000000')",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"chunk_token_num": 3.14},
102,
"",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"chunk_token_num": "1024"},
100,
"",
marks=pytest.mark.skip(reason="issues/6098"),
),
(
"naive",
{"layout_recognize": "DeepDOC"},
0,
"",
),
(
"naive",
{"layout_recognize": "Naive"},
0,
"",
),
("naive", {"html4excel": True}, 0, ""),
("naive", {"html4excel": False}, 0, ""),
pytest.param(
"naive",
{"html4excel": 1},
100,
"AssertionError('html4excel should be True or False')",
marks=pytest.mark.skip(reason="issues/6098"),
),
("naive", {"delimiter": ""}, 0, ""),
("naive", {"delimiter": "`##`"}, 0, ""),
pytest.param(
"naive",
{"delimiter": 1},
100,
"",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"task_page_size": -1},
100,
"AssertionError('task_page_size should be in range from 1 to 100000000')",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"task_page_size": 0},
100,
"AssertionError('task_page_size should be in range from 1 to 100000000')",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"task_page_size": 100000000},
100,
"AssertionError('task_page_size should be in range from 1 to 100000000')",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"task_page_size": 3.14},
100,
"",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"task_page_size": "1024"},
100,
"",
marks=pytest.mark.skip(reason="issues/6098"),
),
("naive", {"raptor": {"use_raptor": True}}, 0, ""),
("naive", {"raptor": {"use_raptor": False}}, 0, ""),
pytest.param(
"naive",
{"invalid_key": "invalid_value"},
100,
"""AssertionError("Abnormal \'parser_config\'. Invalid key: invalid_key")""",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"auto_keywords": -1},
100,
"AssertionError('auto_keywords should be in range from 0 to 32')",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"auto_keywords": 32},
100,
"AssertionError('auto_keywords should be in range from 0 to 32')",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"auto_questions": 3.14},
100,
"",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"auto_keywords": "1024"},
100,
"",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"auto_questions": -1},
100,
"AssertionError('auto_questions should be in range from 0 to 10')",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"auto_questions": 10},
100,
"AssertionError('auto_questions should be in range from 0 to 10')",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"auto_questions": 3.14},
100,
"",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"auto_questions": "1024"},
100,
"",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"topn_tags": -1},
100,
"AssertionError('topn_tags should be in range from 0 to 10')",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"topn_tags": 10},
100,
"AssertionError('topn_tags should be in range from 0 to 10')",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"topn_tags": 3.14},
100,
"",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"topn_tags": "1024"},
100,
"",
marks=pytest.mark.skip(reason="issues/6098"),
),
],
)
def test_parser_config(
get_http_api_auth,
tmp_path,
chunk_method,
parser_config,
expected_code,
expected_message,
):
ids = batch_create_datasets(get_http_api_auth, 1)
document_ids = bulk_upload_documents(get_http_api_auth, ids[0], 1, tmp_path)
res = update_documnet(
get_http_api_auth,
ids[0],
document_ids[0],
{"chunk_method": chunk_method, "parser_config": parser_config},
class TestUpdateDocumentParserConfig:
@pytest.mark.parametrize(
"chunk_method, parser_config, expected_code, expected_message",
[
("naive", {}, 0, ""),
(
"naive",
{
"chunk_token_num": 128,
"layout_recognize": "DeepDOC",
"html4excel": False,
"delimiter": "\\n!?;。;!?",
"task_page_size": 12,
"raptor": {"use_raptor": False},
},
0,
"",
),
pytest.param(
"naive",
{"chunk_token_num": -1},
100,
"AssertionError('chunk_token_num should be in range from 1 to 100000000')",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"chunk_token_num": 0},
100,
"AssertionError('chunk_token_num should be in range from 1 to 100000000')",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"chunk_token_num": 100000000},
100,
"AssertionError('chunk_token_num should be in range from 1 to 100000000')",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"chunk_token_num": 3.14},
102,
"",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"chunk_token_num": "1024"},
100,
"",
marks=pytest.mark.skip(reason="issues/6098"),
),
(
"naive",
{"layout_recognize": "DeepDOC"},
0,
"",
),
(
"naive",
{"layout_recognize": "Naive"},
0,
"",
),
("naive", {"html4excel": True}, 0, ""),
("naive", {"html4excel": False}, 0, ""),
pytest.param(
"naive",
{"html4excel": 1},
100,
"AssertionError('html4excel should be True or False')",
marks=pytest.mark.skip(reason="issues/6098"),
),
("naive", {"delimiter": ""}, 0, ""),
("naive", {"delimiter": "`##`"}, 0, ""),
pytest.param(
"naive",
{"delimiter": 1},
100,
"",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"task_page_size": -1},
100,
"AssertionError('task_page_size should be in range from 1 to 100000000')",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"task_page_size": 0},
100,
"AssertionError('task_page_size should be in range from 1 to 100000000')",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"task_page_size": 100000000},
100,
"AssertionError('task_page_size should be in range from 1 to 100000000')",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"task_page_size": 3.14},
100,
"",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"task_page_size": "1024"},
100,
"",
marks=pytest.mark.skip(reason="issues/6098"),
),
("naive", {"raptor": {"use_raptor": True}}, 0, ""),
("naive", {"raptor": {"use_raptor": False}}, 0, ""),
pytest.param(
"naive",
{"invalid_key": "invalid_value"},
100,
"""AssertionError("Abnormal \'parser_config\'. Invalid key: invalid_key")""",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"auto_keywords": -1},
100,
"AssertionError('auto_keywords should be in range from 0 to 32')",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"auto_keywords": 32},
100,
"AssertionError('auto_keywords should be in range from 0 to 32')",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"auto_questions": 3.14},
100,
"",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"auto_keywords": "1024"},
100,
"",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"auto_questions": -1},
100,
"AssertionError('auto_questions should be in range from 0 to 10')",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"auto_questions": 10},
100,
"AssertionError('auto_questions should be in range from 0 to 10')",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"auto_questions": 3.14},
100,
"",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"auto_questions": "1024"},
100,
"",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"topn_tags": -1},
100,
"AssertionError('topn_tags should be in range from 0 to 10')",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"topn_tags": 10},
100,
"AssertionError('topn_tags should be in range from 0 to 10')",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"topn_tags": 3.14},
100,
"",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"topn_tags": "1024"},
100,
"",
marks=pytest.mark.skip(reason="issues/6098"),
),
],
)
assert res["code"] == expected_code
if expected_code == 0:
res = list_documnet(get_http_api_auth, ids[0], {"id": document_ids[0]})
if parser_config != {}:
for k, v in parser_config.items():
assert res["data"]["docs"][0]["parser_config"][k] == v
else:
assert res["data"]["docs"][0]["parser_config"] == {
"chunk_token_num": 128,
"delimiter": "\\n!?;。;!?",
"html4excel": False,
"layout_recognize": "DeepDOC",
"raptor": {"use_raptor": False},
}
if expected_code != 0 or expected_message:
assert res["message"] == expected_message
def test_parser_config(
self,
get_http_api_auth,
add_documents,
chunk_method,
parser_config,
expected_code,
expected_message,
):
dataset_id, document_ids = add_documents
res = update_documnet(
get_http_api_auth,
dataset_id,
document_ids[0],
{"chunk_method": chunk_method, "parser_config": parser_config},
)
assert res["code"] == expected_code
if expected_code == 0:
res = list_documnets(get_http_api_auth, dataset_id, {"id": document_ids[0]})
if parser_config != {}:
for k, v in parser_config.items():
assert res["data"]["docs"][0]["parser_config"][k] == v
else:
assert res["data"]["docs"][0]["parser_config"] == {
"chunk_token_num": 128,
"delimiter": "\\n!?;。;!?",
"html4excel": False,
"layout_recognize": "DeepDOC",
"raptor": {"use_raptor": False},
}
if expected_code != 0 or expected_message:
assert res["message"] == expected_message
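The parser_config cases that currently fail are parked with pytest.param(..., marks=pytest.mark.skip(reason="issues/6098")) instead of being deleted, so they stay visible in the run report until the referenced issue is fixed. The same pattern in isolation (the test name and values below are illustrative only, not from the repo):

import pytest


@pytest.mark.parametrize(
    "value, expected",
    [
        (128, 128),
        # Known-bad input kept in the matrix but skipped; it shows up as "s"
        # in the summary and is re-enabled by dropping the mark.
        pytest.param(-1, None, marks=pytest.mark.skip(reason="issues/6098")),
    ],
)
def test_example(value, expected):
    assert value == expected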
@@ -19,15 +19,7 @@ from concurrent.futures import ThreadPoolExecutor

import pytest
import requests
from common import (
    DOCUMENT_NAME_LIMIT,
    FILE_API_URL,
    HOST_ADDRESS,
    INVALID_API_TOKEN,
    batch_create_datasets,
    list_dataset,
    upload_documnets,
)
from common import DOCUMENT_NAME_LIMIT, FILE_API_URL, HOST_ADDRESS, INVALID_API_TOKEN, list_datasets, upload_documnets
from libs.auth import RAGFlowHttpApiAuth
from libs.utils.file_utils import create_txt_file
from requests_toolbelt import MultipartEncoder
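All of the upload tests below go through upload_documnets from common.py (the misspelling is the helper's actual name). Its body is not part of this diff; here is a plausible sketch, assuming FILE_API_URL carries a {dataset_id} placeholder, as the .format(...) call in test_filename_empty below suggests.

    # Sketch only -- the real helper lives in common.py and is not shown here.
    import requests

    from common import FILE_API_URL, HOST_ADDRESS


    def upload_documnets(auth, dataset_id, file_paths=None):
        url = f"{HOST_ADDRESS}{FILE_API_URL}".format(dataset_id=dataset_id)
        if file_paths is None:
            # No "file" part at all; the server answers code 101, "No file part!".
            return requests.post(url=url, auth=auth).json()
        files = [("file", (fp.name, fp.open("rb"))) for fp in file_paths]
        return requests.post(url=url, auth=auth, files=files).json()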
@@ -46,21 +38,19 @@ class TestAuthorization:
            ),
        ],
    )
    def test_invalid_auth(self, get_http_api_auth, auth, expected_code, expected_message):
        ids = batch_create_datasets(get_http_api_auth, 1)
        res = upload_documnets(auth, ids[0])
    def test_invalid_auth(self, auth, expected_code, expected_message):
        res = upload_documnets(auth, "dataset_id")
        assert res["code"] == expected_code
        assert res["message"] == expected_message


@pytest.mark.usefixtures("clear_datasets")
class TestUploadDocuments:
    def test_valid_single_upload(self, get_http_api_auth, tmp_path):
        ids = batch_create_datasets(get_http_api_auth, 1)
class TestDocumentsUpload:
    def test_valid_single_upload(self, get_http_api_auth, add_dataset_func, tmp_path):
        dataset_id = add_dataset_func
        fp = create_txt_file(tmp_path / "ragflow_test.txt")
        res = upload_documnets(get_http_api_auth, ids[0], [fp])
        res = upload_documnets(get_http_api_auth, dataset_id, [fp])
        assert res["code"] == 0
        assert res["data"][0]["dataset_id"] == ids[0]
        assert res["data"][0]["dataset_id"] == dataset_id
        assert res["data"][0]["name"] == fp.name

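create_txt_file from libs.utils.file_utils is likewise outside this diff; presumably it is something like the following.

    # Sketch only -- assumed shape of libs/utils/file_utils.create_txt_file.
    from pathlib import Path


    def create_txt_file(path: Path) -> Path:
        # Write a small known payload so the uploaded document has a non-zero size.
        path.write_text("This is a test file.\n", encoding="utf-8")
        return path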
    @pytest.mark.parametrize(
@@ -79,45 +69,45 @@ class TestUploadDocuments:
        ],
        indirect=True,
    )
    def test_file_type_validation(self, get_http_api_auth, generate_test_files, request):
        ids = batch_create_datasets(get_http_api_auth, 1)
    def test_file_type_validation(self, get_http_api_auth, add_dataset_func, generate_test_files, request):
        dataset_id = add_dataset_func
        fp = generate_test_files[request.node.callspec.params["generate_test_files"]]
        res = upload_documnets(get_http_api_auth, ids[0], [fp])
        res = upload_documnets(get_http_api_auth, dataset_id, [fp])
        assert res["code"] == 0
        assert res["data"][0]["dataset_id"] == ids[0]
        assert res["data"][0]["dataset_id"] == dataset_id
        assert res["data"][0]["name"] == fp.name

    @pytest.mark.parametrize(
        "file_type",
        ["exe", "unknown"],
    )
    def test_unsupported_file_type(self, get_http_api_auth, tmp_path, file_type):
        ids = batch_create_datasets(get_http_api_auth, 1)
    def test_unsupported_file_type(self, get_http_api_auth, add_dataset_func, tmp_path, file_type):
        dataset_id = add_dataset_func
        fp = tmp_path / f"ragflow_test.{file_type}"
        fp.touch()
        res = upload_documnets(get_http_api_auth, ids[0], [fp])
        res = upload_documnets(get_http_api_auth, dataset_id, [fp])
        assert res["code"] == 500
        assert res["message"] == f"ragflow_test.{file_type}: This type of file has not been supported yet!"

    def test_missing_file(self, get_http_api_auth):
        ids = batch_create_datasets(get_http_api_auth, 1)
        res = upload_documnets(get_http_api_auth, ids[0])
    def test_missing_file(self, get_http_api_auth, add_dataset_func):
        dataset_id = add_dataset_func
        res = upload_documnets(get_http_api_auth, dataset_id)
        assert res["code"] == 101
        assert res["message"] == "No file part!"

    def test_empty_file(self, get_http_api_auth, tmp_path):
        ids = batch_create_datasets(get_http_api_auth, 1)
    def test_empty_file(self, get_http_api_auth, add_dataset_func, tmp_path):
        dataset_id = add_dataset_func
        fp = tmp_path / "empty.txt"
        fp.touch()

        res = upload_documnets(get_http_api_auth, ids[0], [fp])
        res = upload_documnets(get_http_api_auth, dataset_id, [fp])
        assert res["code"] == 0
        assert res["data"][0]["size"] == 0

    def test_filename_empty(self, get_http_api_auth, tmp_path):
        ids = batch_create_datasets(get_http_api_auth, 1)
    def test_filename_empty(self, get_http_api_auth, add_dataset_func, tmp_path):
        dataset_id = add_dataset_func
        fp = create_txt_file(tmp_path / "ragflow_test.txt")
        url = f"{HOST_ADDRESS}{FILE_API_URL}".format(dataset_id=ids[0])
        url = f"{HOST_ADDRESS}{FILE_API_URL}".format(dataset_id=dataset_id)
        fields = (("file", ("", fp.open("rb"))),)
        m = MultipartEncoder(fields=fields)
        res = requests.post(
@@ -129,11 +119,11 @@ class TestUploadDocuments:
        assert res.json()["code"] == 101
        assert res.json()["message"] == "No file selected!"

    def test_filename_exceeds_max_length(self, get_http_api_auth, tmp_path):
        ids = batch_create_datasets(get_http_api_auth, 1)
    def test_filename_exceeds_max_length(self, get_http_api_auth, add_dataset_func, tmp_path):
        dataset_id = add_dataset_func
        # filename_length = 129
        fp = create_txt_file(tmp_path / f"{'a' * (DOCUMENT_NAME_LIMIT - 3)}.txt")
        res = upload_documnets(get_http_api_auth, ids[0], [fp])
        res = upload_documnets(get_http_api_auth, dataset_id, [fp])
        assert res["code"] == 101
        assert res["message"] == "File name should be less than 128 bytes."

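For reference on the boundary arithmetic above: assuming DOCUMENT_NAME_LIMIT is 128 (consistent with the error message), the generated stem is 128 - 3 = 125 'a' characters and the ".txt" suffix adds 4 bytes, giving a 129-byte name, one byte over the limit, which is why code 101 and the length error are the expected outcome.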
@@ -143,61 +133,61 @@ class TestUploadDocuments:
        assert res["code"] == 100
        assert res["message"] == """LookupError("Can\'t find the dataset with ID invalid_dataset_id!")"""

    def test_duplicate_files(self, get_http_api_auth, tmp_path):
        ids = batch_create_datasets(get_http_api_auth, 1)
    def test_duplicate_files(self, get_http_api_auth, add_dataset_func, tmp_path):
        dataset_id = add_dataset_func
        fp = create_txt_file(tmp_path / "ragflow_test.txt")
        res = upload_documnets(get_http_api_auth, ids[0], [fp, fp])
        res = upload_documnets(get_http_api_auth, dataset_id, [fp, fp])
        assert res["code"] == 0
        assert len(res["data"]) == 2
        for i in range(len(res["data"])):
            assert res["data"][i]["dataset_id"] == ids[0]
            assert res["data"][i]["dataset_id"] == dataset_id
            expected_name = fp.name
            if i != 0:
                expected_name = f"{fp.stem}({i}){fp.suffix}"
            assert res["data"][i]["name"] == expected_name

    def test_same_file_repeat(self, get_http_api_auth, tmp_path):
        ids = batch_create_datasets(get_http_api_auth, 1)
    def test_same_file_repeat(self, get_http_api_auth, add_dataset_func, tmp_path):
        dataset_id = add_dataset_func
        fp = create_txt_file(tmp_path / "ragflow_test.txt")
        for i in range(10):
            res = upload_documnets(get_http_api_auth, ids[0], [fp])
            res = upload_documnets(get_http_api_auth, dataset_id, [fp])
            assert res["code"] == 0
            assert len(res["data"]) == 1
            assert res["data"][0]["dataset_id"] == ids[0]
            assert res["data"][0]["dataset_id"] == dataset_id
            expected_name = fp.name
            if i != 0:
                expected_name = f"{fp.stem}({i}){fp.suffix}"
            assert res["data"][0]["name"] == expected_name

    def test_filename_special_characters(self, get_http_api_auth, tmp_path):
        ids = batch_create_datasets(get_http_api_auth, 1)
    def test_filename_special_characters(self, get_http_api_auth, add_dataset_func, tmp_path):
        dataset_id = add_dataset_func
        illegal_chars = '<>:"/\\|?*'
        translation_table = str.maketrans({char: "_" for char in illegal_chars})
        safe_filename = string.punctuation.translate(translation_table)
        fp = tmp_path / f"{safe_filename}.txt"
        fp.write_text("Sample text content")

        res = upload_documnets(get_http_api_auth, ids[0], [fp])
        res = upload_documnets(get_http_api_auth, dataset_id, [fp])
        assert res["code"] == 0
        assert len(res["data"]) == 1
        assert res["data"][0]["dataset_id"] == ids[0]
        assert res["data"][0]["dataset_id"] == dataset_id
        assert res["data"][0]["name"] == fp.name

    def test_multiple_files(self, get_http_api_auth, tmp_path):
        ids = batch_create_datasets(get_http_api_auth, 1)
    def test_multiple_files(self, get_http_api_auth, add_dataset_func, tmp_path):
        dataset_id = add_dataset_func
        expected_document_count = 20
        fps = []
        for i in range(expected_document_count):
            fp = create_txt_file(tmp_path / f"ragflow_test_{i}.txt")
            fps.append(fp)
        res = upload_documnets(get_http_api_auth, ids[0], fps)
        res = upload_documnets(get_http_api_auth, dataset_id, fps)
        assert res["code"] == 0

        res = list_dataset(get_http_api_auth, {"id": ids[0]})
        res = list_datasets(get_http_api_auth, {"id": dataset_id})
        assert res["data"][0]["document_count"] == expected_document_count

    def test_concurrent_upload(self, get_http_api_auth, tmp_path):
        ids = batch_create_datasets(get_http_api_auth, 1)
    def test_concurrent_upload(self, get_http_api_auth, add_dataset_func, tmp_path):
        dataset_id = add_dataset_func

        expected_document_count = 20
        fps = []
@@ -206,9 +196,9 @@ class TestUploadDocuments:
            fps.append(fp)

        with ThreadPoolExecutor(max_workers=5) as executor:
            futures = [executor.submit(upload_documnets, get_http_api_auth, ids[0], fps[i : i + 1]) for i in range(expected_document_count)]
            futures = [executor.submit(upload_documnets, get_http_api_auth, dataset_id, fps[i : i + 1]) for i in range(expected_document_count)]
        responses = [f.result() for f in futures]
        assert all(r["code"] == 0 for r in responses)

        res = list_dataset(get_http_api_auth, {"id": ids[0]})
        res = list_datasets(get_http_api_auth, {"id": dataset_id})
        assert res["data"][0]["document_count"] == expected_document_count