diff --git a/rembg/bg.py b/rembg/bg.py index 504b279..f3efe3d 100644 --- a/rembg/bg.py +++ b/rembg/bg.py @@ -55,11 +55,11 @@ def alpha_matting_cutout( if img.mode == "RGBA" or img.mode == "CMYK": img = img.convert("RGB") - img = np.asarray(img) - mask = np.asarray(mask) + img_array = np.asarray(img) + mask_array = np.asarray(mask) - is_foreground = mask > foreground_threshold - is_background = mask < background_threshold + is_foreground = mask_array > foreground_threshold + is_background = mask_array < background_threshold structure = None if erode_structure_size > 0: @@ -70,11 +70,11 @@ def alpha_matting_cutout( is_foreground = binary_erosion(is_foreground, structure=structure) is_background = binary_erosion(is_background, structure=structure, border_value=1) - trimap = np.full(mask.shape, dtype=np.uint8, fill_value=128) + trimap = np.full(mask_array.shape, dtype=np.uint8, fill_value=128) trimap[is_foreground] = 255 trimap[is_background] = 0 - img_normalized = img / 255.0 + img_normalized = img_array / 255.0 trimap_normalized = trimap / 255.0 alpha = estimate_alpha_cf(img_normalized, trimap_normalized) diff --git a/rembg/commands/b_command.py b/rembg/commands/b_command.py index c9f362c..7683f6f 100644 --- a/rembg/commands/b_command.py +++ b/rembg/commands/b_command.py @@ -6,7 +6,7 @@ import sys from typing import IO import click -from PIL import Image +from PIL.Image import Image as PILImage from ..bg import remove from ..session_factory import new_session @@ -134,7 +134,7 @@ def b_command( if not os.path.isdir(output_dir): os.makedirs(output_dir, exist_ok=True) - def img_to_byte_array(img: Image) -> bytes: + def img_to_byte_array(img: PILImage) -> bytes: buff = io.BytesIO() img.save(buff, format="PNG") return buff.getvalue() diff --git a/rembg/commands/p_command.py b/rembg/commands/p_command.py index 4f300a0..e83b75e 100644 --- a/rembg/commands/p_command.py +++ b/rembg/commands/p_command.py @@ -186,9 +186,9 @@ def p_command( inputs = list(input.glob("**/*")) if not watch: - inputs = tqdm(inputs) + inputs_tqdm = tqdm(inputs) - for each_input in inputs: + for each_input in inputs_tqdm: if not each_input.is_dir(): process(each_input) diff --git a/rembg/sessions/sam.py b/rembg/sessions/sam.py index 0e494e1..f160b82 100644 --- a/rembg/sessions/sam.py +++ b/rembg/sessions/sam.py @@ -1,6 +1,6 @@ import os from copy import deepcopy -from typing import List +from typing import Dict, List, Tuple import cv2 import numpy as np @@ -87,8 +87,9 @@ class SamSession(BaseSession): self, model_name: str, sess_opts: ort.SessionOptions, + providers=None, *args, - **kwargs, + **kwargs ): """ Initialize a new SamSession with the given model name and session options. @@ -101,52 +102,27 @@ class SamSession(BaseSession): """ self.model_name = model_name - self.providers = [] + valid_providers = [] + available_providers = ort.get_available_providers() - _providers = ort.get_available_providers() - for provider in kwargs.get("providers", []): - if provider in _providers: - self.providers.append(provider) + for provider in (providers or []): + if provider in available_providers: + valid_providers.append(provider) else: - self.providers.extend(_providers) + valid_providers.extend(available_providers) paths = self.__class__.download_models(*args, **kwargs) self.encoder = ort.InferenceSession( str(paths[0]), - providers=self.providers, + providers=valid_providers, sess_options=sess_opts, ) self.decoder = ort.InferenceSession( str(paths[1]), - providers=self.providers, + providers=valid_providers, sess_options=sess_opts, ) - def normalize( - self, - img: np.ndarray, - mean=(), - std=(), - size=(), - *args, - **kwargs, - ): - """ - Normalize the input image by subtracting the mean and dividing by the standard deviation. - - Args: - img (np.ndarray): The input image. - mean (tuple, optional): The mean values for normalization. Defaults to (). - std (tuple, optional): The standard deviation values for normalization. Defaults to (). - size (tuple, optional): The target size of the image. Defaults to (). - *args: Variable length argument list. - **kwargs: Arbitrary keyword arguments. - - Returns: - np.ndarray: The normalized image. - """ - return img - def predict( self, img: PILImage, @@ -269,8 +245,7 @@ class SamSession(BaseSession): for m in masks[0, :, :, :]: mask[m > 0.0] = [255, 255, 255] - mask = Image.fromarray(mask).convert("L") - return [mask] + return [Image.fromarray(mask).convert("L")] @classmethod def download_models(cls, *args, **kwargs):