mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-08-12 17:59:04 +08:00
Test: Added test cases for Retrieval Chunks HTTP API (#6649)
### What problem does this PR solve? cover [retrieval chunk](https://ragflow.io/docs/v0.17.2/http_api_reference#retrieve-chunks) endpoints ### Type of change - [x] add test cases
This commit is contained in:
parent
9aa047257a
commit
aca4cf4369
@ -172,6 +172,12 @@ def delete_chunks(auth, dataset_id, document_id, payload=None):
|
||||
return res.json()
|
||||
|
||||
|
||||
def retrieval_chunks(auth, payload=None):
|
||||
url = f"{HOST_ADDRESS}/api/v1/retrieval"
|
||||
res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload)
|
||||
return res.json()
|
||||
|
||||
|
||||
def batch_add_chunks(auth, dataset_id, document_id, num):
|
||||
chunk_ids = []
|
||||
for i in range(num):
|
||||
|
@ -0,0 +1,303 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from common import (
|
||||
INVALID_API_TOKEN,
|
||||
retrieval_chunks,
|
||||
)
|
||||
from libs.auth import RAGFlowHttpApiAuth
|
||||
|
||||
|
||||
class TestAuthorization:
|
||||
@pytest.mark.parametrize(
|
||||
"auth, expected_code, expected_message",
|
||||
[
|
||||
(None, 0, "`Authorization` can't be empty"),
|
||||
(
|
||||
RAGFlowHttpApiAuth(INVALID_API_TOKEN),
|
||||
109,
|
||||
"Authentication error: API key is invalid!",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_invalid_auth(self, auth, expected_code, expected_message):
|
||||
res = retrieval_chunks(auth)
|
||||
assert res["code"] == expected_code
|
||||
assert res["message"] == expected_message
|
||||
|
||||
|
||||
class TestChunksRetrieval:
|
||||
@pytest.mark.parametrize(
|
||||
"payload, expected_code, expected_page_size, expected_message",
|
||||
[
|
||||
({"question": "chunk", "dataset_ids": None}, 0, 4, ""),
|
||||
({"question": "chunk", "document_ids": None}, 102, 0, "`dataset_ids` is required."),
|
||||
({"question": "chunk", "dataset_ids": None, "document_ids": None}, 0, 4, ""),
|
||||
({"question": "chunk"}, 102, 0, "`dataset_ids` is required."),
|
||||
],
|
||||
)
|
||||
def test_basic_scenarios(
|
||||
self, get_http_api_auth, add_chunks, payload, expected_code, expected_page_size, expected_message
|
||||
):
|
||||
dataset_id, document_id, _ = add_chunks
|
||||
if "dataset_ids" in payload:
|
||||
payload["dataset_ids"] = [dataset_id]
|
||||
if "document_ids" in payload:
|
||||
payload["document_ids"] = [document_id]
|
||||
res = retrieval_chunks(get_http_api_auth, payload)
|
||||
assert res["code"] == expected_code
|
||||
if expected_code == 0:
|
||||
assert len(res["data"]["chunks"]) == expected_page_size
|
||||
else:
|
||||
assert res["message"] == expected_message
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"payload, expected_code, expected_page_size, expected_message",
|
||||
[
|
||||
pytest.param(
|
||||
{"page": None, "page_size": 2},
|
||||
100,
|
||||
2,
|
||||
"""TypeError("int() argument must be a string, a bytes-like object or a real number, not \'NoneType\'")""",
|
||||
marks=pytest.mark.skip,
|
||||
),
|
||||
pytest.param(
|
||||
{"page": 0, "page_size": 2},
|
||||
100,
|
||||
0,
|
||||
"ValueError('Search does not support negative slicing.')",
|
||||
marks=pytest.mark.skip,
|
||||
),
|
||||
pytest.param({"page": 2, "page_size": 2}, 0, 2, "", marks=pytest.mark.skip(reason="issues/6646")),
|
||||
({"page": 3, "page_size": 2}, 0, 0, ""),
|
||||
({"page": "3", "page_size": 2}, 0, 0, ""),
|
||||
pytest.param(
|
||||
{"page": -1, "page_size": 2},
|
||||
100,
|
||||
0,
|
||||
"ValueError('Search does not support negative slicing.')",
|
||||
marks=pytest.mark.skip,
|
||||
),
|
||||
pytest.param(
|
||||
{"page": "a", "page_size": 2},
|
||||
100,
|
||||
0,
|
||||
"""ValueError("invalid literal for int() with base 10: \'a\'")""",
|
||||
marks=pytest.mark.skip,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_page(self, get_http_api_auth, add_chunks, payload, expected_code, expected_page_size, expected_message):
|
||||
dataset_id, _, _ = add_chunks
|
||||
payload.update({"question": "chunk", "dataset_ids": [dataset_id]})
|
||||
res = retrieval_chunks(get_http_api_auth, payload)
|
||||
assert res["code"] == expected_code
|
||||
if expected_code == 0:
|
||||
assert len(res["data"]["chunks"]) == expected_page_size
|
||||
else:
|
||||
assert res["message"] == expected_message
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"payload, expected_code, expected_page_size, expected_message",
|
||||
[
|
||||
pytest.param(
|
||||
{"page_size": None},
|
||||
100,
|
||||
0,
|
||||
"""TypeError("int() argument must be a string, a bytes-like object or a real number, not \'NoneType\'")""",
|
||||
marks=pytest.mark.skip,
|
||||
),
|
||||
({"page_size": 0}, 0, 0, ""),
|
||||
({"page_size": 1}, 0, 1, ""),
|
||||
({"page_size": 5}, 0, 4, ""),
|
||||
({"page_size": "1"}, 0, 1, ""),
|
||||
({"page_size": -1}, 0, 0, ""),
|
||||
pytest.param(
|
||||
{"page_size": "a"},
|
||||
100,
|
||||
0,
|
||||
"""ValueError("invalid literal for int() with base 10: \'a\'")""",
|
||||
marks=pytest.mark.skip,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_page_size(
|
||||
self, get_http_api_auth, add_chunks, payload, expected_code, expected_page_size, expected_message
|
||||
):
|
||||
dataset_id, _, _ = add_chunks
|
||||
payload.update({"question": "chunk", "dataset_ids": [dataset_id]})
|
||||
|
||||
res = retrieval_chunks(get_http_api_auth, payload)
|
||||
assert res["code"] == expected_code
|
||||
if expected_code == 0:
|
||||
assert len(res["data"]["chunks"]) == expected_page_size
|
||||
else:
|
||||
assert res["message"] == expected_message
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"payload, expected_code, expected_page_size, expected_message",
|
||||
[
|
||||
({"vector_similarity_weight": 0}, 0, 4, ""),
|
||||
({"vector_similarity_weight": 0.5}, 0, 4, ""),
|
||||
({"vector_similarity_weight": 10}, 0, 4, ""),
|
||||
pytest.param(
|
||||
{"vector_similarity_weight": "a"},
|
||||
100,
|
||||
0,
|
||||
"""ValueError("could not convert string to float: \'a\'")""",
|
||||
marks=pytest.mark.skip,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_vector_similarity_weight(
|
||||
self, get_http_api_auth, add_chunks, payload, expected_code, expected_page_size, expected_message
|
||||
):
|
||||
dataset_id, _, _ = add_chunks
|
||||
payload.update({"question": "chunk", "dataset_ids": [dataset_id]})
|
||||
res = retrieval_chunks(get_http_api_auth, payload)
|
||||
assert res["code"] == expected_code
|
||||
if expected_code == 0:
|
||||
assert len(res["data"]["chunks"]) == expected_page_size
|
||||
else:
|
||||
assert res["message"] == expected_message
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"payload, expected_code, expected_page_size, expected_message",
|
||||
[
|
||||
({"top_k": 10}, 0, 4, ""),
|
||||
pytest.param(
|
||||
{"top_k": 1},
|
||||
0,
|
||||
4,
|
||||
"",
|
||||
marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="Infinity"),
|
||||
),
|
||||
pytest.param(
|
||||
{"top_k": 1},
|
||||
0,
|
||||
1,
|
||||
"",
|
||||
marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in [None, "elasticsearch"], reason="elasticsearch"),
|
||||
),
|
||||
pytest.param(
|
||||
{"top_k": -1},
|
||||
100,
|
||||
4,
|
||||
"must be greater than 0",
|
||||
marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="Infinity"),
|
||||
),
|
||||
pytest.param(
|
||||
{"top_k": -1},
|
||||
100,
|
||||
4,
|
||||
"3014",
|
||||
marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in [None, "elasticsearch"], reason="elasticsearch"),
|
||||
),
|
||||
pytest.param(
|
||||
{"top_k": "a"},
|
||||
100,
|
||||
0,
|
||||
"""ValueError("invalid literal for int() with base 10: \'a\'")""",
|
||||
marks=pytest.mark.skip,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_top_k(self, get_http_api_auth, add_chunks, payload, expected_code, expected_page_size, expected_message):
|
||||
dataset_id, _, _ = add_chunks
|
||||
payload.update({"question": "chunk", "dataset_ids": [dataset_id]})
|
||||
res = retrieval_chunks(get_http_api_auth, payload)
|
||||
assert res["code"] == expected_code
|
||||
if expected_code == 0:
|
||||
assert len(res["data"]["chunks"]) == expected_page_size
|
||||
else:
|
||||
assert expected_message in res["message"]
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"payload, expected_code, expected_message",
|
||||
[
|
||||
({"rerank_id": "BAAI/bge-reranker-v2-m3"}, 0, ""),
|
||||
pytest.param(
|
||||
{"rerank_id": "unknown"}, 100, "LookupError('Model(unknown) not authorized')", marks=pytest.mark.skip
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_rerank_id(self, get_http_api_auth, add_chunks, payload, expected_code, expected_message):
|
||||
dataset_id, _, _ = add_chunks
|
||||
payload.update({"question": "chunk", "dataset_ids": [dataset_id]})
|
||||
res = retrieval_chunks(get_http_api_auth, payload)
|
||||
assert res["code"] == expected_code
|
||||
if expected_code == 0:
|
||||
assert len(res["data"]["chunks"]) > 0
|
||||
else:
|
||||
assert expected_message in res["message"]
|
||||
|
||||
@pytest.mark.skip(reason="chat model is not set")
|
||||
@pytest.mark.parametrize(
|
||||
"payload, expected_code, expected_page_size, expected_message",
|
||||
[
|
||||
({"keyword": True}, 0, 5, ""),
|
||||
({"keyword": "True"}, 0, 5, ""),
|
||||
({"keyword": False}, 0, 5, ""),
|
||||
({"keyword": "False"}, 0, 5, ""),
|
||||
({"keyword": None}, 0, 5, ""),
|
||||
],
|
||||
)
|
||||
def test_keyword(self, get_http_api_auth, add_chunks, payload, expected_code, expected_page_size, expected_message):
|
||||
dataset_id, _, _ = add_chunks
|
||||
payload.update({"question": "chunk test", "dataset_ids": [dataset_id]})
|
||||
res = retrieval_chunks(get_http_api_auth, payload)
|
||||
assert res["code"] == expected_code
|
||||
if expected_code == 0:
|
||||
assert len(res["data"]["chunks"]) == expected_page_size
|
||||
else:
|
||||
assert res["message"] == expected_message
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"payload, expected_code, expected_highlight, expected_message",
|
||||
[
|
||||
({"highlight": True}, 0, True, ""),
|
||||
({"highlight": "True"}, 0, True, ""),
|
||||
pytest.param({"highlight": False}, 0, False, "", marks=pytest.mark.skip(reason="issues/6648")),
|
||||
({"highlight": "False"}, 0, False, ""),
|
||||
pytest.param({"highlight": None}, 0, False, "", marks=pytest.mark.skip(reason="issues/6648")),
|
||||
],
|
||||
)
|
||||
def test_highlight(
|
||||
self, get_http_api_auth, add_chunks, payload, expected_code, expected_highlight, expected_message
|
||||
):
|
||||
dataset_id, _, _ = add_chunks
|
||||
payload.update({"question": "chunk", "dataset_ids": [dataset_id]})
|
||||
res = retrieval_chunks(get_http_api_auth, payload)
|
||||
assert res["code"] == expected_code
|
||||
if expected_highlight:
|
||||
for chunk in res["data"]["chunks"]:
|
||||
assert "highlight" in chunk
|
||||
else:
|
||||
for chunk in res["data"]["chunks"]:
|
||||
assert "highlight" not in chunk
|
||||
|
||||
if expected_code != 0:
|
||||
assert res["message"] == expected_message
|
||||
|
||||
def test_invalid_params(self, get_http_api_auth, add_chunks):
|
||||
dataset_id, _, _ = add_chunks
|
||||
payload = {"question": "chunk", "dataset_ids": [dataset_id], "a": "b"}
|
||||
res = retrieval_chunks(get_http_api_auth, payload)
|
||||
assert res["code"] == 0
|
||||
assert len(res["data"]["chunks"]) == 4
|
Loading…
x
Reference in New Issue
Block a user