add dis model
33
README.md
@ -265,6 +265,39 @@ The available models are:
|
||||
- u2net_human_seg ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_human_seg.onnx), [source](https://github.com/xuebinqin/U-2-Net)): A pre-trained model for human segmentation.
|
||||
- u2net_cloth_seg ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_cloth_seg.onnx), [source](https://github.com/levindabhi/cloth-segmentation)): A pre-trained model for Cloths Parsing from human portrait. Here clothes are parsed into 3 category: Upper body, Lower body and Full body.
|
||||
- silueta ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/silueta.onnx), [source](https://github.com/xuebinqin/U-2-Net/issues/295)): Same as u2net but the size is reduced to 43Mb.
|
||||
- isnet-general-use ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-general-use.onnx), [source](https://github.com/xuebinqin/U-2-Net/issues/295)): https://github.com/xuebinqin/DIS.
|
||||
|
||||
### Some differences between the models result
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<th>original</th>
|
||||
<th>unet</th>
|
||||
<th>unetp</th>
|
||||
<th>u2net_human_seg</th>
|
||||
<th>u2net_cloth_seg</th>
|
||||
<th>silueta</th>
|
||||
<th>isnet-general-use</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<th><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/tests/fixtures/car-1.jpg" width="100" /></th>
|
||||
<th><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/tests/results/car-1.unet.jpg" width="100" /></th>
|
||||
<th><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/tests/results/car-1.unetp.jpg" width="100" /></th>
|
||||
<th><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/tests/results/car-1.u2net_human_seg.jpg" width="100" /></th>
|
||||
<th><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/tests/results/car-1.u2net_cloth_seg.jpg" width="100" /></th>
|
||||
<th><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/tests/results/car-1.silueta.jpg" width="100" /></th>
|
||||
<th><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/tests/results/car-1.isnet-general-use.jpg" width="100" /></th>
|
||||
</tr>
|
||||
<th><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/tests/fixtures/cloth-1.jpg" width="100" /></th>
|
||||
<th><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/tests/results/cloth-1.unet.jpg" width="100" /></th>
|
||||
<th><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/tests/results/cloth-1.unetp.jpg" width="100" /></th>
|
||||
<th><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/tests/results/cloth-1.u2net_human_seg.jpg" width="100" /></th>
|
||||
<th><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/tests/results/cloth-1.u2net_cloth_seg.jpg" width="100" /></th>
|
||||
<th><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/tests/results/cloth-1.silueta.jpg" width="100" /></th>
|
||||
<th><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/tests/results/cloth-1.isnet-general-use.jpg" width="100" /></th>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
|
||||
### How to train your own model
|
||||
|
||||
|
@ -34,7 +34,7 @@ def main() -> None:
|
||||
"--model",
|
||||
default="u2net",
|
||||
type=click.Choice(
|
||||
["u2net", "u2netp", "u2net_human_seg", "u2net_cloth_seg", "silueta"]
|
||||
["u2net", "u2netp", "u2net_human_seg", "u2net_cloth_seg", "silueta", "isnet-general-use"]
|
||||
),
|
||||
show_default=True,
|
||||
show_choices=True,
|
||||
@ -103,7 +103,7 @@ def i(model: str, input: IO, output: IO, **kwargs) -> None:
|
||||
"--model",
|
||||
default="u2net",
|
||||
type=click.Choice(
|
||||
["u2net", "u2netp", "u2net_human_seg", "u2net_cloth_seg", "silueta"]
|
||||
["u2net", "u2netp", "u2net_human_seg", "u2net_cloth_seg", "silueta", "isnet-general-use"]
|
||||
),
|
||||
show_default=True,
|
||||
show_choices=True,
|
||||
@ -311,6 +311,7 @@ def s(port: int, log_level: str, threads: int) -> None:
|
||||
u2net_human_seg = "u2net_human_seg"
|
||||
u2net_cloth_seg = "u2net_cloth_seg"
|
||||
silueta = "silueta"
|
||||
isnet_general_use = "isnet-general-use"
|
||||
|
||||
class CommonQueryParams:
|
||||
def __init__(
|
||||
|
@ -7,18 +7,18 @@ from PIL.Image import Image as PILImage
|
||||
|
||||
|
||||
class BaseSession:
|
||||
def __init__(self, model_name: str, inner_session: ort.InferenceSession, output_size: Tuple[int, int] = (320, 320)):
|
||||
def __init__(self, model_name: str, inner_session: ort.InferenceSession):
|
||||
self.model_name = model_name
|
||||
self.inner_session = inner_session
|
||||
self.output_size = output_size
|
||||
|
||||
def normalize(
|
||||
self,
|
||||
img: PILImage,
|
||||
mean: Tuple[float, float, float],
|
||||
std: Tuple[float, float, float],
|
||||
size: Tuple[int, int],
|
||||
) -> Dict[str, np.ndarray]:
|
||||
im = img.convert("RGB").resize(self.output_size, Image.LANCZOS)
|
||||
im = img.convert("RGB").resize(size, Image.LANCZOS)
|
||||
|
||||
im_ary = np.array(im)
|
||||
im_ary = im_ary / np.max(im_ary)
|
||||
|
@ -56,7 +56,7 @@ pallete3 = [
|
||||
class ClothSession(BaseSession):
|
||||
def predict(self, img: PILImage) -> List[PILImage]:
|
||||
ort_outs = self.inner_session.run(
|
||||
None, self.normalize(img, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), (768, 768))
|
||||
None, self.normalize(img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (768, 768))
|
||||
)
|
||||
|
||||
pred = ort_outs
|
||||
|
30
rembg/session_dis.py
Normal file
@ -0,0 +1,30 @@
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from PIL.Image import Image as PILImage
|
||||
|
||||
from .session_base import BaseSession
|
||||
|
||||
|
||||
class DisSession(BaseSession):
|
||||
def predict(self, img: PILImage) -> List[PILImage]:
|
||||
ort_outs = self.inner_session.run(
|
||||
None,
|
||||
self.normalize(
|
||||
img, (0.485, 0.456, 0.406), (1., 1., 1.), (1024, 1024)
|
||||
),
|
||||
)
|
||||
|
||||
pred = ort_outs[0][:, 0, :, :]
|
||||
|
||||
ma = np.max(pred)
|
||||
mi = np.min(pred)
|
||||
|
||||
pred = (pred - mi) / (ma - mi)
|
||||
pred = np.squeeze(pred)
|
||||
|
||||
mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
|
||||
mask = mask.resize(img.size, Image.LANCZOS)
|
||||
|
||||
return [mask]
|
@ -11,12 +11,10 @@ import pooch
|
||||
from .session_base import BaseSession
|
||||
from .session_cloth import ClothSession
|
||||
from .session_simple import SimpleSession
|
||||
from .session_dis import DisSession
|
||||
|
||||
|
||||
def new_session(model_name: str = "u2net", output_size=None) -> BaseSession:
|
||||
# Set output size if not set ( because isnet hat a different size )
|
||||
output_size = output_size or (320, 320)
|
||||
|
||||
def new_session(model_name: str = "u2net") -> BaseSession:
|
||||
session_class: Type[BaseSession]
|
||||
md5 = "60024c5c889badc19c04ad937298a77b"
|
||||
url = "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net.onnx"
|
||||
@ -44,8 +42,8 @@ def new_session(model_name: str = "u2net", output_size=None) -> BaseSession:
|
||||
session_class = SimpleSession
|
||||
elif model_name == "isnet-general-use":
|
||||
md5 = "fc16ebd8b0c10d971d3513d564d01e29"
|
||||
url = "https://github.com/Flippchen/rembg/releases/download/test/isnet-general-use.onnx"
|
||||
session_class = SimpleSession
|
||||
url = "https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-general-use.onnx"
|
||||
session_class = DisSession
|
||||
|
||||
u2net_home = os.getenv(
|
||||
"U2NET_HOME", os.path.join(os.getenv("XDG_DATA_HOME", "~"), ".u2net")
|
||||
@ -74,6 +72,5 @@ def new_session(model_name: str = "u2net", output_size=None) -> BaseSession:
|
||||
str(full_path),
|
||||
providers=ort.get_available_providers(),
|
||||
sess_options=sess_opts,
|
||||
),
|
||||
output_size=output_size
|
||||
)
|
||||
)
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import List, Tuple
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
@ -9,17 +9,10 @@ from .session_base import BaseSession
|
||||
|
||||
class SimpleSession(BaseSession):
|
||||
def predict(self, img: PILImage) -> List[PILImage]:
|
||||
if self.model_name == "isnet-general-use":
|
||||
mean = (0.5, 0.5, 0.5)
|
||||
std = (1., 1., 1.)
|
||||
else:
|
||||
mean = (0.485, 0.456, 0.406)
|
||||
std = (0.229, 0.224, 0.225)
|
||||
|
||||
ort_outs = self.inner_session.run(
|
||||
None,
|
||||
self.normalize(
|
||||
img, mean, std
|
||||
img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (320, 320)
|
||||
),
|
||||
)
|
||||
|
||||
|
BIN
tests/fixtures/car-1.jpg
vendored
Normal file
After Width: | Height: | Size: 59 KiB |
BIN
tests/fixtures/cloth-1.jpg
vendored
Normal file
After Width: | Height: | Size: 211 KiB |
BIN
tests/results/car-1.isnet-general-use.png
Normal file
After Width: | Height: | Size: 202 KiB |
BIN
tests/results/car-1.silueta.png
Normal file
After Width: | Height: | Size: 159 KiB |
BIN
tests/results/car-1.u2net.png
Normal file
After Width: | Height: | Size: 237 KiB |
BIN
tests/results/car-1.u2net_cloth_seg.png
Normal file
After Width: | Height: | Size: 6.2 KiB |
BIN
tests/results/car-1.u2net_human_seg.png
Normal file
After Width: | Height: | Size: 68 KiB |
BIN
tests/results/car-1.u2netp.png
Normal file
After Width: | Height: | Size: 249 KiB |
BIN
tests/results/cloth-1.isnet-general-use.png
Normal file
After Width: | Height: | Size: 788 KiB |
BIN
tests/results/cloth-1.silueta.png
Normal file
After Width: | Height: | Size: 797 KiB |
BIN
tests/results/cloth-1.u2net.png
Normal file
After Width: | Height: | Size: 1.2 MiB |
BIN
tests/results/cloth-1.u2net_cloth_seg.png
Normal file
After Width: | Height: | Size: 540 KiB |
BIN
tests/results/cloth-1.u2net_human_seg.png
Normal file
After Width: | Height: | Size: 1.2 MiB |
BIN
tests/results/cloth-1.u2netp.png
Normal file
After Width: | Height: | Size: 1.2 MiB |
@ -1,20 +1,38 @@
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
from imagehash import average_hash
|
||||
from imagehash import phash as hash_img
|
||||
from PIL import Image
|
||||
|
||||
from rembg import remove
|
||||
from rembg import new_session
|
||||
|
||||
here = Path(__file__).parent.resolve()
|
||||
|
||||
|
||||
def test_remove():
|
||||
image = Path(here / ".." / "examples" / "animal-1.jpg").read_bytes()
|
||||
expected = Path(here / ".." / "examples" / "animal-1.out.png").read_bytes()
|
||||
actual = remove(image)
|
||||
for model in ["u2net", "u2netp", "u2net_human_seg", "u2net_cloth_seg", "silueta", "isnet-general-use"]:
|
||||
for picture in ["car-1", "cloth-1"]:
|
||||
image_path = Path(here / "fixtures" / f"{picture}.jpg")
|
||||
expected_path = Path(here / "results" / f"{picture}.{model}.png")
|
||||
|
||||
actual_hash = average_hash(Image.open(BytesIO(actual)))
|
||||
expected_hash = average_hash(Image.open(BytesIO(expected)))
|
||||
image = image_path.read_bytes()
|
||||
expected = expected_path.read_bytes()
|
||||
|
||||
assert actual_hash == expected_hash
|
||||
actual = remove(image, session=new_session(model))
|
||||
|
||||
# Uncomment to update the expected results
|
||||
# f = open(expected_path, "ab")
|
||||
# f.write(actual)
|
||||
# f.close()
|
||||
|
||||
actual_hash = hash_img(Image.open(BytesIO(actual)))
|
||||
expected_hash = hash_img(Image.open(BytesIO(expected)))
|
||||
|
||||
print(f"image_path: {image_path}")
|
||||
print(f"expected_path: {expected_path}")
|
||||
print(f"actual_hash: {actual_hash}")
|
||||
print(f"expected_hash: {expected_hash}")
|
||||
print(f"actual_hash == expected_hash: {actual_hash == expected_hash}")
|
||||
print("---\n")
|
||||
|
||||
assert actual_hash == expected_hash
|
||||
|