mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-12 13:59:06 +08:00
refactor: optimize the calculation of rerank threshold and the logic for forbidden characters in model_uid (#8879)
This commit is contained in:
parent
503561f464
commit
77aef9ff1d
@ -59,6 +59,7 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
|
||||
from core.model_runtime.model_providers.xinference.xinference_helper import (
|
||||
XinferenceHelper,
|
||||
XinferenceModelExtraParameter,
|
||||
validate_model_uid,
|
||||
)
|
||||
from core.model_runtime.utils import helper
|
||||
|
||||
@ -114,7 +115,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
|
||||
}
|
||||
"""
|
||||
try:
|
||||
if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]:
|
||||
if not validate_model_uid(credentials):
|
||||
raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
|
||||
|
||||
extra_param = XinferenceHelper.get_xinference_extra_parameter(
|
||||
|
@ -15,6 +15,7 @@ from core.model_runtime.errors.invoke import (
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
|
||||
from core.model_runtime.model_providers.xinference.xinference_helper import validate_model_uid
|
||||
|
||||
|
||||
class XinferenceRerankModel(RerankModel):
|
||||
@ -77,10 +78,7 @@ class XinferenceRerankModel(RerankModel):
|
||||
)
|
||||
|
||||
# score threshold check
|
||||
if score_threshold is not None:
|
||||
if result["relevance_score"] >= score_threshold:
|
||||
rerank_documents.append(rerank_document)
|
||||
else:
|
||||
if score_threshold is None or result["relevance_score"] >= score_threshold:
|
||||
rerank_documents.append(rerank_document)
|
||||
|
||||
return RerankResult(model=model, docs=rerank_documents)
|
||||
@ -94,7 +92,7 @@ class XinferenceRerankModel(RerankModel):
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]:
|
||||
if not validate_model_uid(credentials):
|
||||
raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
|
||||
|
||||
credentials["server_url"] = credentials["server_url"].removesuffix("/")
|
||||
|
@ -14,6 +14,7 @@ from core.model_runtime.errors.invoke import (
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
|
||||
from core.model_runtime.model_providers.xinference.xinference_helper import validate_model_uid
|
||||
|
||||
|
||||
class XinferenceSpeech2TextModel(Speech2TextModel):
|
||||
@ -42,7 +43,7 @@ class XinferenceSpeech2TextModel(Speech2TextModel):
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]:
|
||||
if not validate_model_uid(credentials):
|
||||
raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
|
||||
|
||||
credentials["server_url"] = credentials["server_url"].removesuffix("/")
|
||||
|
@ -17,7 +17,7 @@ from core.model_runtime.errors.invoke import (
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper
|
||||
from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper, validate_model_uid
|
||||
|
||||
|
||||
class XinferenceTextEmbeddingModel(TextEmbeddingModel):
|
||||
@ -110,7 +110,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]:
|
||||
if not validate_model_uid(credentials):
|
||||
raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
|
||||
|
||||
server_url = credentials["server_url"]
|
||||
|
@ -15,7 +15,7 @@ from core.model_runtime.errors.invoke import (
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.tts_model import TTSModel
|
||||
from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper
|
||||
from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper, validate_model_uid
|
||||
|
||||
|
||||
class XinferenceText2SpeechModel(TTSModel):
|
||||
@ -70,7 +70,7 @@ class XinferenceText2SpeechModel(TTSModel):
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]:
|
||||
if not validate_model_uid(credentials):
|
||||
raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
|
||||
|
||||
credentials["server_url"] = credentials["server_url"].removesuffix("/")
|
||||
|
@ -132,3 +132,16 @@ class XinferenceHelper:
|
||||
context_length=context_length,
|
||||
model_family=model_family,
|
||||
)
|
||||
|
||||
|
||||
def validate_model_uid(credentials: dict) -> bool:
|
||||
"""
|
||||
Validate the model_uid within the credentials dictionary to ensure it does not
|
||||
contain forbidden characters ("/", "?", "#").
|
||||
|
||||
param credentials: model credentials
|
||||
:return: True if the model_uid does not contain forbidden characters ("/", "?", "#"), else False.
|
||||
"""
|
||||
forbidden_characters = ["/", "?", "#"]
|
||||
model_uid = credentials.get("model_uid", "")
|
||||
return not any(char in forbidden_characters for char in model_uid)
|
||||
|
Loading…
x
Reference in New Issue
Block a user