Test: Added test cases for Create Chat Assistant HTTP API (#6763)

### What problem does this PR solve?

Cover the [create chat assistant](https://ragflow.io/docs/v0.17.2/http_api_reference#create-chat-assistant) endpoint.

### Type of change

- [x] add test cases
This commit is contained in:
liu an 2025-04-02 18:49:59 +08:00 committed by GitHub
parent 6c77ef5a5e
commit 0d1c5fdd2f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 284 additions and 19 deletions

View File

@ -32,6 +32,7 @@ CHAT_ASSISTANT_API_URL = "/api/v1/chats"
INVALID_API_TOKEN = "invalid_key_123" INVALID_API_TOKEN = "invalid_key_123"
DATASET_NAME_LIMIT = 128 DATASET_NAME_LIMIT = 128
DOCUMENT_NAME_LIMIT = 128 DOCUMENT_NAME_LIMIT = 128
CHAT_ASSISTANT_LIMIT = 255
# DATASET MANAGEMENT # DATASET MANAGEMENT

View File

@ -15,7 +15,8 @@
# #
import pytest import pytest
from common import batch_create_datasets, bulk_upload_documents, delete_datasets from common import add_chunk, batch_create_datasets, bulk_upload_documents, delete_chat_assistants, delete_datasets, list_documnets, parse_documnets
from libs.utils import wait_for
from libs.utils.file_utils import ( from libs.utils.file_utils import (
create_docx_file, create_docx_file,
create_eml_file, create_eml_file,
@ -30,12 +31,27 @@ from libs.utils.file_utils import (
) )
@wait_for(30, 1, "Document parsing timeout")
def condition(_auth, _dataset_id):
    """Poll predicate: True once every document in the dataset has run == "DONE".

    Retried by the wait_for decorator (30s budget, 1s interval); raises on timeout.
    """
    docs = list_documnets(_auth, _dataset_id)["data"]["docs"]
    return all(doc["run"] == "DONE" for doc in docs)
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def clear_datasets(get_http_api_auth): def clear_datasets(get_http_api_auth):
yield yield
delete_datasets(get_http_api_auth) delete_datasets(get_http_api_auth)
@pytest.fixture(scope="function")
def clear_chat_assistants(get_http_api_auth):
    """Teardown-only fixture: deletes all chat assistants after the test finishes."""
    # No setup work — hand control to the test first.
    yield
    # Teardown: purge every chat assistant the test created.
    delete_chat_assistants(get_http_api_auth)
@pytest.fixture @pytest.fixture
def generate_test_files(request, tmp_path): def generate_test_files(request, tmp_path):
file_creators = { file_creators = {
@ -92,3 +108,21 @@ def add_document(get_http_api_auth, add_dataset, ragflow_tmp_dir):
dataset_id = add_dataset dataset_id = add_dataset
document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 1, ragflow_tmp_dir) document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 1, ragflow_tmp_dir)
return dataset_id, document_ids[0] return dataset_id, document_ids[0]
@pytest.fixture(scope="class")
def add_chunks(get_http_api_auth, add_document):
    """Parse the shared document and add four test chunks to it.

    Returns:
        (dataset_id, document_id, chunk_ids) — a dataset with parsed content,
        shared by every test in the requesting class (scope="class").
    """
    # Hoisted from mid-function: imports belong at the top of their scope.
    from time import sleep

    dataset_id, document_id = add_document
    parse_documnets(get_http_api_auth, dataset_id, {"document_ids": [document_id]})
    # Block until every document in the dataset reports run == "DONE".
    condition(get_http_api_auth, dataset_id)
    chunk_ids = []
    for i in range(4):
        res = add_chunk(get_http_api_auth, dataset_id, document_id, {"content": f"chunk test {i}"})
        chunk_ids.append(res["data"]["chunk"]["id"])

    # issues/6487: give the backend a moment to persist the new chunks.
    sleep(1)
    return dataset_id, document_id, chunk_ids

View File

@ -0,0 +1,248 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pytest
from common import CHAT_ASSISTANT_LIMIT, INVALID_API_TOKEN, create_chat_assistant
from libs.auth import RAGFlowHttpApiAuth
from libs.utils import encode_avatar
from libs.utils.file_utils import create_image_file
@pytest.mark.usefixtures("clear_chat_assistants")
class TestAuthorization:
    """Authentication checks for the create-chat-assistant endpoint."""

    @pytest.mark.parametrize(
        "auth, expected_code, expected_message",
        [
            pytest.param(None, 0, "`Authorization` can't be empty"),
            pytest.param(
                RAGFlowHttpApiAuth(INVALID_API_TOKEN),
                109,
                "Authentication error: API key is invalid!",
            ),
        ],
    )
    def test_invalid_auth(self, auth, expected_code, expected_message):
        """Missing or invalid credentials must be rejected with the exact message."""
        response = create_chat_assistant(auth)
        assert response["code"] == expected_code
        assert response["message"] == expected_message
@pytest.mark.usefixtures("clear_chat_assistants")
class TestChatAssistantCreate:
    """Functional tests for the create-chat-assistant HTTP endpoint (POST /api/v1/chats)."""

    @pytest.mark.parametrize(
        "payload, expected_code, expected_message",
        [
            # Happy path: a plain valid name succeeds.
            ({"name": "valid_name"}, 0, ""),
            # Over-length and non-string names: known gaps, skipped until the tracked issue is fixed.
            pytest.param({"name": "a" * (CHAT_ASSISTANT_LIMIT + 1)}, 102, "", marks=pytest.mark.skip(reason="issues/")),
            pytest.param({"name": 1}, 100, "", marks=pytest.mark.skip(reason="issues/")),
            ({"name": ""}, 102, "`name` is required."),
            # Duplicate detection: exact match and case-insensitive match.
            ({"name": "duplicated_name"}, 102, "Duplicated chat name in creating chat."),
            ({"name": "case insensitive"}, 102, "Duplicated chat name in creating chat."),
        ],
    )
    def test_name(self, get_http_api_auth, add_chunks, payload, expected_code, expected_message):
        """Name validation: required, unique, and uniqueness is case-insensitive."""
        payload["dataset_ids"] = []  # issues/ — workaround: create without datasets; see tracker
        # Pre-create an assistant so the duplicate cases have something to collide with.
        if payload["name"] == "duplicated_name":
            create_chat_assistant(get_http_api_auth, payload)
        elif payload["name"] == "case insensitive":
            create_chat_assistant(get_http_api_auth, {"name": payload["name"].upper()})

        res = create_chat_assistant(get_http_api_auth, payload)
        assert res["code"] == expected_code, res
        if expected_code == 0:
            assert res["data"]["name"] == payload["name"]
        else:
            assert res["message"] == expected_message

    @pytest.mark.parametrize(
        "dataset_ids, expected_code, expected_message",
        [
            ([], 0, ""),
            # A callable is resolved with the real dataset id created by the fixture.
            (lambda r: [r], 0, ""),
            (["invalid_dataset_id"], 102, "You don't own the dataset invalid_dataset_id"),
            # NOTE(review): message truncates to "i" — server appears to iterate the
            # bare string element-wise; confirm against the endpoint implementation.
            ("invalid_dataset_id", 102, "You don't own the dataset i"),
        ],
    )
    def test_dataset_ids(self, get_http_api_auth, add_chunks, dataset_ids, expected_code, expected_message):
        """dataset_ids validation: empty list, valid id, unknown id, wrong type."""
        dataset_id, _, _ = add_chunks
        payload = {"name": "ragflow test"}
        if callable(dataset_ids):
            payload["dataset_ids"] = dataset_ids(dataset_id)
        else:
            payload["dataset_ids"] = dataset_ids
        res = create_chat_assistant(get_http_api_auth, payload)
        assert res["code"] == expected_code, res
        if expected_code == 0:
            assert res["data"]["name"] == payload["name"]
        else:
            assert res["message"] == expected_message

    def test_avatar(self, get_http_api_auth, tmp_path):
        """A base64-encoded PNG is accepted as the assistant avatar."""
        fn = create_image_file(tmp_path / "ragflow_test.png")
        payload = {"name": "avatar_test", "avatar": encode_avatar(fn), "dataset_ids": []}
        res = create_chat_assistant(get_http_api_auth, payload)
        assert res["code"] == 0

    @pytest.mark.parametrize(
        "llm, expected_code, expected_message",
        [
            # Empty llm dict -> server defaults (asserted in the test body).
            ({}, 0, ""),
            # model_name
            ({"model_name": "glm-4"}, 0, ""),
            ({"model_name": "unknown"}, 102, "`model_name` unknown doesn't exist"),
            # temperature: boundary values accepted; out-of-range / wrong-type cases
            # are known gaps and skipped.
            ({"temperature": 0}, 0, ""),
            ({"temperature": 1}, 0, ""),
            pytest.param({"temperature": -1}, 0, "", marks=pytest.mark.skip),
            pytest.param({"temperature": 10}, 0, "", marks=pytest.mark.skip),
            pytest.param({"temperature": "a"}, 0, "", marks=pytest.mark.skip),
            # top_p
            ({"top_p": 0}, 0, ""),
            ({"top_p": 1}, 0, ""),
            pytest.param({"top_p": -1}, 0, "", marks=pytest.mark.skip),
            pytest.param({"top_p": 10}, 0, "", marks=pytest.mark.skip),
            pytest.param({"top_p": "a"}, 0, "", marks=pytest.mark.skip),
            # presence_penalty
            ({"presence_penalty": 0}, 0, ""),
            ({"presence_penalty": 1}, 0, ""),
            pytest.param({"presence_penalty": -1}, 0, "", marks=pytest.mark.skip),
            pytest.param({"presence_penalty": 10}, 0, "", marks=pytest.mark.skip),
            pytest.param({"presence_penalty": "a"}, 0, "", marks=pytest.mark.skip),
            # frequency_penalty
            ({"frequency_penalty": 0}, 0, ""),
            ({"frequency_penalty": 1}, 0, ""),
            pytest.param({"frequency_penalty": -1}, 0, "", marks=pytest.mark.skip),
            pytest.param({"frequency_penalty": 10}, 0, "", marks=pytest.mark.skip),
            pytest.param({"frequency_penalty": "a"}, 0, "", marks=pytest.mark.skip),
            # max_token (request key; default response field is "max_tokens" — see body)
            ({"max_token": 0}, 0, ""),
            ({"max_token": 1024}, 0, ""),
            pytest.param({"max_token": -1}, 0, "", marks=pytest.mark.skip),
            pytest.param({"max_token": 10}, 0, "", marks=pytest.mark.skip),
            pytest.param({"max_token": "a"}, 0, "", marks=pytest.mark.skip),
            pytest.param({"unknown": "unknown"}, 0, "", marks=pytest.mark.skip),
        ],
    )
    def test_llm(self, get_http_api_auth, add_chunks, llm, expected_code, expected_message):
        """llm settings: echoed back when supplied, server defaults when omitted."""
        dataset_id, _, _ = add_chunks
        payload = {"name": "llm_test", "dataset_ids": [dataset_id], "llm": llm}
        res = create_chat_assistant(get_http_api_auth, payload)
        assert res["code"] == expected_code
        if expected_code == 0:
            if llm:
                # Every supplied llm key must round-trip unchanged.
                for k, v in llm.items():
                    assert res["data"]["llm"][k] == v
            else:
                # Server defaults applied when llm is an empty dict.
                assert res["data"]["llm"]["model_name"] == "glm-4-flash@ZHIPU-AI"
                assert res["data"]["llm"]["temperature"] == 0.1
                assert res["data"]["llm"]["top_p"] == 0.3
                assert res["data"]["llm"]["presence_penalty"] == 0.4
                assert res["data"]["llm"]["frequency_penalty"] == 0.7
                assert res["data"]["llm"]["max_tokens"] == 512
        else:
            assert res["message"] == expected_message

    @pytest.mark.parametrize(
        "prompt, expected_code, expected_message",
        [
            # Empty prompt dict -> server defaults (asserted in the test body).
            ({}, 0, ""),
            # similarity_threshold: boundaries accepted; invalid cases skipped (known gaps).
            ({"similarity_threshold": 0}, 0, ""),
            ({"similarity_threshold": 1}, 0, ""),
            pytest.param({"similarity_threshold": -1}, 0, "", marks=pytest.mark.skip),
            pytest.param({"similarity_threshold": 10}, 0, "", marks=pytest.mark.skip),
            pytest.param({"similarity_threshold": "a"}, 0, "", marks=pytest.mark.skip),
            # keywords_similarity_weight (stored complemented — see test body)
            ({"keywords_similarity_weight": 0}, 0, ""),
            ({"keywords_similarity_weight": 1}, 0, ""),
            pytest.param({"keywords_similarity_weight": -1}, 0, "", marks=pytest.mark.skip),
            pytest.param({"keywords_similarity_weight": 10}, 0, "", marks=pytest.mark.skip),
            pytest.param({"keywords_similarity_weight": "a"}, 0, "", marks=pytest.mark.skip),
            ({"variables": []}, 0, ""),
            # top_n
            ({"top_n": 0}, 0, ""),
            ({"top_n": 1}, 0, ""),
            pytest.param({"top_n": -1}, 0, "", marks=pytest.mark.skip),
            pytest.param({"top_n": 10}, 0, "", marks=pytest.mark.skip),
            pytest.param({"top_n": "a"}, 0, "", marks=pytest.mark.skip),
            # empty_response: arbitrary text accepted, including empty and non-ASCII.
            ({"empty_response": "Hello World"}, 0, ""),
            ({"empty_response": ""}, 0, ""),
            ({"empty_response": "!@#$%^&*()"}, 0, ""),
            ({"empty_response": "中文测试"}, 0, ""),
            pytest.param({"empty_response": 123}, 0, "", marks=pytest.mark.skip),
            pytest.param({"empty_response": True}, 0, "", marks=pytest.mark.skip),
            pytest.param({"empty_response": " "}, 0, "", marks=pytest.mark.skip),
            # opener
            ({"opener": "Hello World"}, 0, ""),
            ({"opener": ""}, 0, ""),
            ({"opener": "!@#$%^&*()"}, 0, ""),
            ({"opener": "中文测试"}, 0, ""),
            pytest.param({"opener": 123}, 0, "", marks=pytest.mark.skip),
            pytest.param({"opener": True}, 0, "", marks=pytest.mark.skip),
            pytest.param({"opener": " "}, 0, "", marks=pytest.mark.skip),
            ({"show_quote": True}, 0, ""),
            ({"show_quote": False}, 0, ""),
            # prompt template: must reference {knowledge} unless variables is overridden.
            ({"prompt": "Hello World {knowledge}"}, 0, ""),
            ({"prompt": "{knowledge}"}, 0, ""),
            ({"prompt": "!@#$%^&*() {knowledge}"}, 0, ""),
            ({"prompt": "中文测试 {knowledge}"}, 0, ""),
            ({"prompt": "Hello World"}, 102, "Parameter 'knowledge' is not used"),
            ({"prompt": "Hello World", "variables": []}, 0, ""),
            pytest.param({"prompt": 123}, 100, """AttributeError("\'int\' object has no attribute \'find\'")""", marks=pytest.mark.skip),
            pytest.param({"prompt": True}, 100, """AttributeError("\'int\' object has no attribute \'find\'")""", marks=pytest.mark.skip),
            pytest.param({"unknown": "unknown"}, 0, "", marks=pytest.mark.skip),
        ],
    )
    def test_prompt(self, get_http_api_auth, add_chunks, prompt, expected_code, expected_message):
        """prompt settings: echoed back when supplied, server defaults when omitted."""
        dataset_id, _, _ = add_chunks
        payload = {"name": "prompt_test", "dataset_ids": [dataset_id], "prompt": prompt}
        res = create_chat_assistant(get_http_api_auth, payload)
        assert res["code"] == expected_code
        if expected_code == 0:
            if prompt:
                for k, v in prompt.items():
                    if k == "keywords_similarity_weight":
                        # Server stores the complement: response value is 1 - requested weight.
                        assert res["data"]["prompt"][k] == 1 - v
                    else:
                        assert res["data"]["prompt"][k] == v
            else:
                # Server defaults applied when prompt is an empty dict.
                assert res["data"]["prompt"]["similarity_threshold"] == 0.2
                assert res["data"]["prompt"]["keywords_similarity_weight"] == 0.7
                assert res["data"]["prompt"]["top_n"] == 6
                assert res["data"]["prompt"]["variables"] == [{"key": "knowledge", "optional": False}]
                assert res["data"]["prompt"]["rerank_model"] == ""
                assert res["data"]["prompt"]["empty_response"] == "Sorry! No relevant content was found in the knowledge base!"
                assert res["data"]["prompt"]["opener"] == "Hi! I'm your assistant, what can I do for you?"
                assert res["data"]["prompt"]["show_quote"] is True
                assert (
                    res["data"]["prompt"]["prompt"]
                    == 'You are an intelligent assistant. Please summarize the content of the knowledge base to answer the question. Please list the data in the knowledge base and answer in detail. When all knowledge base content is irrelevant to the question, your answer must include the sentence "The answer you are looking for is not found in the knowledge base!" Answers need to consider chat history.\n      Here is the knowledge base:\n      {knowledge}\n      The above is the knowledge base.'
                )
        else:
            assert res["message"] == expected_message

    @pytest.mark.parametrize(
        "dataset_id, expected_code, expected_message",
        [
            ("invalid_dataset_id", 102, "You don't own the dataset invalid_dataset_id"),
        ],
    )
    def test_invalid_dataset_id(self, get_http_api_auth, dataset_id, expected_code, expected_message):
        """An unknown dataset id must be rejected with code 102."""
        payload = {"name": "prompt_test", "dataset_ids": [dataset_id]}
        res = create_chat_assistant(get_http_api_auth, payload)
        assert res["code"] == expected_code
        assert expected_message in res["message"]
@pytest.mark.usefixtures("clear_chat_assistants")
class TestChatAssistantCreate2:
    """Creation attempts against a dataset whose document has not been parsed."""

    def test_unparsed_document(self, get_http_api_auth, add_document):
        """Creating an assistant on an unparsed dataset must fail with code 102."""
        dataset_id, _ = add_document
        response = create_chat_assistant(
            get_http_api_auth,
            {"name": "prompt_test", "dataset_ids": [dataset_id]},
        )
        assert response["code"] == 102
        assert "doesn't own parsed file" in response["message"]

View File

@ -29,24 +29,6 @@ def condition(_auth, _dataset_id):
return True return True
@pytest.fixture(scope="class")
def add_chunks(get_http_api_auth, add_document):
    """Parse the shared document and add four test chunks to it.

    Returns (dataset_id, document_id, chunk_ids); shared across the
    requesting test class (scope="class").
    """
    dataset_id, document_id = add_document
    parse_documnets(get_http_api_auth, dataset_id, {"document_ids": [document_id]})
    # Block until every document in the dataset reports run == "DONE".
    condition(get_http_api_auth, dataset_id)
    chunk_ids = []
    for i in range(4):
        res = add_chunk(get_http_api_auth, dataset_id, document_id, {"content": f"chunk test {i}"})
        chunk_ids.append(res["data"]["chunk"]["id"])

    # issues/6487: give the backend a moment to persist the new chunks.
    from time import sleep
    sleep(1)
    return dataset_id, document_id, chunk_ids
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def add_chunks_func(request, get_http_api_auth, add_document): def add_chunks_func(request, get_http_api_auth, add_document):
dataset_id, document_id = add_document dataset_id, document_id = add_document