mirror of
https://git.mirrors.martin98.com/https://github.com/danielgatis/rembg
synced 2025-08-19 23:19:10 +08:00
Add a custom u2net session (#482)
This commit is contained in:
parent
8c02c27f21
commit
c0b08f831b
@ -164,6 +164,10 @@ Passing extras parameters
|
|||||||
rembg i -m sam -x '{"input_labels": [1], "input_points": [[100,100]]}' path/to/input.png path/to/output.png
|
rembg i -m sam -x '{"input_labels": [1], "input_points": [[100,100]]}' path/to/input.png path/to/output.png
|
||||||
```
|
```
|
||||||
|
|
||||||
|
```
|
||||||
|
rembg i -m u2net_custom -x '{"model_path": "~/.u2net/u2net.onnx"}' path/to/input.png path/to/output.png
|
||||||
|
```
|
||||||
|
|
||||||
### rembg `p`
|
### rembg `p`
|
||||||
|
|
||||||
Used when input and output are folders.
|
Used when input and output are folders.
|
||||||
|
@ -107,7 +107,7 @@ def rs_command(
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
session = new_session(model)
|
session = new_session(model, **kwargs)
|
||||||
bytes_per_img = image_width * image_height * 3
|
bytes_per_img = image_width * image_height * 3
|
||||||
|
|
||||||
if output_specifier:
|
if output_specifier:
|
||||||
|
@ -90,4 +90,4 @@ def i_command(model: str, extras: str, input: IO, output: IO, **kwargs) -> None:
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
output.write(remove(input.read(), session=new_session(model), **kwargs))
|
output.write(remove(input.read(), session=new_session(model, **kwargs), **kwargs))
|
||||||
|
@ -122,7 +122,7 @@ def p_command(
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
session = new_session(model)
|
session = new_session(model, **kwargs)
|
||||||
|
|
||||||
def process(each_input: pathlib.Path) -> None:
|
def process(each_input: pathlib.Path) -> None:
|
||||||
try:
|
try:
|
||||||
|
@ -186,7 +186,9 @@ def s_command(port: int, log_level: str, threads: int) -> None:
|
|||||||
return Response(
|
return Response(
|
||||||
remove(
|
remove(
|
||||||
content,
|
content,
|
||||||
session=sessions.setdefault(commons.model, new_session(commons.model)),
|
session=sessions.setdefault(
|
||||||
|
commons.model, new_session(commons.model, **kwargs)
|
||||||
|
),
|
||||||
alpha_matting=commons.a,
|
alpha_matting=commons.a,
|
||||||
alpha_matting_foreground_threshold=commons.af,
|
alpha_matting_foreground_threshold=commons.af,
|
||||||
alpha_matting_background_threshold=commons.ab,
|
alpha_matting_background_threshold=commons.ab,
|
||||||
@ -245,12 +247,18 @@ def s_command(port: int, log_level: str, threads: int) -> None:
|
|||||||
return await asyncify(im_without_bg)(file, commons) # type: ignore
|
return await asyncify(im_without_bg)(file, commons) # type: ignore
|
||||||
|
|
||||||
def gr_app(app):
|
def gr_app(app):
|
||||||
def inference(input_path, model):
|
def inference(input_path, model, cmd_args):
|
||||||
output_path = "output.png"
|
output_path = "output.png"
|
||||||
|
|
||||||
|
kwargs = {}
|
||||||
|
if cmd_args:
|
||||||
|
kwargs.update(json.loads(cmd_args))
|
||||||
|
kwargs["session"] = new_session(model, **kwargs)
|
||||||
|
|
||||||
with open(input_path, "rb") as i:
|
with open(input_path, "rb") as i:
|
||||||
with open(output_path, "wb") as o:
|
with open(output_path, "wb") as o:
|
||||||
input = i.read()
|
input = i.read()
|
||||||
output = remove(input, session=new_session(model))
|
output = remove(input, **kwargs)
|
||||||
o.write(output)
|
o.write(output)
|
||||||
return os.path.join(output_path)
|
return os.path.join(output_path)
|
||||||
|
|
||||||
@ -258,19 +266,8 @@ def s_command(port: int, log_level: str, threads: int) -> None:
|
|||||||
inference,
|
inference,
|
||||||
[
|
[
|
||||||
gr.components.Image(type="filepath", label="Input"),
|
gr.components.Image(type="filepath", label="Input"),
|
||||||
gr.components.Dropdown(
|
gr.components.Dropdown(sessions_names, value="u2net", label="Models"),
|
||||||
[
|
gr.components.Textbox(label="Arguments"),
|
||||||
"u2net",
|
|
||||||
"u2netp",
|
|
||||||
"u2net_human_seg",
|
|
||||||
"u2net_cloth_seg",
|
|
||||||
"silueta",
|
|
||||||
"isnet-general-use",
|
|
||||||
"isnet-anime",
|
|
||||||
],
|
|
||||||
value="u2net",
|
|
||||||
label="Models",
|
|
||||||
),
|
|
||||||
],
|
],
|
||||||
gr.components.Image(type="filepath", label="Output"),
|
gr.components.Image(type="filepath", label="Output"),
|
||||||
)
|
)
|
||||||
|
@ -29,7 +29,7 @@ class BaseSession:
|
|||||||
self.providers.extend(_providers)
|
self.providers.extend(_providers)
|
||||||
|
|
||||||
self.inner_session = ort.InferenceSession(
|
self.inner_session = ort.InferenceSession(
|
||||||
str(self.__class__.download_models()),
|
str(self.__class__.download_models(*args, **kwargs)),
|
||||||
providers=self.providers,
|
providers=self.providers,
|
||||||
sess_options=sess_opts,
|
sess_options=sess_opts,
|
||||||
)
|
)
|
||||||
|
@ -31,7 +31,7 @@ class DisSession(BaseSession):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def download_models(cls, *args, **kwargs):
|
def download_models(cls, *args, **kwargs):
|
||||||
fname = f"{cls.name()}.onnx"
|
fname = f"{cls.name(*args, **kwargs)}.onnx"
|
||||||
pooch.retrieve(
|
pooch.retrieve(
|
||||||
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-anime.onnx",
|
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-anime.onnx",
|
||||||
None
|
None
|
||||||
@ -42,7 +42,7 @@ class DisSession(BaseSession):
|
|||||||
progressbar=True,
|
progressbar=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
return os.path.join(cls.u2net_home(), fname)
|
return os.path.join(cls.u2net_home(*args, **kwargs), fname)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def name(cls, *args, **kwargs):
|
def name(cls, *args, **kwargs):
|
||||||
|
@ -31,7 +31,7 @@ class DisSession(BaseSession):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def download_models(cls, *args, **kwargs):
|
def download_models(cls, *args, **kwargs):
|
||||||
fname = f"{cls.name()}.onnx"
|
fname = f"{cls.name(*args, **kwargs)}.onnx"
|
||||||
pooch.retrieve(
|
pooch.retrieve(
|
||||||
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-general-use.onnx",
|
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-general-use.onnx",
|
||||||
None
|
None
|
||||||
@ -42,7 +42,7 @@ class DisSession(BaseSession):
|
|||||||
progressbar=True,
|
progressbar=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
return os.path.join(cls.u2net_home(), fname)
|
return os.path.join(cls.u2net_home(*args, **kwargs), fname)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def name(cls, *args, **kwargs):
|
def name(cls, *args, **kwargs):
|
||||||
|
@ -136,8 +136,8 @@ class SamSession(BaseSession):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def download_models(cls, *args, **kwargs):
|
def download_models(cls, *args, **kwargs):
|
||||||
fname_encoder = f"{cls.name()}_encoder.onnx"
|
fname_encoder = f"{cls.name(*args, **kwargs)}_encoder.onnx"
|
||||||
fname_decoder = f"{cls.name()}_decoder.onnx"
|
fname_decoder = f"{cls.name(*args, **kwargs)}_decoder.onnx"
|
||||||
|
|
||||||
pooch.retrieve(
|
pooch.retrieve(
|
||||||
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-encoder-quant.onnx",
|
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-encoder-quant.onnx",
|
||||||
@ -160,8 +160,8 @@ class SamSession(BaseSession):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
os.path.join(cls.u2net_home(), fname_encoder),
|
os.path.join(cls.u2net_home(*args, **kwargs), fname_encoder),
|
||||||
os.path.join(cls.u2net_home(), fname_decoder),
|
os.path.join(cls.u2net_home(*args, **kwargs), fname_decoder),
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -44,7 +44,7 @@ class SiluetaSession(BaseSession):
|
|||||||
progressbar=True,
|
progressbar=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
return os.path.join(cls.u2net_home(), fname)
|
return os.path.join(cls.u2net_home(*args, **kwargs), fname)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def name(cls, *args, **kwargs):
|
def name(cls, *args, **kwargs):
|
||||||
|
@ -33,7 +33,7 @@ class U2netSession(BaseSession):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def download_models(cls, *args, **kwargs):
|
def download_models(cls, *args, **kwargs):
|
||||||
fname = f"{cls.name()}.onnx"
|
fname = f"{cls.name(*args, **kwargs)}.onnx"
|
||||||
pooch.retrieve(
|
pooch.retrieve(
|
||||||
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net.onnx",
|
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net.onnx",
|
||||||
None
|
None
|
||||||
@ -44,7 +44,7 @@ class U2netSession(BaseSession):
|
|||||||
progressbar=True,
|
progressbar=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
return os.path.join(cls.u2net_home(), fname)
|
return os.path.join(cls.u2net_home(*args, **kwargs), fname)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def name(cls, *args, **kwargs):
|
def name(cls, *args, **kwargs):
|
||||||
|
@ -94,7 +94,7 @@ class Unet2ClothSession(BaseSession):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def download_models(cls, *args, **kwargs):
|
def download_models(cls, *args, **kwargs):
|
||||||
fname = f"{cls.name()}.onnx"
|
fname = f"{cls.name(*args, **kwargs)}.onnx"
|
||||||
pooch.retrieve(
|
pooch.retrieve(
|
||||||
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_cloth_seg.onnx",
|
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_cloth_seg.onnx",
|
||||||
None
|
None
|
||||||
@ -105,7 +105,7 @@ class Unet2ClothSession(BaseSession):
|
|||||||
progressbar=True,
|
progressbar=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
return os.path.join(cls.u2net_home(), fname)
|
return os.path.join(cls.u2net_home(*args, **kwargs), fname)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def name(cls, *args, **kwargs):
|
def name(cls, *args, **kwargs):
|
||||||
|
45
rembg/sessions/u2net_custom.py
Normal file
45
rembg/sessions/u2net_custom.py
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
import os
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pooch
|
||||||
|
from PIL import Image
|
||||||
|
from PIL.Image import Image as PILImage
|
||||||
|
|
||||||
|
from .base import BaseSession
|
||||||
|
|
||||||
|
|
||||||
|
class U2netCustomSession(BaseSession):
|
||||||
|
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
||||||
|
ort_outs = self.inner_session.run(
|
||||||
|
None,
|
||||||
|
self.normalize(
|
||||||
|
img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (320, 320)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
pred = ort_outs[0][:, 0, :, :]
|
||||||
|
|
||||||
|
ma = np.max(pred)
|
||||||
|
mi = np.min(pred)
|
||||||
|
|
||||||
|
pred = (pred - mi) / (ma - mi)
|
||||||
|
pred = np.squeeze(pred)
|
||||||
|
|
||||||
|
mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
|
||||||
|
mask = mask.resize(img.size, Image.LANCZOS)
|
||||||
|
|
||||||
|
return [mask]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def download_models(cls, *args, **kwargs):
|
||||||
|
model_path = kwargs.get("model_path")
|
||||||
|
|
||||||
|
if model_path is None:
|
||||||
|
raise ValueError("model_path is required")
|
||||||
|
|
||||||
|
return os.path.abspath(os.path.expanduser(model_path))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def name(cls, *args, **kwargs):
|
||||||
|
return "u2net_custom"
|
@ -33,7 +33,7 @@ class U2netHumanSegSession(BaseSession):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def download_models(cls, *args, **kwargs):
|
def download_models(cls, *args, **kwargs):
|
||||||
fname = f"{cls.name()}.onnx"
|
fname = f"{cls.name(*args, **kwargs)}.onnx"
|
||||||
pooch.retrieve(
|
pooch.retrieve(
|
||||||
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_human_seg.onnx",
|
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_human_seg.onnx",
|
||||||
None
|
None
|
||||||
@ -44,7 +44,7 @@ class U2netHumanSegSession(BaseSession):
|
|||||||
progressbar=True,
|
progressbar=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
return os.path.join(cls.u2net_home(), fname)
|
return os.path.join(cls.u2net_home(*args, **kwargs), fname)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def name(cls, *args, **kwargs):
|
def name(cls, *args, **kwargs):
|
||||||
|
@ -33,7 +33,7 @@ class U2netpSession(BaseSession):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def download_models(cls, *args, **kwargs):
|
def download_models(cls, *args, **kwargs):
|
||||||
fname = f"{cls.name()}.onnx"
|
fname = f"{cls.name(*args, **kwargs)}.onnx"
|
||||||
pooch.retrieve(
|
pooch.retrieve(
|
||||||
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2netp.onnx",
|
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2netp.onnx",
|
||||||
None
|
None
|
||||||
@ -44,7 +44,7 @@ class U2netpSession(BaseSession):
|
|||||||
progressbar=True,
|
progressbar=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
return os.path.join(cls.u2net_home(), fname)
|
return os.path.join(cls.u2net_home(*args, **kwargs), fname)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def name(cls, *args, **kwargs):
|
def name(cls, *args, **kwargs):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user