From d7828b0369c1256b2d67d1386e8051eb8a295c76 Mon Sep 17 00:00:00 2001 From: Flippchen <91947480+Flippchen@users.noreply.github.com> Date: Thu, 20 Apr 2023 12:39:57 +0200 Subject: [PATCH] added input for remove function --- rembg/bg.py | 10 +++++++++- rembg/session_factory.py | 2 +- rembg/session_sam.py | 4 ++-- 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/rembg/bg.py b/rembg/bg.py index a1f4215..8218b9f 100644 --- a/rembg/bg.py +++ b/rembg/bg.py @@ -123,6 +123,8 @@ def remove( only_mask: bool = False, post_process_mask: bool = False, bgcolor: Optional[Tuple[int, int, int, int]] = None, + input_point: Optional[np.ndarray] = None, + input_label: Optional[np.ndarray] = None, ) -> Union[bytes, PILImage, np.ndarray]: if isinstance(data, PILImage): return_type = ReturnType.PILLOW @@ -139,7 +141,13 @@ def remove( if session is None: session = new_session("u2net") - masks = session.predict(img) + if session.model_name == "sam": + if input_point is None or input_label is None: + raise ValueError("Input point and label are required for SAM model.") + masks = session.predict(img, input_point, input_label) + else: + masks = session.predict(img) + cutouts = [] for mask in masks: diff --git a/rembg/session_factory.py b/rembg/session_factory.py index 23d3f3f..9d021c4 100644 --- a/rembg/session_factory.py +++ b/rembg/session_factory.py @@ -64,7 +64,7 @@ def new_session(model_name: str = "u2net") -> BaseSession: md5 = "fc16ebd8b0c10d971d3513d564d01e29" url = "https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-general-use.onnx" session_class = DisSession - elif model_name == "SAM": + elif model_name == "sam": path = Path(u2net_home).expanduser() fname_encoder = f"{model_name}_encoder.onnx" diff --git a/rembg/session_sam.py b/rembg/session_sam.py index 6900f4b..712b4df 100644 --- a/rembg/session_sam.py +++ b/rembg/session_sam.py @@ -71,8 +71,8 @@ class SamSession(BaseSession): def predict( self, img: PILImage, - input_point=np.array([[500, 375]]), - input_label=np.array([1]), + input_point: np.ndarray, + input_label: np.ndarray, ) -> List[PILImage]: # Preprocess image image = resize_longes_side(img)