From aca4cf436996da1019777e50c02aff09b217193d Mon Sep 17 00:00:00 2001 From: liu an Date: Mon, 31 Mar 2025 10:05:35 +0800 Subject: [PATCH] Test: Added test cases for Retrieval Chunks HTTP API (#6649) ### What problem does this PR solve? cover [retrieval chunk](https://ragflow.io/docs/v0.17.2/http_api_reference#retrieve-chunks) endpoints ### Type of change - [x] add test cases --- sdk/python/test/test_http_api/common.py | 6 + .../test_retrieval_chunks.py | 303 ++++++++++++++++++ 2 files changed, 309 insertions(+) create mode 100644 sdk/python/test/test_http_api/test_chunk_management_within_dataset/test_retrieval_chunks.py diff --git a/sdk/python/test/test_http_api/common.py b/sdk/python/test/test_http_api/common.py index a76771364..739fd06da 100644 --- a/sdk/python/test/test_http_api/common.py +++ b/sdk/python/test/test_http_api/common.py @@ -172,6 +172,12 @@ def delete_chunks(auth, dataset_id, document_id, payload=None): return res.json() +def retrieval_chunks(auth, payload=None): + url = f"{HOST_ADDRESS}/api/v1/retrieval" + res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload) + return res.json() + + def batch_add_chunks(auth, dataset_id, document_id, num): chunk_ids = [] for i in range(num): diff --git a/sdk/python/test/test_http_api/test_chunk_management_within_dataset/test_retrieval_chunks.py b/sdk/python/test/test_http_api/test_chunk_management_within_dataset/test_retrieval_chunks.py new file mode 100644 index 000000000..833deff50 --- /dev/null +++ b/sdk/python/test/test_http_api/test_chunk_management_within_dataset/test_retrieval_chunks.py @@ -0,0 +1,303 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os + +import pytest +from common import ( + INVALID_API_TOKEN, + retrieval_chunks, +) +from libs.auth import RAGFlowHttpApiAuth + + +class TestAuthorization: + @pytest.mark.parametrize( + "auth, expected_code, expected_message", + [ + (None, 0, "`Authorization` can't be empty"), + ( + RAGFlowHttpApiAuth(INVALID_API_TOKEN), + 109, + "Authentication error: API key is invalid!", + ), + ], + ) + def test_invalid_auth(self, auth, expected_code, expected_message): + res = retrieval_chunks(auth) + assert res["code"] == expected_code + assert res["message"] == expected_message + + +class TestChunksRetrieval: + @pytest.mark.parametrize( + "payload, expected_code, expected_page_size, expected_message", + [ + ({"question": "chunk", "dataset_ids": None}, 0, 4, ""), + ({"question": "chunk", "document_ids": None}, 102, 0, "`dataset_ids` is required."), + ({"question": "chunk", "dataset_ids": None, "document_ids": None}, 0, 4, ""), + ({"question": "chunk"}, 102, 0, "`dataset_ids` is required."), + ], + ) + def test_basic_scenarios( + self, get_http_api_auth, add_chunks, payload, expected_code, expected_page_size, expected_message + ): + dataset_id, document_id, _ = add_chunks + if "dataset_ids" in payload: + payload["dataset_ids"] = [dataset_id] + if "document_ids" in payload: + payload["document_ids"] = [document_id] + res = retrieval_chunks(get_http_api_auth, payload) + assert res["code"] == expected_code + if expected_code == 0: + assert len(res["data"]["chunks"]) == expected_page_size + else: + assert res["message"] == expected_message + + @pytest.mark.parametrize( + "payload, expected_code, expected_page_size, expected_message", + [ + pytest.param( + {"page": None, "page_size": 2}, + 100, + 2, + """TypeError("int() argument must be a string, a bytes-like object or a real number, not \'NoneType\'")""", + marks=pytest.mark.skip, + ), + pytest.param( + {"page": 0, "page_size": 2}, + 100, + 0, + "ValueError('Search does not support negative slicing.')", + marks=pytest.mark.skip, + ), + pytest.param({"page": 2, "page_size": 2}, 0, 2, "", marks=pytest.mark.skip(reason="issues/6646")), + ({"page": 3, "page_size": 2}, 0, 0, ""), + ({"page": "3", "page_size": 2}, 0, 0, ""), + pytest.param( + {"page": -1, "page_size": 2}, + 100, + 0, + "ValueError('Search does not support negative slicing.')", + marks=pytest.mark.skip, + ), + pytest.param( + {"page": "a", "page_size": 2}, + 100, + 0, + """ValueError("invalid literal for int() with base 10: \'a\'")""", + marks=pytest.mark.skip, + ), + ], + ) + def test_page(self, get_http_api_auth, add_chunks, payload, expected_code, expected_page_size, expected_message): + dataset_id, _, _ = add_chunks + payload.update({"question": "chunk", "dataset_ids": [dataset_id]}) + res = retrieval_chunks(get_http_api_auth, payload) + assert res["code"] == expected_code + if expected_code == 0: + assert len(res["data"]["chunks"]) == expected_page_size + else: + assert res["message"] == expected_message + + @pytest.mark.parametrize( + "payload, expected_code, expected_page_size, expected_message", + [ + pytest.param( + {"page_size": None}, + 100, + 0, + """TypeError("int() argument must be a string, a bytes-like object or a real number, not \'NoneType\'")""", + marks=pytest.mark.skip, + ), + ({"page_size": 0}, 0, 0, ""), + ({"page_size": 1}, 0, 1, ""), + ({"page_size": 5}, 0, 4, ""), + ({"page_size": "1"}, 0, 1, ""), + ({"page_size": -1}, 0, 0, ""), + pytest.param( + {"page_size": "a"}, + 100, + 0, + """ValueError("invalid literal for int() with base 10: \'a\'")""", + marks=pytest.mark.skip, + ), + ], + ) + def test_page_size( + self, get_http_api_auth, add_chunks, payload, expected_code, expected_page_size, expected_message + ): + dataset_id, _, _ = add_chunks + payload.update({"question": "chunk", "dataset_ids": [dataset_id]}) + + res = retrieval_chunks(get_http_api_auth, payload) + assert res["code"] == expected_code + if expected_code == 0: + assert len(res["data"]["chunks"]) == expected_page_size + else: + assert res["message"] == expected_message + + @pytest.mark.parametrize( + "payload, expected_code, expected_page_size, expected_message", + [ + ({"vector_similarity_weight": 0}, 0, 4, ""), + ({"vector_similarity_weight": 0.5}, 0, 4, ""), + ({"vector_similarity_weight": 10}, 0, 4, ""), + pytest.param( + {"vector_similarity_weight": "a"}, + 100, + 0, + """ValueError("could not convert string to float: \'a\'")""", + marks=pytest.mark.skip, + ), + ], + ) + def test_vector_similarity_weight( + self, get_http_api_auth, add_chunks, payload, expected_code, expected_page_size, expected_message + ): + dataset_id, _, _ = add_chunks + payload.update({"question": "chunk", "dataset_ids": [dataset_id]}) + res = retrieval_chunks(get_http_api_auth, payload) + assert res["code"] == expected_code + if expected_code == 0: + assert len(res["data"]["chunks"]) == expected_page_size + else: + assert res["message"] == expected_message + + @pytest.mark.parametrize( + "payload, expected_code, expected_page_size, expected_message", + [ + ({"top_k": 10}, 0, 4, ""), + pytest.param( + {"top_k": 1}, + 0, + 4, + "", + marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="Infinity"), + ), + pytest.param( + {"top_k": 1}, + 0, + 1, + "", + marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in [None, "elasticsearch"], reason="elasticsearch"), + ), + pytest.param( + {"top_k": -1}, + 100, + 4, + "must be greater than 0", + marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="Infinity"), + ), + pytest.param( + {"top_k": -1}, + 100, + 4, + "3014", + marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in [None, "elasticsearch"], reason="elasticsearch"), + ), + pytest.param( + {"top_k": "a"}, + 100, + 0, + """ValueError("invalid literal for int() with base 10: \'a\'")""", + marks=pytest.mark.skip, + ), + ], + ) + def test_top_k(self, get_http_api_auth, add_chunks, payload, expected_code, expected_page_size, expected_message): + dataset_id, _, _ = add_chunks + payload.update({"question": "chunk", "dataset_ids": [dataset_id]}) + res = retrieval_chunks(get_http_api_auth, payload) + assert res["code"] == expected_code + if expected_code == 0: + assert len(res["data"]["chunks"]) == expected_page_size + else: + assert expected_message in res["message"] + + @pytest.mark.parametrize( + "payload, expected_code, expected_message", + [ + ({"rerank_id": "BAAI/bge-reranker-v2-m3"}, 0, ""), + pytest.param( + {"rerank_id": "unknown"}, 100, "LookupError('Model(unknown) not authorized')", marks=pytest.mark.skip + ), + ], + ) + def test_rerank_id(self, get_http_api_auth, add_chunks, payload, expected_code, expected_message): + dataset_id, _, _ = add_chunks + payload.update({"question": "chunk", "dataset_ids": [dataset_id]}) + res = retrieval_chunks(get_http_api_auth, payload) + assert res["code"] == expected_code + if expected_code == 0: + assert len(res["data"]["chunks"]) > 0 + else: + assert expected_message in res["message"] + + @pytest.mark.skip(reason="chat model is not set") + @pytest.mark.parametrize( + "payload, expected_code, expected_page_size, expected_message", + [ + ({"keyword": True}, 0, 5, ""), + ({"keyword": "True"}, 0, 5, ""), + ({"keyword": False}, 0, 5, ""), + ({"keyword": "False"}, 0, 5, ""), + ({"keyword": None}, 0, 5, ""), + ], + ) + def test_keyword(self, get_http_api_auth, add_chunks, payload, expected_code, expected_page_size, expected_message): + dataset_id, _, _ = add_chunks + payload.update({"question": "chunk test", "dataset_ids": [dataset_id]}) + res = retrieval_chunks(get_http_api_auth, payload) + assert res["code"] == expected_code + if expected_code == 0: + assert len(res["data"]["chunks"]) == expected_page_size + else: + assert res["message"] == expected_message + + @pytest.mark.parametrize( + "payload, expected_code, expected_highlight, expected_message", + [ + ({"highlight": True}, 0, True, ""), + ({"highlight": "True"}, 0, True, ""), + pytest.param({"highlight": False}, 0, False, "", marks=pytest.mark.skip(reason="issues/6648")), + ({"highlight": "False"}, 0, False, ""), + pytest.param({"highlight": None}, 0, False, "", marks=pytest.mark.skip(reason="issues/6648")), + ], + ) + def test_highlight( + self, get_http_api_auth, add_chunks, payload, expected_code, expected_highlight, expected_message + ): + dataset_id, _, _ = add_chunks + payload.update({"question": "chunk", "dataset_ids": [dataset_id]}) + res = retrieval_chunks(get_http_api_auth, payload) + assert res["code"] == expected_code + if expected_highlight: + for chunk in res["data"]["chunks"]: + assert "highlight" in chunk + else: + for chunk in res["data"]["chunks"]: + assert "highlight" not in chunk + + if expected_code != 0: + assert res["message"] == expected_message + + def test_invalid_params(self, get_http_api_auth, add_chunks): + dataset_id, _, _ = add_chunks + payload = {"question": "chunk", "dataset_ids": [dataset_id], "a": "b"} + res = retrieval_chunks(get_http_api_auth, payload) + assert res["code"] == 0 + assert len(res["data"]["chunks"]) == 4