From 939e3896efa5c95152539050383da18a69460c03 Mon Sep 17 00:00:00 2001 From: Flippchen <91947480+Flippchen@users.noreply.github.com> Date: Thu, 20 Apr 2023 11:16:34 +0200 Subject: [PATCH 1/9] updated onnxruntime to support new models --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index f38891a..fb41d7f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,7 +6,7 @@ filetype==1.2.0 pooch==1.6.0 imagehash==4.3.1 numpy==1.23.5 -onnxruntime==1.13.1 +onnxruntime==1.14.1 opencv-python-headless==4.6.0.66 pillow==9.3.0 pymatting==1.1.8 From 83976f416b7691eaeaa6eddace95c7a510e33c88 Mon Sep 17 00:00:00 2001 From: Flippchen <91947480+Flippchen@users.noreply.github.com> Date: Thu, 20 Apr 2023 11:18:50 +0200 Subject: [PATCH 2/9] added Segment Anything model class --- rembg/session_sam.py | 97 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 97 insertions(+) create mode 100644 rembg/session_sam.py diff --git a/rembg/session_sam.py b/rembg/session_sam.py new file mode 100644 index 0000000..2fb177c --- /dev/null +++ b/rembg/session_sam.py @@ -0,0 +1,97 @@ +from typing import List + +import numpy +import numpy as np +from PIL import Image +from PIL.Image import Image as PILImage +import onnxruntime as ort +from matplotlib import pyplot as plt + +from .session_base import BaseSession + + +def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int): + scale = long_side_length * 1.0 / max(oldh, oldw) + newh, neww = oldh * scale, oldw * scale + neww = int(neww + 0.5) + newh = int(newh + 0.5) + return (newh, neww) + + +def apply_coords(coords: np.ndarray, original_size, target_length) -> np.ndarray: + old_h, old_w = original_size + new_h, new_w = get_preprocess_shape( + original_size[0], original_size[1], target_length + ) + coords = coords.copy().astype(float) + coords[..., 0] = coords[..., 0] * (new_w / old_w) + coords[..., 1] = coords[..., 1] * (new_h / old_h) + return coords + + +def resize_longes_side(img: PILImage, size=1024): + w, h = img.size + if h > w: + new_h, new_w = size, int(w * size / h) + else: + new_h, new_w = int(h * size / w), size + + return img.resize((new_w, new_h)) + + +def pad_to_square(img: numpy.ndarray, size=1024): + h, w = img.shape[:2] + padh = size - h + padw = size - w + img = np.pad(img, ((0, padh), (0, padw), (0, 0)), mode='constant') + img = img.astype(np.float32) + return img + + +class SamSession(BaseSession): + def __init__(self, model_name: str, encoder: ort.InferenceSession, decoder: ort.InferenceSession): + super().__init__(model_name, encoder) + self.decoder = decoder + + def normalize(self, img: numpy.ndarray, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), size=(1024, 1024)): + pixel_mean = np.array([123.675, 116.28, 103.53]).reshape(1, 1, -1) + pixel_std = np.array([58.395, 57.12, 57.375]).reshape(1, 1, -1) + x = (img - pixel_mean) / pixel_std + return x + + def predict(self, img: PILImage, input_point=np.array([[500, 375]]), input_label=np.array([1])) -> List[PILImage]: + # Preprocess image + image = resize_longes_side(img) + image = numpy.array(image) + image = self.normalize(image) + image = pad_to_square(image) + + # Transpose + image = image.transpose(2, 0, 1)[None, :, :, :] + # Run encoder (Image embedding) + encoded = self.inner_session.run(None, {"x": image}) + image_embedding = encoded[0] + + # Add a batch index, concatenate a padding point, and transform. + onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[None, :, :] + onnx_label = np.concatenate([input_label, np.array([-1])], axis=0)[None, :].astype(np.float32) + onnx_coord = apply_coords(onnx_coord, img.size[::1], 1024).astype(np.float32) + + # Create an empty mask input and an indicator for no mask. + onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32) + onnx_has_mask_input = np.zeros(1, dtype=np.float32) + + decoder_inputs = { + "image_embeddings": image_embedding, + "point_coords": onnx_coord, + "point_labels": onnx_label, + "mask_input": onnx_mask_input, + "has_mask_input": onnx_has_mask_input, + "orig_im_size": np.array(img.size[::-1], dtype=np.float32) + } + + masks, _, low_res_logits = self.decoder.run(None, decoder_inputs) + masks = masks > 0.0 + masks = [Image.fromarray((masks[i, 0] * 255).astype(np.uint8)) for i in range(masks.shape[0])] + + return masks From 106254c42d257cbd6e9db57c25505ab5880c4cd7 Mon Sep 17 00:00:00 2001 From: Flippchen <91947480+Flippchen@users.noreply.github.com> Date: Thu, 20 Apr 2023 11:21:45 +0200 Subject: [PATCH 3/9] edited session factory to support models with more than one model --- rembg/session_factory.py | 57 +++++++++++++++++++++++++++++++--------- 1 file changed, 44 insertions(+), 13 deletions(-) diff --git a/rembg/session_factory.py b/rembg/session_factory.py index bfb69cd..db2cb65 100644 --- a/rembg/session_factory.py +++ b/rembg/session_factory.py @@ -12,9 +12,29 @@ from .session_base import BaseSession from .session_cloth import ClothSession from .session_dis import DisSession from .session_simple import SimpleSession +from .session_sam import SamSession + + +def download_model(url: str, md5: str, fname: str, path: Path): + pooch.retrieve( + url, + f"md5:{md5}", + fname=fname, + path=path, + progressbar=True, + ) def new_session(model_name: str = "u2net") -> BaseSession: + # Define the model path + u2net_home = os.getenv( + "U2NET_HOME", os.path.join(os.getenv("XDG_DATA_HOME", "~"), ".u2net") + ) + + fname = f"{model_name}.onnx" + path = Path(u2net_home).expanduser() + full_path = Path(u2net_home).expanduser() / fname + session_class: Type[BaseSession] md5 = "60024c5c889badc19c04ad937298a77b" url = "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net.onnx" @@ -44,22 +64,33 @@ def new_session(model_name: str = "u2net") -> BaseSession: md5 = "fc16ebd8b0c10d971d3513d564d01e29" url = "https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-general-use.onnx" session_class = DisSession + elif model_name == "SAM": + path = Path(u2net_home).expanduser() - u2net_home = os.getenv( - "U2NET_HOME", os.path.join(os.getenv("XDG_DATA_HOME", "~"), ".u2net") - ) + fname_encoder = f"{model_name}_encoder.onnx" + encoder_md5 = "13d97c5c79ab13ef86d67cbde5f1b250" + encoder_url = "https://github.com/Flippchen/rembg/releases/download/test/vit_b-encoder-quant.onnx" - fname = f"{model_name}.onnx" - path = Path(u2net_home).expanduser() - full_path = Path(u2net_home).expanduser() / fname + fname_decoder = f"{model_name}_decoder.onnx" + decoder_md5 = "fa3d1c36a3187d3de1c8deebf33dd127" + decoder_url = "https://github.com/Flippchen/rembg/releases/download/test/vit_b-decoder-quant.onnx" - pooch.retrieve( - url, - f"md5:{md5}", - fname=fname, - path=Path(u2net_home).expanduser(), - progressbar=True, - ) + + download_model(encoder_url, encoder_md5, fname_encoder, path) + download_model(decoder_url, decoder_md5, fname_decoder, path) + + sess_opts = ort.SessionOptions() + + if "OMP_NUM_THREADS" in os.environ: + sess_opts.inter_op_num_threads = int(os.environ["OMP_NUM_THREADS"]) + + return SamSession( + model_name, + ort.InferenceSession(str(path / fname_encoder), providers=ort.get_available_providers(), sess_options=sess_opts), + ort.InferenceSession(str(path / fname_decoder), providers=ort.get_available_providers(), sess_options=sess_opts) + ) + + download_model(url, md5, fname, path) sess_opts = ort.SessionOptions() From bb3c58f4113a28329de4173e677b9ab0ec52c0a3 Mon Sep 17 00:00:00 2001 From: Flippchen <91947480+Flippchen@users.noreply.github.com> Date: Thu, 20 Apr 2023 11:31:15 +0200 Subject: [PATCH 4/9] fix lint --- rembg/session_factory.py | 13 ++++++++++--- rembg/session_sam.py | 39 +++++++++++++++++++++++++++++++-------- 2 files changed, 41 insertions(+), 11 deletions(-) diff --git a/rembg/session_factory.py b/rembg/session_factory.py index db2cb65..7711dca 100644 --- a/rembg/session_factory.py +++ b/rembg/session_factory.py @@ -75,7 +75,6 @@ def new_session(model_name: str = "u2net") -> BaseSession: decoder_md5 = "fa3d1c36a3187d3de1c8deebf33dd127" decoder_url = "https://github.com/Flippchen/rembg/releases/download/test/vit_b-decoder-quant.onnx" - download_model(encoder_url, encoder_md5, fname_encoder, path) download_model(decoder_url, decoder_md5, fname_decoder, path) @@ -86,8 +85,16 @@ def new_session(model_name: str = "u2net") -> BaseSession: return SamSession( model_name, - ort.InferenceSession(str(path / fname_encoder), providers=ort.get_available_providers(), sess_options=sess_opts), - ort.InferenceSession(str(path / fname_decoder), providers=ort.get_available_providers(), sess_options=sess_opts) + ort.InferenceSession( + str(path / fname_encoder), + providers=ort.get_available_providers(), + sess_options=sess_opts + ), + ort.InferenceSession( + str(path / fname_decoder), + providers=ort.get_available_providers(), + sess_options=sess_opts + ), ) download_model(url, md5, fname, path) diff --git a/rembg/session_sam.py b/rembg/session_sam.py index 2fb177c..eb88114 100644 --- a/rembg/session_sam.py +++ b/rembg/session_sam.py @@ -43,23 +43,39 @@ def pad_to_square(img: numpy.ndarray, size=1024): h, w = img.shape[:2] padh = size - h padw = size - w - img = np.pad(img, ((0, padh), (0, padw), (0, 0)), mode='constant') + img = np.pad(img, ((0, padh), (0, padw), (0, 0)), mode="constant") img = img.astype(np.float32) return img class SamSession(BaseSession): - def __init__(self, model_name: str, encoder: ort.InferenceSession, decoder: ort.InferenceSession): + def __init__( + self, + model_name: str, + encoder: ort.InferenceSession, + decoder: ort.InferenceSession + ): super().__init__(model_name, encoder) self.decoder = decoder - def normalize(self, img: numpy.ndarray, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), size=(1024, 1024)): + def normalize( + self, + img: numpy.ndarray, + mean=(0.485, 0.456, 0.406), + std=(0.229, 0.224, 0.225), + size=(1024, 1024) + ): pixel_mean = np.array([123.675, 116.28, 103.53]).reshape(1, 1, -1) pixel_std = np.array([58.395, 57.12, 57.375]).reshape(1, 1, -1) x = (img - pixel_mean) / pixel_std return x - def predict(self, img: PILImage, input_point=np.array([[500, 375]]), input_label=np.array([1])) -> List[PILImage]: + def predict( + self, + img: PILImage, + input_point=np.array([[500, 375]]), + input_label=np.array([1]) + ) -> List[PILImage]: # Preprocess image image = resize_longes_side(img) image = numpy.array(image) @@ -73,8 +89,12 @@ class SamSession(BaseSession): image_embedding = encoded[0] # Add a batch index, concatenate a padding point, and transform. - onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[None, :, :] - onnx_label = np.concatenate([input_label, np.array([-1])], axis=0)[None, :].astype(np.float32) + onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[ + None, :, : + ] + onnx_label = np.concatenate([input_label, np.array([-1])], axis=0)[ + None, : + ].astype(np.float32) onnx_coord = apply_coords(onnx_coord, img.size[::1], 1024).astype(np.float32) # Create an empty mask input and an indicator for no mask. @@ -87,11 +107,14 @@ class SamSession(BaseSession): "point_labels": onnx_label, "mask_input": onnx_mask_input, "has_mask_input": onnx_has_mask_input, - "orig_im_size": np.array(img.size[::-1], dtype=np.float32) + "orig_im_size": np.array(img.size[::-1], dtype=np.float32), } masks, _, low_res_logits = self.decoder.run(None, decoder_inputs) masks = masks > 0.0 - masks = [Image.fromarray((masks[i, 0] * 255).astype(np.uint8)) for i in range(masks.shape[0])] + masks = [ + Image.fromarray((masks[i, 0] * 255).astype(np.uint8)) + for i in range(masks.shape[0]) + ] return masks From ff38b9a377efd8d37e41c33a396df469c78bddb5 Mon Sep 17 00:00:00 2001 From: Flippchen <91947480+Flippchen@users.noreply.github.com> Date: Thu, 20 Apr 2023 11:36:20 +0200 Subject: [PATCH 5/9] fix lint and refactored normalizing --- rembg/session_factory.py | 4 ++-- rembg/session_sam.py | 18 +++++++++--------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/rembg/session_factory.py b/rembg/session_factory.py index 7711dca..f107d14 100644 --- a/rembg/session_factory.py +++ b/rembg/session_factory.py @@ -88,12 +88,12 @@ def new_session(model_name: str = "u2net") -> BaseSession: ort.InferenceSession( str(path / fname_encoder), providers=ort.get_available_providers(), - sess_options=sess_opts + sess_options=sess_opts, ), ort.InferenceSession( str(path / fname_decoder), providers=ort.get_available_providers(), - sess_options=sess_opts + sess_options=sess_opts, ), ) diff --git a/rembg/session_sam.py b/rembg/session_sam.py index eb88114..0cb0b40 100644 --- a/rembg/session_sam.py +++ b/rembg/session_sam.py @@ -53,7 +53,7 @@ class SamSession(BaseSession): self, model_name: str, encoder: ort.InferenceSession, - decoder: ort.InferenceSession + decoder: ort.InferenceSession, ): super().__init__(model_name, encoder) self.decoder = decoder @@ -61,12 +61,12 @@ class SamSession(BaseSession): def normalize( self, img: numpy.ndarray, - mean=(0.485, 0.456, 0.406), - std=(0.229, 0.224, 0.225), - size=(1024, 1024) + mean=(123.675, 116.28, 103.53), + std=(58.395, 57.12, 57.375), + size=(1024, 1024), ): - pixel_mean = np.array([123.675, 116.28, 103.53]).reshape(1, 1, -1) - pixel_std = np.array([58.395, 57.12, 57.375]).reshape(1, 1, -1) + pixel_mean = np.array([*mean]).reshape(1, 1, -1) + pixel_std = np.array([*std]).reshape(1, 1, -1) x = (img - pixel_mean) / pixel_std return x @@ -74,7 +74,7 @@ class SamSession(BaseSession): self, img: PILImage, input_point=np.array([[500, 375]]), - input_label=np.array([1]) + input_label=np.array([1]), ) -> List[PILImage]: # Preprocess image image = resize_longes_side(img) @@ -90,10 +90,10 @@ class SamSession(BaseSession): # Add a batch index, concatenate a padding point, and transform. onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[ - None, :, : + None, :, : ] onnx_label = np.concatenate([input_label, np.array([-1])], axis=0)[ - None, : + None, : ].astype(np.float32) onnx_coord = apply_coords(onnx_coord, img.size[::1], 1024).astype(np.float32) From 72d1c6c64c6a48a8a611d3523cd4ee1820960cfa Mon Sep 17 00:00:00 2001 From: Flippchen <91947480+Flippchen@users.noreply.github.com> Date: Thu, 20 Apr 2023 11:41:07 +0200 Subject: [PATCH 6/9] reordered imports --- rembg/session_factory.py | 2 +- rembg/session_sam.py | 8 +++----- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/rembg/session_factory.py b/rembg/session_factory.py index f107d14..23d3f3f 100644 --- a/rembg/session_factory.py +++ b/rembg/session_factory.py @@ -11,8 +11,8 @@ import pooch from .session_base import BaseSession from .session_cloth import ClothSession from .session_dis import DisSession -from .session_simple import SimpleSession from .session_sam import SamSession +from .session_simple import SimpleSession def download_model(url: str, md5: str, fname: str, path: Path): diff --git a/rembg/session_sam.py b/rembg/session_sam.py index 0cb0b40..2e45132 100644 --- a/rembg/session_sam.py +++ b/rembg/session_sam.py @@ -1,11 +1,9 @@ from typing import List -import numpy import numpy as np from PIL import Image from PIL.Image import Image as PILImage import onnxruntime as ort -from matplotlib import pyplot as plt from .session_base import BaseSession @@ -39,7 +37,7 @@ def resize_longes_side(img: PILImage, size=1024): return img.resize((new_w, new_h)) -def pad_to_square(img: numpy.ndarray, size=1024): +def pad_to_square(img: np.ndarray, size=1024): h, w = img.shape[:2] padh = size - h padw = size - w @@ -60,7 +58,7 @@ class SamSession(BaseSession): def normalize( self, - img: numpy.ndarray, + img: np.ndarray, mean=(123.675, 116.28, 103.53), std=(58.395, 57.12, 57.375), size=(1024, 1024), @@ -78,7 +76,7 @@ class SamSession(BaseSession): ) -> List[PILImage]: # Preprocess image image = resize_longes_side(img) - image = numpy.array(image) + image = np.array(image) image = self.normalize(image) image = pad_to_square(image) From 3bdc06dff6f8afdc359c42d7af1c3acc6ee1a3eb Mon Sep 17 00:00:00 2001 From: Flippchen <91947480+Flippchen@users.noreply.github.com> Date: Thu, 20 Apr 2023 11:43:47 +0200 Subject: [PATCH 7/9] order pylint --- rembg/session_sam.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rembg/session_sam.py b/rembg/session_sam.py index 2e45132..6900f4b 100644 --- a/rembg/session_sam.py +++ b/rembg/session_sam.py @@ -1,9 +1,9 @@ from typing import List import numpy as np +import onnxruntime as ort from PIL import Image from PIL.Image import Image as PILImage -import onnxruntime as ort from .session_base import BaseSession From d7828b0369c1256b2d67d1386e8051eb8a295c76 Mon Sep 17 00:00:00 2001 From: Flippchen <91947480+Flippchen@users.noreply.github.com> Date: Thu, 20 Apr 2023 12:39:57 +0200 Subject: [PATCH 8/9] added input for remove function --- rembg/bg.py | 10 +++++++++- rembg/session_factory.py | 2 +- rembg/session_sam.py | 4 ++-- 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/rembg/bg.py b/rembg/bg.py index a1f4215..8218b9f 100644 --- a/rembg/bg.py +++ b/rembg/bg.py @@ -123,6 +123,8 @@ def remove( only_mask: bool = False, post_process_mask: bool = False, bgcolor: Optional[Tuple[int, int, int, int]] = None, + input_point: Optional[np.ndarray] = None, + input_label: Optional[np.ndarray] = None, ) -> Union[bytes, PILImage, np.ndarray]: if isinstance(data, PILImage): return_type = ReturnType.PILLOW @@ -139,7 +141,13 @@ def remove( if session is None: session = new_session("u2net") - masks = session.predict(img) + if session.model_name == "sam": + if input_point is None or input_label is None: + raise ValueError("Input point and label are required for SAM model.") + masks = session.predict(img, input_point, input_label) + else: + masks = session.predict(img) + cutouts = [] for mask in masks: diff --git a/rembg/session_factory.py b/rembg/session_factory.py index 23d3f3f..9d021c4 100644 --- a/rembg/session_factory.py +++ b/rembg/session_factory.py @@ -64,7 +64,7 @@ def new_session(model_name: str = "u2net") -> BaseSession: md5 = "fc16ebd8b0c10d971d3513d564d01e29" url = "https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-general-use.onnx" session_class = DisSession - elif model_name == "SAM": + elif model_name == "sam": path = Path(u2net_home).expanduser() fname_encoder = f"{model_name}_encoder.onnx" diff --git a/rembg/session_sam.py b/rembg/session_sam.py index 6900f4b..712b4df 100644 --- a/rembg/session_sam.py +++ b/rembg/session_sam.py @@ -71,8 +71,8 @@ class SamSession(BaseSession): def predict( self, img: PILImage, - input_point=np.array([[500, 375]]), - input_label=np.array([1]), + input_point: np.ndarray, + input_label: np.ndarray, ) -> List[PILImage]: # Preprocess image image = resize_longes_side(img) From 394ab21ab94032e1eb9381dd2a26f9ca52881c1b Mon Sep 17 00:00:00 2001 From: Flippchen <91947480+Flippchen@users.noreply.github.com> Date: Thu, 20 Apr 2023 12:47:14 +0200 Subject: [PATCH 9/9] fix pylint --- rembg/bg.py | 7 ++++--- rembg/session_sam.py | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/rembg/bg.py b/rembg/bg.py index 8218b9f..ab342cb 100644 --- a/rembg/bg.py +++ b/rembg/bg.py @@ -20,6 +20,7 @@ from scipy.ndimage import binary_erosion from .session_base import BaseSession from .session_factory import new_session +from .session_sam import SamSession kernel = getStructuringElement(MORPH_ELLIPSE, (3, 3)) @@ -119,7 +120,7 @@ def remove( alpha_matting_foreground_threshold: int = 240, alpha_matting_background_threshold: int = 10, alpha_matting_erode_size: int = 10, - session: Optional[BaseSession] = None, + session: Optional[Union[BaseSession, SamSession]] = None, only_mask: bool = False, post_process_mask: bool = False, bgcolor: Optional[Tuple[int, int, int, int]] = None, @@ -141,10 +142,10 @@ def remove( if session is None: session = new_session("u2net") - if session.model_name == "sam": + if isinstance(session, SamSession): if input_point is None or input_label is None: raise ValueError("Input point and label are required for SAM model.") - masks = session.predict(img, input_point, input_label) + masks = session.predict_sam(img, input_point, input_label) else: masks = session.predict(img) diff --git a/rembg/session_sam.py b/rembg/session_sam.py index 712b4df..5bf2067 100644 --- a/rembg/session_sam.py +++ b/rembg/session_sam.py @@ -68,7 +68,7 @@ class SamSession(BaseSession): x = (img - pixel_mean) / pixel_std return x - def predict( + def predict_sam( self, img: PILImage, input_point: np.ndarray,