mirror of
https://git.mirrors.martin98.com/https://github.com/danielgatis/rembg
synced 2025-08-19 23:19:10 +08:00
Add a custom u2net session (#482)
This commit is contained in:
parent
8c02c27f21
commit
c0b08f831b
@ -164,6 +164,10 @@ Passing extras parameters
|
||||
rembg i -m sam -x '{"input_labels": [1], "input_points": [[100,100]]}' path/to/input.png path/to/output.png
|
||||
```
|
||||
|
||||
```
|
||||
rembg i -m u2net_custom -x '{"model_path": "~/.u2net/u2net.onnx"}' path/to/input.png path/to/output.png
|
||||
```
|
||||
|
||||
### rembg `p`
|
||||
|
||||
Used when input and output are folders.
|
||||
|
@ -107,7 +107,7 @@ def rs_command(
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
session = new_session(model)
|
||||
session = new_session(model, **kwargs)
|
||||
bytes_per_img = image_width * image_height * 3
|
||||
|
||||
if output_specifier:
|
||||
|
@ -90,4 +90,4 @@ def i_command(model: str, extras: str, input: IO, output: IO, **kwargs) -> None:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
output.write(remove(input.read(), session=new_session(model), **kwargs))
|
||||
output.write(remove(input.read(), session=new_session(model, **kwargs), **kwargs))
|
||||
|
@ -122,7 +122,7 @@ def p_command(
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
session = new_session(model)
|
||||
session = new_session(model, **kwargs)
|
||||
|
||||
def process(each_input: pathlib.Path) -> None:
|
||||
try:
|
||||
|
@ -186,7 +186,9 @@ def s_command(port: int, log_level: str, threads: int) -> None:
|
||||
return Response(
|
||||
remove(
|
||||
content,
|
||||
session=sessions.setdefault(commons.model, new_session(commons.model)),
|
||||
session=sessions.setdefault(
|
||||
commons.model, new_session(commons.model, **kwargs)
|
||||
),
|
||||
alpha_matting=commons.a,
|
||||
alpha_matting_foreground_threshold=commons.af,
|
||||
alpha_matting_background_threshold=commons.ab,
|
||||
@ -245,12 +247,18 @@ def s_command(port: int, log_level: str, threads: int) -> None:
|
||||
return await asyncify(im_without_bg)(file, commons) # type: ignore
|
||||
|
||||
def gr_app(app):
|
||||
def inference(input_path, model):
|
||||
def inference(input_path, model, cmd_args):
|
||||
output_path = "output.png"
|
||||
|
||||
kwargs = {}
|
||||
if cmd_args:
|
||||
kwargs.update(json.loads(cmd_args))
|
||||
kwargs["session"] = new_session(model, **kwargs)
|
||||
|
||||
with open(input_path, "rb") as i:
|
||||
with open(output_path, "wb") as o:
|
||||
input = i.read()
|
||||
output = remove(input, session=new_session(model))
|
||||
output = remove(input, **kwargs)
|
||||
o.write(output)
|
||||
return os.path.join(output_path)
|
||||
|
||||
@ -258,19 +266,8 @@ def s_command(port: int, log_level: str, threads: int) -> None:
|
||||
inference,
|
||||
[
|
||||
gr.components.Image(type="filepath", label="Input"),
|
||||
gr.components.Dropdown(
|
||||
[
|
||||
"u2net",
|
||||
"u2netp",
|
||||
"u2net_human_seg",
|
||||
"u2net_cloth_seg",
|
||||
"silueta",
|
||||
"isnet-general-use",
|
||||
"isnet-anime",
|
||||
],
|
||||
value="u2net",
|
||||
label="Models",
|
||||
),
|
||||
gr.components.Dropdown(sessions_names, value="u2net", label="Models"),
|
||||
gr.components.Textbox(label="Arguments"),
|
||||
],
|
||||
gr.components.Image(type="filepath", label="Output"),
|
||||
)
|
||||
|
@ -29,7 +29,7 @@ class BaseSession:
|
||||
self.providers.extend(_providers)
|
||||
|
||||
self.inner_session = ort.InferenceSession(
|
||||
str(self.__class__.download_models()),
|
||||
str(self.__class__.download_models(*args, **kwargs)),
|
||||
providers=self.providers,
|
||||
sess_options=sess_opts,
|
||||
)
|
||||
|
@ -31,7 +31,7 @@ class DisSession(BaseSession):
|
||||
|
||||
@classmethod
|
||||
def download_models(cls, *args, **kwargs):
|
||||
fname = f"{cls.name()}.onnx"
|
||||
fname = f"{cls.name(*args, **kwargs)}.onnx"
|
||||
pooch.retrieve(
|
||||
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-anime.onnx",
|
||||
None
|
||||
@ -42,7 +42,7 @@ class DisSession(BaseSession):
|
||||
progressbar=True,
|
||||
)
|
||||
|
||||
return os.path.join(cls.u2net_home(), fname)
|
||||
return os.path.join(cls.u2net_home(*args, **kwargs), fname)
|
||||
|
||||
@classmethod
|
||||
def name(cls, *args, **kwargs):
|
||||
|
@ -31,7 +31,7 @@ class DisSession(BaseSession):
|
||||
|
||||
@classmethod
|
||||
def download_models(cls, *args, **kwargs):
|
||||
fname = f"{cls.name()}.onnx"
|
||||
fname = f"{cls.name(*args, **kwargs)}.onnx"
|
||||
pooch.retrieve(
|
||||
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-general-use.onnx",
|
||||
None
|
||||
@ -42,7 +42,7 @@ class DisSession(BaseSession):
|
||||
progressbar=True,
|
||||
)
|
||||
|
||||
return os.path.join(cls.u2net_home(), fname)
|
||||
return os.path.join(cls.u2net_home(*args, **kwargs), fname)
|
||||
|
||||
@classmethod
|
||||
def name(cls, *args, **kwargs):
|
||||
|
@ -136,8 +136,8 @@ class SamSession(BaseSession):
|
||||
|
||||
@classmethod
|
||||
def download_models(cls, *args, **kwargs):
|
||||
fname_encoder = f"{cls.name()}_encoder.onnx"
|
||||
fname_decoder = f"{cls.name()}_decoder.onnx"
|
||||
fname_encoder = f"{cls.name(*args, **kwargs)}_encoder.onnx"
|
||||
fname_decoder = f"{cls.name(*args, **kwargs)}_decoder.onnx"
|
||||
|
||||
pooch.retrieve(
|
||||
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-encoder-quant.onnx",
|
||||
@ -160,8 +160,8 @@ class SamSession(BaseSession):
|
||||
)
|
||||
|
||||
return (
|
||||
os.path.join(cls.u2net_home(), fname_encoder),
|
||||
os.path.join(cls.u2net_home(), fname_decoder),
|
||||
os.path.join(cls.u2net_home(*args, **kwargs), fname_encoder),
|
||||
os.path.join(cls.u2net_home(*args, **kwargs), fname_decoder),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
@ -44,7 +44,7 @@ class SiluetaSession(BaseSession):
|
||||
progressbar=True,
|
||||
)
|
||||
|
||||
return os.path.join(cls.u2net_home(), fname)
|
||||
return os.path.join(cls.u2net_home(*args, **kwargs), fname)
|
||||
|
||||
@classmethod
|
||||
def name(cls, *args, **kwargs):
|
||||
|
@ -33,7 +33,7 @@ class U2netSession(BaseSession):
|
||||
|
||||
@classmethod
|
||||
def download_models(cls, *args, **kwargs):
|
||||
fname = f"{cls.name()}.onnx"
|
||||
fname = f"{cls.name(*args, **kwargs)}.onnx"
|
||||
pooch.retrieve(
|
||||
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net.onnx",
|
||||
None
|
||||
@ -44,7 +44,7 @@ class U2netSession(BaseSession):
|
||||
progressbar=True,
|
||||
)
|
||||
|
||||
return os.path.join(cls.u2net_home(), fname)
|
||||
return os.path.join(cls.u2net_home(*args, **kwargs), fname)
|
||||
|
||||
@classmethod
|
||||
def name(cls, *args, **kwargs):
|
||||
|
@ -94,7 +94,7 @@ class Unet2ClothSession(BaseSession):
|
||||
|
||||
@classmethod
|
||||
def download_models(cls, *args, **kwargs):
|
||||
fname = f"{cls.name()}.onnx"
|
||||
fname = f"{cls.name(*args, **kwargs)}.onnx"
|
||||
pooch.retrieve(
|
||||
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_cloth_seg.onnx",
|
||||
None
|
||||
@ -105,7 +105,7 @@ class Unet2ClothSession(BaseSession):
|
||||
progressbar=True,
|
||||
)
|
||||
|
||||
return os.path.join(cls.u2net_home(), fname)
|
||||
return os.path.join(cls.u2net_home(*args, **kwargs), fname)
|
||||
|
||||
@classmethod
|
||||
def name(cls, *args, **kwargs):
|
||||
|
45
rembg/sessions/u2net_custom.py
Normal file
45
rembg/sessions/u2net_custom.py
Normal file
@ -0,0 +1,45 @@
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import pooch
|
||||
from PIL import Image
|
||||
from PIL.Image import Image as PILImage
|
||||
|
||||
from .base import BaseSession
|
||||
|
||||
|
||||
class U2netCustomSession(BaseSession):
|
||||
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
||||
ort_outs = self.inner_session.run(
|
||||
None,
|
||||
self.normalize(
|
||||
img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (320, 320)
|
||||
),
|
||||
)
|
||||
|
||||
pred = ort_outs[0][:, 0, :, :]
|
||||
|
||||
ma = np.max(pred)
|
||||
mi = np.min(pred)
|
||||
|
||||
pred = (pred - mi) / (ma - mi)
|
||||
pred = np.squeeze(pred)
|
||||
|
||||
mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
|
||||
mask = mask.resize(img.size, Image.LANCZOS)
|
||||
|
||||
return [mask]
|
||||
|
||||
@classmethod
|
||||
def download_models(cls, *args, **kwargs):
|
||||
model_path = kwargs.get("model_path")
|
||||
|
||||
if model_path is None:
|
||||
raise ValueError("model_path is required")
|
||||
|
||||
return os.path.abspath(os.path.expanduser(model_path))
|
||||
|
||||
@classmethod
|
||||
def name(cls, *args, **kwargs):
|
||||
return "u2net_custom"
|
@ -33,7 +33,7 @@ class U2netHumanSegSession(BaseSession):
|
||||
|
||||
@classmethod
|
||||
def download_models(cls, *args, **kwargs):
|
||||
fname = f"{cls.name()}.onnx"
|
||||
fname = f"{cls.name(*args, **kwargs)}.onnx"
|
||||
pooch.retrieve(
|
||||
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_human_seg.onnx",
|
||||
None
|
||||
@ -44,7 +44,7 @@ class U2netHumanSegSession(BaseSession):
|
||||
progressbar=True,
|
||||
)
|
||||
|
||||
return os.path.join(cls.u2net_home(), fname)
|
||||
return os.path.join(cls.u2net_home(*args, **kwargs), fname)
|
||||
|
||||
@classmethod
|
||||
def name(cls, *args, **kwargs):
|
||||
|
@ -33,7 +33,7 @@ class U2netpSession(BaseSession):
|
||||
|
||||
@classmethod
|
||||
def download_models(cls, *args, **kwargs):
|
||||
fname = f"{cls.name()}.onnx"
|
||||
fname = f"{cls.name(*args, **kwargs)}.onnx"
|
||||
pooch.retrieve(
|
||||
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2netp.onnx",
|
||||
None
|
||||
@ -44,7 +44,7 @@ class U2netpSession(BaseSession):
|
||||
progressbar=True,
|
||||
)
|
||||
|
||||
return os.path.join(cls.u2net_home(), fname)
|
||||
return os.path.join(cls.u2net_home(*args, **kwargs), fname)
|
||||
|
||||
@classmethod
|
||||
def name(cls, *args, **kwargs):
|
||||
|
Loading…
x
Reference in New Issue
Block a user