From 660fc3bb347b4b4b865ca2bdd449bb317cf84c24 Mon Sep 17 00:00:00 2001 From: Ziyu Huang Date: Sun, 20 Oct 2024 21:59:58 +0800 Subject: [PATCH] Resolve 9508 openai compatible rerank (#9511) --- .../openai_api_compatible.yaml | 14 ++ .../openai_api_compatible/rerank/__init__.py | 0 .../openai_api_compatible/rerank/rerank.py | 159 ++++++++++++++++++ 3 files changed, 173 insertions(+) create mode 100644 api/core/model_runtime/model_providers/openai_api_compatible/rerank/__init__.py create mode 100644 api/core/model_runtime/model_providers/openai_api_compatible/rerank/rerank.py diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml b/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml index 88c76fe16e..69a081f35c 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml +++ b/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml @@ -8,6 +8,7 @@ supported_model_types: - llm - text-embedding - speech2text + - rerank configurate_methods: - customizable-model model_credential_schema: @@ -83,6 +84,19 @@ model_credential_schema: placeholder: zh_Hans: 在此输入您的模型上下文长度 en_US: Enter your Model context size + - variable: context_size + label: + zh_Hans: 模型上下文长度 + en_US: Model context size + required: true + show_on: + - variable: __model_type + value: rerank + type: text-input + default: '4096' + placeholder: + zh_Hans: 在此输入您的模型上下文长度 + en_US: Enter your Model context size - variable: max_tokens_to_sample label: zh_Hans: 最大 token 上限 diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/rerank/__init__.py b/api/core/model_runtime/model_providers/openai_api_compatible/rerank/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/rerank/rerank.py b/api/core/model_runtime/model_providers/openai_api_compatible/rerank/rerank.py new file mode 100644 index 0000000000..508da4bf20 --- /dev/null +++ b/api/core/model_runtime/model_providers/openai_api_compatible/rerank/rerank.py @@ -0,0 +1,159 @@ +from json import dumps +from typing import Optional + +import httpx +from requests import post +from yarl import URL + +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.rerank_model import RerankModel + + +class OAICompatRerankModel(RerankModel): + """ + rerank model API is compatible with Jina rerank model API. So copy the JinaRerankModel class code here. + we need enhance for llama.cpp , which return raw score, not normalize score 0~1. It seems Dify need it + """ + + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: + """ + Invoke rerank model + + :param model: model name + :param credentials: model credentials + :param query: search query + :param docs: docs for reranking + :param score_threshold: score threshold + :param top_n: top n documents to return + :param user: unique user id + :return: rerank result + """ + if len(docs) == 0: + return RerankResult(model=model, docs=[]) + + server_url = credentials["endpoint_url"] + model_name = model + + if not server_url: + raise CredentialsValidateFailedError("server_url is required") + if not model_name: + raise CredentialsValidateFailedError("model_name is required") + + url = server_url + headers = {"Authorization": f"Bearer {credentials.get('api_key')}", "Content-Type": "application/json"} + + # TODO: Do we need truncate docs to avoid llama.cpp return error? + + data = {"model": model_name, "query": query, "documents": docs, "top_n": top_n} + + try: + response = post(str(URL(url) / "rerank"), headers=headers, data=dumps(data), timeout=60) + response.raise_for_status() + results = response.json() + + rerank_documents = [] + scores = [result["relevance_score"] for result in results["results"]] + + # Min-Max Normalization: Normalize scores to 0 ~ 1.0 range + min_score = min(scores) + max_score = max(scores) + score_range = max_score - min_score if max_score != min_score else 1.0 # Avoid division by zero + + for result in results["results"]: + index = result["index"] + + # Retrieve document text (fallback if llama.cpp rerank doesn't return it) + text = result.get("document", {}).get("text", docs[index]) + + # Normalize the score + normalized_score = (result["relevance_score"] - min_score) / score_range + + # Create RerankDocument object with normalized score + rerank_document = RerankDocument( + index=index, + text=text, + score=normalized_score, + ) + + # Apply threshold (if defined) + if score_threshold is None or normalized_score >= score_threshold: + rerank_documents.append(rerank_document) + + # Sort rerank_documents by normalized score in descending order + rerank_documents.sort(key=lambda doc: doc.score, reverse=True) + + return RerankResult(model=model, docs=rerank_documents) + + except httpx.HTTPStatusError as e: + raise InvokeServerUnavailableError(str(e)) + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + try: + self._invoke( + model=model, + credentials=credentials, + query="What is the capital of the United States?", + docs=[ + "Carson City is the capital city of the American state of Nevada. At the 2010 United States " + "Census, Carson City had a population of 55,274.", + "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " + "are a political division controlled by the United States. Its capital is Saipan.", + ], + score_threshold=0.8, + ) + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + """ + return { + InvokeConnectionError: [httpx.ConnectError], + InvokeServerUnavailableError: [httpx.RemoteProtocolError], + InvokeRateLimitError: [], + InvokeAuthorizationError: [httpx.HTTPStatusError], + InvokeBadRequestError: [httpx.RequestError], + } + + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: + """ + generate custom model entities from credentials + """ + entity = AIModelEntity( + model=model, + label=I18nObject(en_US=model), + model_type=ModelType.RERANK, + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties={}, + ) + + return entity