From 91f1814a87cd830630ef2fb276384a6fe27f2c7a Mon Sep 17 00:00:00 2001 From: Kevin Hu Date: Thu, 28 Nov 2024 18:56:10 +0800 Subject: [PATCH] Fix error response (#3719) ### What problem does this PR solve? ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --------- Co-authored-by: Jin Hai --- api/apps/chunk_app.py | 2 +- rag/llm/rerank_model.py | 2 ++ sdk/python/test/test_frontend_api/test_dataset.py | 4 ++++ 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/api/apps/chunk_app.py b/api/apps/chunk_app.py index 0863df133..1d21bc103 100644 --- a/api/apps/chunk_app.py +++ b/api/apps/chunk_app.py @@ -96,7 +96,7 @@ def get(): kb_ids = KnowledgebaseService.get_kb_ids(tenant_id) chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant_id), kb_ids) if chunk is None: - return server_error_response("Chunk not found") + return server_error_response(Exception("Chunk not found")) k = [] for n in chunk.keys(): if re.search(r"(_vec$|_sm_|_tks|_ltks)", n): diff --git a/rag/llm/rerank_model.py b/rag/llm/rerank_model.py index e2eb3a93d..28420daab 100644 --- a/rag/llm/rerank_model.py +++ b/rag/llm/rerank_model.py @@ -158,6 +158,8 @@ class XInferenceRerank(Base): def __init__(self, key="xxxxxxx", model_name="", base_url=""): if base_url.find("/v1") == -1: base_url = urljoin(base_url, "/v1/rerank") + if base_url.find("/rerank") == -1: + base_url = urljoin(base_url, "/v1/rerank") self.model_name = model_name self.base_url = base_url self.headers = { diff --git a/sdk/python/test/test_frontend_api/test_dataset.py b/sdk/python/test/test_frontend_api/test_dataset.py index f9421d2be..d4e69c7aa 100644 --- a/sdk/python/test/test_frontend_api/test_dataset.py +++ b/sdk/python/test/test_frontend_api/test_dataset.py @@ -4,6 +4,7 @@ import pytest import random import string + def test_dataset(get_auth): # create dataset res = create_dataset(get_auth, "test_create_dataset") @@ -58,6 +59,7 @@ def test_dataset_1k_dataset(get_auth): assert res.get("code") == 0, f"{res.get('message')}" print(f"{len(dataset_list)} datasets are deleted") + def test_duplicated_name_dataset(get_auth): # create dataset for i in range(20): @@ -81,6 +83,7 @@ def test_duplicated_name_dataset(get_auth): assert res.get("code") == 0, f"{res.get('message')}" print(f"{len(dataset_list)} datasets are deleted") + def test_invalid_name_dataset(get_auth): # create dataset # with pytest.raises(Exception) as e: @@ -99,6 +102,7 @@ def test_invalid_name_dataset(get_auth): assert res['code'] == 102 print(res) + def test_update_different_params_dataset(get_auth): # create dataset res = create_dataset(get_auth, "test_create_dataset")