From b213e88cca59eb99573e6cd338c7782e40b8150a Mon Sep 17 00:00:00 2001 From: liu an Date: Thu, 3 Apr 2025 17:22:23 +0800 Subject: [PATCH] Test: Added test cases for List Chat Assistants HTTP API (#6792) ### What problem does this PR solve? cover [list chat assistant](https://ragflow.io/docs/v0.17.2/http_api_reference#list-chat-assistants) endpoints ### Type of change - [x] add test cases --- .../conftest.py | 33 ++ .../test_list_chat_assistants.py | 358 ++++++++++++++++++ .../conftest.py | 2 +- .../test_list_datasets.py | 1 + .../test_parse_documents.py | 51 +-- 5 files changed, 412 insertions(+), 33 deletions(-) create mode 100644 sdk/python/test/test_http_api/test_chat_assistant_management/conftest.py create mode 100644 sdk/python/test/test_http_api/test_chat_assistant_management/test_list_chat_assistants.py diff --git a/sdk/python/test/test_http_api/test_chat_assistant_management/conftest.py b/sdk/python/test/test_http_api/test_chat_assistant_management/conftest.py new file mode 100644 index 000000000..ebc0cbff1 --- /dev/null +++ b/sdk/python/test/test_http_api/test_chat_assistant_management/conftest.py @@ -0,0 +1,33 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import pytest +from common import create_chat_assistant, delete_chat_assistants + + +@pytest.fixture(scope="class") +def add_chat_assistants(request, get_http_api_auth, add_chunks): + def cleanup(): + delete_chat_assistants(get_http_api_auth) + + request.addfinalizer(cleanup) + + dataset_id, document_id, chunk_ids = add_chunks + chat_assistant_ids = [] + for i in range(5): + res = create_chat_assistant(get_http_api_auth, {"name": f"test_chat_assistant_{i}", "dataset_ids": [dataset_id]}) + chat_assistant_ids.append(res["data"]["id"]) + + return dataset_id, document_id, chunk_ids, chat_assistant_ids diff --git a/sdk/python/test/test_http_api/test_chat_assistant_management/test_list_chat_assistants.py b/sdk/python/test/test_http_api/test_chat_assistant_management/test_list_chat_assistants.py new file mode 100644 index 000000000..7b663db1f --- /dev/null +++ b/sdk/python/test/test_http_api/test_chat_assistant_management/test_list_chat_assistants.py @@ -0,0 +1,358 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from concurrent.futures import ThreadPoolExecutor + +import pytest +from common import INVALID_API_TOKEN, delete_datasets, list_chat_assistants +from libs.auth import RAGFlowHttpApiAuth + + +def is_sorted(data, field, descending=True): + timestamps = [ds[field] for ds in data] + return all(a >= b for a, b in zip(timestamps, timestamps[1:])) if descending else all(a <= b for a, b in zip(timestamps, timestamps[1:])) + + +class TestAuthorization: + @pytest.mark.parametrize( + "auth, expected_code, expected_message", + [ + (None, 0, "`Authorization` can't be empty"), + ( + RAGFlowHttpApiAuth(INVALID_API_TOKEN), + 109, + "Authentication error: API key is invalid!", + ), + ], + ) + def test_invalid_auth(self, auth, expected_code, expected_message): + res = list_chat_assistants(auth) + assert res["code"] == expected_code + assert res["message"] == expected_message + + +@pytest.mark.usefixtures("add_chat_assistants") +class TestChatAssistantsList: + def test_default(self, get_http_api_auth): + res = list_chat_assistants(get_http_api_auth) + assert res["code"] == 0 + assert len(res["data"]) == 5 + + @pytest.mark.parametrize( + "params, expected_code, expected_page_size, expected_message", + [ + ({"page": None, "page_size": 2}, 0, 2, ""), + ({"page": 0, "page_size": 2}, 0, 2, ""), + ({"page": 2, "page_size": 2}, 0, 2, ""), + ({"page": 3, "page_size": 2}, 0, 1, ""), + ({"page": "3", "page_size": 2}, 0, 1, ""), + pytest.param( + {"page": -1, "page_size": 2}, + 100, + 0, + "1064", + marks=pytest.mark.skip(reason="issues/5851"), + ), + pytest.param( + {"page": "a", "page_size": 2}, + 100, + 0, + """ValueError("invalid literal for int() with base 10: \'a\'")""", + marks=pytest.mark.skip(reason="issues/5851"), + ), + ], + ) + def test_page(self, get_http_api_auth, params, expected_code, expected_page_size, expected_message): + res = list_chat_assistants(get_http_api_auth, params=params) + assert res["code"] == expected_code + if expected_code == 0: + assert len(res["data"]) == expected_page_size + else: + assert res["message"] == expected_message + + @pytest.mark.parametrize( + "params, expected_code, expected_page_size, expected_message", + [ + ({"page_size": None}, 0, 5, ""), + ({"page_size": 0}, 0, 0, ""), + ({"page_size": 1}, 0, 1, ""), + ({"page_size": 6}, 0, 5, ""), + ({"page_size": "1"}, 0, 1, ""), + pytest.param( + {"page_size": -1}, + 100, + 0, + "1064", + marks=pytest.mark.skip(reason="issues/5851"), + ), + pytest.param( + {"page_size": "a"}, + 100, + 0, + """ValueError("invalid literal for int() with base 10: \'a\'")""", + marks=pytest.mark.skip(reason="issues/5851"), + ), + ], + ) + def test_page_size( + self, + get_http_api_auth, + params, + expected_code, + expected_page_size, + expected_message, + ): + res = list_chat_assistants(get_http_api_auth, params=params) + assert res["code"] == expected_code + if expected_code == 0: + assert len(res["data"]) == expected_page_size + else: + assert res["message"] == expected_message + + @pytest.mark.parametrize( + "params, expected_code, assertions, expected_message", + [ + ( + {"orderby": None}, + 0, + lambda r: (is_sorted(r["data"], "create_time", True)), + "", + ), + ( + {"orderby": "create_time"}, + 0, + lambda r: (is_sorted(r["data"], "create_time", True)), + "", + ), + ( + {"orderby": "update_time"}, + 0, + lambda r: (is_sorted(r["data"], "update_time", True)), + "", + ), + pytest.param( + {"orderby": "name", "desc": "False"}, + 0, + lambda r: (is_sorted(r["data"]["docs"], "name", False)), + "", + marks=pytest.mark.skip(reason="issues/5851"), + ), + pytest.param( + {"orderby": "unknown"}, + 102, + 0, + "orderby should be create_time or update_time", + marks=pytest.mark.skip(reason="issues/5851"), + ), + ], + ) + def test_orderby( + self, + get_http_api_auth, + params, + expected_code, + assertions, + expected_message, + ): + res = list_chat_assistants(get_http_api_auth, params=params) + assert res["code"] == expected_code + if expected_code == 0: + if callable(assertions): + assert assertions(res) + else: + assert res["message"] == expected_message + + @pytest.mark.parametrize( + "params, expected_code, assertions, expected_message", + [ + ( + {"desc": None}, + 0, + lambda r: (is_sorted(r["data"], "create_time", True)), + "", + ), + ( + {"desc": "true"}, + 0, + lambda r: (is_sorted(r["data"], "create_time", True)), + "", + ), + ( + {"desc": "True"}, + 0, + lambda r: (is_sorted(r["data"], "create_time", True)), + "", + ), + ( + {"desc": True}, + 0, + lambda r: (is_sorted(r["data"], "create_time", True)), + "", + ), + ( + {"desc": "false"}, + 0, + lambda r: (is_sorted(r["data"], "create_time", False)), + "", + ), + ( + {"desc": "False"}, + 0, + lambda r: (is_sorted(r["data"], "create_time", False)), + "", + ), + ( + {"desc": False}, + 0, + lambda r: (is_sorted(r["data"], "create_time", False)), + "", + ), + ( + {"desc": "False", "orderby": "update_time"}, + 0, + lambda r: (is_sorted(r["data"], "update_time", False)), + "", + ), + pytest.param( + {"desc": "unknown"}, + 102, + 0, + "desc should be true or false", + marks=pytest.mark.skip(reason="issues/5851"), + ), + ], + ) + def test_desc( + self, + get_http_api_auth, + params, + expected_code, + assertions, + expected_message, + ): + res = list_chat_assistants(get_http_api_auth, params=params) + assert res["code"] == expected_code + if expected_code == 0: + if callable(assertions): + assert assertions(res) + else: + assert res["message"] == expected_message + + @pytest.mark.parametrize( + "params, expected_code, expected_num, expected_message", + [ + ({"name": None}, 0, 5, ""), + ({"name": ""}, 0, 5, ""), + ({"name": "test_chat_assistant_1"}, 0, 1, ""), + ({"name": "unknown"}, 102, 0, "The chat doesn't exist"), + ], + ) + def test_name(self, get_http_api_auth, params, expected_code, expected_num, expected_message): + res = list_chat_assistants(get_http_api_auth, params=params) + assert res["code"] == expected_code + if expected_code == 0: + if params["name"] in [None, ""]: + assert len(res["data"]) == expected_num + else: + assert res["data"][0]["name"] == params["name"] + else: + assert res["message"] == expected_message + + @pytest.mark.parametrize( + "chat_assistant_id, expected_code, expected_num, expected_message", + [ + (None, 0, 5, ""), + ("", 0, 5, ""), + (lambda r: r[0], 0, 1, ""), + ("unknown", 102, 0, "The chat doesn't exist"), + ], + ) + def test_id( + self, + get_http_api_auth, + add_chat_assistants, + chat_assistant_id, + expected_code, + expected_num, + expected_message, + ): + _, _, _, chat_assistant_ids = add_chat_assistants + if callable(chat_assistant_id): + params = {"id": chat_assistant_id(chat_assistant_ids)} + else: + params = {"id": chat_assistant_id} + + res = list_chat_assistants(get_http_api_auth, params=params) + assert res["code"] == expected_code + if expected_code == 0: + if params["id"] in [None, ""]: + assert len(res["data"]) == expected_num + else: + assert res["data"][0]["id"] == params["id"] + else: + assert res["message"] == expected_message + + @pytest.mark.parametrize( + "chat_assistant_id, name, expected_code, expected_num, expected_message", + [ + (lambda r: r[0], "test_chat_assistant_0", 0, 1, ""), + (lambda r: r[0], "test_chat_assistant_1", 102, 0, "The chat doesn't exist"), + (lambda r: r[0], "unknown", 102, 0, "The chat doesn't exist"), + ("id", "chat_assistant_0", 102, 0, "The chat doesn't exist"), + ], + ) + def test_name_and_id( + self, + get_http_api_auth, + add_chat_assistants, + chat_assistant_id, + name, + expected_code, + expected_num, + expected_message, + ): + _, _, _, chat_assistant_ids = add_chat_assistants + if callable(chat_assistant_id): + params = {"id": chat_assistant_id(chat_assistant_ids), "name": name} + else: + params = {"id": chat_assistant_id, "name": name} + + res = list_chat_assistants(get_http_api_auth, params=params) + assert res["code"] == expected_code + if expected_code == 0: + assert len(res["data"]) == expected_num + else: + assert res["message"] == expected_message + + def test_concurrent_list(self, get_http_api_auth): + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(list_chat_assistants, get_http_api_auth) for i in range(100)] + responses = [f.result() for f in futures] + assert all(r["code"] == 0 for r in responses) + + def test_invalid_params(self, get_http_api_auth): + params = {"a": "b"} + res = list_chat_assistants(get_http_api_auth, params=params) + assert res["code"] == 0 + assert len(res["data"]) == 5 + + def test_list_chats_after_deleting_associated_dataset(self, get_http_api_auth, add_chat_assistants): + dataset_id, _, _, _ = add_chat_assistants + res = delete_datasets(get_http_api_auth, {"ids": [dataset_id]}) + assert res["code"] == 0 + + res = list_chat_assistants(get_http_api_auth) + assert res["code"] == 0 + assert len(res["data"]) == 5 diff --git a/sdk/python/test/test_http_api/test_chunk_management_within_dataset/conftest.py b/sdk/python/test/test_http_api/test_chunk_management_within_dataset/conftest.py index 31100abc1..ab1ed2622 100644 --- a/sdk/python/test/test_http_api/test_chunk_management_within_dataset/conftest.py +++ b/sdk/python/test/test_http_api/test_chunk_management_within_dataset/conftest.py @@ -20,7 +20,7 @@ from common import add_chunk, delete_chunks, list_documnets, parse_documnets from libs.utils import wait_for -@wait_for(10, 1, "Document parsing timeout") +@wait_for(30, 1, "Document parsing timeout") def condition(_auth, _dataset_id): res = list_documnets(_auth, _dataset_id) for doc in res["data"]["docs"]: diff --git a/sdk/python/test/test_http_api/test_dataset_mangement/test_list_datasets.py b/sdk/python/test/test_http_api/test_dataset_mangement/test_list_datasets.py index 35f3057de..0418919ff 100644 --- a/sdk/python/test/test_http_api/test_dataset_mangement/test_list_datasets.py +++ b/sdk/python/test/test_http_api/test_dataset_mangement/test_list_datasets.py @@ -331,6 +331,7 @@ class TestDatasetsList: params = {"id": dataset_id, "name": name} res = list_datasets(get_http_api_auth, params=params) + assert res["code"] == expected_code if expected_code == 0: assert len(res["data"]) == expected_num else: diff --git a/sdk/python/test/test_http_api/test_file_management_within_dataset/test_parse_documents.py b/sdk/python/test/test_http_api/test_file_management_within_dataset/test_parse_documents.py index 83fa40cf9..8e81ea589 100644 --- a/sdk/python/test/test_http_api/test_file_management_within_dataset/test_parse_documents.py +++ b/sdk/python/test/test_http_api/test_file_management_within_dataset/test_parse_documents.py @@ -21,6 +21,25 @@ from libs.auth import RAGFlowHttpApiAuth from libs.utils import wait_for +@wait_for(30, 1, "Document parsing timeout") +def condition(_auth, _dataset_id, _document_ids=None): + res = list_documnets(_auth, _dataset_id) + target_docs = res["data"]["docs"] + + if _document_ids is None: + for doc in target_docs: + if doc["run"] != "DONE": + return False + return True + + target_ids = set(_document_ids) + for doc in target_docs: + if doc["id"] in target_ids: + if doc.get("run") != "DONE": + return False + return True + + def validate_document_details(auth, dataset_id, document_ids): for document_id in document_ids: res = list_documnets(auth, dataset_id, params={"id": document_id}) @@ -82,14 +101,6 @@ class TestDocumentsParse: ], ) def test_basic_scenarios(self, get_http_api_auth, add_documents_func, payload, expected_code, expected_message): - @wait_for(10, 1, "Document parsing timeout") - def condition(_auth, _dataset_id, _document_ids): - for _document_id in _document_ids: - res = list_documnets(_auth, _dataset_id, {"id": _document_id}) - if res["data"]["docs"][0]["run"] != "DONE": - return False - return True - dataset_id, document_ids = add_documents_func if callable(payload): payload = payload(document_ids) @@ -134,14 +145,6 @@ class TestDocumentsParse: ], ) def test_parse_partial_invalid_document_id(self, get_http_api_auth, add_documents_func, payload): - @wait_for(10, 1, "Document parsing timeout") - def condition(_auth, _dataset_id): - res = list_documnets(_auth, _dataset_id) - for doc in res["data"]["docs"]: - if doc["run"] != "DONE": - return False - return True - dataset_id, document_ids = add_documents_func if callable(payload): payload = payload(document_ids) @@ -154,14 +157,6 @@ class TestDocumentsParse: validate_document_details(get_http_api_auth, dataset_id, document_ids) def test_repeated_parse(self, get_http_api_auth, add_documents_func): - @wait_for(10, 1, "Document parsing timeout") - def condition(_auth, _dataset_id): - res = list_documnets(_auth, _dataset_id) - for doc in res["data"]["docs"]: - if doc["run"] != "DONE": - return False - return True - dataset_id, document_ids = add_documents_func res = parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids}) assert res["code"] == 0 @@ -172,14 +167,6 @@ class TestDocumentsParse: assert res["code"] == 0 def test_duplicate_parse(self, get_http_api_auth, add_documents_func): - @wait_for(10, 1, "Document parsing timeout") - def condition(_auth, _dataset_id): - res = list_documnets(_auth, _dataset_id) - for doc in res["data"]["docs"]: - if doc["run"] != "DONE": - return False - return True - dataset_id, document_ids = add_documents_func res = parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids + document_ids}) assert res["code"] == 0