From 77aef9ff1db220dfbab5da6e48061051b98b7cbd Mon Sep 17 00:00:00 2001 From: zhuhao <37029601+hwzhuhao@users.noreply.github.com> Date: Mon, 30 Sep 2024 12:55:01 +0800 Subject: [PATCH] refactor: optimize the calculation of rerank threshold and the logic for forbidden characters in model_uid (#8879) --- .../model_providers/xinference/llm/llm.py | 3 ++- .../model_providers/xinference/rerank/rerank.py | 8 +++----- .../xinference/speech2text/speech2text.py | 3 ++- .../xinference/text_embedding/text_embedding.py | 4 ++-- .../model_providers/xinference/tts/tts.py | 4 ++-- .../model_providers/xinference/xinference_helper.py | 13 +++++++++++++ 6 files changed, 24 insertions(+), 11 deletions(-) diff --git a/api/core/model_runtime/model_providers/xinference/llm/llm.py b/api/core/model_runtime/model_providers/xinference/llm/llm.py index 286640079b..0c9d08679a 100644 --- a/api/core/model_runtime/model_providers/xinference/llm/llm.py +++ b/api/core/model_runtime/model_providers/xinference/llm/llm.py @@ -59,6 +59,7 @@ from core.model_runtime.model_providers.__base.large_language_model import Large from core.model_runtime.model_providers.xinference.xinference_helper import ( XinferenceHelper, XinferenceModelExtraParameter, + validate_model_uid, ) from core.model_runtime.utils import helper @@ -114,7 +115,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): } """ try: - if "/" in credentials["model_uid"] or "?" 
in credentials["model_uid"] or "#" in credentials["model_uid"]: + if not validate_model_uid(credentials): raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") extra_param = XinferenceHelper.get_xinference_extra_parameter( diff --git a/api/core/model_runtime/model_providers/xinference/rerank/rerank.py b/api/core/model_runtime/model_providers/xinference/rerank/rerank.py index 8f18bc42d2..6368cd76dc 100644 --- a/api/core/model_runtime/model_providers/xinference/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/xinference/rerank/rerank.py @@ -15,6 +15,7 @@ from core.model_runtime.errors.invoke import ( ) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.rerank_model import RerankModel +from core.model_runtime.model_providers.xinference.xinference_helper import validate_model_uid class XinferenceRerankModel(RerankModel): @@ -77,10 +78,7 @@ class XinferenceRerankModel(RerankModel): ) # score threshold check - if score_threshold is not None: - if result["relevance_score"] >= score_threshold: - rerank_documents.append(rerank_document) - else: + if score_threshold is None or result["relevance_score"] >= score_threshold: rerank_documents.append(rerank_document) return RerankResult(model=model, docs=rerank_documents) @@ -94,7 +92,7 @@ class XinferenceRerankModel(RerankModel): :return: """ try: - if "/" in credentials["model_uid"] or "?" 
in credentials["model_uid"] or "#" in credentials["model_uid"]: + if not validate_model_uid(credentials): raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") credentials["server_url"] = credentials["server_url"].removesuffix("/") diff --git a/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py b/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py index a6c5b8a0a5..c5ad383911 100644 --- a/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py @@ -14,6 +14,7 @@ from core.model_runtime.errors.invoke import ( ) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel +from core.model_runtime.model_providers.xinference.xinference_helper import validate_model_uid class XinferenceSpeech2TextModel(Speech2TextModel): @@ -42,7 +43,7 @@ class XinferenceSpeech2TextModel(Speech2TextModel): :return: """ try: - if "/" in credentials["model_uid"] or "?" 
in credentials["model_uid"] or "#" in credentials["model_uid"]: + if not validate_model_uid(credentials): raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") credentials["server_url"] = credentials["server_url"].removesuffix("/") diff --git a/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py index 1627239132..ddc21b365c 100644 --- a/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py @@ -17,7 +17,7 @@ from core.model_runtime.errors.invoke import ( ) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper +from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper, validate_model_uid class XinferenceTextEmbeddingModel(TextEmbeddingModel): @@ -110,7 +110,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): :return: """ try: - if "/" in credentials["model_uid"] or "?" 
in credentials["model_uid"] or "#" in credentials["model_uid"]: + if not validate_model_uid(credentials): raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") server_url = credentials["server_url"] diff --git a/api/core/model_runtime/model_providers/xinference/tts/tts.py b/api/core/model_runtime/model_providers/xinference/tts/tts.py index 81dbe397d2..6290e8551d 100644 --- a/api/core/model_runtime/model_providers/xinference/tts/tts.py +++ b/api/core/model_runtime/model_providers/xinference/tts/tts.py @@ -15,7 +15,7 @@ from core.model_runtime.errors.invoke import ( ) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.tts_model import TTSModel -from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper +from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper, validate_model_uid class XinferenceText2SpeechModel(TTSModel): @@ -70,7 +70,7 @@ class XinferenceText2SpeechModel(TTSModel): :return: """ try: - if "/" in credentials["model_uid"] or "?" 
in credentials["model_uid"] or "#" in credentials["model_uid"]: + if not validate_model_uid(credentials): raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") credentials["server_url"] = credentials["server_url"].removesuffix("/") diff --git a/api/core/model_runtime/model_providers/xinference/xinference_helper.py b/api/core/model_runtime/model_providers/xinference/xinference_helper.py index 619ee1492a..baa3ccbe8a 100644 --- a/api/core/model_runtime/model_providers/xinference/xinference_helper.py +++ b/api/core/model_runtime/model_providers/xinference/xinference_helper.py @@ -132,3 +132,16 @@ class XinferenceHelper: context_length=context_length, model_family=model_family, ) + + +def validate_model_uid(credentials: dict) -> bool: + """ + Validate the model_uid within the credentials dictionary to ensure it does not + contain forbidden characters ("/", "?", "#"). + + :param credentials: model credentials + :return: True if the model_uid does not contain forbidden characters ("/", "?", "#"), else False. + """ + forbidden_characters = ["/", "?", "#"] + model_uid = credentials.get("model_uid", "") + return not any(char in forbidden_characters for char in model_uid)