diff --git a/README.md b/README.md
index 9a497ea..2143586 100644
--- a/README.md
+++ b/README.md
@@ -265,6 +265,39 @@ The available models are:
- u2net_human_seg ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_human_seg.onnx), [source](https://github.com/xuebinqin/U-2-Net)): A pre-trained model for human segmentation.
- u2net_cloth_seg ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_cloth_seg.onnx), [source](https://github.com/levindabhi/cloth-segmentation)): A pre-trained model for Cloths Parsing from human portrait. Here clothes are parsed into 3 category: Upper body, Lower body and Full body.
- silueta ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/silueta.onnx), [source](https://github.com/xuebinqin/U-2-Net/issues/295)): Same as u2net but the size is reduced to 43Mb.
+- isnet-general-use ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-general-use.onnx), [source](https://github.com/xuebinqin/U-2-Net/issues/295)): https://github.com/xuebinqin/DIS.
+
+### Some differences between the models result
+
+
+
+ original |
+ unet |
+ unetp |
+ u2net_human_seg |
+ u2net_cloth_seg |
+ silueta |
+ isnet-general-use |
+
+
+  |
+  |
+  |
+  |
+  |
+  |
+  |
+
+  |
+  |
+  |
+  |
+  |
+  |
+  |
+
+
+
### How to train your own model
diff --git a/rembg/cli.py b/rembg/cli.py
index b7cba46..10adb29 100644
--- a/rembg/cli.py
+++ b/rembg/cli.py
@@ -34,7 +34,7 @@ def main() -> None:
"--model",
default="u2net",
type=click.Choice(
- ["u2net", "u2netp", "u2net_human_seg", "u2net_cloth_seg", "silueta"]
+ ["u2net", "u2netp", "u2net_human_seg", "u2net_cloth_seg", "silueta", "isnet-general-use"]
),
show_default=True,
show_choices=True,
@@ -103,7 +103,7 @@ def i(model: str, input: IO, output: IO, **kwargs) -> None:
"--model",
default="u2net",
type=click.Choice(
- ["u2net", "u2netp", "u2net_human_seg", "u2net_cloth_seg", "silueta"]
+ ["u2net", "u2netp", "u2net_human_seg", "u2net_cloth_seg", "silueta", "isnet-general-use"]
),
show_default=True,
show_choices=True,
@@ -311,6 +311,7 @@ def s(port: int, log_level: str, threads: int) -> None:
u2net_human_seg = "u2net_human_seg"
u2net_cloth_seg = "u2net_cloth_seg"
silueta = "silueta"
+ isnet_general_use = "isnet-general-use"
class CommonQueryParams:
def __init__(
diff --git a/rembg/session_base.py b/rembg/session_base.py
index 1409e3a..aa98693 100644
--- a/rembg/session_base.py
+++ b/rembg/session_base.py
@@ -7,18 +7,18 @@ from PIL.Image import Image as PILImage
class BaseSession:
- def __init__(self, model_name: str, inner_session: ort.InferenceSession, output_size: Tuple[int, int] = (320, 320)):
+ def __init__(self, model_name: str, inner_session: ort.InferenceSession):
self.model_name = model_name
self.inner_session = inner_session
- self.output_size = output_size
def normalize(
self,
img: PILImage,
mean: Tuple[float, float, float],
std: Tuple[float, float, float],
+ size: Tuple[int, int],
) -> Dict[str, np.ndarray]:
- im = img.convert("RGB").resize(self.output_size, Image.LANCZOS)
+ im = img.convert("RGB").resize(size, Image.LANCZOS)
im_ary = np.array(im)
im_ary = im_ary / np.max(im_ary)
diff --git a/rembg/session_cloth.py b/rembg/session_cloth.py
index 11bcef7..d2967d1 100644
--- a/rembg/session_cloth.py
+++ b/rembg/session_cloth.py
@@ -56,7 +56,7 @@ pallete3 = [
class ClothSession(BaseSession):
def predict(self, img: PILImage) -> List[PILImage]:
ort_outs = self.inner_session.run(
- None, self.normalize(img, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), (768, 768))
+ None, self.normalize(img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (768, 768))
)
pred = ort_outs
diff --git a/rembg/session_dis.py b/rembg/session_dis.py
new file mode 100644
index 0000000..1d4244f
--- /dev/null
+++ b/rembg/session_dis.py
@@ -0,0 +1,30 @@
+from typing import List
+
+import numpy as np
+from PIL import Image
+from PIL.Image import Image as PILImage
+
+from .session_base import BaseSession
+
+
+class DisSession(BaseSession):
+ def predict(self, img: PILImage) -> List[PILImage]:
+ ort_outs = self.inner_session.run(
+ None,
+ self.normalize(
+ img, (0.485, 0.456, 0.406), (1., 1., 1.), (1024, 1024)
+ ),
+ )
+
+ pred = ort_outs[0][:, 0, :, :]
+
+ ma = np.max(pred)
+ mi = np.min(pred)
+
+ pred = (pred - mi) / (ma - mi)
+ pred = np.squeeze(pred)
+
+ mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
+ mask = mask.resize(img.size, Image.LANCZOS)
+
+ return [mask]
diff --git a/rembg/session_factory.py b/rembg/session_factory.py
index ea59623..a15c8ed 100644
--- a/rembg/session_factory.py
+++ b/rembg/session_factory.py
@@ -11,12 +11,10 @@ import pooch
from .session_base import BaseSession
from .session_cloth import ClothSession
from .session_simple import SimpleSession
+from .session_dis import DisSession
-def new_session(model_name: str = "u2net", output_size=None) -> BaseSession:
- # Set output size if not set ( because isnet hat a different size )
- output_size = output_size or (320, 320)
-
+def new_session(model_name: str = "u2net") -> BaseSession:
session_class: Type[BaseSession]
md5 = "60024c5c889badc19c04ad937298a77b"
url = "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net.onnx"
@@ -44,8 +42,8 @@ def new_session(model_name: str = "u2net", output_size=None) -> BaseSession:
session_class = SimpleSession
elif model_name == "isnet-general-use":
md5 = "fc16ebd8b0c10d971d3513d564d01e29"
- url = "https://github.com/Flippchen/rembg/releases/download/test/isnet-general-use.onnx"
- session_class = SimpleSession
+ url = "https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-general-use.onnx"
+ session_class = DisSession
u2net_home = os.getenv(
"U2NET_HOME", os.path.join(os.getenv("XDG_DATA_HOME", "~"), ".u2net")
@@ -74,6 +72,5 @@ def new_session(model_name: str = "u2net", output_size=None) -> BaseSession:
str(full_path),
providers=ort.get_available_providers(),
sess_options=sess_opts,
- ),
- output_size=output_size
+ )
)
diff --git a/rembg/session_simple.py b/rembg/session_simple.py
index 9417491..7ec3181 100644
--- a/rembg/session_simple.py
+++ b/rembg/session_simple.py
@@ -1,4 +1,4 @@
-from typing import List, Tuple
+from typing import List
import numpy as np
from PIL import Image
@@ -9,17 +9,10 @@ from .session_base import BaseSession
class SimpleSession(BaseSession):
def predict(self, img: PILImage) -> List[PILImage]:
- if self.model_name == "isnet-general-use":
- mean = (0.5, 0.5, 0.5)
- std = (1., 1., 1.)
- else:
- mean = (0.485, 0.456, 0.406)
- std = (0.229, 0.224, 0.225)
-
ort_outs = self.inner_session.run(
None,
self.normalize(
- img, mean, std
+ img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (320, 320)
),
)
diff --git a/tests/fixtures/car-1.jpg b/tests/fixtures/car-1.jpg
new file mode 100644
index 0000000..c6dc1e6
Binary files /dev/null and b/tests/fixtures/car-1.jpg differ
diff --git a/tests/fixtures/cloth-1.jpg b/tests/fixtures/cloth-1.jpg
new file mode 100644
index 0000000..a33ba7e
Binary files /dev/null and b/tests/fixtures/cloth-1.jpg differ
diff --git a/tests/results/car-1.isnet-general-use.png b/tests/results/car-1.isnet-general-use.png
new file mode 100644
index 0000000..2e4beb7
Binary files /dev/null and b/tests/results/car-1.isnet-general-use.png differ
diff --git a/tests/results/car-1.silueta.png b/tests/results/car-1.silueta.png
new file mode 100644
index 0000000..82f572f
Binary files /dev/null and b/tests/results/car-1.silueta.png differ
diff --git a/tests/results/car-1.u2net.png b/tests/results/car-1.u2net.png
new file mode 100644
index 0000000..e5c3994
Binary files /dev/null and b/tests/results/car-1.u2net.png differ
diff --git a/tests/results/car-1.u2net_cloth_seg.png b/tests/results/car-1.u2net_cloth_seg.png
new file mode 100644
index 0000000..64ffd88
Binary files /dev/null and b/tests/results/car-1.u2net_cloth_seg.png differ
diff --git a/tests/results/car-1.u2net_human_seg.png b/tests/results/car-1.u2net_human_seg.png
new file mode 100644
index 0000000..fee65a6
Binary files /dev/null and b/tests/results/car-1.u2net_human_seg.png differ
diff --git a/tests/results/car-1.u2netp.png b/tests/results/car-1.u2netp.png
new file mode 100644
index 0000000..ba72870
Binary files /dev/null and b/tests/results/car-1.u2netp.png differ
diff --git a/tests/results/cloth-1.isnet-general-use.png b/tests/results/cloth-1.isnet-general-use.png
new file mode 100644
index 0000000..6e474f7
Binary files /dev/null and b/tests/results/cloth-1.isnet-general-use.png differ
diff --git a/tests/results/cloth-1.silueta.png b/tests/results/cloth-1.silueta.png
new file mode 100644
index 0000000..9bc356a
Binary files /dev/null and b/tests/results/cloth-1.silueta.png differ
diff --git a/tests/results/cloth-1.u2net.png b/tests/results/cloth-1.u2net.png
new file mode 100644
index 0000000..501bb5b
Binary files /dev/null and b/tests/results/cloth-1.u2net.png differ
diff --git a/tests/results/cloth-1.u2net_cloth_seg.png b/tests/results/cloth-1.u2net_cloth_seg.png
new file mode 100644
index 0000000..bc72550
Binary files /dev/null and b/tests/results/cloth-1.u2net_cloth_seg.png differ
diff --git a/tests/results/cloth-1.u2net_human_seg.png b/tests/results/cloth-1.u2net_human_seg.png
new file mode 100644
index 0000000..2abde7f
Binary files /dev/null and b/tests/results/cloth-1.u2net_human_seg.png differ
diff --git a/tests/results/cloth-1.u2netp.png b/tests/results/cloth-1.u2netp.png
new file mode 100644
index 0000000..cc11944
Binary files /dev/null and b/tests/results/cloth-1.u2netp.png differ
diff --git a/tests/test_remove.py b/tests/test_remove.py
index 7421b12..7c384b0 100644
--- a/tests/test_remove.py
+++ b/tests/test_remove.py
@@ -1,20 +1,38 @@
from io import BytesIO
from pathlib import Path
-from imagehash import average_hash
+from imagehash import phash as hash_img
from PIL import Image
from rembg import remove
+from rembg import new_session
here = Path(__file__).parent.resolve()
-
def test_remove():
- image = Path(here / ".." / "examples" / "animal-1.jpg").read_bytes()
- expected = Path(here / ".." / "examples" / "animal-1.out.png").read_bytes()
- actual = remove(image)
+ for model in ["u2net", "u2netp", "u2net_human_seg", "u2net_cloth_seg", "silueta", "isnet-general-use"]:
+ for picture in ["car-1", "cloth-1"]:
+ image_path = Path(here / "fixtures" / f"{picture}.jpg")
+ expected_path = Path(here / "results" / f"{picture}.{model}.png")
- actual_hash = average_hash(Image.open(BytesIO(actual)))
- expected_hash = average_hash(Image.open(BytesIO(expected)))
+ image = image_path.read_bytes()
+ expected = expected_path.read_bytes()
- assert actual_hash == expected_hash
+ actual = remove(image, session=new_session(model))
+
+ # Uncomment to update the expected results
+ # f = open(expected_path, "ab")
+ # f.write(actual)
+ # f.close()
+
+ actual_hash = hash_img(Image.open(BytesIO(actual)))
+ expected_hash = hash_img(Image.open(BytesIO(expected)))
+
+ print(f"image_path: {image_path}")
+ print(f"expected_path: {expected_path}")
+ print(f"actual_hash: {actual_hash}")
+ print(f"expected_hash: {expected_hash}")
+ print(f"actual_hash == expected_hash: {actual_hash == expected_hash}")
+ print("---\n")
+
+ assert actual_hash == expected_hash