From b527af6af5248cfdbcfd143bdb93f2609470c92b Mon Sep 17 00:00:00 2001 From: Daniel Gatis Date: Thu, 29 Jun 2023 21:09:22 -0300 Subject: [PATCH] fix docker --- rembg/sessions/base.py | 5 ----- rembg/sessions/u2net_custom.py | 20 ++++++++++++++++++++ 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/rembg/sessions/base.py b/rembg/sessions/base.py index 8136a03..16f988e 100644 --- a/rembg/sessions/base.py +++ b/rembg/sessions/base.py @@ -28,11 +28,6 @@ class BaseSession: else: self.providers.extend(_providers) - model_path = kwargs.get("model_path") - - if model_path is None: - raise ValueError("model_path is required") - self.inner_session = ort.InferenceSession( str(self.__class__.download_models(*args, **kwargs)), providers=self.providers, diff --git a/rembg/sessions/u2net_custom.py b/rembg/sessions/u2net_custom.py index 7380295..de6346f 100644 --- a/rembg/sessions/u2net_custom.py +++ b/rembg/sessions/u2net_custom.py @@ -10,6 +10,26 @@ from .base import BaseSession class U2netCustomSession(BaseSession): + def __init__( + self, + model_name: str, + sess_opts: ort.SessionOptions, + providers=None, + *args, + **kwargs + ): + model_path = kwargs.get("model_path") + if model_path is None: + raise ValueError("model_path is required") + + super().__init__( + model_name, + sess_opts, + providers, + *args, + **kwargs + ) + def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]: ort_outs = self.inner_session.run( None,