Add BiRefNet-General and BiRefNet-Portrait models as available models (#665)

This commit is contained in:
Dimitri Barbot 2024-08-26 16:29:24 +02:00 committed by GitHub
parent ed1c29576f
commit d4c40e1c3e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
38 changed files with 453 additions and 1 deletions

View File

@ -334,6 +334,13 @@ The available models are:
- isnet-general-use ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-general-use.onnx), [source](https://github.com/xuebinqin/DIS)): A new pre-trained model for general use cases.
- isnet-anime ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-anime.onnx), [source](https://github.com/SkyTNT/anime-segmentation)): A high-accuracy segmentation for anime character.
- sam ([download encoder](https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-encoder-quant.onnx), [download decoder](https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-decoder-quant.onnx), [source](https://github.com/facebookresearch/segment-anything)): A pre-trained model for any use cases.
- birefnet-general ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-general-epoch_244.onnx), [source](https://github.com/ZhengPeng7/BiRefNet)): A pre-trained model for general use cases.
- birefnet-general-lite ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-general-bb_swin_v1_tiny-epoch_232.onnx), [source](https://github.com/ZhengPeng7/BiRefNet)): A light pre-trained model for general use cases.
- birefnet-portrait ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-portrait-epoch_150.onnx), [source](https://github.com/ZhengPeng7/BiRefNet)): A pre-trained model for human portraits.
- birefnet-dis ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-DIS-epoch_590.onnx), [source](https://github.com/ZhengPeng7/BiRefNet)): A pre-trained model for dichotomous image segmentation (DIS).
- birefnet-hrsod ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-HRSOD_DHU-epoch_115.onnx), [source](https://github.com/ZhengPeng7/BiRefNet)): A pre-trained model for high-resolution salient object detection (HRSOD).
- birefnet-cod ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-COD-epoch_125.onnx), [source](https://github.com/ZhengPeng7/BiRefNet)): A pre-trained model for concealed object detection (COD).
- birefnet-massive ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-massive-TR_DIS5K_TR_TEs-epoch_420.onnx), [source](https://github.com/ZhengPeng7/BiRefNet)): A pre-trained model with massive dataset.
### How to train your own model

View File

@ -7,6 +7,41 @@ from .base import BaseSession
sessions_class: List[type[BaseSession]] = []
sessions_names: List[str] = []
from .birefnet_general import BiRefNetSessionGeneral
sessions_class.append(BiRefNetSessionGeneral)
sessions_names.append(BiRefNetSessionGeneral.name())
from .birefnet_general_lite import BiRefNetSessionGeneralLite
sessions_class.append(BiRefNetSessionGeneralLite)
sessions_names.append(BiRefNetSessionGeneralLite.name())
from .birefnet_portrait import BiRefNetSessionPortrait
sessions_class.append(BiRefNetSessionPortrait)
sessions_names.append(BiRefNetSessionPortrait.name())
from .birefnet_dis import BiRefNetSessionDIS
sessions_class.append(BiRefNetSessionDIS)
sessions_names.append(BiRefNetSessionDIS.name())
from .birefnet_hrsod import BiRefNetSessionHRSOD
sessions_class.append(BiRefNetSessionHRSOD)
sessions_names.append(BiRefNetSessionHRSOD.name())
from .birefnet_cod import BiRefNetSessionCOD
sessions_class.append(BiRefNetSessionCOD)
sessions_names.append(BiRefNetSessionCOD.name())
from .birefnet_massive import BiRefNetSessionMassive
sessions_class.append(BiRefNetSessionMassive)
sessions_names.append(BiRefNetSessionMassive.name())
from .dis_anime import DisSession
sessions_class.append(DisSession)

View File

@ -0,0 +1,52 @@
import os
import pooch
from . import BiRefNetSessionGeneral
class BiRefNetSessionCOD(BiRefNetSessionGeneral):
"""
This class represents a BiRefNet-COD session, which is a subclass of BiRefNetSessionGeneral.
"""
@classmethod
def download_models(cls, *args, **kwargs):
"""
Downloads the BiRefNet-COD model file from a specific URL and saves it.
Parameters:
*args: Additional positional arguments.
**kwargs: Additional keyword arguments.
Returns:
str: The path to the downloaded model file.
"""
fname = f"{cls.name(*args, **kwargs)}.onnx"
pooch.retrieve(
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-COD-epoch_125.onnx",
(
None
if cls.checksum_disabled(*args, **kwargs)
else "md5:f6d0d21ca89d287f17e7afe9f5fd3b45"
),
fname=fname,
path=cls.u2net_home(*args, **kwargs),
progressbar=True,
)
return os.path.join(cls.u2net_home(*args, **kwargs), fname)
@classmethod
def name(cls, *args, **kwargs):
"""
Returns the name of the BiRefNet-COD session.
Parameters:
*args: Additional positional arguments.
**kwargs: Additional keyword arguments.
Returns:
str: The name of the session.
"""
return "birefnet-cod"

View File

@ -0,0 +1,52 @@
import os
import pooch
from . import BiRefNetSessionGeneral
class BiRefNetSessionDIS(BiRefNetSessionGeneral):
"""
This class represents a BiRefNet-DIS session, which is a subclass of BiRefNetSessionGeneral.
"""
@classmethod
def download_models(cls, *args, **kwargs):
"""
Downloads the BiRefNet-DIS model file from a specific URL and saves it.
Parameters:
*args: Additional positional arguments.
**kwargs: Additional keyword arguments.
Returns:
str: The path to the downloaded model file.
"""
fname = f"{cls.name(*args, **kwargs)}.onnx"
pooch.retrieve(
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-DIS-epoch_590.onnx",
(
None
if cls.checksum_disabled(*args, **kwargs)
else "md5:2d4d44102b446f33a4ebb2e56c051f2b"
),
fname=fname,
path=cls.u2net_home(*args, **kwargs),
progressbar=True,
)
return os.path.join(cls.u2net_home(*args, **kwargs), fname)
@classmethod
def name(cls, *args, **kwargs):
"""
Returns the name of the BiRefNet-DIS session.
Parameters:
*args: Additional positional arguments.
**kwargs: Additional keyword arguments.
Returns:
str: The name of the session.
"""
return "birefnet-dis"

View File

@ -0,0 +1,91 @@
import os
from typing import List
import numpy as np
import pooch
from PIL import Image
from PIL.Image import Image as PILImage
from .base import BaseSession
class BiRefNetSessionGeneral(BaseSession):
"""
This class represents a BiRefNet-General session, which is a subclass of BaseSession.
"""
def sigmoid(self, mat):
return 1 / (1 + np.exp(-mat))
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
"""
Predicts the output masks for the input image using the inner session.
Parameters:
img (PILImage): The input image.
*args: Additional positional arguments.
**kwargs: Additional keyword arguments.
Returns:
List[PILImage]: The list of output masks.
"""
ort_outs = self.inner_session.run(
None,
self.normalize(
img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (1024, 1024)
),
)
pred = self.sigmoid(ort_outs[0][:, 0, :, :])
ma = np.max(pred)
mi = np.min(pred)
pred = (pred - mi) / (ma - mi)
pred = np.squeeze(pred)
mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
mask = mask.resize(img.size, Image.Resampling.LANCZOS)
return [mask]
@classmethod
def download_models(cls, *args, **kwargs):
"""
Downloads the BiRefNet-General model file from a specific URL and saves it.
Parameters:
*args: Additional positional arguments.
**kwargs: Additional keyword arguments.
Returns:
str: The path to the downloaded model file.
"""
fname = f"{cls.name(*args, **kwargs)}.onnx"
pooch.retrieve(
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-general-epoch_244.onnx",
(
None
if cls.checksum_disabled(*args, **kwargs)
else "md5:7a35a0141cbbc80de11d9c9a28f52697"
),
fname=fname,
path=cls.u2net_home(*args, **kwargs),
progressbar=True,
)
return os.path.join(cls.u2net_home(*args, **kwargs), fname)
@classmethod
def name(cls, *args, **kwargs):
"""
Returns the name of the BiRefNet-General session.
Parameters:
*args: Additional positional arguments.
**kwargs: Additional keyword arguments.
Returns:
str: The name of the session.
"""
return "birefnet-general"

View File

@ -0,0 +1,52 @@
import os
import pooch
from . import BiRefNetSessionGeneral
class BiRefNetSessionGeneralLite(BiRefNetSessionGeneral):
"""
This class represents a BiRefNet-General-Lite session, which is a subclass of BiRefNetSessionGeneral.
"""
@classmethod
def download_models(cls, *args, **kwargs):
"""
Downloads the BiRefNet-General-Lite model file from a specific URL and saves it.
Parameters:
*args: Additional positional arguments.
**kwargs: Additional keyword arguments.
Returns:
str: The path to the downloaded model file.
"""
fname = f"{cls.name(*args, **kwargs)}.onnx"
pooch.retrieve(
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-general-bb_swin_v1_tiny-epoch_232.onnx",
(
None
if cls.checksum_disabled(*args, **kwargs)
else "md5:4fab47adc4ff364be1713e97b7e66334"
),
fname=fname,
path=cls.u2net_home(*args, **kwargs),
progressbar=True,
)
return os.path.join(cls.u2net_home(*args, **kwargs), fname)
@classmethod
def name(cls, *args, **kwargs):
"""
Returns the name of the BiRefNet-General-Lite session.
Parameters:
*args: Additional positional arguments.
**kwargs: Additional keyword arguments.
Returns:
str: The name of the session.
"""
return "birefnet-general-lite"

View File

@ -0,0 +1,52 @@
import os
import pooch
from . import BiRefNetSessionGeneral
class BiRefNetSessionHRSOD(BiRefNetSessionGeneral):
"""
This class represents a BiRefNet-HRSOD session, which is a subclass of BiRefNetSessionGeneral.
"""
@classmethod
def download_models(cls, *args, **kwargs):
"""
Downloads the BiRefNet-HRSOD model file from a specific URL and saves it.
Parameters:
*args: Additional positional arguments.
**kwargs: Additional keyword arguments.
Returns:
str: The path to the downloaded model file.
"""
fname = f"{cls.name(*args, **kwargs)}.onnx"
pooch.retrieve(
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-HRSOD_DHU-epoch_115.onnx",
(
None
if cls.checksum_disabled(*args, **kwargs)
else "md5:c017ade5de8a50ff0fd74d790d268dda"
),
fname=fname,
path=cls.u2net_home(*args, **kwargs),
progressbar=True,
)
return os.path.join(cls.u2net_home(*args, **kwargs), fname)
@classmethod
def name(cls, *args, **kwargs):
"""
Returns the name of the BiRefNet-HRSOD session.
Parameters:
*args: Additional positional arguments.
**kwargs: Additional keyword arguments.
Returns:
str: The name of the session.
"""
return "birefnet-hrsod"

View File

@ -0,0 +1,52 @@
import os
import pooch
from . import BiRefNetSessionGeneral
class BiRefNetSessionMassive(BiRefNetSessionGeneral):
"""
This class represents a BiRefNet-Massive session, which is a subclass of BiRefNetSessionGeneral.
"""
@classmethod
def download_models(cls, *args, **kwargs):
"""
Downloads the BiRefNet-Massive model file from a specific URL and saves it.
Parameters:
*args: Additional positional arguments.
**kwargs: Additional keyword arguments.
Returns:
str: The path to the downloaded model file.
"""
fname = f"{cls.name(*args, **kwargs)}.onnx"
pooch.retrieve(
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-massive-TR_DIS5K_TR_TEs-epoch_420.onnx",
(
None
if cls.checksum_disabled(*args, **kwargs)
else "md5:33e726a2136a3d59eb0fdf613e31e3e9"
),
fname=fname,
path=cls.u2net_home(*args, **kwargs),
progressbar=True,
)
return os.path.join(cls.u2net_home(*args, **kwargs), fname)
@classmethod
def name(cls, *args, **kwargs):
"""
Returns the name of the BiRefNet-Massive session.
Parameters:
*args: Additional positional arguments.
**kwargs: Additional keyword arguments.
Returns:
str: The name of the session.
"""
return "birefnet-massive"

View File

@ -0,0 +1,52 @@
import os
import pooch
from . import BiRefNetSessionGeneral
class BiRefNetSessionPortrait(BiRefNetSessionGeneral):
"""
This class represents a BiRefNet-Portrait session, which is a subclass of BiRefNetSessionGeneral.
"""
@classmethod
def download_models(cls, *args, **kwargs):
"""
Downloads the BiRefNet-Portrait model file from a specific URL and saves it.
Parameters:
*args: Additional positional arguments.
**kwargs: Additional keyword arguments.
Returns:
str: The path to the downloaded model file.
"""
fname = f"{cls.name(*args, **kwargs)}.onnx"
pooch.retrieve(
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-portrait-epoch_150.onnx",
(
None
if cls.checksum_disabled(*args, **kwargs)
else "md5:c3a64a6abf20250d090cd055f12a3b67"
),
fname=fname,
path=cls.u2net_home(*args, **kwargs),
progressbar=True,
)
return os.path.join(cls.u2net_home(*args, **kwargs), fname)
@classmethod
def name(cls, *args, **kwargs):
"""
Returns the name of the BiRefNet-Portrait session.
Parameters:
*args: Additional positional arguments.
**kwargs: Additional keyword arguments.
Returns:
str: The name of the session.
"""
return "birefnet-portrait"

Binary file not shown.

After

Width:  |  Height:  |  Size: 183 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 187 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 185 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 186 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 193 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 225 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 194 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 76 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 94 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 95 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 77 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 96 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 95 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 79 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 373 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 391 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 388 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 390 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 393 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 390 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 391 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 393 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 813 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 680 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 629 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 738 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 718 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 239 KiB

View File

@ -37,7 +37,14 @@ def test_remove():
"silueta",
"isnet-general-use",
"isnet-anime",
"sam"
"sam",
"birefnet-general",
"birefnet-general-lite",
"birefnet-portrait",
"birefnet-dis",
"birefnet-hrsod",
"birefnet-cod",
"birefnet-massive"
]:
for picture in ["anime-girl-1", "car-1", "cloth-1", "plants-1"]:
image_path = Path(here / "fixtures" / f"{picture}.jpg")