diff --git a/README.md b/README.md index 9a497ea..2143586 100644 --- a/README.md +++ b/README.md @@ -265,6 +265,39 @@ The available models are: - u2net_human_seg ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_human_seg.onnx), [source](https://github.com/xuebinqin/U-2-Net)): A pre-trained model for human segmentation. - u2net_cloth_seg ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_cloth_seg.onnx), [source](https://github.com/levindabhi/cloth-segmentation)): A pre-trained model for Cloths Parsing from human portrait. Here clothes are parsed into 3 category: Upper body, Lower body and Full body. - silueta ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/silueta.onnx), [source](https://github.com/xuebinqin/U-2-Net/issues/295)): Same as u2net but the size is reduced to 43Mb. +- isnet-general-use ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-general-use.onnx), [source](https://github.com/xuebinqin/U-2-Net/issues/295)): https://github.com/xuebinqin/DIS. + +### Some differences between the models result + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
originalunetunetpu2net_human_segu2net_cloth_segsiluetaisnet-general-use
+ ### How to train your own model diff --git a/rembg/cli.py b/rembg/cli.py index b7cba46..10adb29 100644 --- a/rembg/cli.py +++ b/rembg/cli.py @@ -34,7 +34,7 @@ def main() -> None: "--model", default="u2net", type=click.Choice( - ["u2net", "u2netp", "u2net_human_seg", "u2net_cloth_seg", "silueta"] + ["u2net", "u2netp", "u2net_human_seg", "u2net_cloth_seg", "silueta", "isnet-general-use"] ), show_default=True, show_choices=True, @@ -103,7 +103,7 @@ def i(model: str, input: IO, output: IO, **kwargs) -> None: "--model", default="u2net", type=click.Choice( - ["u2net", "u2netp", "u2net_human_seg", "u2net_cloth_seg", "silueta"] + ["u2net", "u2netp", "u2net_human_seg", "u2net_cloth_seg", "silueta", "isnet-general-use"] ), show_default=True, show_choices=True, @@ -311,6 +311,7 @@ def s(port: int, log_level: str, threads: int) -> None: u2net_human_seg = "u2net_human_seg" u2net_cloth_seg = "u2net_cloth_seg" silueta = "silueta" + isnet_general_use = "isnet-general-use" class CommonQueryParams: def __init__( diff --git a/rembg/session_base.py b/rembg/session_base.py index 1409e3a..aa98693 100644 --- a/rembg/session_base.py +++ b/rembg/session_base.py @@ -7,18 +7,18 @@ from PIL.Image import Image as PILImage class BaseSession: - def __init__(self, model_name: str, inner_session: ort.InferenceSession, output_size: Tuple[int, int] = (320, 320)): + def __init__(self, model_name: str, inner_session: ort.InferenceSession): self.model_name = model_name self.inner_session = inner_session - self.output_size = output_size def normalize( self, img: PILImage, mean: Tuple[float, float, float], std: Tuple[float, float, float], + size: Tuple[int, int], ) -> Dict[str, np.ndarray]: - im = img.convert("RGB").resize(self.output_size, Image.LANCZOS) + im = img.convert("RGB").resize(size, Image.LANCZOS) im_ary = np.array(im) im_ary = im_ary / np.max(im_ary) diff --git a/rembg/session_cloth.py b/rembg/session_cloth.py index 11bcef7..d2967d1 100644 --- a/rembg/session_cloth.py +++ b/rembg/session_cloth.py @@ -56,7 +56,7 @@ pallete3 = [ class ClothSession(BaseSession): def predict(self, img: PILImage) -> List[PILImage]: ort_outs = self.inner_session.run( - None, self.normalize(img, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), (768, 768)) + None, self.normalize(img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (768, 768)) ) pred = ort_outs diff --git a/rembg/session_dis.py b/rembg/session_dis.py new file mode 100644 index 0000000..1d4244f --- /dev/null +++ b/rembg/session_dis.py @@ -0,0 +1,30 @@ +from typing import List + +import numpy as np +from PIL import Image +from PIL.Image import Image as PILImage + +from .session_base import BaseSession + + +class DisSession(BaseSession): + def predict(self, img: PILImage) -> List[PILImage]: + ort_outs = self.inner_session.run( + None, + self.normalize( + img, (0.485, 0.456, 0.406), (1., 1., 1.), (1024, 1024) + ), + ) + + pred = ort_outs[0][:, 0, :, :] + + ma = np.max(pred) + mi = np.min(pred) + + pred = (pred - mi) / (ma - mi) + pred = np.squeeze(pred) + + mask = Image.fromarray((pred * 255).astype("uint8"), mode="L") + mask = mask.resize(img.size, Image.LANCZOS) + + return [mask] diff --git a/rembg/session_factory.py b/rembg/session_factory.py index ea59623..a15c8ed 100644 --- a/rembg/session_factory.py +++ b/rembg/session_factory.py @@ -11,12 +11,10 @@ import pooch from .session_base import BaseSession from .session_cloth import ClothSession from .session_simple import SimpleSession +from .session_dis import DisSession -def new_session(model_name: str = "u2net", output_size=None) -> BaseSession: - # Set output size if not set ( because isnet hat a different size ) - output_size = output_size or (320, 320) - +def new_session(model_name: str = "u2net") -> BaseSession: session_class: Type[BaseSession] md5 = "60024c5c889badc19c04ad937298a77b" url = "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net.onnx" @@ -44,8 +42,8 @@ def new_session(model_name: str = "u2net", output_size=None) -> BaseSession: session_class = SimpleSession elif model_name == "isnet-general-use": md5 = "fc16ebd8b0c10d971d3513d564d01e29" - url = "https://github.com/Flippchen/rembg/releases/download/test/isnet-general-use.onnx" - session_class = SimpleSession + url = "https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-general-use.onnx" + session_class = DisSession u2net_home = os.getenv( "U2NET_HOME", os.path.join(os.getenv("XDG_DATA_HOME", "~"), ".u2net") @@ -74,6 +72,5 @@ def new_session(model_name: str = "u2net", output_size=None) -> BaseSession: str(full_path), providers=ort.get_available_providers(), sess_options=sess_opts, - ), - output_size=output_size + ) ) diff --git a/rembg/session_simple.py b/rembg/session_simple.py index 9417491..7ec3181 100644 --- a/rembg/session_simple.py +++ b/rembg/session_simple.py @@ -1,4 +1,4 @@ -from typing import List, Tuple +from typing import List import numpy as np from PIL import Image @@ -9,17 +9,10 @@ from .session_base import BaseSession class SimpleSession(BaseSession): def predict(self, img: PILImage) -> List[PILImage]: - if self.model_name == "isnet-general-use": - mean = (0.5, 0.5, 0.5) - std = (1., 1., 1.) - else: - mean = (0.485, 0.456, 0.406) - std = (0.229, 0.224, 0.225) - ort_outs = self.inner_session.run( None, self.normalize( - img, mean, std + img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (320, 320) ), ) diff --git a/tests/fixtures/car-1.jpg b/tests/fixtures/car-1.jpg new file mode 100644 index 0000000..c6dc1e6 Binary files /dev/null and b/tests/fixtures/car-1.jpg differ diff --git a/tests/fixtures/cloth-1.jpg b/tests/fixtures/cloth-1.jpg new file mode 100644 index 0000000..a33ba7e Binary files /dev/null and b/tests/fixtures/cloth-1.jpg differ diff --git a/tests/results/car-1.isnet-general-use.png b/tests/results/car-1.isnet-general-use.png new file mode 100644 index 0000000..2e4beb7 Binary files /dev/null and b/tests/results/car-1.isnet-general-use.png differ diff --git a/tests/results/car-1.silueta.png b/tests/results/car-1.silueta.png new file mode 100644 index 0000000..82f572f Binary files /dev/null and b/tests/results/car-1.silueta.png differ diff --git a/tests/results/car-1.u2net.png b/tests/results/car-1.u2net.png new file mode 100644 index 0000000..e5c3994 Binary files /dev/null and b/tests/results/car-1.u2net.png differ diff --git a/tests/results/car-1.u2net_cloth_seg.png b/tests/results/car-1.u2net_cloth_seg.png new file mode 100644 index 0000000..64ffd88 Binary files /dev/null and b/tests/results/car-1.u2net_cloth_seg.png differ diff --git a/tests/results/car-1.u2net_human_seg.png b/tests/results/car-1.u2net_human_seg.png new file mode 100644 index 0000000..fee65a6 Binary files /dev/null and b/tests/results/car-1.u2net_human_seg.png differ diff --git a/tests/results/car-1.u2netp.png b/tests/results/car-1.u2netp.png new file mode 100644 index 0000000..ba72870 Binary files /dev/null and b/tests/results/car-1.u2netp.png differ diff --git a/tests/results/cloth-1.isnet-general-use.png b/tests/results/cloth-1.isnet-general-use.png new file mode 100644 index 0000000..6e474f7 Binary files /dev/null and b/tests/results/cloth-1.isnet-general-use.png differ diff --git a/tests/results/cloth-1.silueta.png b/tests/results/cloth-1.silueta.png new file mode 100644 index 0000000..9bc356a Binary files /dev/null and b/tests/results/cloth-1.silueta.png differ diff --git a/tests/results/cloth-1.u2net.png b/tests/results/cloth-1.u2net.png new file mode 100644 index 0000000..501bb5b Binary files /dev/null and b/tests/results/cloth-1.u2net.png differ diff --git a/tests/results/cloth-1.u2net_cloth_seg.png b/tests/results/cloth-1.u2net_cloth_seg.png new file mode 100644 index 0000000..bc72550 Binary files /dev/null and b/tests/results/cloth-1.u2net_cloth_seg.png differ diff --git a/tests/results/cloth-1.u2net_human_seg.png b/tests/results/cloth-1.u2net_human_seg.png new file mode 100644 index 0000000..2abde7f Binary files /dev/null and b/tests/results/cloth-1.u2net_human_seg.png differ diff --git a/tests/results/cloth-1.u2netp.png b/tests/results/cloth-1.u2netp.png new file mode 100644 index 0000000..cc11944 Binary files /dev/null and b/tests/results/cloth-1.u2netp.png differ diff --git a/tests/test_remove.py b/tests/test_remove.py index 7421b12..7c384b0 100644 --- a/tests/test_remove.py +++ b/tests/test_remove.py @@ -1,20 +1,38 @@ from io import BytesIO from pathlib import Path -from imagehash import average_hash +from imagehash import phash as hash_img from PIL import Image from rembg import remove +from rembg import new_session here = Path(__file__).parent.resolve() - def test_remove(): - image = Path(here / ".." / "examples" / "animal-1.jpg").read_bytes() - expected = Path(here / ".." / "examples" / "animal-1.out.png").read_bytes() - actual = remove(image) + for model in ["u2net", "u2netp", "u2net_human_seg", "u2net_cloth_seg", "silueta", "isnet-general-use"]: + for picture in ["car-1", "cloth-1"]: + image_path = Path(here / "fixtures" / f"{picture}.jpg") + expected_path = Path(here / "results" / f"{picture}.{model}.png") - actual_hash = average_hash(Image.open(BytesIO(actual))) - expected_hash = average_hash(Image.open(BytesIO(expected))) + image = image_path.read_bytes() + expected = expected_path.read_bytes() - assert actual_hash == expected_hash + actual = remove(image, session=new_session(model)) + + # Uncomment to update the expected results + # f = open(expected_path, "ab") + # f.write(actual) + # f.close() + + actual_hash = hash_img(Image.open(BytesIO(actual))) + expected_hash = hash_img(Image.open(BytesIO(expected))) + + print(f"image_path: {image_path}") + print(f"expected_path: {expected_path}") + print(f"actual_hash: {actual_hash}") + print(f"expected_hash: {expected_hash}") + print(f"actual_hash == expected_hash: {actual_hash == expected_hash}") + print("---\n") + + assert actual_hash == expected_hash