From c0b08f831b7158b66806544fd6239b880827a1d4 Mon Sep 17 00:00:00 2001 From: Daniel Gatis Date: Wed, 28 Jun 2023 20:07:13 -0300 Subject: [PATCH] Add a custom u2net session (#482) --- README.md | 4 +++ rembg/commands/b_command.py | 2 +- rembg/commands/i_command.py | 2 +- rembg/commands/p_command.py | 2 +- rembg/commands/s_command.py | 29 +++++++++----------- rembg/sessions/base.py | 2 +- rembg/sessions/dis_anime.py | 4 +-- rembg/sessions/dis_general_use.py | 4 +-- rembg/sessions/sam.py | 8 +++--- rembg/sessions/silueta.py | 2 +- rembg/sessions/u2net.py | 4 +-- rembg/sessions/u2net_cloth_seg.py | 4 +-- rembg/sessions/u2net_custom.py | 45 +++++++++++++++++++++++++++++++ rembg/sessions/u2net_human_seg.py | 4 +-- rembg/sessions/u2netp.py | 4 +-- 15 files changed, 83 insertions(+), 37 deletions(-) create mode 100644 rembg/sessions/u2net_custom.py diff --git a/README.md b/README.md index 0164b2f..293b256 100644 --- a/README.md +++ b/README.md @@ -164,6 +164,10 @@ Passing extras parameters rembg i -m sam -x '{"input_labels": [1], "input_points": [[100,100]]}' path/to/input.png path/to/output.png ``` +``` +rembg i -m u2net_custom -x '{"model_path": "~/.u2net/u2net.onnx"}' path/to/input.png path/to/output.png +``` + ### rembg `p` Used when input and output are folders. diff --git a/rembg/commands/b_command.py b/rembg/commands/b_command.py index a47118d..f6ab0ca 100644 --- a/rembg/commands/b_command.py +++ b/rembg/commands/b_command.py @@ -107,7 +107,7 @@ def rs_command( except Exception: pass - session = new_session(model) + session = new_session(model, **kwargs) bytes_per_img = image_width * image_height * 3 if output_specifier: diff --git a/rembg/commands/i_command.py b/rembg/commands/i_command.py index d65313c..1bce19e 100644 --- a/rembg/commands/i_command.py +++ b/rembg/commands/i_command.py @@ -90,4 +90,4 @@ def i_command(model: str, extras: str, input: IO, output: IO, **kwargs) -> None: except Exception: pass - output.write(remove(input.read(), session=new_session(model), **kwargs)) + output.write(remove(input.read(), session=new_session(model, **kwargs), **kwargs)) diff --git a/rembg/commands/p_command.py b/rembg/commands/p_command.py index 2163bfb..a866049 100644 --- a/rembg/commands/p_command.py +++ b/rembg/commands/p_command.py @@ -122,7 +122,7 @@ def p_command( except Exception: pass - session = new_session(model) + session = new_session(model, **kwargs) def process(each_input: pathlib.Path) -> None: try: diff --git a/rembg/commands/s_command.py b/rembg/commands/s_command.py index 4fba1ce..2330a6c 100644 --- a/rembg/commands/s_command.py +++ b/rembg/commands/s_command.py @@ -186,7 +186,9 @@ def s_command(port: int, log_level: str, threads: int) -> None: return Response( remove( content, - session=sessions.setdefault(commons.model, new_session(commons.model)), + session=sessions.setdefault( + commons.model, new_session(commons.model, **kwargs) + ), alpha_matting=commons.a, alpha_matting_foreground_threshold=commons.af, alpha_matting_background_threshold=commons.ab, @@ -245,12 +247,18 @@ def s_command(port: int, log_level: str, threads: int) -> None: return await asyncify(im_without_bg)(file, commons) # type: ignore def gr_app(app): - def inference(input_path, model): + def inference(input_path, model, cmd_args): output_path = "output.png" + + kwargs = {} + if cmd_args: + kwargs.update(json.loads(cmd_args)) + kwargs["session"] = new_session(model, **kwargs) + with open(input_path, "rb") as i: with open(output_path, "wb") as o: input = i.read() - output = remove(input, session=new_session(model)) + output = remove(input, **kwargs) o.write(output) return os.path.join(output_path) @@ -258,19 +266,8 @@ def s_command(port: int, log_level: str, threads: int) -> None: inference, [ gr.components.Image(type="filepath", label="Input"), - gr.components.Dropdown( - [ - "u2net", - "u2netp", - "u2net_human_seg", - "u2net_cloth_seg", - "silueta", - "isnet-general-use", - "isnet-anime", - ], - value="u2net", - label="Models", - ), + gr.components.Dropdown(sessions_names, value="u2net", label="Models"), + gr.components.Textbox(label="Arguments"), ], gr.components.Image(type="filepath", label="Output"), ) diff --git a/rembg/sessions/base.py b/rembg/sessions/base.py index f6bbe5a..16f988e 100644 --- a/rembg/sessions/base.py +++ b/rembg/sessions/base.py @@ -29,7 +29,7 @@ class BaseSession: self.providers.extend(_providers) self.inner_session = ort.InferenceSession( - str(self.__class__.download_models()), + str(self.__class__.download_models(*args, **kwargs)), providers=self.providers, sess_options=sess_opts, ) diff --git a/rembg/sessions/dis_anime.py b/rembg/sessions/dis_anime.py index a71618f..822051a 100644 --- a/rembg/sessions/dis_anime.py +++ b/rembg/sessions/dis_anime.py @@ -31,7 +31,7 @@ class DisSession(BaseSession): @classmethod def download_models(cls, *args, **kwargs): - fname = f"{cls.name()}.onnx" + fname = f"{cls.name(*args, **kwargs)}.onnx" pooch.retrieve( "https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-anime.onnx", None @@ -42,7 +42,7 @@ class DisSession(BaseSession): progressbar=True, ) - return os.path.join(cls.u2net_home(), fname) + return os.path.join(cls.u2net_home(*args, **kwargs), fname) @classmethod def name(cls, *args, **kwargs): diff --git a/rembg/sessions/dis_general_use.py b/rembg/sessions/dis_general_use.py index a71b34f..6a4cdae 100644 --- a/rembg/sessions/dis_general_use.py +++ b/rembg/sessions/dis_general_use.py @@ -31,7 +31,7 @@ class DisSession(BaseSession): @classmethod def download_models(cls, *args, **kwargs): - fname = f"{cls.name()}.onnx" + fname = f"{cls.name(*args, **kwargs)}.onnx" pooch.retrieve( "https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-general-use.onnx", None @@ -42,7 +42,7 @@ class DisSession(BaseSession): progressbar=True, ) - return os.path.join(cls.u2net_home(), fname) + return os.path.join(cls.u2net_home(*args, **kwargs), fname) @classmethod def name(cls, *args, **kwargs): diff --git a/rembg/sessions/sam.py b/rembg/sessions/sam.py index 9d1ed7d..0bee3e0 100644 --- a/rembg/sessions/sam.py +++ b/rembg/sessions/sam.py @@ -136,8 +136,8 @@ class SamSession(BaseSession): @classmethod def download_models(cls, *args, **kwargs): - fname_encoder = f"{cls.name()}_encoder.onnx" - fname_decoder = f"{cls.name()}_decoder.onnx" + fname_encoder = f"{cls.name(*args, **kwargs)}_encoder.onnx" + fname_decoder = f"{cls.name(*args, **kwargs)}_decoder.onnx" pooch.retrieve( "https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-encoder-quant.onnx", @@ -160,8 +160,8 @@ class SamSession(BaseSession): ) return ( - os.path.join(cls.u2net_home(), fname_encoder), - os.path.join(cls.u2net_home(), fname_decoder), + os.path.join(cls.u2net_home(*args, **kwargs), fname_encoder), + os.path.join(cls.u2net_home(*args, **kwargs), fname_decoder), ) @classmethod diff --git a/rembg/sessions/silueta.py b/rembg/sessions/silueta.py index 50094f1..137d78f 100644 --- a/rembg/sessions/silueta.py +++ b/rembg/sessions/silueta.py @@ -44,7 +44,7 @@ class SiluetaSession(BaseSession): progressbar=True, ) - return os.path.join(cls.u2net_home(), fname) + return os.path.join(cls.u2net_home(*args, **kwargs), fname) @classmethod def name(cls, *args, **kwargs): diff --git a/rembg/sessions/u2net.py b/rembg/sessions/u2net.py index e984b18..15664f4 100644 --- a/rembg/sessions/u2net.py +++ b/rembg/sessions/u2net.py @@ -33,7 +33,7 @@ class U2netSession(BaseSession): @classmethod def download_models(cls, *args, **kwargs): - fname = f"{cls.name()}.onnx" + fname = f"{cls.name(*args, **kwargs)}.onnx" pooch.retrieve( "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net.onnx", None @@ -44,7 +44,7 @@ class U2netSession(BaseSession): progressbar=True, ) - return os.path.join(cls.u2net_home(), fname) + return os.path.join(cls.u2net_home(*args, **kwargs), fname) @classmethod def name(cls, *args, **kwargs): diff --git a/rembg/sessions/u2net_cloth_seg.py b/rembg/sessions/u2net_cloth_seg.py index 0308743..97179a8 100644 --- a/rembg/sessions/u2net_cloth_seg.py +++ b/rembg/sessions/u2net_cloth_seg.py @@ -94,7 +94,7 @@ class Unet2ClothSession(BaseSession): @classmethod def download_models(cls, *args, **kwargs): - fname = f"{cls.name()}.onnx" + fname = f"{cls.name(*args, **kwargs)}.onnx" pooch.retrieve( "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_cloth_seg.onnx", None @@ -105,7 +105,7 @@ class Unet2ClothSession(BaseSession): progressbar=True, ) - return os.path.join(cls.u2net_home(), fname) + return os.path.join(cls.u2net_home(*args, **kwargs), fname) @classmethod def name(cls, *args, **kwargs): diff --git a/rembg/sessions/u2net_custom.py b/rembg/sessions/u2net_custom.py new file mode 100644 index 0000000..09894b2 --- /dev/null +++ b/rembg/sessions/u2net_custom.py @@ -0,0 +1,45 @@ +import os +from typing import List + +import numpy as np +import pooch +from PIL import Image +from PIL.Image import Image as PILImage + +from .base import BaseSession + + +class U2netCustomSession(BaseSession): + def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]: + ort_outs = self.inner_session.run( + None, + self.normalize( + img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (320, 320) + ), + ) + + pred = ort_outs[0][:, 0, :, :] + + ma = np.max(pred) + mi = np.min(pred) + + pred = (pred - mi) / (ma - mi) + pred = np.squeeze(pred) + + mask = Image.fromarray((pred * 255).astype("uint8"), mode="L") + mask = mask.resize(img.size, Image.LANCZOS) + + return [mask] + + @classmethod + def download_models(cls, *args, **kwargs): + model_path = kwargs.get("model_path") + + if model_path is None: + raise ValueError("model_path is required") + + return os.path.abspath(os.path.expanduser(model_path)) + + @classmethod + def name(cls, *args, **kwargs): + return "u2net_custom" diff --git a/rembg/sessions/u2net_human_seg.py b/rembg/sessions/u2net_human_seg.py index 166c195..2c6c804 100644 --- a/rembg/sessions/u2net_human_seg.py +++ b/rembg/sessions/u2net_human_seg.py @@ -33,7 +33,7 @@ class U2netHumanSegSession(BaseSession): @classmethod def download_models(cls, *args, **kwargs): - fname = f"{cls.name()}.onnx" + fname = f"{cls.name(*args, **kwargs)}.onnx" pooch.retrieve( "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_human_seg.onnx", None @@ -44,7 +44,7 @@ class U2netHumanSegSession(BaseSession): progressbar=True, ) - return os.path.join(cls.u2net_home(), fname) + return os.path.join(cls.u2net_home(*args, **kwargs), fname) @classmethod def name(cls, *args, **kwargs): diff --git a/rembg/sessions/u2netp.py b/rembg/sessions/u2netp.py index b28fb6a..e34420b 100644 --- a/rembg/sessions/u2netp.py +++ b/rembg/sessions/u2netp.py @@ -33,7 +33,7 @@ class U2netpSession(BaseSession): @classmethod def download_models(cls, *args, **kwargs): - fname = f"{cls.name()}.onnx" + fname = f"{cls.name(*args, **kwargs)}.onnx" pooch.retrieve( "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2netp.onnx", None @@ -44,7 +44,7 @@ class U2netpSession(BaseSession): progressbar=True, ) - return os.path.join(cls.u2net_home(), fname) + return os.path.join(cls.u2net_home(*args, **kwargs), fname) @classmethod def name(cls, *args, **kwargs):