Add a custom u2net session (#482)

This commit is contained in:
Daniel Gatis 2023-06-28 20:07:13 -03:00 committed by GitHub
parent 8c02c27f21
commit c0b08f831b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 83 additions and 37 deletions

View File

@ -164,6 +164,10 @@ Passing extras parameters
rembg i -m sam -x '{"input_labels": [1], "input_points": [[100,100]]}' path/to/input.png path/to/output.png rembg i -m sam -x '{"input_labels": [1], "input_points": [[100,100]]}' path/to/input.png path/to/output.png
``` ```
```
rembg i -m u2net_custom -x '{"model_path": "~/.u2net/u2net.onnx"}' path/to/input.png path/to/output.png
```
### rembg `p` ### rembg `p`
Used when input and output are folders. Used when input and output are folders.

View File

@ -107,7 +107,7 @@ def rs_command(
except Exception: except Exception:
pass pass
session = new_session(model) session = new_session(model, **kwargs)
bytes_per_img = image_width * image_height * 3 bytes_per_img = image_width * image_height * 3
if output_specifier: if output_specifier:

View File

@ -90,4 +90,4 @@ def i_command(model: str, extras: str, input: IO, output: IO, **kwargs) -> None:
except Exception: except Exception:
pass pass
output.write(remove(input.read(), session=new_session(model), **kwargs)) output.write(remove(input.read(), session=new_session(model, **kwargs), **kwargs))

View File

@ -122,7 +122,7 @@ def p_command(
except Exception: except Exception:
pass pass
session = new_session(model) session = new_session(model, **kwargs)
def process(each_input: pathlib.Path) -> None: def process(each_input: pathlib.Path) -> None:
try: try:

View File

@ -186,7 +186,9 @@ def s_command(port: int, log_level: str, threads: int) -> None:
return Response( return Response(
remove( remove(
content, content,
session=sessions.setdefault(commons.model, new_session(commons.model)), session=sessions.setdefault(
commons.model, new_session(commons.model, **kwargs)
),
alpha_matting=commons.a, alpha_matting=commons.a,
alpha_matting_foreground_threshold=commons.af, alpha_matting_foreground_threshold=commons.af,
alpha_matting_background_threshold=commons.ab, alpha_matting_background_threshold=commons.ab,
@ -245,12 +247,18 @@ def s_command(port: int, log_level: str, threads: int) -> None:
return await asyncify(im_without_bg)(file, commons) # type: ignore return await asyncify(im_without_bg)(file, commons) # type: ignore
def gr_app(app): def gr_app(app):
def inference(input_path, model): def inference(input_path, model, cmd_args):
output_path = "output.png" output_path = "output.png"
kwargs = {}
if cmd_args:
kwargs.update(json.loads(cmd_args))
kwargs["session"] = new_session(model, **kwargs)
with open(input_path, "rb") as i: with open(input_path, "rb") as i:
with open(output_path, "wb") as o: with open(output_path, "wb") as o:
input = i.read() input = i.read()
output = remove(input, session=new_session(model)) output = remove(input, **kwargs)
o.write(output) o.write(output)
return os.path.join(output_path) return os.path.join(output_path)
@ -258,19 +266,8 @@ def s_command(port: int, log_level: str, threads: int) -> None:
inference, inference,
[ [
gr.components.Image(type="filepath", label="Input"), gr.components.Image(type="filepath", label="Input"),
gr.components.Dropdown( gr.components.Dropdown(sessions_names, value="u2net", label="Models"),
[ gr.components.Textbox(label="Arguments"),
"u2net",
"u2netp",
"u2net_human_seg",
"u2net_cloth_seg",
"silueta",
"isnet-general-use",
"isnet-anime",
],
value="u2net",
label="Models",
),
], ],
gr.components.Image(type="filepath", label="Output"), gr.components.Image(type="filepath", label="Output"),
) )

View File

@ -29,7 +29,7 @@ class BaseSession:
self.providers.extend(_providers) self.providers.extend(_providers)
self.inner_session = ort.InferenceSession( self.inner_session = ort.InferenceSession(
str(self.__class__.download_models()), str(self.__class__.download_models(*args, **kwargs)),
providers=self.providers, providers=self.providers,
sess_options=sess_opts, sess_options=sess_opts,
) )

View File

@ -31,7 +31,7 @@ class DisSession(BaseSession):
@classmethod @classmethod
def download_models(cls, *args, **kwargs): def download_models(cls, *args, **kwargs):
fname = f"{cls.name()}.onnx" fname = f"{cls.name(*args, **kwargs)}.onnx"
pooch.retrieve( pooch.retrieve(
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-anime.onnx", "https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-anime.onnx",
None None
@ -42,7 +42,7 @@ class DisSession(BaseSession):
progressbar=True, progressbar=True,
) )
return os.path.join(cls.u2net_home(), fname) return os.path.join(cls.u2net_home(*args, **kwargs), fname)
@classmethod @classmethod
def name(cls, *args, **kwargs): def name(cls, *args, **kwargs):

View File

@ -31,7 +31,7 @@ class DisSession(BaseSession):
@classmethod @classmethod
def download_models(cls, *args, **kwargs): def download_models(cls, *args, **kwargs):
fname = f"{cls.name()}.onnx" fname = f"{cls.name(*args, **kwargs)}.onnx"
pooch.retrieve( pooch.retrieve(
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-general-use.onnx", "https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-general-use.onnx",
None None
@ -42,7 +42,7 @@ class DisSession(BaseSession):
progressbar=True, progressbar=True,
) )
return os.path.join(cls.u2net_home(), fname) return os.path.join(cls.u2net_home(*args, **kwargs), fname)
@classmethod @classmethod
def name(cls, *args, **kwargs): def name(cls, *args, **kwargs):

View File

@ -136,8 +136,8 @@ class SamSession(BaseSession):
@classmethod @classmethod
def download_models(cls, *args, **kwargs): def download_models(cls, *args, **kwargs):
fname_encoder = f"{cls.name()}_encoder.onnx" fname_encoder = f"{cls.name(*args, **kwargs)}_encoder.onnx"
fname_decoder = f"{cls.name()}_decoder.onnx" fname_decoder = f"{cls.name(*args, **kwargs)}_decoder.onnx"
pooch.retrieve( pooch.retrieve(
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-encoder-quant.onnx", "https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-encoder-quant.onnx",
@ -160,8 +160,8 @@ class SamSession(BaseSession):
) )
return ( return (
os.path.join(cls.u2net_home(), fname_encoder), os.path.join(cls.u2net_home(*args, **kwargs), fname_encoder),
os.path.join(cls.u2net_home(), fname_decoder), os.path.join(cls.u2net_home(*args, **kwargs), fname_decoder),
) )
@classmethod @classmethod

View File

@ -44,7 +44,7 @@ class SiluetaSession(BaseSession):
progressbar=True, progressbar=True,
) )
return os.path.join(cls.u2net_home(), fname) return os.path.join(cls.u2net_home(*args, **kwargs), fname)
@classmethod @classmethod
def name(cls, *args, **kwargs): def name(cls, *args, **kwargs):

View File

@ -33,7 +33,7 @@ class U2netSession(BaseSession):
@classmethod @classmethod
def download_models(cls, *args, **kwargs): def download_models(cls, *args, **kwargs):
fname = f"{cls.name()}.onnx" fname = f"{cls.name(*args, **kwargs)}.onnx"
pooch.retrieve( pooch.retrieve(
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net.onnx", "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net.onnx",
None None
@ -44,7 +44,7 @@ class U2netSession(BaseSession):
progressbar=True, progressbar=True,
) )
return os.path.join(cls.u2net_home(), fname) return os.path.join(cls.u2net_home(*args, **kwargs), fname)
@classmethod @classmethod
def name(cls, *args, **kwargs): def name(cls, *args, **kwargs):

View File

@ -94,7 +94,7 @@ class Unet2ClothSession(BaseSession):
@classmethod @classmethod
def download_models(cls, *args, **kwargs): def download_models(cls, *args, **kwargs):
fname = f"{cls.name()}.onnx" fname = f"{cls.name(*args, **kwargs)}.onnx"
pooch.retrieve( pooch.retrieve(
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_cloth_seg.onnx", "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_cloth_seg.onnx",
None None
@ -105,7 +105,7 @@ class Unet2ClothSession(BaseSession):
progressbar=True, progressbar=True,
) )
return os.path.join(cls.u2net_home(), fname) return os.path.join(cls.u2net_home(*args, **kwargs), fname)
@classmethod @classmethod
def name(cls, *args, **kwargs): def name(cls, *args, **kwargs):

View File

@ -0,0 +1,45 @@
import os
from typing import List
import numpy as np
import pooch
from PIL import Image
from PIL.Image import Image as PILImage
from .base import BaseSession
class U2netCustomSession(BaseSession):
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
ort_outs = self.inner_session.run(
None,
self.normalize(
img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (320, 320)
),
)
pred = ort_outs[0][:, 0, :, :]
ma = np.max(pred)
mi = np.min(pred)
pred = (pred - mi) / (ma - mi)
pred = np.squeeze(pred)
mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
mask = mask.resize(img.size, Image.LANCZOS)
return [mask]
@classmethod
def download_models(cls, *args, **kwargs):
model_path = kwargs.get("model_path")
if model_path is None:
raise ValueError("model_path is required")
return os.path.abspath(os.path.expanduser(model_path))
@classmethod
def name(cls, *args, **kwargs):
return "u2net_custom"

View File

@ -33,7 +33,7 @@ class U2netHumanSegSession(BaseSession):
@classmethod @classmethod
def download_models(cls, *args, **kwargs): def download_models(cls, *args, **kwargs):
fname = f"{cls.name()}.onnx" fname = f"{cls.name(*args, **kwargs)}.onnx"
pooch.retrieve( pooch.retrieve(
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_human_seg.onnx", "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_human_seg.onnx",
None None
@ -44,7 +44,7 @@ class U2netHumanSegSession(BaseSession):
progressbar=True, progressbar=True,
) )
return os.path.join(cls.u2net_home(), fname) return os.path.join(cls.u2net_home(*args, **kwargs), fname)
@classmethod @classmethod
def name(cls, *args, **kwargs): def name(cls, *args, **kwargs):

View File

@ -33,7 +33,7 @@ class U2netpSession(BaseSession):
@classmethod @classmethod
def download_models(cls, *args, **kwargs): def download_models(cls, *args, **kwargs):
fname = f"{cls.name()}.onnx" fname = f"{cls.name(*args, **kwargs)}.onnx"
pooch.retrieve( pooch.retrieve(
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2netp.onnx", "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2netp.onnx",
None None
@ -44,7 +44,7 @@ class U2netpSession(BaseSession):
progressbar=True, progressbar=True,
) )
return os.path.join(cls.u2net_home(), fname) return os.path.join(cls.u2net_home(*args, **kwargs), fname)
@classmethod @classmethod
def name(cls, *args, **kwargs): def name(cls, *args, **kwargs):