refactor: optimize the calculation of rerank threshold and the logic for forbidden characters in model_uid (#8879)

This commit is contained in:
zhuhao 2024-09-30 12:55:01 +08:00 committed by GitHub
parent 503561f464
commit 77aef9ff1d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 24 additions and 11 deletions

View File

@ -59,6 +59,7 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
from core.model_runtime.model_providers.xinference.xinference_helper import (
XinferenceHelper,
XinferenceModelExtraParameter,
validate_model_uid,
)
from core.model_runtime.utils import helper
@ -114,7 +115,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
}
"""
try:
if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]:
if not validate_model_uid(credentials):
raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
extra_param = XinferenceHelper.get_xinference_extra_parameter(

View File

@ -15,6 +15,7 @@ from core.model_runtime.errors.invoke import (
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
from core.model_runtime.model_providers.xinference.xinference_helper import validate_model_uid
class XinferenceRerankModel(RerankModel):
@ -77,10 +78,7 @@ class XinferenceRerankModel(RerankModel):
)
# score threshold check
if score_threshold is not None:
if result["relevance_score"] >= score_threshold:
rerank_documents.append(rerank_document)
else:
if score_threshold is None or result["relevance_score"] >= score_threshold:
rerank_documents.append(rerank_document)
return RerankResult(model=model, docs=rerank_documents)
@ -94,7 +92,7 @@ class XinferenceRerankModel(RerankModel):
:return:
"""
try:
if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]:
if not validate_model_uid(credentials):
raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
credentials["server_url"] = credentials["server_url"].removesuffix("/")

View File

@ -14,6 +14,7 @@ from core.model_runtime.errors.invoke import (
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
from core.model_runtime.model_providers.xinference.xinference_helper import validate_model_uid
class XinferenceSpeech2TextModel(Speech2TextModel):
@ -42,7 +43,7 @@ class XinferenceSpeech2TextModel(Speech2TextModel):
:return:
"""
try:
if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]:
if not validate_model_uid(credentials):
raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
credentials["server_url"] = credentials["server_url"].removesuffix("/")

View File

@ -17,7 +17,7 @@ from core.model_runtime.errors.invoke import (
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper
from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper, validate_model_uid
class XinferenceTextEmbeddingModel(TextEmbeddingModel):
@ -110,7 +110,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
:return:
"""
try:
if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]:
if not validate_model_uid(credentials):
raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
server_url = credentials["server_url"]

View File

@ -15,7 +15,7 @@ from core.model_runtime.errors.invoke import (
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.tts_model import TTSModel
from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper
from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper, validate_model_uid
class XinferenceText2SpeechModel(TTSModel):
@ -70,7 +70,7 @@ class XinferenceText2SpeechModel(TTSModel):
:return:
"""
try:
if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]:
if not validate_model_uid(credentials):
raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
credentials["server_url"] = credentials["server_url"].removesuffix("/")

View File

@ -132,3 +132,16 @@ class XinferenceHelper:
context_length=context_length,
model_family=model_family,
)
def validate_model_uid(credentials: dict) -> bool:
"""
Validate the model_uid within the credentials dictionary to ensure it does not
contain forbidden characters ("/", "?", "#").
param credentials: model credentials
:return: True if the model_uid does not contain forbidden characters ("/", "?", "#"), else False.
"""
forbidden_characters = ["/", "?", "#"]
model_uid = credentials.get("model_uid", "")
return not any(char in forbidden_characters for char in model_uid)