Add BiRefNet-General and BiRefNet-Portrait models as available models (#665)
@ -334,6 +334,13 @@ The available models are:
|
||||
- isnet-general-use ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-general-use.onnx), [source](https://github.com/xuebinqin/DIS)): A new pre-trained model for general use cases.
|
||||
- isnet-anime ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-anime.onnx), [source](https://github.com/SkyTNT/anime-segmentation)): A high-accuracy segmentation model for anime characters.
|
||||
- sam ([download encoder](https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-encoder-quant.onnx), [download decoder](https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-decoder-quant.onnx), [source](https://github.com/facebookresearch/segment-anything)): A pre-trained model for any use case.
|
||||
- birefnet-general ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-general-epoch_244.onnx), [source](https://github.com/ZhengPeng7/BiRefNet)): A pre-trained model for general use cases.
|
||||
- birefnet-general-lite ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-general-bb_swin_v1_tiny-epoch_232.onnx), [source](https://github.com/ZhengPeng7/BiRefNet)): A light pre-trained model for general use cases.
|
||||
- birefnet-portrait ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-portrait-epoch_150.onnx), [source](https://github.com/ZhengPeng7/BiRefNet)): A pre-trained model for human portraits.
|
||||
- birefnet-dis ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-DIS-epoch_590.onnx), [source](https://github.com/ZhengPeng7/BiRefNet)): A pre-trained model for dichotomous image segmentation (DIS).
|
||||
- birefnet-hrsod ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-HRSOD_DHU-epoch_115.onnx), [source](https://github.com/ZhengPeng7/BiRefNet)): A pre-trained model for high-resolution salient object detection (HRSOD).
|
||||
- birefnet-cod ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-COD-epoch_125.onnx), [source](https://github.com/ZhengPeng7/BiRefNet)): A pre-trained model for concealed object detection (COD).
|
||||
- birefnet-massive ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-massive-TR_DIS5K_TR_TEs-epoch_420.onnx), [source](https://github.com/ZhengPeng7/BiRefNet)): A pre-trained model trained on a massive dataset.
|
||||
|
||||
### How to train your own model
|
||||
|
||||
|
@ -7,6 +7,41 @@ from .base import BaseSession
|
||||
sessions_class: List[type[BaseSession]] = []
|
||||
sessions_names: List[str] = []
|
||||
|
||||
from .birefnet_general import BiRefNetSessionGeneral
|
||||
|
||||
sessions_class.append(BiRefNetSessionGeneral)
|
||||
sessions_names.append(BiRefNetSessionGeneral.name())
|
||||
|
||||
from .birefnet_general_lite import BiRefNetSessionGeneralLite
|
||||
|
||||
sessions_class.append(BiRefNetSessionGeneralLite)
|
||||
sessions_names.append(BiRefNetSessionGeneralLite.name())
|
||||
|
||||
from .birefnet_portrait import BiRefNetSessionPortrait
|
||||
|
||||
sessions_class.append(BiRefNetSessionPortrait)
|
||||
sessions_names.append(BiRefNetSessionPortrait.name())
|
||||
|
||||
from .birefnet_dis import BiRefNetSessionDIS
|
||||
|
||||
sessions_class.append(BiRefNetSessionDIS)
|
||||
sessions_names.append(BiRefNetSessionDIS.name())
|
||||
|
||||
from .birefnet_hrsod import BiRefNetSessionHRSOD
|
||||
|
||||
sessions_class.append(BiRefNetSessionHRSOD)
|
||||
sessions_names.append(BiRefNetSessionHRSOD.name())
|
||||
|
||||
from .birefnet_cod import BiRefNetSessionCOD
|
||||
|
||||
sessions_class.append(BiRefNetSessionCOD)
|
||||
sessions_names.append(BiRefNetSessionCOD.name())
|
||||
|
||||
from .birefnet_massive import BiRefNetSessionMassive
|
||||
|
||||
sessions_class.append(BiRefNetSessionMassive)
|
||||
sessions_names.append(BiRefNetSessionMassive.name())
|
||||
|
||||
from .dis_anime import DisSession
|
||||
|
||||
sessions_class.append(DisSession)
|
||||
|
52
rembg/sessions/birefnet_cod.py
Normal file
@ -0,0 +1,52 @@
|
||||
import os
|
||||
|
||||
import pooch
|
||||
|
||||
from . import BiRefNetSessionGeneral
|
||||
|
||||
|
||||
class BiRefNetSessionCOD(BiRefNetSessionGeneral):
    """
    BiRefNet session backed by the COD weights.

    Only the model download and the session name are overridden here;
    everything else is inherited from BiRefNetSessionGeneral.
    """

    @classmethod
    def download_models(cls, *args, **kwargs):
        """
        Fetch the BiRefNet-COD ONNX file with pooch and report where it lives.

        Parameters:
            *args: Forwarded unchanged to the name, checksum_disabled and
                u2net_home class hooks.
            **kwargs: Forwarded unchanged to the same hooks.

        Returns:
            str: The model file name joined onto the directory returned by
                u2net_home.
        """
        model_file = f"{cls.name(*args, **kwargs)}.onnx"

        # A known hash of None is handed to pooch when checksums are disabled.
        if cls.checksum_disabled(*args, **kwargs):
            known_hash = None
        else:
            known_hash = "md5:f6d0d21ca89d287f17e7afe9f5fd3b45"

        pooch.retrieve(
            "https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-COD-epoch_125.onnx",
            known_hash,
            fname=model_file,
            path=cls.u2net_home(*args, **kwargs),
            progressbar=True,
        )

        return os.path.join(cls.u2net_home(*args, **kwargs), model_file)

    @classmethod
    def name(cls, *args, **kwargs):
        """
        Give the identifier of the BiRefNet-COD session.

        Parameters:
            *args: Accepted for interface compatibility; unused.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            str: Always "birefnet-cod".
        """
        return "birefnet-cod"
|
52
rembg/sessions/birefnet_dis.py
Normal file
@ -0,0 +1,52 @@
|
||||
import os
|
||||
|
||||
import pooch
|
||||
|
||||
from . import BiRefNetSessionGeneral
|
||||
|
||||
|
||||
class BiRefNetSessionDIS(BiRefNetSessionGeneral):
    """
    BiRefNet session backed by the DIS weights.

    Only the model download and the session name are overridden here;
    everything else is inherited from BiRefNetSessionGeneral.
    """

    @classmethod
    def download_models(cls, *args, **kwargs):
        """
        Fetch the BiRefNet-DIS ONNX file with pooch and report where it lives.

        Parameters:
            *args: Forwarded unchanged to the name, checksum_disabled and
                u2net_home class hooks.
            **kwargs: Forwarded unchanged to the same hooks.

        Returns:
            str: The model file name joined onto the directory returned by
                u2net_home.
        """
        model_file = f"{cls.name(*args, **kwargs)}.onnx"

        # A known hash of None is handed to pooch when checksums are disabled.
        if cls.checksum_disabled(*args, **kwargs):
            known_hash = None
        else:
            known_hash = "md5:2d4d44102b446f33a4ebb2e56c051f2b"

        pooch.retrieve(
            "https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-DIS-epoch_590.onnx",
            known_hash,
            fname=model_file,
            path=cls.u2net_home(*args, **kwargs),
            progressbar=True,
        )

        return os.path.join(cls.u2net_home(*args, **kwargs), model_file)

    @classmethod
    def name(cls, *args, **kwargs):
        """
        Give the identifier of the BiRefNet-DIS session.

        Parameters:
            *args: Accepted for interface compatibility; unused.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            str: Always "birefnet-dis".
        """
        return "birefnet-dis"
|
91
rembg/sessions/birefnet_general.py
Normal file
@ -0,0 +1,91 @@
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import pooch
|
||||
from PIL import Image
|
||||
from PIL.Image import Image as PILImage
|
||||
|
||||
from .base import BaseSession
|
||||
|
||||
|
||||
class BiRefNetSessionGeneral(BaseSession):
    """
    This class represents a BiRefNet-General session, which is a subclass of BaseSession.

    It also serves as the base class for the other BiRefNet variants, which
    reuse predict() and override only download_models() and name().
    """

    def sigmoid(self, mat):
        """
        Apply the element-wise logistic function 1 / (1 + exp(-mat)).

        Parameters:
            mat: NumPy array (or scalar) of raw model outputs.

        Returns:
            Array of the same shape with values in [0, 1].
        """
        return 1 / (1 + np.exp(-mat))

    def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
        """
        Predicts the output masks for the input image using the inner session.

        Parameters:
            img (PILImage): The input image.
            *args: Additional positional arguments (unused here).
            **kwargs: Additional keyword arguments (unused here).

        Returns:
            List[PILImage]: A single-element list holding an "L" mode mask
                resized to the size of the input image.
        """
        # NOTE(review): the two triples and the (1024, 1024) pair are
        # presumably per-channel mean/std and the network input size --
        # confirm against BaseSession.normalize.
        ort_outs = self.inner_session.run(
            None,
            self.normalize(
                img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (1024, 1024)
            ),
        )

        # Use the first output and index 0 on its second axis, then squash
        # the values into [0, 1].
        pred = self.sigmoid(ort_outs[0][:, 0, :, :])

        ma = np.max(pred)
        mi = np.min(pred)

        # Min-max stretch to the full [0, 1] range. A perfectly flat
        # prediction has zero range; dividing by it would fill the mask with
        # NaN, so in that case the sigmoid values are kept as they are.
        span = ma - mi
        if span > 0:
            pred = (pred - mi) / span
        pred = np.squeeze(pred)

        mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
        mask = mask.resize(img.size, Image.Resampling.LANCZOS)

        return [mask]

    @classmethod
    def download_models(cls, *args, **kwargs):
        """
        Downloads the BiRefNet-General model file from a specific URL and saves it.

        Parameters:
            *args: Forwarded to the name, checksum_disabled and u2net_home
                class hooks.
            **kwargs: Forwarded to the same hooks.

        Returns:
            str: The path to the downloaded model file.
        """
        fname = f"{cls.name(*args, **kwargs)}.onnx"
        pooch.retrieve(
            "https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-general-epoch_244.onnx",
            (
                # None is passed as the known hash when checksums are disabled.
                None
                if cls.checksum_disabled(*args, **kwargs)
                else "md5:7a35a0141cbbc80de11d9c9a28f52697"
            ),
            fname=fname,
            path=cls.u2net_home(*args, **kwargs),
            progressbar=True,
        )

        return os.path.join(cls.u2net_home(*args, **kwargs), fname)

    @classmethod
    def name(cls, *args, **kwargs):
        """
        Returns the name of the BiRefNet-General session.

        Parameters:
            *args: Additional positional arguments (unused).
            **kwargs: Additional keyword arguments (unused).

        Returns:
            str: The name of the session, "birefnet-general".
        """
        return "birefnet-general"
|
52
rembg/sessions/birefnet_general_lite.py
Normal file
@ -0,0 +1,52 @@
|
||||
import os
|
||||
|
||||
import pooch
|
||||
|
||||
from . import BiRefNetSessionGeneral
|
||||
|
||||
|
||||
class BiRefNetSessionGeneralLite(BiRefNetSessionGeneral):
    """
    BiRefNet session backed by the General-Lite weights.

    Only the model download and the session name are overridden here;
    everything else is inherited from BiRefNetSessionGeneral.
    """

    @classmethod
    def download_models(cls, *args, **kwargs):
        """
        Fetch the BiRefNet-General-Lite ONNX file with pooch and report where it lives.

        Parameters:
            *args: Forwarded unchanged to the name, checksum_disabled and
                u2net_home class hooks.
            **kwargs: Forwarded unchanged to the same hooks.

        Returns:
            str: The model file name joined onto the directory returned by
                u2net_home.
        """
        model_file = f"{cls.name(*args, **kwargs)}.onnx"

        # A known hash of None is handed to pooch when checksums are disabled.
        if cls.checksum_disabled(*args, **kwargs):
            known_hash = None
        else:
            known_hash = "md5:4fab47adc4ff364be1713e97b7e66334"

        pooch.retrieve(
            "https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-general-bb_swin_v1_tiny-epoch_232.onnx",
            known_hash,
            fname=model_file,
            path=cls.u2net_home(*args, **kwargs),
            progressbar=True,
        )

        return os.path.join(cls.u2net_home(*args, **kwargs), model_file)

    @classmethod
    def name(cls, *args, **kwargs):
        """
        Give the identifier of the BiRefNet-General-Lite session.

        Parameters:
            *args: Accepted for interface compatibility; unused.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            str: Always "birefnet-general-lite".
        """
        return "birefnet-general-lite"
|
52
rembg/sessions/birefnet_hrsod.py
Normal file
@ -0,0 +1,52 @@
|
||||
import os
|
||||
|
||||
import pooch
|
||||
|
||||
from . import BiRefNetSessionGeneral
|
||||
|
||||
|
||||
class BiRefNetSessionHRSOD(BiRefNetSessionGeneral):
    """
    BiRefNet session backed by the HRSOD weights.

    Only the model download and the session name are overridden here;
    everything else is inherited from BiRefNetSessionGeneral.
    """

    @classmethod
    def download_models(cls, *args, **kwargs):
        """
        Fetch the BiRefNet-HRSOD ONNX file with pooch and report where it lives.

        Parameters:
            *args: Forwarded unchanged to the name, checksum_disabled and
                u2net_home class hooks.
            **kwargs: Forwarded unchanged to the same hooks.

        Returns:
            str: The model file name joined onto the directory returned by
                u2net_home.
        """
        model_file = f"{cls.name(*args, **kwargs)}.onnx"

        # A known hash of None is handed to pooch when checksums are disabled.
        if cls.checksum_disabled(*args, **kwargs):
            known_hash = None
        else:
            known_hash = "md5:c017ade5de8a50ff0fd74d790d268dda"

        pooch.retrieve(
            "https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-HRSOD_DHU-epoch_115.onnx",
            known_hash,
            fname=model_file,
            path=cls.u2net_home(*args, **kwargs),
            progressbar=True,
        )

        return os.path.join(cls.u2net_home(*args, **kwargs), model_file)

    @classmethod
    def name(cls, *args, **kwargs):
        """
        Give the identifier of the BiRefNet-HRSOD session.

        Parameters:
            *args: Accepted for interface compatibility; unused.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            str: Always "birefnet-hrsod".
        """
        return "birefnet-hrsod"
|
52
rembg/sessions/birefnet_massive.py
Normal file
@ -0,0 +1,52 @@
|
||||
import os
|
||||
|
||||
import pooch
|
||||
|
||||
from . import BiRefNetSessionGeneral
|
||||
|
||||
|
||||
class BiRefNetSessionMassive(BiRefNetSessionGeneral):
    """
    BiRefNet session backed by the Massive weights.

    Only the model download and the session name are overridden here;
    everything else is inherited from BiRefNetSessionGeneral.
    """

    @classmethod
    def download_models(cls, *args, **kwargs):
        """
        Fetch the BiRefNet-Massive ONNX file with pooch and report where it lives.

        Parameters:
            *args: Forwarded unchanged to the name, checksum_disabled and
                u2net_home class hooks.
            **kwargs: Forwarded unchanged to the same hooks.

        Returns:
            str: The model file name joined onto the directory returned by
                u2net_home.
        """
        model_file = f"{cls.name(*args, **kwargs)}.onnx"

        # A known hash of None is handed to pooch when checksums are disabled.
        if cls.checksum_disabled(*args, **kwargs):
            known_hash = None
        else:
            known_hash = "md5:33e726a2136a3d59eb0fdf613e31e3e9"

        pooch.retrieve(
            "https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-massive-TR_DIS5K_TR_TEs-epoch_420.onnx",
            known_hash,
            fname=model_file,
            path=cls.u2net_home(*args, **kwargs),
            progressbar=True,
        )

        return os.path.join(cls.u2net_home(*args, **kwargs), model_file)

    @classmethod
    def name(cls, *args, **kwargs):
        """
        Give the identifier of the BiRefNet-Massive session.

        Parameters:
            *args: Accepted for interface compatibility; unused.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            str: Always "birefnet-massive".
        """
        return "birefnet-massive"
|
52
rembg/sessions/birefnet_portrait.py
Normal file
@ -0,0 +1,52 @@
|
||||
import os
|
||||
|
||||
import pooch
|
||||
|
||||
from . import BiRefNetSessionGeneral
|
||||
|
||||
|
||||
class BiRefNetSessionPortrait(BiRefNetSessionGeneral):
    """
    BiRefNet session backed by the Portrait weights.

    Only the model download and the session name are overridden here;
    everything else is inherited from BiRefNetSessionGeneral.
    """

    @classmethod
    def download_models(cls, *args, **kwargs):
        """
        Fetch the BiRefNet-Portrait ONNX file with pooch and report where it lives.

        Parameters:
            *args: Forwarded unchanged to the name, checksum_disabled and
                u2net_home class hooks.
            **kwargs: Forwarded unchanged to the same hooks.

        Returns:
            str: The model file name joined onto the directory returned by
                u2net_home.
        """
        model_file = f"{cls.name(*args, **kwargs)}.onnx"

        # A known hash of None is handed to pooch when checksums are disabled.
        if cls.checksum_disabled(*args, **kwargs):
            known_hash = None
        else:
            known_hash = "md5:c3a64a6abf20250d090cd055f12a3b67"

        pooch.retrieve(
            "https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-portrait-epoch_150.onnx",
            known_hash,
            fname=model_file,
            path=cls.u2net_home(*args, **kwargs),
            progressbar=True,
        )

        return os.path.join(cls.u2net_home(*args, **kwargs), model_file)

    @classmethod
    def name(cls, *args, **kwargs):
        """
        Give the identifier of the BiRefNet-Portrait session.

        Parameters:
            *args: Accepted for interface compatibility; unused.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            str: Always "birefnet-portrait".
        """
        return "birefnet-portrait"
|
BIN
tests/results/anime-girl-1.birefnet-cod.png
Normal file
After Width: | Height: | Size: 183 KiB |
BIN
tests/results/anime-girl-1.birefnet-dis.png
Normal file
After Width: | Height: | Size: 187 KiB |
BIN
tests/results/anime-girl-1.birefnet-general-lite.png
Normal file
After Width: | Height: | Size: 185 KiB |
BIN
tests/results/anime-girl-1.birefnet-general.png
Normal file
After Width: | Height: | Size: 186 KiB |
BIN
tests/results/anime-girl-1.birefnet-hrsod.png
Normal file
After Width: | Height: | Size: 193 KiB |
BIN
tests/results/anime-girl-1.birefnet-massive.png
Normal file
After Width: | Height: | Size: 225 KiB |
BIN
tests/results/anime-girl-1.birefnet-portrait.png
Normal file
After Width: | Height: | Size: 194 KiB |
BIN
tests/results/car-1.birefnet-cod.png
Normal file
After Width: | Height: | Size: 76 KiB |
BIN
tests/results/car-1.birefnet-dis.png
Normal file
After Width: | Height: | Size: 94 KiB |
BIN
tests/results/car-1.birefnet-general-lite.png
Normal file
After Width: | Height: | Size: 95 KiB |
BIN
tests/results/car-1.birefnet-general.png
Normal file
After Width: | Height: | Size: 77 KiB |
BIN
tests/results/car-1.birefnet-hrsod.png
Normal file
After Width: | Height: | Size: 96 KiB |
BIN
tests/results/car-1.birefnet-massive.png
Normal file
After Width: | Height: | Size: 95 KiB |
BIN
tests/results/car-1.birefnet-portrait.png
Normal file
After Width: | Height: | Size: 79 KiB |
BIN
tests/results/cloth-1.birefnet-cod.png
Normal file
After Width: | Height: | Size: 373 KiB |
BIN
tests/results/cloth-1.birefnet-dis.png
Normal file
After Width: | Height: | Size: 391 KiB |
BIN
tests/results/cloth-1.birefnet-general-lite.png
Normal file
After Width: | Height: | Size: 388 KiB |
BIN
tests/results/cloth-1.birefnet-general.png
Normal file
After Width: | Height: | Size: 390 KiB |
BIN
tests/results/cloth-1.birefnet-hrsod.png
Normal file
After Width: | Height: | Size: 393 KiB |
BIN
tests/results/cloth-1.birefnet-massive.png
Normal file
After Width: | Height: | Size: 390 KiB |
BIN
tests/results/cloth-1.birefnet-portrait.png
Normal file
After Width: | Height: | Size: 391 KiB |
BIN
tests/results/plants-1.birefnet-cod.png
Normal file
After Width: | Height: | Size: 393 KiB |
BIN
tests/results/plants-1.birefnet-dis.png
Normal file
After Width: | Height: | Size: 813 KiB |
BIN
tests/results/plants-1.birefnet-general-lite.png
Normal file
After Width: | Height: | Size: 680 KiB |
BIN
tests/results/plants-1.birefnet-general.png
Normal file
After Width: | Height: | Size: 629 KiB |
BIN
tests/results/plants-1.birefnet-hrsod.png
Normal file
After Width: | Height: | Size: 738 KiB |
BIN
tests/results/plants-1.birefnet-massive.png
Normal file
After Width: | Height: | Size: 718 KiB |
BIN
tests/results/plants-1.birefnet-portrait.png
Normal file
After Width: | Height: | Size: 239 KiB |
@ -37,7 +37,14 @@ def test_remove():
|
||||
"silueta",
|
||||
"isnet-general-use",
|
||||
"isnet-anime",
|
||||
"sam"
|
||||
"sam",
|
||||
"birefnet-general",
|
||||
"birefnet-general-lite",
|
||||
"birefnet-portrait",
|
||||
"birefnet-dis",
|
||||
"birefnet-hrsod",
|
||||
"birefnet-cod",
|
||||
"birefnet-massive"
|
||||
]:
|
||||
for picture in ["anime-girl-1", "car-1", "cloth-1", "plants-1"]:
|
||||
image_path = Path(here / "fixtures" / f"{picture}.jpg")
|
||||
|