mirror of
https://git.mirrors.martin98.com/https://github.com/danielgatis/rembg
synced 2025-08-16 05:25:58 +08:00
Added docstrings to function definitions and classes / class methods (#522)
This commit is contained in:
parent
7c1b3c4cb3
commit
0eaaa0b393
93
rembg/bg.py
93
rembg/bg.py
@ -38,6 +38,17 @@ def alpha_matting_cutout(
|
|||||||
background_threshold: int,
|
background_threshold: int,
|
||||||
erode_structure_size: int,
|
erode_structure_size: int,
|
||||||
) -> PILImage:
|
) -> PILImage:
|
||||||
|
"""
|
||||||
|
Perform alpha matting on an image using a given mask and threshold values.
|
||||||
|
|
||||||
|
This function takes a PIL image `img` and a PIL image `mask` as input, along with
|
||||||
|
the `foreground_threshold` and `background_threshold` values used to determine
|
||||||
|
foreground and background pixels. The `erode_structure_size` parameter specifies
|
||||||
|
the size of the erosion structure to be applied to the mask.
|
||||||
|
|
||||||
|
The function returns a PIL image representing the cutout of the foreground object
|
||||||
|
from the original image.
|
||||||
|
"""
|
||||||
if img.mode == "RGBA" or img.mode == "CMYK":
|
if img.mode == "RGBA" or img.mode == "CMYK":
|
||||||
img = img.convert("RGB")
|
img = img.convert("RGB")
|
||||||
|
|
||||||
@ -74,17 +85,46 @@ def alpha_matting_cutout(
|
|||||||
|
|
||||||
|
|
||||||
def naive_cutout(img: PILImage, mask: PILImage) -> PILImage:
|
def naive_cutout(img: PILImage, mask: PILImage) -> PILImage:
|
||||||
|
"""
|
||||||
|
Perform a simple cutout operation on an image using a mask.
|
||||||
|
|
||||||
|
This function takes a PIL image `img` and a PIL image `mask` as input.
|
||||||
|
It uses the mask to create a new image where the pixels from `img` are
|
||||||
|
cut out based on the mask.
|
||||||
|
|
||||||
|
The function returns a PIL image representing the cutout of the original
|
||||||
|
image using the mask.
|
||||||
|
"""
|
||||||
empty = Image.new("RGBA", (img.size), 0)
|
empty = Image.new("RGBA", (img.size), 0)
|
||||||
cutout = Image.composite(img, empty, mask)
|
cutout = Image.composite(img, empty, mask)
|
||||||
return cutout
|
return cutout
|
||||||
|
|
||||||
|
|
||||||
def putalpha_cutout(img: PILImage, mask: PILImage) -> PILImage:
|
def putalpha_cutout(img: PILImage, mask: PILImage) -> PILImage:
|
||||||
|
"""
|
||||||
|
Apply the specified mask to the image as an alpha cutout.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
img (PILImage): The image to be modified.
|
||||||
|
mask (PILImage): The mask to be applied.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
PILImage: The modified image with the alpha cutout applied.
|
||||||
|
"""
|
||||||
img.putalpha(mask)
|
img.putalpha(mask)
|
||||||
return img
|
return img
|
||||||
|
|
||||||
|
|
||||||
def get_concat_v_multi(imgs: List[PILImage]) -> PILImage:
|
def get_concat_v_multi(imgs: List[PILImage]) -> PILImage:
|
||||||
|
"""
|
||||||
|
Concatenate multiple images vertically.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
imgs (List[PILImage]): The list of images to be concatenated.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
PILImage: The concatenated image.
|
||||||
|
"""
|
||||||
pivot = imgs.pop(0)
|
pivot = imgs.pop(0)
|
||||||
for im in imgs:
|
for im in imgs:
|
||||||
pivot = get_concat_v(pivot, im)
|
pivot = get_concat_v(pivot, im)
|
||||||
@ -92,6 +132,16 @@ def get_concat_v_multi(imgs: List[PILImage]) -> PILImage:
|
|||||||
|
|
||||||
|
|
||||||
def get_concat_v(img1: PILImage, img2: PILImage) -> PILImage:
|
def get_concat_v(img1: PILImage, img2: PILImage) -> PILImage:
|
||||||
|
"""
|
||||||
|
Concatenate two images vertically.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
img1 (PILImage): The first image.
|
||||||
|
img2 (PILImage): The second image to be concatenated below the first image.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
PILImage: The concatenated image.
|
||||||
|
"""
|
||||||
dst = Image.new("RGBA", (img1.width, img1.height + img2.height))
|
dst = Image.new("RGBA", (img1.width, img1.height + img2.height))
|
||||||
dst.paste(img1, (0, 0))
|
dst.paste(img1, (0, 0))
|
||||||
dst.paste(img2, (0, img1.height))
|
dst.paste(img2, (0, img1.height))
|
||||||
@ -112,6 +162,16 @@ def post_process(mask: np.ndarray) -> np.ndarray:
|
|||||||
|
|
||||||
|
|
||||||
def apply_background_color(img: PILImage, color: Tuple[int, int, int, int]) -> PILImage:
|
def apply_background_color(img: PILImage, color: Tuple[int, int, int, int]) -> PILImage:
|
||||||
|
"""
|
||||||
|
Apply the specified background color to the image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
img (PILImage): The image to be modified.
|
||||||
|
color (Tuple[int, int, int, int]): The RGBA color to be applied.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
PILImage: The modified image with the background color applied.
|
||||||
|
"""
|
||||||
r, g, b, a = color
|
r, g, b, a = color
|
||||||
colored_image = Image.new("RGBA", img.size, (r, g, b, a))
|
colored_image = Image.new("RGBA", img.size, (r, g, b, a))
|
||||||
colored_image.paste(img, mask=img)
|
colored_image.paste(img, mask=img)
|
||||||
@ -120,10 +180,22 @@ def apply_background_color(img: PILImage, color: Tuple[int, int, int, int]) -> P
|
|||||||
|
|
||||||
|
|
||||||
def fix_image_orientation(img: PILImage) -> PILImage:
|
def fix_image_orientation(img: PILImage) -> PILImage:
|
||||||
|
"""
|
||||||
|
Fix the orientation of the image based on its EXIF data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
img (PILImage): The image to be fixed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
PILImage: The fixed image.
|
||||||
|
"""
|
||||||
return ImageOps.exif_transpose(img)
|
return ImageOps.exif_transpose(img)
|
||||||
|
|
||||||
|
|
||||||
def download_models() -> None:
|
def download_models() -> None:
|
||||||
|
"""
|
||||||
|
Download models for image processing.
|
||||||
|
"""
|
||||||
for session in sessions_class:
|
for session in sessions_class:
|
||||||
session.download_models()
|
session.download_models()
|
||||||
|
|
||||||
@ -141,6 +213,27 @@ def remove(
|
|||||||
*args: Optional[Any],
|
*args: Optional[Any],
|
||||||
**kwargs: Optional[Any]
|
**kwargs: Optional[Any]
|
||||||
) -> Union[bytes, PILImage, np.ndarray]:
|
) -> Union[bytes, PILImage, np.ndarray]:
|
||||||
|
"""
|
||||||
|
Remove the background from an input image.
|
||||||
|
|
||||||
|
This function takes in various parameters and returns a modified version of the input image with the background removed. The function can handle input data in the form of bytes, a PIL image, or a numpy array. The function first checks the type of the input data and converts it to a PIL image if necessary. It then fixes the orientation of the image and proceeds to perform background removal using the 'u2net' model. The result is a list of binary masks representing the foreground objects in the image. These masks are post-processed and combined to create a final cutout image. If a background color is provided, it is applied to the cutout image. The function returns the resulting cutout in the same form as the input data (bytes, a PIL image, or a numpy array); the return type is derived from the input type rather than passed as a parameter.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
data (Union[bytes, PILImage, np.ndarray]): The input image data.
|
||||||
|
alpha_matting (bool, optional): Flag indicating whether to use alpha matting. Defaults to False.
|
||||||
|
alpha_matting_foreground_threshold (int, optional): Foreground threshold for alpha matting. Defaults to 240.
|
||||||
|
alpha_matting_background_threshold (int, optional): Background threshold for alpha matting. Defaults to 10.
|
||||||
|
alpha_matting_erode_size (int, optional): Erosion size for alpha matting. Defaults to 10.
|
||||||
|
session (Optional[BaseSession], optional): The session object wrapping the model used for prediction; any BaseSession subclass is accepted. Defaults to None.
|
||||||
|
only_mask (bool, optional): Flag indicating whether to return only the binary masks. Defaults to False.
|
||||||
|
post_process_mask (bool, optional): Flag indicating whether to post-process the masks. Defaults to False.
|
||||||
|
bgcolor (Optional[Tuple[int, int, int, int]], optional): Background color for the cutout image. Defaults to None.
|
||||||
|
*args (Optional[Any]): Additional positional arguments.
|
||||||
|
**kwargs (Optional[Any]): Additional keyword arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Union[bytes, PILImage, np.ndarray]: The cutout image with the background removed.
|
||||||
|
"""
|
||||||
if isinstance(data, PILImage):
|
if isinstance(data, PILImage):
|
||||||
return_type = ReturnType.PILLOW
|
return_type = ReturnType.PILLOW
|
||||||
img = data
|
img = data
|
||||||
|
@ -102,6 +102,22 @@ def rs_command(
|
|||||||
output_specifier: str,
|
output_specifier: str,
|
||||||
**kwargs
|
**kwargs
|
||||||
) -> None:
|
) -> None:
|
||||||
|
"""
|
||||||
|
Command-line interface for processing images by removing the background using a specified model and generating a mask.
|
||||||
|
|
||||||
|
This CLI command takes several options and arguments to configure the background removal process and save the processed images.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
model (str): The name of the model to use for background removal.
|
||||||
|
extras (str): Additional options in JSON format that can be passed to customize the background removal process.
|
||||||
|
image_width (int): The width of the input images in pixels.
|
||||||
|
image_height (int): The height of the input images in pixels.
|
||||||
|
output_specifier (str): A printf-style specifier for the output filenames. If specified, the processed images will be saved to the specified output directory with filenames generated using the specifier.
|
||||||
|
**kwargs: Additional keyword arguments that can be used to customize the background removal process.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
kwargs.update(json.loads(extras))
|
kwargs.update(json.loads(extras))
|
||||||
except Exception:
|
except Exception:
|
||||||
|
@ -85,6 +85,21 @@ from ..sessions import sessions_names
|
|||||||
type=click.File("wb", lazy=True),
|
type=click.File("wb", lazy=True),
|
||||||
)
|
)
|
||||||
def i_command(model: str, extras: str, input: IO, output: IO, **kwargs) -> None:
|
def i_command(model: str, extras: str, input: IO, output: IO, **kwargs) -> None:
|
||||||
|
"""
|
||||||
|
Click command line interface function to process an input file based on the provided options.
|
||||||
|
|
||||||
|
This function implements one command of the CLI program. It reads an input file, applies image processing operations based on the provided options, and writes the output to a file.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
model (str): The name of the model to use for image processing.
|
||||||
|
extras (str): Additional options in JSON format.
|
||||||
|
input: The input file to process.
|
||||||
|
output: The output file to write the processed image to.
|
||||||
|
**kwargs: Additional keyword arguments corresponding to the command line options.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
kwargs.update(json.loads(extras))
|
kwargs.update(json.loads(extras))
|
||||||
except Exception:
|
except Exception:
|
||||||
|
@ -117,6 +117,26 @@ def p_command(
|
|||||||
watch: bool,
|
watch: bool,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
"""
|
||||||
|
Command-line interface (CLI) program for performing background removal on images in a folder.
|
||||||
|
|
||||||
|
This program takes a folder as input and uses a specified model to remove the background from the images in the folder.
|
||||||
|
It provides various options for configuration, such as choosing the model, enabling alpha matting, setting trimap thresholds, erode size, etc.
|
||||||
|
Additional options include outputting only the mask and post-processing the mask.
|
||||||
|
The program can also watch the input folder for changes and automatically process new images.
|
||||||
|
The resulting images with the background removed are saved in the specified output folder.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
model (str): The name of the model to use for background removal.
|
||||||
|
extras (str): Additional options in JSON format.
|
||||||
|
input (pathlib.Path): The path to the input folder.
|
||||||
|
output (pathlib.Path): The path to the output folder.
|
||||||
|
watch (bool): Whether to watch the input folder for changes.
|
||||||
|
**kwargs: Additional keyword arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
kwargs.update(json.loads(extras))
|
kwargs.update(json.loads(extras))
|
||||||
except Exception:
|
except Exception:
|
||||||
|
@ -48,6 +48,12 @@ from ..sessions.base import BaseSession
|
|||||||
help="number of worker threads",
|
help="number of worker threads",
|
||||||
)
|
)
|
||||||
def s_command(port: int, log_level: str, threads: int) -> None:
|
def s_command(port: int, log_level: str, threads: int) -> None:
|
||||||
|
"""
|
||||||
|
Command-line interface for running the FastAPI web server.
|
||||||
|
|
||||||
|
This function starts the FastAPI web server with the specified port and log level.
|
||||||
|
If the number of worker threads is specified, it sets the thread limiter accordingly.
|
||||||
|
"""
|
||||||
sessions: dict[str, BaseSession] = {}
|
sessions: dict[str, BaseSession] = {}
|
||||||
tags_metadata = [
|
tags_metadata = [
|
||||||
{
|
{
|
||||||
|
@ -11,6 +11,23 @@ from .sessions.u2net import U2netSession
|
|||||||
def new_session(
|
def new_session(
|
||||||
model_name: str = "u2net", providers=None, *args, **kwargs
|
model_name: str = "u2net", providers=None, *args, **kwargs
|
||||||
) -> BaseSession:
|
) -> BaseSession:
|
||||||
|
"""
|
||||||
|
Create a new session object based on the specified model name.
|
||||||
|
|
||||||
|
This function searches for the session class based on the model name in the 'sessions_class' list.
|
||||||
|
It then creates an instance of the session class with the provided arguments.
|
||||||
|
The 'sess_opts' object is created using the 'ort.SessionOptions()' constructor.
|
||||||
|
If the 'OMP_NUM_THREADS' environment variable is set, the 'inter_op_num_threads' option of 'sess_opts' is set to its value.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
model_name (str): The name of the model.
|
||||||
|
providers: The providers for the session.
|
||||||
|
*args: Additional positional arguments.
|
||||||
|
**kwargs: Additional keyword arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
BaseSession: The created session object.
|
||||||
|
"""
|
||||||
session_class: Type[BaseSession] = U2netSession
|
session_class: Type[BaseSession] = U2netSession
|
||||||
|
|
||||||
for sc in sessions_class:
|
for sc in sessions_class:
|
||||||
|
@ -8,6 +8,7 @@ from PIL.Image import Image as PILImage
|
|||||||
|
|
||||||
|
|
||||||
class BaseSession:
|
class BaseSession:
|
||||||
|
"""This is a base class for managing a session with a machine learning model."""
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
@ -16,6 +17,7 @@ class BaseSession:
|
|||||||
*args,
|
*args,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
|
"""Initialize an instance of the BaseSession class."""
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
|
|
||||||
self.providers = []
|
self.providers = []
|
||||||
|
@ -10,7 +10,21 @@ from .base import BaseSession
|
|||||||
|
|
||||||
|
|
||||||
class DisSession(BaseSession):
|
class DisSession(BaseSession):
|
||||||
|
"""
|
||||||
|
This class represents a session for the 'isnet-anime' model, which predicts mask images for an input image.
|
||||||
|
"""
|
||||||
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
||||||
|
"""
|
||||||
|
Use a pre-trained model to predict the object in the given image.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
img (PILImage): The input image.
|
||||||
|
*args: Variable length argument list.
|
||||||
|
**kwargs: Arbitrary keyword arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[PILImage]: A list of predicted mask images.
|
||||||
|
"""
|
||||||
ort_outs = self.inner_session.run(
|
ort_outs = self.inner_session.run(
|
||||||
None,
|
None,
|
||||||
self.normalize(img, (0.485, 0.456, 0.406), (1.0, 1.0, 1.0), (1024, 1024)),
|
self.normalize(img, (0.485, 0.456, 0.406), (1.0, 1.0, 1.0), (1024, 1024)),
|
||||||
@ -31,6 +45,16 @@ class DisSession(BaseSession):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def download_models(cls, *args, **kwargs):
|
def download_models(cls, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Download the pre-trained models.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
*args: Variable length argument list.
|
||||||
|
**kwargs: Arbitrary keyword arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The path of the downloaded model file.
|
||||||
|
"""
|
||||||
fname = f"{cls.name(*args, **kwargs)}.onnx"
|
fname = f"{cls.name(*args, **kwargs)}.onnx"
|
||||||
pooch.retrieve(
|
pooch.retrieve(
|
||||||
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-anime.onnx",
|
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-anime.onnx",
|
||||||
@ -46,4 +70,14 @@ class DisSession(BaseSession):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def name(cls, *args, **kwargs):
|
def name(cls, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Get the name of the pre-trained model.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
*args: Variable length argument list.
|
||||||
|
**kwargs: Arbitrary keyword arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The name of the pre-trained model.
|
||||||
|
"""
|
||||||
return "isnet-anime"
|
return "isnet-anime"
|
||||||
|
@ -11,6 +11,17 @@ from .base import BaseSession
|
|||||||
|
|
||||||
class DisSession(BaseSession):
|
class DisSession(BaseSession):
|
||||||
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
||||||
|
"""
|
||||||
|
Predicts the mask image for the input image.
|
||||||
|
|
||||||
|
This method takes a PILImage object as input and returns a list of PILImage objects as output. It performs several image processing operations to generate the mask image.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
img (PILImage): The input image.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[PILImage]: A list of PILImage objects representing the generated mask image.
|
||||||
|
"""
|
||||||
ort_outs = self.inner_session.run(
|
ort_outs = self.inner_session.run(
|
||||||
None,
|
None,
|
||||||
self.normalize(img, (0.485, 0.456, 0.406), (1.0, 1.0, 1.0), (1024, 1024)),
|
self.normalize(img, (0.485, 0.456, 0.406), (1.0, 1.0, 1.0), (1024, 1024)),
|
||||||
@ -31,6 +42,18 @@ class DisSession(BaseSession):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def download_models(cls, *args, **kwargs):
|
def download_models(cls, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Downloads the pre-trained model file.
|
||||||
|
|
||||||
|
This class method downloads the pre-trained model file from a specified URL using the pooch library.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
args: Additional positional arguments.
|
||||||
|
kwargs: Additional keyword arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The path to the downloaded model file.
|
||||||
|
"""
|
||||||
fname = f"{cls.name(*args, **kwargs)}.onnx"
|
fname = f"{cls.name(*args, **kwargs)}.onnx"
|
||||||
pooch.retrieve(
|
pooch.retrieve(
|
||||||
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-general-use.onnx",
|
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-general-use.onnx",
|
||||||
@ -46,4 +69,16 @@ class DisSession(BaseSession):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def name(cls, *args, **kwargs):
|
def name(cls, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Returns the name of the model.
|
||||||
|
|
||||||
|
This class method returns the name of the model.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
args: Additional positional arguments.
|
||||||
|
kwargs: Additional keyword arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The name of the model.
|
||||||
|
"""
|
||||||
return "isnet-general-use"
|
return "isnet-general-use"
|
||||||
|
@ -49,7 +49,25 @@ def pad_to_square(img: np.ndarray, size=1024):
|
|||||||
|
|
||||||
|
|
||||||
class SamSession(BaseSession):
|
class SamSession(BaseSession):
|
||||||
|
"""
|
||||||
|
This class represents a session for the Sam model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name (str): The name of the model.
|
||||||
|
sess_opts (ort.SessionOptions): The session options.
|
||||||
|
*args: Variable length argument list.
|
||||||
|
**kwargs: Arbitrary keyword arguments.
|
||||||
|
"""
|
||||||
def __init__(self, model_name: str, sess_opts: ort.SessionOptions, *args, **kwargs):
|
def __init__(self, model_name: str, sess_opts: ort.SessionOptions, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Initialize a new SamSession with the given model name and session options.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name (str): The name of the model.
|
||||||
|
sess_opts (ort.SessionOptions): The session options.
|
||||||
|
*args: Variable length argument list.
|
||||||
|
**kwargs: Arbitrary keyword arguments.
|
||||||
|
"""
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
paths = self.__class__.download_models()
|
paths = self.__class__.download_models()
|
||||||
self.encoder = ort.InferenceSession(
|
self.encoder = ort.InferenceSession(
|
||||||
@ -72,6 +90,20 @@ class SamSession(BaseSession):
|
|||||||
*args,
|
*args,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Normalize the input image by subtracting the mean and dividing by the standard deviation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
img (np.ndarray): The input image.
|
||||||
|
mean (tuple, optional): The mean values for normalization. Defaults to (123.675, 116.28, 103.53).
|
||||||
|
std (tuple, optional): The standard deviation values for normalization. Defaults to (58.395, 57.12, 57.375).
|
||||||
|
size (tuple, optional): The target size of the image. Defaults to (1024, 1024).
|
||||||
|
*args: Variable length argument list.
|
||||||
|
**kwargs: Arbitrary keyword arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
np.ndarray: The normalized image.
|
||||||
|
"""
|
||||||
pixel_mean = np.array([*mean]).reshape(1, 1, -1)
|
pixel_mean = np.array([*mean]).reshape(1, 1, -1)
|
||||||
pixel_std = np.array([*std]).reshape(1, 1, -1)
|
pixel_std = np.array([*std]).reshape(1, 1, -1)
|
||||||
x = (img - pixel_mean) / pixel_std
|
x = (img - pixel_mean) / pixel_std
|
||||||
@ -83,6 +115,19 @@ class SamSession(BaseSession):
|
|||||||
*args,
|
*args,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> List[PILImage]:
|
) -> List[PILImage]:
|
||||||
|
"""
|
||||||
|
Predict masks for an input image.
|
||||||
|
|
||||||
|
This function takes an image as input and performs various preprocessing steps on the image. It then runs the image through an encoder to obtain an image embedding. The function also takes input labels and points as additional arguments. It concatenates the input points and labels with padding and transforms them. It creates an empty mask input and an indicator for no mask. The function then passes the image embedding, point coordinates, point labels, mask input, and has mask input to a decoder. The decoder generates masks based on the input and returns them as a list of images.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
img (PILImage): The input image.
|
||||||
|
*args: Additional arguments.
|
||||||
|
**kwargs: Additional keyword arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[PILImage]: A list of masks generated by the decoder.
|
||||||
|
"""
|
||||||
# Preprocess image
|
# Preprocess image
|
||||||
image = resize_longes_side(img)
|
image = resize_longes_side(img)
|
||||||
image = np.array(image)
|
image = np.array(image)
|
||||||
@ -136,6 +181,19 @@ class SamSession(BaseSession):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def download_models(cls, *args, **kwargs):
|
def download_models(cls, *args, **kwargs):
|
||||||
|
'''
|
||||||
|
Class method to download ONNX model files.
|
||||||
|
|
||||||
|
This method is responsible for downloading two ONNX model files from specified URLs and saving them locally. The downloaded files are saved with the naming convention 'name_encoder.onnx' and 'name_decoder.onnx', where 'name' is the value returned by the 'name' method.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
cls: The class object.
|
||||||
|
*args: Variable length argument list.
|
||||||
|
**kwargs: Arbitrary keyword arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: A tuple containing the file paths of the downloaded encoder and decoder models.
|
||||||
|
'''
|
||||||
fname_encoder = f"{cls.name(*args, **kwargs)}_encoder.onnx"
|
fname_encoder = f"{cls.name(*args, **kwargs)}_encoder.onnx"
|
||||||
fname_decoder = f"{cls.name(*args, **kwargs)}_decoder.onnx"
|
fname_decoder = f"{cls.name(*args, **kwargs)}_decoder.onnx"
|
||||||
|
|
||||||
@ -166,4 +224,17 @@ class SamSession(BaseSession):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def name(cls, *args, **kwargs):
|
def name(cls, *args, **kwargs):
|
||||||
|
'''
|
||||||
|
Class method to return the name of the model.
|
||||||
|
|
||||||
|
This method returns the string value 'sam'.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
cls: The class object.
|
||||||
|
*args: Variable length argument list.
|
||||||
|
**kwargs: Arbitrary keyword arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The string value 'sam'.
|
||||||
|
'''
|
||||||
return "sam"
|
return "sam"
|
||||||
|
@ -10,7 +10,21 @@ from .base import BaseSession
|
|||||||
|
|
||||||
|
|
||||||
class SiluetaSession(BaseSession):
|
class SiluetaSession(BaseSession):
|
||||||
|
"""Session for the Silueta model, which predicts a mask for an input image."""
|
||||||
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
||||||
|
"""
|
||||||
|
Predict the mask of the input image.
|
||||||
|
|
||||||
|
This method takes an image as input, preprocesses it, and performs a prediction to generate a mask. The generated mask is then post-processed and returned as a list of PILImage objects.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
img (PILImage): The input image to be processed.
|
||||||
|
*args: Variable length argument list.
|
||||||
|
**kwargs: Arbitrary keyword arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[PILImage]: A list of post-processed masks.
|
||||||
|
"""
|
||||||
ort_outs = self.inner_session.run(
|
ort_outs = self.inner_session.run(
|
||||||
None,
|
None,
|
||||||
self.normalize(
|
self.normalize(
|
||||||
@ -33,6 +47,18 @@ class SiluetaSession(BaseSession):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def download_models(cls, *args, **kwargs):
|
def download_models(cls, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Download the pre-trained model file.
|
||||||
|
|
||||||
|
This method downloads the pre-trained model file from a specified URL. The file is saved to the U2NET home directory.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
*args: Variable length argument list.
|
||||||
|
**kwargs: Arbitrary keyword arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The path to the downloaded model file.
|
||||||
|
"""
|
||||||
fname = f"{cls.name()}.onnx"
|
fname = f"{cls.name()}.onnx"
|
||||||
pooch.retrieve(
|
pooch.retrieve(
|
||||||
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/silueta.onnx",
|
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/silueta.onnx",
|
||||||
@ -48,4 +74,16 @@ class SiluetaSession(BaseSession):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def name(cls, *args, **kwargs):
|
def name(cls, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Return the name of the model.
|
||||||
|
|
||||||
|
This method returns the name of the Silueta model.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
*args: Variable length argument list.
|
||||||
|
**kwargs: Arbitrary keyword arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The name of the model.
|
||||||
|
"""
|
||||||
return "silueta"
|
return "silueta"
|
||||||
|
@ -10,7 +10,21 @@ from .base import BaseSession
|
|||||||
|
|
||||||
|
|
||||||
class U2netSession(BaseSession):
|
class U2netSession(BaseSession):
|
||||||
|
"""
|
||||||
|
This class represents a U2net session, which is a subclass of BaseSession.
|
||||||
|
"""
|
||||||
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
||||||
|
"""
|
||||||
|
Predicts the output masks for the input image using the inner session.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
img (PILImage): The input image.
|
||||||
|
*args: Additional positional arguments.
|
||||||
|
**kwargs: Additional keyword arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[PILImage]: The list of output masks.
|
||||||
|
"""
|
||||||
ort_outs = self.inner_session.run(
|
ort_outs = self.inner_session.run(
|
||||||
None,
|
None,
|
||||||
self.normalize(
|
self.normalize(
|
||||||
@ -33,6 +47,16 @@ class U2netSession(BaseSession):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def download_models(cls, *args, **kwargs):
|
def download_models(cls, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Downloads the U2net model file from a specific URL and saves it.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
*args: Additional positional arguments.
|
||||||
|
**kwargs: Additional keyword arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The path to the downloaded model file.
|
||||||
|
"""
|
||||||
fname = f"{cls.name(*args, **kwargs)}.onnx"
|
fname = f"{cls.name(*args, **kwargs)}.onnx"
|
||||||
pooch.retrieve(
|
pooch.retrieve(
|
||||||
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net.onnx",
|
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net.onnx",
|
||||||
@ -48,4 +72,14 @@ class U2netSession(BaseSession):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def name(cls, *args, **kwargs):
|
def name(cls, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Returns the name of the U2net session.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
*args: Additional positional arguments.
|
||||||
|
**kwargs: Additional keyword arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The name of the session.
|
||||||
|
"""
|
||||||
return "u2net"
|
return "u2net"
|
||||||
|
@ -57,6 +57,22 @@ palette3 = [
|
|||||||
|
|
||||||
class Unet2ClothSession(BaseSession):
|
class Unet2ClothSession(BaseSession):
|
||||||
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
||||||
|
"""
|
||||||
|
Predict cloth masks for an image.
|
||||||
|
|
||||||
|
This method takes an image as input and predicts a mask of the clothing in the image.
|
||||||
|
The method uses the inner_session to make predictions using a pre-trained model.
|
||||||
|
The predicted mask is then converted to an image and resized to match the size of the input image.
|
||||||
|
Depending on the cloth category specified in the method arguments, the method applies different color palettes to the mask and appends the resulting images to a list.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
img (PILImage): The input image.
|
||||||
|
*args: Additional positional arguments.
|
||||||
|
**kwargs: Additional keyword arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[PILImage]: A list of images representing the predicted masks.
|
||||||
|
"""
|
||||||
ort_outs = self.inner_session.run(
|
ort_outs = self.inner_session.run(
|
||||||
None,
|
None,
|
||||||
self.normalize(
|
self.normalize(
|
||||||
|
@ -11,6 +11,7 @@ from .base import BaseSession
|
|||||||
|
|
||||||
|
|
||||||
class U2netCustomSession(BaseSession):
|
class U2netCustomSession(BaseSession):
|
||||||
|
"""This is a class representing a custom session for the U2net model."""
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
@ -19,6 +20,19 @@ class U2netCustomSession(BaseSession):
|
|||||||
*args,
|
*args,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Initialize a new U2netCustomSession object.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
model_name (str): The name of the model.
|
||||||
|
sess_opts (ort.SessionOptions): The session options.
|
||||||
|
providers: The providers.
|
||||||
|
*args: Additional positional arguments.
|
||||||
|
**kwargs: Additional keyword arguments.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If model_path is None.
|
||||||
|
"""
|
||||||
model_path = kwargs.get("model_path")
|
model_path = kwargs.get("model_path")
|
||||||
if model_path is None:
|
if model_path is None:
|
||||||
raise ValueError("model_path is required")
|
raise ValueError("model_path is required")
|
||||||
@ -26,6 +40,17 @@ class U2netCustomSession(BaseSession):
|
|||||||
super().__init__(model_name, sess_opts, providers, *args, **kwargs)
|
super().__init__(model_name, sess_opts, providers, *args, **kwargs)
|
||||||
|
|
||||||
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
||||||
|
"""
|
||||||
|
Predict the segmentation mask for the input image.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
img (PILImage): The input image.
|
||||||
|
*args: Additional positional arguments.
|
||||||
|
**kwargs: Additional keyword arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[PILImage]: A list of PILImage objects representing the segmentation mask.
|
||||||
|
"""
|
||||||
ort_outs = self.inner_session.run(
|
ort_outs = self.inner_session.run(
|
||||||
None,
|
None,
|
||||||
self.normalize(
|
self.normalize(
|
||||||
@ -48,6 +73,16 @@ class U2netCustomSession(BaseSession):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def download_models(cls, *args, **kwargs):
|
def download_models(cls, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Download the model files.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
*args: Additional positional arguments.
|
||||||
|
**kwargs: Additional keyword arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The absolute path to the model files, or None if `model_path` is not given in kwargs.
|
||||||
|
"""
|
||||||
model_path = kwargs.get("model_path")
|
model_path = kwargs.get("model_path")
|
||||||
if model_path is None:
|
if model_path is None:
|
||||||
return
|
return
|
||||||
@ -56,4 +91,14 @@ class U2netCustomSession(BaseSession):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def name(cls, *args, **kwargs):
|
def name(cls, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Get the name of the model.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
*args: Additional positional arguments.
|
||||||
|
**kwargs: Additional keyword arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The name of the model.
|
||||||
|
"""
|
||||||
return "u2net_custom"
|
return "u2net_custom"
|
||||||
|
@ -10,7 +10,21 @@ from .base import BaseSession
|
|||||||
|
|
||||||
|
|
||||||
class U2netHumanSegSession(BaseSession):
|
class U2netHumanSegSession(BaseSession):
|
||||||
|
"""
|
||||||
|
This class represents a session for performing human segmentation using the U2Net model.
|
||||||
|
"""
|
||||||
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
||||||
|
"""
|
||||||
|
Predicts human segmentation masks for the input image.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
img (PILImage): The input image.
|
||||||
|
*args: Variable length argument list.
|
||||||
|
**kwargs: Arbitrary keyword arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[PILImage]: A list of predicted masks.
|
||||||
|
"""
|
||||||
ort_outs = self.inner_session.run(
|
ort_outs = self.inner_session.run(
|
||||||
None,
|
None,
|
||||||
self.normalize(
|
self.normalize(
|
||||||
@ -33,6 +47,16 @@ class U2netHumanSegSession(BaseSession):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def download_models(cls, *args, **kwargs):
|
def download_models(cls, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Downloads the U2Net human segmentation model weights.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
*args: Variable length argument list.
|
||||||
|
**kwargs: Arbitrary keyword arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The path to the downloaded model weights.
|
||||||
|
"""
|
||||||
fname = f"{cls.name(*args, **kwargs)}.onnx"
|
fname = f"{cls.name(*args, **kwargs)}.onnx"
|
||||||
pooch.retrieve(
|
pooch.retrieve(
|
||||||
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_human_seg.onnx",
|
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_human_seg.onnx",
|
||||||
@ -48,4 +72,14 @@ class U2netHumanSegSession(BaseSession):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def name(cls, *args, **kwargs):
|
def name(cls, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Returns the name of the U2Net human segmentation model.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
*args: Variable length argument list.
|
||||||
|
**kwargs: Arbitrary keyword arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The name of the model.
|
||||||
|
"""
|
||||||
return "u2net_human_seg"
|
return "u2net_human_seg"
|
||||||
|
@ -10,7 +10,17 @@ from .base import BaseSession
|
|||||||
|
|
||||||
|
|
||||||
class U2netpSession(BaseSession):
|
class U2netpSession(BaseSession):
|
||||||
|
"""This class represents a session for using the U2netp model."""
|
||||||
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
||||||
|
"""
|
||||||
|
Predicts the mask for the given image using the U2netp model.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
img (PILImage): The input image.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[PILImage]: The predicted mask.
|
||||||
|
"""
|
||||||
ort_outs = self.inner_session.run(
|
ort_outs = self.inner_session.run(
|
||||||
None,
|
None,
|
||||||
self.normalize(
|
self.normalize(
|
||||||
@ -33,6 +43,12 @@ class U2netpSession(BaseSession):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def download_models(cls, *args, **kwargs):
|
def download_models(cls, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Downloads the U2netp model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The path to the downloaded model.
|
||||||
|
"""
|
||||||
fname = f"{cls.name(*args, **kwargs)}.onnx"
|
fname = f"{cls.name(*args, **kwargs)}.onnx"
|
||||||
pooch.retrieve(
|
pooch.retrieve(
|
||||||
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2netp.onnx",
|
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2netp.onnx",
|
||||||
@ -48,4 +64,10 @@ class U2netpSession(BaseSession):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def name(cls, *args, **kwargs):
|
def name(cls, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Returns the name of the U2netp model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The name of the model.
|
||||||
|
"""
|
||||||
return "u2netp"
|
return "u2netp"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user