add dis model

Daniel Gatis 2023-03-28 23:12:29 -03:00
parent ece38e1701
commit a54000d507
22 changed files with 103 additions and 31 deletions

@@ -265,6 +265,39 @@ The available models are:
- u2net_human_seg ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_human_seg.onnx), [source](https://github.com/xuebinqin/U-2-Net)): A pre-trained model for human segmentation.
- u2net_cloth_seg ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_cloth_seg.onnx), [source](https://github.com/levindabhi/cloth-segmentation)): A pre-trained model for clothes parsing from human portraits. Clothes are parsed into 3 categories: Upper body, Lower body and Full body.
- silueta ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/silueta.onnx), [source](https://github.com/xuebinqin/U-2-Net/issues/295)): Same as u2net but the size is reduced to 43Mb.
- isnet-general-use ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-general-use.onnx), [source](https://github.com/xuebinqin/DIS)): A new pre-trained model for general use cases.
### Some differences between the models' results
<table>
<tr>
<th>original</th>
<th>u2net</th>
<th>u2netp</th>
<th>u2net_human_seg</th>
<th>u2net_cloth_seg</th>
<th>silueta</th>
<th>isnet-general-use</th>
</tr>
<tr>
<th><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/tests/fixtures/car-1.jpg" width="100" /></th>
<th><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/tests/results/car-1.unet.jpg" width="100" /></th>
<th><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/tests/results/car-1.unetp.jpg" width="100" /></th>
<th><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/tests/results/car-1.u2net_human_seg.jpg" width="100" /></th>
<th><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/tests/results/car-1.u2net_cloth_seg.jpg" width="100" /></th>
<th><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/tests/results/car-1.silueta.jpg" width="100" /></th>
<th><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/tests/results/car-1.isnet-general-use.jpg" width="100" /></th>
</tr>
<th><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/tests/fixtures/cloth-1.jpg" width="100" /></th>
<th><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/tests/results/cloth-1.unet.jpg" width="100" /></th>
<th><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/tests/results/cloth-1.unetp.jpg" width="100" /></th>
<th><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/tests/results/cloth-1.u2net_human_seg.jpg" width="100" /></th>
<th><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/tests/results/cloth-1.u2net_cloth_seg.jpg" width="100" /></th>
<th><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/tests/results/cloth-1.silueta.jpg" width="100" /></th>
<th><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/tests/results/cloth-1.isnet-general-use.jpg" width="100" /></th>
</tr>
</table>
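
For context, a minimal sketch of selecting one of these models through the Python API, mirroring the `new_session`/`remove` calls used in the tests further down in this commit; the file names here are only placeholders:

```python
from rembg import new_session, remove

# "u2net" is the default; pass any of the model names listed above.
session = new_session("isnet-general-use")

with open("car-1.jpg", "rb") as src:          # placeholder input path
    output = remove(src.read(), session=session)

with open("car-1.out.png", "wb") as dst:      # placeholder output path
    dst.write(output)
```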
### How to train your own model

@@ -34,7 +34,7 @@ def main() -> None:
"--model",
default="u2net",
type=click.Choice(
["u2net", "u2netp", "u2net_human_seg", "u2net_cloth_seg", "silueta"]
["u2net", "u2netp", "u2net_human_seg", "u2net_cloth_seg", "silueta", "isnet-general-use"]
),
show_default=True,
show_choices=True,
@@ -103,7 +103,7 @@ def i(model: str, input: IO, output: IO, **kwargs) -> None:
"--model",
default="u2net",
type=click.Choice(
["u2net", "u2netp", "u2net_human_seg", "u2net_cloth_seg", "silueta"]
["u2net", "u2netp", "u2net_human_seg", "u2net_cloth_seg", "silueta", "isnet-general-use"]
),
show_default=True,
show_choices=True,
@@ -311,6 +311,7 @@ def s(port: int, log_level: str, threads: int) -> None:
u2net_human_seg = "u2net_human_seg"
u2net_cloth_seg = "u2net_cloth_seg"
silueta = "silueta"
isnet_general_use = "isnet-general-use"
class CommonQueryParams:
def __init__(

@@ -7,18 +7,18 @@ from PIL.Image import Image as PILImage
class BaseSession:
def __init__(self, model_name: str, inner_session: ort.InferenceSession, output_size: Tuple[int, int] = (320, 320)):
def __init__(self, model_name: str, inner_session: ort.InferenceSession):
self.model_name = model_name
self.inner_session = inner_session
self.output_size = output_size
def normalize(
self,
img: PILImage,
mean: Tuple[float, float, float],
std: Tuple[float, float, float],
size: Tuple[int, int],
) -> Dict[str, np.ndarray]:
im = img.convert("RGB").resize(self.output_size, Image.LANCZOS)
im = img.convert("RGB").resize(size, Image.LANCZOS)
im_ary = np.array(im)
im_ary = im_ary / np.max(im_ary)
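
For clarity, a rough sketch of how the rest of `normalize` typically proceeds, assuming the standard per-channel `(value - mean) / std` transform and an NCHW input tensor; the helper name `finish_normalize` is purely illustrative and not part of the diff:

```python
import numpy as np

def finish_normalize(im_ary: np.ndarray, mean, std, input_name: str) -> dict:
    # Per-channel normalization of an HxWx3 array already scaled to [0, 1].
    tmp = np.zeros((im_ary.shape[0], im_ary.shape[1], 3), dtype=np.float32)
    for c in range(3):
        tmp[:, :, c] = (im_ary[:, :, c] - mean[c]) / std[c]
    # HWC -> CHW, then add a batch dimension for the ONNX session input.
    tmp = tmp.transpose((2, 0, 1))
    return {input_name: np.expand_dims(tmp, 0).astype(np.float32)}
```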

@@ -56,7 +56,7 @@ pallete3 = [
class ClothSession(BaseSession):
def predict(self, img: PILImage) -> List[PILImage]:
ort_outs = self.inner_session.run(
None, self.normalize(img, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), (768, 768))
None, self.normalize(img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (768, 768))
)
pred = ort_outs

rembg/session_dis.py Normal file
@@ -0,0 +1,30 @@
from typing import List
import numpy as np
from PIL import Image
from PIL.Image import Image as PILImage
from .session_base import BaseSession
class DisSession(BaseSession):
def predict(self, img: PILImage) -> List[PILImage]:
ort_outs = self.inner_session.run(
None,
self.normalize(
img, (0.485, 0.456, 0.406), (1., 1., 1.), (1024, 1024)
),
)
pred = ort_outs[0][:, 0, :, :]
ma = np.max(pred)
mi = np.min(pred)
pred = (pred - mi) / (ma - mi)
pred = np.squeeze(pred)
mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
mask = mask.resize(img.size, Image.LANCZOS)
return [mask]
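
For illustration, the single mask returned by `predict` can be composited back onto the source image as an alpha channel; a minimal sketch (the helper name is hypothetical, not part of this commit):

```python
from PIL.Image import Image as PILImage

def apply_mask(img: PILImage, mask: PILImage) -> PILImage:
    # Use the predicted "L"-mode mask (already resized to img.size) as the alpha channel.
    cutout = img.convert("RGB")
    cutout.putalpha(mask)
    return cutout
```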

@@ -11,12 +11,10 @@ import pooch
from .session_base import BaseSession
from .session_cloth import ClothSession
from .session_simple import SimpleSession
from .session_dis import DisSession
def new_session(model_name: str = "u2net", output_size=None) -> BaseSession:
# Set output size if not set ( because isnet hat a different size )
output_size = output_size or (320, 320)
def new_session(model_name: str = "u2net") -> BaseSession:
session_class: Type[BaseSession]
md5 = "60024c5c889badc19c04ad937298a77b"
url = "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net.onnx"
@@ -44,8 +42,8 @@ def new_session(model_name: str = "u2net", output_size=None) -> BaseSession:
session_class = SimpleSession
elif model_name == "isnet-general-use":
md5 = "fc16ebd8b0c10d971d3513d564d01e29"
url = "https://github.com/Flippchen/rembg/releases/download/test/isnet-general-use.onnx"
session_class = SimpleSession
url = "https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-general-use.onnx"
session_class = DisSession
u2net_home = os.getenv(
"U2NET_HOME", os.path.join(os.getenv("XDG_DATA_HOME", "~"), ".u2net")
@@ -74,6 +72,5 @@ def new_session(model_name: str = "u2net", output_size=None) -> BaseSession:
str(full_path),
providers=ort.get_available_providers(),
sess_options=sess_opts,
),
output_size=output_size
)
)
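
As a usage note, the download location resolved above can be overridden before a session is created; a small sketch, assuming only the `U2NET_HOME` environment variable read in this file (the cache path is a placeholder):

```python
import os

# Redirect the model cache before the first new_session() call.
os.environ["U2NET_HOME"] = os.path.expanduser("~/models/rembg")  # placeholder path

from rembg import new_session

# Should fetch isnet-general-use.onnx into U2NET_HOME on first use, then reuse it.
session = new_session("isnet-general-use")
```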

@@ -1,4 +1,4 @@
from typing import List, Tuple
from typing import List
import numpy as np
from PIL import Image
@@ -9,17 +9,10 @@ from .session_base import BaseSession
class SimpleSession(BaseSession):
def predict(self, img: PILImage) -> List[PILImage]:
if self.model_name == "isnet-general-use":
mean = (0.5, 0.5, 0.5)
std = (1., 1., 1.)
else:
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
ort_outs = self.inner_session.run(
None,
self.normalize(
img, mean, std
img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (320, 320)
),
)

BIN
tests/fixtures/car-1.jpg vendored Normal file (binary, 59 KiB, not shown)

BIN
tests/fixtures/cloth-1.jpg vendored Normal file (binary, 211 KiB, not shown)

(12 additional binary image files added, not shown)

@@ -1,20 +1,38 @@
from io import BytesIO
from pathlib import Path
from imagehash import average_hash
from imagehash import phash as hash_img
from PIL import Image
from rembg import remove
from rembg import new_session
here = Path(__file__).parent.resolve()
def test_remove():
image = Path(here / ".." / "examples" / "animal-1.jpg").read_bytes()
expected = Path(here / ".." / "examples" / "animal-1.out.png").read_bytes()
actual = remove(image)
for model in ["u2net", "u2netp", "u2net_human_seg", "u2net_cloth_seg", "silueta", "isnet-general-use"]:
for picture in ["car-1", "cloth-1"]:
image_path = Path(here / "fixtures" / f"{picture}.jpg")
expected_path = Path(here / "results" / f"{picture}.{model}.png")
actual_hash = average_hash(Image.open(BytesIO(actual)))
expected_hash = average_hash(Image.open(BytesIO(expected)))
image = image_path.read_bytes()
expected = expected_path.read_bytes()
assert actual_hash == expected_hash
actual = remove(image, session=new_session(model))
# Uncomment to update the expected results
# f = open(expected_path, "wb")
# f.write(actual)
# f.close()
actual_hash = hash_img(Image.open(BytesIO(actual)))
expected_hash = hash_img(Image.open(BytesIO(expected)))
print(f"image_path: {image_path}")
print(f"expected_path: {expected_path}")
print(f"actual_hash: {actual_hash}")
print(f"expected_hash: {expected_hash}")
print(f"actual_hash == expected_hash: {actual_hash == expected_hash}")
print("---\n")
assert actual_hash == expected_hash