diff --git a/api/core/model_runtime/model_providers/__base/ai_model.py b/api/core/model_runtime/model_providers/__base/ai_model.py index bdbafc8ded..a044a948aa 100644 --- a/api/core/model_runtime/model_providers/__base/ai_model.py +++ b/api/core/model_runtime/model_providers/__base/ai_model.py @@ -10,8 +10,15 @@ from core.model_runtime.entities.model_entities import ( PriceInfo, PriceType, ) -from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError -from core.plugin.entities.plugin_daemon import PluginModelProviderEntity +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.plugin.entities.plugin_daemon import PluginDaemonInnerError, PluginModelProviderEntity from core.plugin.manager.model import PluginModelManager @@ -31,7 +38,7 @@ class AIModel(BaseModel): model_config = ConfigDict(protected_namespaces=()) @property - def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + def _invoke_error_mapping(self) -> dict[type[Exception], list[type[Exception]]]: """ Map model invoke error to unified error The key is the error type thrown to the caller @@ -40,9 +47,17 @@ class AIModel(BaseModel): :return: Invoke error mapping """ - raise NotImplementedError + return { + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [InvokeBadRequestError], + PluginDaemonInnerError: [PluginDaemonInnerError], + ValueError: [ValueError], + } - def _transform_invoke_error(self, error: Exception) -> InvokeError: + def _transform_invoke_error(self, error: Exception) -> Exception: """ Transform invoke error to unified error @@ -52,13 +67,15 @@ class AIModel(BaseModel): for invoke_error, model_errors in self._invoke_error_mapping.items(): if isinstance(error, tuple(model_errors)): if invoke_error == InvokeAuthorizationError: - return invoke_error( + return InvokeAuthorizationError( description=( f"[{self.provider_name}] Incorrect model credentials provided, please check and try again." ) ) - - return invoke_error(description=f"[{self.provider_name}] {invoke_error.description}, {str(error)}") + elif isinstance(invoke_error, InvokeError): + return invoke_error(description=f"[{self.provider_name}] {invoke_error.description}, {str(error)}") + else: + return error return InvokeError(description=f"[{self.provider_name}] Error: {str(error)}") diff --git a/api/core/plugin/manager/base.py b/api/core/plugin/manager/base.py index b25282cde2..6654c05e2d 100644 --- a/api/core/plugin/manager/base.py +++ b/api/core/plugin/manager/base.py @@ -53,7 +53,7 @@ class BasePluginManager: ) except requests.exceptions.ConnectionError as e: logger.exception(f"Request to Plugin Daemon Service failed: {e}") - raise ValueError("Request to Plugin Daemon Service failed") + raise PluginDaemonInnerError(code=-500, message="Request to Plugin Daemon Service failed") return response @@ -157,8 +157,17 @@ class BasePluginManager: Make a stream request to the plugin daemon inner API and yield the response as a model. """ for line in self._stream_request(method, path, params, headers, data, files): - line_data = json.loads(line) - rep = PluginDaemonBasicResponse[type](**line_data) + line_data = None + try: + line_data = json.loads(line) + rep = PluginDaemonBasicResponse[type](**line_data) + except Exception as e: + # TODO modify this when line_data has code and message + if line_data and "error" in line_data: + raise ValueError(line_data["error"]) + else: + raise ValueError(line) + if rep.code != 0: if rep.code == -500: try: diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 57af05861c..277513f3f2 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -103,7 +103,7 @@ class RetrievalService: if exceptions: exception_message = ";\n".join(exceptions) - raise Exception(exception_message) + raise ValueError(exception_message) if retrieval_method == RetrievalMethod.HYBRID_SEARCH.value: data_post_processor = DataPostProcessor(