refactor: optimize the calculation of rerank threshold and the logic for forbidden characters in model_uid (#8879)

This commit is contained in:
zhuhao 2024-09-30 12:55:01 +08:00 committed by GitHub
parent 503561f464
commit 77aef9ff1d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 24 additions and 11 deletions

View File

@ -59,6 +59,7 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
from core.model_runtime.model_providers.xinference.xinference_helper import ( from core.model_runtime.model_providers.xinference.xinference_helper import (
XinferenceHelper, XinferenceHelper,
XinferenceModelExtraParameter, XinferenceModelExtraParameter,
validate_model_uid,
) )
from core.model_runtime.utils import helper from core.model_runtime.utils import helper
@ -114,7 +115,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
} }
""" """
try: try:
if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]: if not validate_model_uid(credentials):
raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
extra_param = XinferenceHelper.get_xinference_extra_parameter( extra_param = XinferenceHelper.get_xinference_extra_parameter(

View File

@ -15,6 +15,7 @@ from core.model_runtime.errors.invoke import (
) )
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.rerank_model import RerankModel from core.model_runtime.model_providers.__base.rerank_model import RerankModel
from core.model_runtime.model_providers.xinference.xinference_helper import validate_model_uid
class XinferenceRerankModel(RerankModel): class XinferenceRerankModel(RerankModel):
@ -77,10 +78,7 @@ class XinferenceRerankModel(RerankModel):
) )
# score threshold check # score threshold check
if score_threshold is not None: if score_threshold is None or result["relevance_score"] >= score_threshold:
if result["relevance_score"] >= score_threshold:
rerank_documents.append(rerank_document)
else:
rerank_documents.append(rerank_document) rerank_documents.append(rerank_document)
return RerankResult(model=model, docs=rerank_documents) return RerankResult(model=model, docs=rerank_documents)
@ -94,7 +92,7 @@ class XinferenceRerankModel(RerankModel):
:return: :return:
""" """
try: try:
if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]: if not validate_model_uid(credentials):
raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
credentials["server_url"] = credentials["server_url"].removesuffix("/") credentials["server_url"] = credentials["server_url"].removesuffix("/")

View File

@ -14,6 +14,7 @@ from core.model_runtime.errors.invoke import (
) )
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
from core.model_runtime.model_providers.xinference.xinference_helper import validate_model_uid
class XinferenceSpeech2TextModel(Speech2TextModel): class XinferenceSpeech2TextModel(Speech2TextModel):
@ -42,7 +43,7 @@ class XinferenceSpeech2TextModel(Speech2TextModel):
:return: :return:
""" """
try: try:
if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]: if not validate_model_uid(credentials):
raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
credentials["server_url"] = credentials["server_url"].removesuffix("/") credentials["server_url"] = credentials["server_url"].removesuffix("/")

View File

@ -17,7 +17,7 @@ from core.model_runtime.errors.invoke import (
) )
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper, validate_model_uid
class XinferenceTextEmbeddingModel(TextEmbeddingModel): class XinferenceTextEmbeddingModel(TextEmbeddingModel):
@ -110,7 +110,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
:return: :return:
""" """
try: try:
if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]: if not validate_model_uid(credentials):
raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
server_url = credentials["server_url"] server_url = credentials["server_url"]

View File

@ -15,7 +15,7 @@ from core.model_runtime.errors.invoke import (
) )
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.tts_model import TTSModel from core.model_runtime.model_providers.__base.tts_model import TTSModel
from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper, validate_model_uid
class XinferenceText2SpeechModel(TTSModel): class XinferenceText2SpeechModel(TTSModel):
@ -70,7 +70,7 @@ class XinferenceText2SpeechModel(TTSModel):
:return: :return:
""" """
try: try:
if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]: if not validate_model_uid(credentials):
raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
credentials["server_url"] = credentials["server_url"].removesuffix("/") credentials["server_url"] = credentials["server_url"].removesuffix("/")

View File

@ -132,3 +132,16 @@ class XinferenceHelper:
context_length=context_length, context_length=context_length,
model_family=model_family, model_family=model_family,
) )
def validate_model_uid(credentials: dict) -> bool:
"""
Validate the model_uid within the credentials dictionary to ensure it does not
contain forbidden characters ("/", "?", "#").
param credentials: model credentials
:return: True if the model_uid does not contain forbidden characters ("/", "?", "#"), else False.
"""
forbidden_characters = ["/", "?", "#"]
model_uid = credentials.get("model_uid", "")
return not any(char in forbidden_characters for char in model_uid)