diff --git a/api/core/model_runtime/model_providers/tongyi/rerank/__init__.py b/api/core/model_runtime/model_providers/tongyi/rerank/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/tongyi/rerank/_position.yaml b/api/core/model_runtime/model_providers/tongyi/rerank/_position.yaml new file mode 100644 index 0000000000..439afda992 --- /dev/null +++ b/api/core/model_runtime/model_providers/tongyi/rerank/_position.yaml @@ -0,0 +1 @@ +- gte-rerank diff --git a/api/core/model_runtime/model_providers/tongyi/rerank/gte-rerank.yaml b/api/core/model_runtime/model_providers/tongyi/rerank/gte-rerank.yaml new file mode 100644 index 0000000000..44d51b9b0d --- /dev/null +++ b/api/core/model_runtime/model_providers/tongyi/rerank/gte-rerank.yaml @@ -0,0 +1,4 @@ +model: gte-rerank +model_type: rerank +model_properties: + context_size: 4000 diff --git a/api/core/model_runtime/model_providers/tongyi/rerank/rerank.py b/api/core/model_runtime/model_providers/tongyi/rerank/rerank.py new file mode 100644 index 0000000000..c9245bd82d --- /dev/null +++ b/api/core/model_runtime/model_providers/tongyi/rerank/rerank.py @@ -0,0 +1,136 @@ +from typing import Optional + +import dashscope +from dashscope.common.error import ( + AuthenticationError, + InvalidParameter, + RequestFailure, + ServiceUnavailableError, + UnsupportedHTTPMethod, + UnsupportedModel, +) + +from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.rerank_model import RerankModel + + +class GTERerankModel(RerankModel): + """ + Model class for GTE rerank model. + """ + + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: + """ + Invoke rerank model + + :param model: model name + :param credentials: model credentials + :param query: search query + :param docs: docs for reranking + :param score_threshold: score threshold + :param top_n: top n + :param user: unique user id + :return: rerank result + """ + if len(docs) == 0: + return RerankResult(model=model, docs=docs) + + # initialize client + dashscope.api_key = credentials["dashscope_api_key"] + + response = dashscope.TextReRank.call( + query=query, + documents=docs, + model=model, + top_n=top_n, + return_documents=True, + ) + + rerank_documents = [] + for _, result in enumerate(response.output.results): + # format document + rerank_document = RerankDocument( + index=result.index, + score=result.relevance_score, + text=result["document"]["text"], + ) + + # score threshold check + if score_threshold is not None: + if result.relevance_score >= score_threshold: + rerank_documents.append(rerank_document) + else: + rerank_documents.append(rerank_document) + + return RerankResult(model=model, docs=rerank_documents) + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + try: + self.invoke( + model=model, + credentials=credentials, + query="What is the capital of the United States?", + docs=[ + "Carson City is the capital city of the American state of Nevada. At the 2010 United States " + "Census, Carson City had a population of 55,274.", + "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " + "are a political division controlled by the United States. Its capital is Saipan.", + ], + score_threshold=0.8, + ) + except Exception as ex: + print(ex) + raise CredentialsValidateFailedError(str(ex)) + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + The key is the error type thrown to the caller + The value is the error type thrown by the model, + which needs to be converted into a unified error type for the caller. + + :return: Invoke error mapping + """ + return { + InvokeConnectionError: [ + RequestFailure, + ], + InvokeServerUnavailableError: [ + ServiceUnavailableError, + ], + InvokeRateLimitError: [], + InvokeAuthorizationError: [ + AuthenticationError, + ], + InvokeBadRequestError: [ + InvalidParameter, + UnsupportedModel, + UnsupportedHTTPMethod, + ], + } diff --git a/api/core/model_runtime/model_providers/tongyi/tongyi.yaml b/api/core/model_runtime/model_providers/tongyi/tongyi.yaml index 1a09c20fd9..6349c22714 100644 --- a/api/core/model_runtime/model_providers/tongyi/tongyi.yaml +++ b/api/core/model_runtime/model_providers/tongyi/tongyi.yaml @@ -18,6 +18,7 @@ supported_model_types: - llm - tts - text-embedding + - rerank configurate_methods: - predefined-model - customizable-model diff --git a/api/tests/integration_tests/model_runtime/tongyi/test_rerank.py b/api/tests/integration_tests/model_runtime/tongyi/test_rerank.py new file mode 100644 index 0000000000..2dcfb92c63 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/tongyi/test_rerank.py @@ -0,0 +1,40 @@ +import os + +import dashscope +import pytest + +from core.model_runtime.entities.rerank_entities import RerankResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.tongyi.rerank.rerank import GTERerankModel + + +def test_validate_credentials(): + model = GTERerankModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials(model="get-rank", credentials={"dashscope_api_key": "invalid_key"}) + + model.validate_credentials( + model="get-rank", credentials={"dashscope_api_key": os.environ.get("TONGYI_DASHSCOPE_API_KEY")} + ) + + +def test_invoke_model(): + model = GTERerankModel() + + result = model.invoke( + model=dashscope.TextReRank.Models.gte_rerank, + credentials={"dashscope_api_key": os.environ.get("TONGYI_DASHSCOPE_API_KEY")}, + query="什么是文本排序模型", + docs=[ + "文本排序模型广泛用于搜索引擎和推荐系统中,它们根据文本相关性对候选文本进行排序", + "量子计算是计算科学的一个前沿领域", + "预训练语言模型的发展给文本排序模型带来了新的进展", + ], + score_threshold=0.7, + ) + + assert isinstance(result, RerankResult) + assert len(result.docs) == 1 + assert result.docs[0].index == 0 + assert result.docs[0].score >= 0.7