diff --git a/rembg/sessions/base.py b/rembg/sessions/base.py index 16f988e..8136a03 100644 --- a/rembg/sessions/base.py +++ b/rembg/sessions/base.py @@ -28,6 +28,11 @@ class BaseSession: else: self.providers.extend(_providers) + model_path = kwargs.get("model_path") + + if model_path is None: + raise ValueError("model_path is required") + self.inner_session = ort.InferenceSession( str(self.__class__.download_models(*args, **kwargs)), providers=self.providers, diff --git a/rembg/sessions/u2net_custom.py b/rembg/sessions/u2net_custom.py index 09894b2..7380295 100644 --- a/rembg/sessions/u2net_custom.py +++ b/rembg/sessions/u2net_custom.py @@ -34,9 +34,8 @@ class U2netCustomSession(BaseSession): @classmethod def download_models(cls, *args, **kwargs): model_path = kwargs.get("model_path") - if model_path is None: - raise ValueError("model_path is required") + return return os.path.abspath(os.path.expanduser(model_path))