diff --git a/README.md b/README.md index ea836f0..1779d5e 100644 --- a/README.md +++ b/README.md @@ -334,6 +334,13 @@ The available models are: - isnet-general-use ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-general-use.onnx), [source](https://github.com/xuebinqin/DIS)): A new pre-trained model for general use cases. - isnet-anime ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-anime.onnx), [source](https://github.com/SkyTNT/anime-segmentation)): A high-accuracy segmentation for anime character. - sam ([download encoder](https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-encoder-quant.onnx), [download decoder](https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-decoder-quant.onnx), [source](https://github.com/facebookresearch/segment-anything)): A pre-trained model for any use cases. +- birefnet-general ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-general-epoch_244.onnx), [source](https://github.com/ZhengPeng7/BiRefNet)): A pre-trained model for general use cases. +- birefnet-general-lite ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-general-bb_swin_v1_tiny-epoch_232.onnx), [source](https://github.com/ZhengPeng7/BiRefNet)): A light pre-trained model for general use cases. +- birefnet-portrait ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-portrait-epoch_150.onnx), [source](https://github.com/ZhengPeng7/BiRefNet)): A pre-trained model for human portraits. +- birefnet-dis ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-DIS-epoch_590.onnx), [source](https://github.com/ZhengPeng7/BiRefNet)): A pre-trained model for dichotomous image segmentation (DIS). +- birefnet-hrsod ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-HRSOD_DHU-epoch_115.onnx), [source](https://github.com/ZhengPeng7/BiRefNet)): A pre-trained model for high-resolution salient object detection (HRSOD). +- birefnet-cod ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-COD-epoch_125.onnx), [source](https://github.com/ZhengPeng7/BiRefNet)): A pre-trained model for concealed object detection (COD). +- birefnet-massive ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-massive-TR_DIS5K_TR_TEs-epoch_420.onnx), [source](https://github.com/ZhengPeng7/BiRefNet)): A pre-trained model with massive dataset. ### How to train your own model diff --git a/rembg/sessions/__init__.py b/rembg/sessions/__init__.py index 367d9a5..b6e4c0d 100644 --- a/rembg/sessions/__init__.py +++ b/rembg/sessions/__init__.py @@ -7,6 +7,41 @@ from .base import BaseSession sessions_class: List[type[BaseSession]] = [] sessions_names: List[str] = [] +from .birefnet_general import BiRefNetSessionGeneral + +sessions_class.append(BiRefNetSessionGeneral) +sessions_names.append(BiRefNetSessionGeneral.name()) + +from .birefnet_general_lite import BiRefNetSessionGeneralLite + +sessions_class.append(BiRefNetSessionGeneralLite) +sessions_names.append(BiRefNetSessionGeneralLite.name()) + +from .birefnet_portrait import BiRefNetSessionPortrait + +sessions_class.append(BiRefNetSessionPortrait) +sessions_names.append(BiRefNetSessionPortrait.name()) + +from .birefnet_dis import BiRefNetSessionDIS + +sessions_class.append(BiRefNetSessionDIS) +sessions_names.append(BiRefNetSessionDIS.name()) + +from .birefnet_hrsod import BiRefNetSessionHRSOD + +sessions_class.append(BiRefNetSessionHRSOD) +sessions_names.append(BiRefNetSessionHRSOD.name()) + +from .birefnet_cod import BiRefNetSessionCOD + +sessions_class.append(BiRefNetSessionCOD) +sessions_names.append(BiRefNetSessionCOD.name()) + +from .birefnet_massive import BiRefNetSessionMassive + +sessions_class.append(BiRefNetSessionMassive) +sessions_names.append(BiRefNetSessionMassive.name()) + from .dis_anime import DisSession sessions_class.append(DisSession) diff --git a/rembg/sessions/birefnet_cod.py b/rembg/sessions/birefnet_cod.py new file mode 100644 index 0000000..d678f1a --- /dev/null +++ b/rembg/sessions/birefnet_cod.py @@ -0,0 +1,52 @@ +import os + +import pooch + +from . import BiRefNetSessionGeneral + + +class BiRefNetSessionCOD(BiRefNetSessionGeneral): + """ + This class represents a BiRefNet-COD session, which is a subclass of BiRefNetSessionGeneral. + """ + + @classmethod + def download_models(cls, *args, **kwargs): + """ + Downloads the BiRefNet-COD model file from a specific URL and saves it. + + Parameters: + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. + + Returns: + str: The path to the downloaded model file. + """ + fname = f"{cls.name(*args, **kwargs)}.onnx" + pooch.retrieve( + "https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-COD-epoch_125.onnx", + ( + None + if cls.checksum_disabled(*args, **kwargs) + else "md5:f6d0d21ca89d287f17e7afe9f5fd3b45" + ), + fname=fname, + path=cls.u2net_home(*args, **kwargs), + progressbar=True, + ) + + return os.path.join(cls.u2net_home(*args, **kwargs), fname) + + @classmethod + def name(cls, *args, **kwargs): + """ + Returns the name of the BiRefNet-COD session. + + Parameters: + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. + + Returns: + str: The name of the session. + """ + return "birefnet-cod" diff --git a/rembg/sessions/birefnet_dis.py b/rembg/sessions/birefnet_dis.py new file mode 100644 index 0000000..76bdb85 --- /dev/null +++ b/rembg/sessions/birefnet_dis.py @@ -0,0 +1,52 @@ +import os + +import pooch + +from . import BiRefNetSessionGeneral + + +class BiRefNetSessionDIS(BiRefNetSessionGeneral): + """ + This class represents a BiRefNet-DIS session, which is a subclass of BiRefNetSessionGeneral. + """ + + @classmethod + def download_models(cls, *args, **kwargs): + """ + Downloads the BiRefNet-DIS model file from a specific URL and saves it. + + Parameters: + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. + + Returns: + str: The path to the downloaded model file. + """ + fname = f"{cls.name(*args, **kwargs)}.onnx" + pooch.retrieve( + "https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-DIS-epoch_590.onnx", + ( + None + if cls.checksum_disabled(*args, **kwargs) + else "md5:2d4d44102b446f33a4ebb2e56c051f2b" + ), + fname=fname, + path=cls.u2net_home(*args, **kwargs), + progressbar=True, + ) + + return os.path.join(cls.u2net_home(*args, **kwargs), fname) + + @classmethod + def name(cls, *args, **kwargs): + """ + Returns the name of the BiRefNet-DIS session. + + Parameters: + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. + + Returns: + str: The name of the session. + """ + return "birefnet-dis" diff --git a/rembg/sessions/birefnet_general.py b/rembg/sessions/birefnet_general.py new file mode 100644 index 0000000..3fa1dca --- /dev/null +++ b/rembg/sessions/birefnet_general.py @@ -0,0 +1,91 @@ +import os +from typing import List + +import numpy as np +import pooch +from PIL import Image +from PIL.Image import Image as PILImage + +from .base import BaseSession + + +class BiRefNetSessionGeneral(BaseSession): + """ + This class represents a BiRefNet-General session, which is a subclass of BaseSession. + """ + + def sigmoid(self, mat): + return 1 / (1 + np.exp(-mat)) + + def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]: + """ + Predicts the output masks for the input image using the inner session. + + Parameters: + img (PILImage): The input image. + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. + + Returns: + List[PILImage]: The list of output masks. + """ + ort_outs = self.inner_session.run( + None, + self.normalize( + img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (1024, 1024) + ), + ) + + pred = self.sigmoid(ort_outs[0][:, 0, :, :]) + + ma = np.max(pred) + mi = np.min(pred) + + pred = (pred - mi) / (ma - mi) + pred = np.squeeze(pred) + + mask = Image.fromarray((pred * 255).astype("uint8"), mode="L") + mask = mask.resize(img.size, Image.Resampling.LANCZOS) + + return [mask] + + @classmethod + def download_models(cls, *args, **kwargs): + """ + Downloads the BiRefNet-General model file from a specific URL and saves it. + + Parameters: + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. + + Returns: + str: The path to the downloaded model file. + """ + fname = f"{cls.name(*args, **kwargs)}.onnx" + pooch.retrieve( + "https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-general-epoch_244.onnx", + ( + None + if cls.checksum_disabled(*args, **kwargs) + else "md5:7a35a0141cbbc80de11d9c9a28f52697" + ), + fname=fname, + path=cls.u2net_home(*args, **kwargs), + progressbar=True, + ) + + return os.path.join(cls.u2net_home(*args, **kwargs), fname) + + @classmethod + def name(cls, *args, **kwargs): + """ + Returns the name of the BiRefNet-General session. + + Parameters: + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. + + Returns: + str: The name of the session. + """ + return "birefnet-general" diff --git a/rembg/sessions/birefnet_general_lite.py b/rembg/sessions/birefnet_general_lite.py new file mode 100644 index 0000000..6196712 --- /dev/null +++ b/rembg/sessions/birefnet_general_lite.py @@ -0,0 +1,52 @@ +import os + +import pooch + +from . import BiRefNetSessionGeneral + + +class BiRefNetSessionGeneralLite(BiRefNetSessionGeneral): + """ + This class represents a BiRefNet-General-Lite session, which is a subclass of BiRefNetSessionGeneral. + """ + + @classmethod + def download_models(cls, *args, **kwargs): + """ + Downloads the BiRefNet-General-Lite model file from a specific URL and saves it. + + Parameters: + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. + + Returns: + str: The path to the downloaded model file. + """ + fname = f"{cls.name(*args, **kwargs)}.onnx" + pooch.retrieve( + "https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-general-bb_swin_v1_tiny-epoch_232.onnx", + ( + None + if cls.checksum_disabled(*args, **kwargs) + else "md5:4fab47adc4ff364be1713e97b7e66334" + ), + fname=fname, + path=cls.u2net_home(*args, **kwargs), + progressbar=True, + ) + + return os.path.join(cls.u2net_home(*args, **kwargs), fname) + + @classmethod + def name(cls, *args, **kwargs): + """ + Returns the name of the BiRefNet-General-Lite session. + + Parameters: + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. + + Returns: + str: The name of the session. + """ + return "birefnet-general-lite" diff --git a/rembg/sessions/birefnet_hrsod.py b/rembg/sessions/birefnet_hrsod.py new file mode 100644 index 0000000..2260a20 --- /dev/null +++ b/rembg/sessions/birefnet_hrsod.py @@ -0,0 +1,52 @@ +import os + +import pooch + +from . import BiRefNetSessionGeneral + + +class BiRefNetSessionHRSOD(BiRefNetSessionGeneral): + """ + This class represents a BiRefNet-HRSOD session, which is a subclass of BiRefNetSessionGeneral. + """ + + @classmethod + def download_models(cls, *args, **kwargs): + """ + Downloads the BiRefNet-HRSOD model file from a specific URL and saves it. + + Parameters: + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. + + Returns: + str: The path to the downloaded model file. + """ + fname = f"{cls.name(*args, **kwargs)}.onnx" + pooch.retrieve( + "https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-HRSOD_DHU-epoch_115.onnx", + ( + None + if cls.checksum_disabled(*args, **kwargs) + else "md5:c017ade5de8a50ff0fd74d790d268dda" + ), + fname=fname, + path=cls.u2net_home(*args, **kwargs), + progressbar=True, + ) + + return os.path.join(cls.u2net_home(*args, **kwargs), fname) + + @classmethod + def name(cls, *args, **kwargs): + """ + Returns the name of the BiRefNet-HRSOD session. + + Parameters: + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. + + Returns: + str: The name of the session. + """ + return "birefnet-hrsod" diff --git a/rembg/sessions/birefnet_massive.py b/rembg/sessions/birefnet_massive.py new file mode 100644 index 0000000..79233b8 --- /dev/null +++ b/rembg/sessions/birefnet_massive.py @@ -0,0 +1,52 @@ +import os + +import pooch + +from . import BiRefNetSessionGeneral + + +class BiRefNetSessionMassive(BiRefNetSessionGeneral): + """ + This class represents a BiRefNet-Massive session, which is a subclass of BiRefNetSessionGeneral. + """ + + @classmethod + def download_models(cls, *args, **kwargs): + """ + Downloads the BiRefNet-Massive model file from a specific URL and saves it. + + Parameters: + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. + + Returns: + str: The path to the downloaded model file. + """ + fname = f"{cls.name(*args, **kwargs)}.onnx" + pooch.retrieve( + "https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-massive-TR_DIS5K_TR_TEs-epoch_420.onnx", + ( + None + if cls.checksum_disabled(*args, **kwargs) + else "md5:33e726a2136a3d59eb0fdf613e31e3e9" + ), + fname=fname, + path=cls.u2net_home(*args, **kwargs), + progressbar=True, + ) + + return os.path.join(cls.u2net_home(*args, **kwargs), fname) + + @classmethod + def name(cls, *args, **kwargs): + """ + Returns the name of the BiRefNet-Massive session. + + Parameters: + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. + + Returns: + str: The name of the session. + """ + return "birefnet-massive" diff --git a/rembg/sessions/birefnet_portrait.py b/rembg/sessions/birefnet_portrait.py new file mode 100644 index 0000000..e6a5835 --- /dev/null +++ b/rembg/sessions/birefnet_portrait.py @@ -0,0 +1,52 @@ +import os + +import pooch + +from . import BiRefNetSessionGeneral + + +class BiRefNetSessionPortrait(BiRefNetSessionGeneral): + """ + This class represents a BiRefNet-Portrait session, which is a subclass of BiRefNetSessionGeneral. + """ + + @classmethod + def download_models(cls, *args, **kwargs): + """ + Downloads the BiRefNet-Portrait model file from a specific URL and saves it. + + Parameters: + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. + + Returns: + str: The path to the downloaded model file. + """ + fname = f"{cls.name(*args, **kwargs)}.onnx" + pooch.retrieve( + "https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-portrait-epoch_150.onnx", + ( + None + if cls.checksum_disabled(*args, **kwargs) + else "md5:c3a64a6abf20250d090cd055f12a3b67" + ), + fname=fname, + path=cls.u2net_home(*args, **kwargs), + progressbar=True, + ) + + return os.path.join(cls.u2net_home(*args, **kwargs), fname) + + @classmethod + def name(cls, *args, **kwargs): + """ + Returns the name of the BiRefNet-Portrait session. + + Parameters: + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. + + Returns: + str: The name of the session. + """ + return "birefnet-portrait" diff --git a/tests/results/anime-girl-1.birefnet-cod.png b/tests/results/anime-girl-1.birefnet-cod.png new file mode 100644 index 0000000..d2c740e Binary files /dev/null and b/tests/results/anime-girl-1.birefnet-cod.png differ diff --git a/tests/results/anime-girl-1.birefnet-dis.png b/tests/results/anime-girl-1.birefnet-dis.png new file mode 100644 index 0000000..c5bc4ec Binary files /dev/null and b/tests/results/anime-girl-1.birefnet-dis.png differ diff --git a/tests/results/anime-girl-1.birefnet-general-lite.png b/tests/results/anime-girl-1.birefnet-general-lite.png new file mode 100644 index 0000000..a1e966f Binary files /dev/null and b/tests/results/anime-girl-1.birefnet-general-lite.png differ diff --git a/tests/results/anime-girl-1.birefnet-general.png b/tests/results/anime-girl-1.birefnet-general.png new file mode 100644 index 0000000..6b39907 Binary files /dev/null and b/tests/results/anime-girl-1.birefnet-general.png differ diff --git a/tests/results/anime-girl-1.birefnet-hrsod.png b/tests/results/anime-girl-1.birefnet-hrsod.png new file mode 100644 index 0000000..b570a32 Binary files /dev/null and b/tests/results/anime-girl-1.birefnet-hrsod.png differ diff --git a/tests/results/anime-girl-1.birefnet-massive.png b/tests/results/anime-girl-1.birefnet-massive.png new file mode 100644 index 0000000..9674920 Binary files /dev/null and b/tests/results/anime-girl-1.birefnet-massive.png differ diff --git a/tests/results/anime-girl-1.birefnet-portrait.png b/tests/results/anime-girl-1.birefnet-portrait.png new file mode 100644 index 0000000..0180b8a Binary files /dev/null and b/tests/results/anime-girl-1.birefnet-portrait.png differ diff --git a/tests/results/car-1.birefnet-cod.png b/tests/results/car-1.birefnet-cod.png new file mode 100644 index 0000000..d0709f3 Binary files /dev/null and b/tests/results/car-1.birefnet-cod.png differ diff --git a/tests/results/car-1.birefnet-dis.png b/tests/results/car-1.birefnet-dis.png new file mode 100644 index 0000000..b56d96d Binary files /dev/null and b/tests/results/car-1.birefnet-dis.png differ diff --git a/tests/results/car-1.birefnet-general-lite.png b/tests/results/car-1.birefnet-general-lite.png new file mode 100644 index 0000000..0c6cc2f Binary files /dev/null and b/tests/results/car-1.birefnet-general-lite.png differ diff --git a/tests/results/car-1.birefnet-general.png b/tests/results/car-1.birefnet-general.png new file mode 100644 index 0000000..f026551 Binary files /dev/null and b/tests/results/car-1.birefnet-general.png differ diff --git a/tests/results/car-1.birefnet-hrsod.png b/tests/results/car-1.birefnet-hrsod.png new file mode 100644 index 0000000..d094aa0 Binary files /dev/null and b/tests/results/car-1.birefnet-hrsod.png differ diff --git a/tests/results/car-1.birefnet-massive.png b/tests/results/car-1.birefnet-massive.png new file mode 100644 index 0000000..f2bc667 Binary files /dev/null and b/tests/results/car-1.birefnet-massive.png differ diff --git a/tests/results/car-1.birefnet-portrait.png b/tests/results/car-1.birefnet-portrait.png new file mode 100644 index 0000000..0e783c0 Binary files /dev/null and b/tests/results/car-1.birefnet-portrait.png differ diff --git a/tests/results/cloth-1.birefnet-cod.png b/tests/results/cloth-1.birefnet-cod.png new file mode 100644 index 0000000..74fc20d Binary files /dev/null and b/tests/results/cloth-1.birefnet-cod.png differ diff --git a/tests/results/cloth-1.birefnet-dis.png b/tests/results/cloth-1.birefnet-dis.png new file mode 100644 index 0000000..db420a0 Binary files /dev/null and b/tests/results/cloth-1.birefnet-dis.png differ diff --git a/tests/results/cloth-1.birefnet-general-lite.png b/tests/results/cloth-1.birefnet-general-lite.png new file mode 100644 index 0000000..b9dc114 Binary files /dev/null and b/tests/results/cloth-1.birefnet-general-lite.png differ diff --git a/tests/results/cloth-1.birefnet-general.png b/tests/results/cloth-1.birefnet-general.png new file mode 100644 index 0000000..a224cb6 Binary files /dev/null and b/tests/results/cloth-1.birefnet-general.png differ diff --git a/tests/results/cloth-1.birefnet-hrsod.png b/tests/results/cloth-1.birefnet-hrsod.png new file mode 100644 index 0000000..3e4b712 Binary files /dev/null and b/tests/results/cloth-1.birefnet-hrsod.png differ diff --git a/tests/results/cloth-1.birefnet-massive.png b/tests/results/cloth-1.birefnet-massive.png new file mode 100644 index 0000000..2bbc1d6 Binary files /dev/null and b/tests/results/cloth-1.birefnet-massive.png differ diff --git a/tests/results/cloth-1.birefnet-portrait.png b/tests/results/cloth-1.birefnet-portrait.png new file mode 100644 index 0000000..3f6c41f Binary files /dev/null and b/tests/results/cloth-1.birefnet-portrait.png differ diff --git a/tests/results/plants-1.birefnet-cod.png b/tests/results/plants-1.birefnet-cod.png new file mode 100644 index 0000000..77a9266 Binary files /dev/null and b/tests/results/plants-1.birefnet-cod.png differ diff --git a/tests/results/plants-1.birefnet-dis.png b/tests/results/plants-1.birefnet-dis.png new file mode 100644 index 0000000..17b1a3e Binary files /dev/null and b/tests/results/plants-1.birefnet-dis.png differ diff --git a/tests/results/plants-1.birefnet-general-lite.png b/tests/results/plants-1.birefnet-general-lite.png new file mode 100644 index 0000000..a093d94 Binary files /dev/null and b/tests/results/plants-1.birefnet-general-lite.png differ diff --git a/tests/results/plants-1.birefnet-general.png b/tests/results/plants-1.birefnet-general.png new file mode 100644 index 0000000..2a7e431 Binary files /dev/null and b/tests/results/plants-1.birefnet-general.png differ diff --git a/tests/results/plants-1.birefnet-hrsod.png b/tests/results/plants-1.birefnet-hrsod.png new file mode 100644 index 0000000..27cfb04 Binary files /dev/null and b/tests/results/plants-1.birefnet-hrsod.png differ diff --git a/tests/results/plants-1.birefnet-massive.png b/tests/results/plants-1.birefnet-massive.png new file mode 100644 index 0000000..8f2dd20 Binary files /dev/null and b/tests/results/plants-1.birefnet-massive.png differ diff --git a/tests/results/plants-1.birefnet-portrait.png b/tests/results/plants-1.birefnet-portrait.png new file mode 100644 index 0000000..8abc38a Binary files /dev/null and b/tests/results/plants-1.birefnet-portrait.png differ diff --git a/tests/test_remove.py b/tests/test_remove.py index b35caa3..6e7ee75 100644 --- a/tests/test_remove.py +++ b/tests/test_remove.py @@ -37,7 +37,14 @@ def test_remove(): "silueta", "isnet-general-use", "isnet-anime", - "sam" + "sam", + "birefnet-general", + "birefnet-general-lite", + "birefnet-portrait", + "birefnet-dis", + "birefnet-hrsod", + "birefnet-cod", + "birefnet-massive" ]: for picture in ["anime-girl-1", "car-1", "cloth-1", "plants-1"]: image_path = Path(here / "fixtures" / f"{picture}.jpg")