refactor wenxin rerank (#9486)

Co-authored-by: cuihz <cuihz@knowbox.cn>
This commit is contained in:
chzphoenix 2024-10-21 09:03:25 +08:00 committed by GitHub
parent 444dc01931
commit 42fe208eda
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -2,20 +2,15 @@ from typing import Optional
import httpx import httpx
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.errors.invoke import ( from core.model_runtime.errors.invoke import InvokeError
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.rerank_model import RerankModel from core.model_runtime.model_providers.__base.rerank_model import RerankModel
from core.model_runtime.model_providers.wenxin._common import _CommonWenxin from core.model_runtime.model_providers.wenxin._common import _CommonWenxin
from core.model_runtime.model_providers.wenxin.wenxin_errors import (
InternalServerError,
invoke_error_mapping,
)
class WenxinRerank(_CommonWenxin): class WenxinRerank(_CommonWenxin):
@ -32,7 +27,7 @@ class WenxinRerank(_CommonWenxin):
response.raise_for_status() response.raise_for_status()
return response.json() return response.json()
except httpx.HTTPStatusError as e: except httpx.HTTPStatusError as e:
raise InvokeServerUnavailableError(str(e)) raise InternalServerError(str(e))
class WenxinRerankModel(RerankModel): class WenxinRerankModel(RerankModel):
@ -93,7 +88,7 @@ class WenxinRerankModel(RerankModel):
return RerankResult(model=model, docs=rerank_documents) return RerankResult(model=model, docs=rerank_documents)
except httpx.HTTPStatusError as e: except httpx.HTTPStatusError as e:
raise InvokeServerUnavailableError(str(e)) raise InternalServerError(str(e))
def validate_credentials(self, model: str, credentials: dict) -> None: def validate_credentials(self, model: str, credentials: dict) -> None:
""" """
@ -124,24 +119,4 @@ class WenxinRerankModel(RerankModel):
""" """
Map model invoke error to unified error Map model invoke error to unified error
""" """
return { return invoke_error_mapping()
InvokeConnectionError: [httpx.ConnectError],
InvokeServerUnavailableError: [httpx.RemoteProtocolError],
InvokeRateLimitError: [],
InvokeAuthorizationError: [httpx.HTTPStatusError],
InvokeBadRequestError: [httpx.RequestError],
}
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
"""
generate custom model entities from credentials
"""
entity = AIModelEntity(
model=model,
label=I18nObject(en_US=model),
model_type=ModelType.RERANK,
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))},
)
return entity