diff --git a/rembg/bg.py b/rembg/bg.py
index a1f4215..ab342cb 100644
--- a/rembg/bg.py
+++ b/rembg/bg.py
@@ -20,6 +20,7 @@ from scipy.ndimage import binary_erosion
 
 from .session_base import BaseSession
 from .session_factory import new_session
+from .session_sam import SamSession
 
 kernel = getStructuringElement(MORPH_ELLIPSE, (3, 3))
 
@@ -119,10 +120,12 @@ def remove(
     alpha_matting_foreground_threshold: int = 240,
     alpha_matting_background_threshold: int = 10,
     alpha_matting_erode_size: int = 10,
-    session: Optional[BaseSession] = None,
+    session: Optional[Union[BaseSession, SamSession]] = None,
     only_mask: bool = False,
     post_process_mask: bool = False,
     bgcolor: Optional[Tuple[int, int, int, int]] = None,
+    input_point: Optional[np.ndarray] = None,
+    input_label: Optional[np.ndarray] = None,
 ) -> Union[bytes, PILImage, np.ndarray]:
     if isinstance(data, PILImage):
         return_type = ReturnType.PILLOW
@@ -139,7 +142,13 @@ def remove(
     if session is None:
         session = new_session("u2net")
 
-    masks = session.predict(img)
+    if isinstance(session, SamSession):
+        if input_point is None or input_label is None:
+            raise ValueError("Input point and label are required for SAM model.")
+        masks = session.predict_sam(img, input_point, input_label)
+    else:
+        masks = session.predict(img)
+
     cutouts = []
 
     for mask in masks:
diff --git a/rembg/session_factory.py b/rembg/session_factory.py
index bfb69cd..9d021c4 100644
--- a/rembg/session_factory.py
+++ b/rembg/session_factory.py
@@ -11,10 +11,30 @@ import pooch
 from .session_base import BaseSession
 from .session_cloth import ClothSession
 from .session_dis import DisSession
+from .session_sam import SamSession
 from .session_simple import SimpleSession
 
 
+def download_model(url: str, md5: str, fname: str, path: Path):
+    pooch.retrieve(
+        url,
+        f"md5:{md5}",
+        fname=fname,
+        path=path,
+        progressbar=True,
+    )
+
+
 def new_session(model_name: str = "u2net") -> BaseSession:
+    # Define the model path
+    u2net_home = os.getenv(
+        "U2NET_HOME", os.path.join(os.getenv("XDG_DATA_HOME", "~"), ".u2net")
+    )
+
+    fname = f"{model_name}.onnx"
+    path = Path(u2net_home).expanduser()
+    full_path = Path(u2net_home).expanduser() / fname
+
     session_class: Type[BaseSession]
     md5 = "60024c5c889badc19c04ad937298a77b"
     url = "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net.onnx"
@@ -44,22 +64,40 @@ def new_session(model_name: str = "u2net") -> BaseSession:
         md5 = "fc16ebd8b0c10d971d3513d564d01e29"
         url = "https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-general-use.onnx"
         session_class = DisSession
+    elif model_name == "sam":
+        path = Path(u2net_home).expanduser()
 
-    u2net_home = os.getenv(
-        "U2NET_HOME", os.path.join(os.getenv("XDG_DATA_HOME", "~"), ".u2net")
-    )
+        fname_encoder = f"{model_name}_encoder.onnx"
+        encoder_md5 = "13d97c5c79ab13ef86d67cbde5f1b250"
+        encoder_url = "https://github.com/Flippchen/rembg/releases/download/test/vit_b-encoder-quant.onnx"
 
-    fname = f"{model_name}.onnx"
-    path = Path(u2net_home).expanduser()
-    full_path = Path(u2net_home).expanduser() / fname
+        fname_decoder = f"{model_name}_decoder.onnx"
+        decoder_md5 = "fa3d1c36a3187d3de1c8deebf33dd127"
+        decoder_url = "https://github.com/Flippchen/rembg/releases/download/test/vit_b-decoder-quant.onnx"
 
-    pooch.retrieve(
-        url,
-        f"md5:{md5}",
-        fname=fname,
-        path=Path(u2net_home).expanduser(),
-        progressbar=True,
-    )
+        download_model(encoder_url, encoder_md5, fname_encoder, path)
+        download_model(decoder_url, decoder_md5, fname_decoder, path)
+
+        sess_opts = ort.SessionOptions()
+
+        if "OMP_NUM_THREADS" in os.environ:
+            sess_opts.inter_op_num_threads = int(os.environ["OMP_NUM_THREADS"])
+
+        return SamSession(
+            model_name,
+            ort.InferenceSession(
+                str(path / fname_encoder),
+                providers=ort.get_available_providers(),
+                sess_options=sess_opts,
+            ),
+            ort.InferenceSession(
+                str(path / fname_decoder),
+                providers=ort.get_available_providers(),
+                sess_options=sess_opts,
+            ),
+        )
+
+    download_model(url, md5, fname, path)
 
     sess_opts = ort.SessionOptions()
 
diff --git a/rembg/session_sam.py b/rembg/session_sam.py
new file mode 100644
index 0000000..5bf2067
--- /dev/null
+++ b/rembg/session_sam.py
@@ -0,0 +1,132 @@
+from typing import List
+
+import numpy as np
+import onnxruntime as ort
+from PIL import Image
+from PIL.Image import Image as PILImage
+
+from .session_base import BaseSession
+
+
+def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int):
+    """Return (height, width) after scaling the longer side to the target."""
+    scale = long_side_length * 1.0 / max(oldh, oldw)
+    newh, neww = oldh * scale, oldw * scale
+    neww = int(neww + 0.5)
+    newh = int(newh + 0.5)
+    return (newh, neww)
+
+
+def apply_coords(coords: np.ndarray, original_size, target_length) -> np.ndarray:
+    """Scale (x, y) coordinates from the original image to the resized one.
+
+    ``original_size`` is ``(height, width)``; ``coords`` is not modified.
+    """
+    old_h, old_w = original_size
+    new_h, new_w = get_preprocess_shape(
+        original_size[0], original_size[1], target_length
+    )
+    coords = coords.copy().astype(float)
+    coords[..., 0] = coords[..., 0] * (new_w / old_w)
+    coords[..., 1] = coords[..., 1] * (new_h / old_h)
+    return coords
+
+
+def resize_longes_side(img: PILImage, size=1024):
+    """Resize ``img`` so that its longer side is ``size`` pixels."""
+    w, h = img.size
+    # Reuse the rounding rule of apply_coords so that prompt coordinates and
+    # resized pixels stay aligned.
+    new_h, new_w = get_preprocess_shape(h, w, size)
+    return img.resize((new_w, new_h))
+
+
+def pad_to_square(img: np.ndarray, size=1024):
+    """Zero-pad an (H, W, C) array at the bottom/right to ``size`` x ``size``."""
+    h, w = img.shape[:2]
+    padh = size - h
+    padw = size - w
+    img = np.pad(img, ((0, padh), (0, padw), (0, 0)), mode="constant")
+    img = img.astype(np.float32)
+    return img
+
+
+class SamSession(BaseSession):
+    """Session that runs an encoder and a decoder ONNX model for point prompts."""
+
+    def __init__(
+        self,
+        model_name: str,
+        encoder: ort.InferenceSession,
+        decoder: ort.InferenceSession,
+    ):
+        super().__init__(model_name, encoder)
+        self.decoder = decoder
+
+    def normalize(
+        self,
+        img: np.ndarray,
+        mean=(123.675, 116.28, 103.53),
+        std=(58.395, 57.12, 57.375),
+        size=(1024, 1024),
+    ):
+        """Standardize ``img`` per channel; ``size`` is accepted but unused."""
+        pixel_mean = np.array([*mean]).reshape(1, 1, -1)
+        pixel_std = np.array([*std]).reshape(1, 1, -1)
+        x = (img - pixel_mean) / pixel_std
+        return x
+
+    def predict_sam(
+        self,
+        img: PILImage,
+        input_point: np.ndarray,
+        input_label: np.ndarray,
+    ) -> List[PILImage]:
+        """Return one mask image per entry of the decoder's mask output.
+
+        ``input_point`` holds (x, y) coordinates in the original image and
+        ``input_label`` holds one label per point.
+        """
+        # Preprocess image; normalize() expects exactly three channels.
+        image = resize_longes_side(img.convert("RGB"))
+        image = np.array(image)
+        image = self.normalize(image)
+        image = pad_to_square(image)
+
+        # Transpose
+        image = image.transpose(2, 0, 1)[None, :, :, :]
+        # Run encoder (Image embedding)
+        encoded = self.inner_session.run(None, {"x": image})
+        image_embedding = encoded[0]
+
+        # Add a batch index, concatenate a padding point, and transform.
+        onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[
+            None, :, :
+        ]
+        onnx_label = np.concatenate([input_label, np.array([-1])], axis=0)[
+            None, :
+        ].astype(np.float32)
+        # PIL reports (width, height); apply_coords expects (height, width).
+        onnx_coord = apply_coords(onnx_coord, img.size[::-1], 1024).astype(np.float32)
+
+        # Create an empty mask input and an indicator for no mask.
+        onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
+        onnx_has_mask_input = np.zeros(1, dtype=np.float32)
+
+        decoder_inputs = {
+            "image_embeddings": image_embedding,
+            "point_coords": onnx_coord,
+            "point_labels": onnx_label,
+            "mask_input": onnx_mask_input,
+            "has_mask_input": onnx_has_mask_input,
+            "orig_im_size": np.array(img.size[::-1], dtype=np.float32),
+        }
+
+        masks, _, _ = self.decoder.run(None, decoder_inputs)
+        masks = masks > 0.0
+        masks = [
+            Image.fromarray((masks[i, 0] * 255).astype(np.uint8))
+            for i in range(masks.shape[0])
+        ]
+
+        return masks
diff --git a/requirements.txt b/requirements.txt
index f38891a..fb41d7f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,7 +6,7 @@ filetype==1.2.0
 pooch==1.6.0
 imagehash==4.3.1
 numpy==1.23.5
-onnxruntime==1.13.1
+onnxruntime==1.14.1
 opencv-python-headless==4.6.0.66
 pillow==9.3.0
 pymatting==1.1.8