add mask option and better gpu install

This commit is contained in:
Daniel Gatis 2022-02-11 10:50:50 -03:00
parent 521e49a18a
commit 722e23cc8d
7 changed files with 26 additions and 82 deletions

View File

@ -11,7 +11,7 @@ WORKDIR /rembg
COPY . .
RUN GPU=1 pip3 install .
RUN ["pip3", "install", ".[gpu]"]
ENTRYPOINT ["rembg"]
CMD []

View File

@ -45,7 +45,7 @@ pip install rembg
GPU support:
```bash
GPU=1 pip install rembg
pip install rembg[gpu]
```
### Usage as a cli

View File

@ -18,13 +18,7 @@ def alpha_matting_cutout(
foreground_threshold: int,
background_threshold: int,
erode_structure_size: int,
base_size: int,
) -> Image:
size = img.size
img.thumbnail((base_size, base_size), Image.LANCZOS)
mask = mask.resize(img.size, Image.LANCZOS)
img = np.asarray(img)
mask = np.asarray(mask)
@ -60,45 +54,37 @@ def alpha_matting_cutout(
cutout = np.clip(cutout * 255, 0, 255).astype(np.uint8)
cutout = Image.fromarray(cutout)
cutout = cutout.resize(size, Image.LANCZOS)
return cutout
def naive_cutout(img: Image, mask: Image) -> Image:
empty = Image.new("RGBA", (img.size), 0)
cutout = Image.composite(img, empty, mask.resize(img.size, Image.LANCZOS))
cutout = Image.composite(img, empty, mask)
return cutout
def resize_image(img: Image, width: Optional[int], height: Optional[int]) -> Image:
original_width, original_height = img.size
width = original_width if width is None else width
height = original_height if height is None else height
return img.resize((width, height))
def remove(
data: bytes,
alpha_matting: bool = False,
alpha_matting_foreground_threshold: int = 240,
alpha_matting_background_threshold: int = 10,
alpha_matting_erode_size: int = 10,
alpha_matting_base_size: int = 1000,
session: Optional[ort.InferenceSession] = None,
width: Optional[int] = None,
height: Optional[int] = None,
only_mask: bool = False,
) -> bytes:
img = Image.open(io.BytesIO(data)).convert("RGB")
if width is not None or height is not None:
img = resize_image(img, width, height)
if session is None:
session = ort_session("u2net")
mask = predict(session, np.array(img)).convert("L")
mask = mask.resize(img.size, Image.LANCZOS)
if alpha_matting:
if only_mask:
cutout = mask
elif alpha_matting:
try:
cutout = alpha_matting_cutout(
img,
@ -106,7 +92,6 @@ def remove(
alpha_matting_foreground_threshold,
alpha_matting_background_threshold,
alpha_matting_erode_size,
alpha_matting_base_size,
)
except Exception:
cutout = naive_cutout(img, mask)

View File

@ -66,28 +66,11 @@ def main():
help="erode size",
)
@click.option(
"-az",
"--alpha-matting-base-size",
default=1000,
type=int,
"-om",
"--only-mask",
is_flag=True,
show_default=True,
help="image base size",
)
@click.option(
"-w",
"--width",
default=None,
type=int,
show_default=True,
help="output image size",
)
@click.option(
"-h",
"--height",
default=None,
type=int,
show_default=True,
help="output image size",
help="output only the mask",
)
@click.argument(
"input", default=(None if sys.stdin.isatty() else "-"), type=click.File("rb")
@ -143,28 +126,11 @@ def i(model: str, input: IO, output: IO, **kwargs):
help="erode size",
)
@click.option(
"-az",
"--alpha-matting-base-size",
default=1000,
type=int,
"-om",
"--only-mask",
is_flag=True,
show_default=True,
help="image base size",
)
@click.option(
"-w",
"--width",
default=None,
type=int,
show_default=True,
help="output image size",
)
@click.option(
"-h",
"--height",
default=None,
type=int,
show_default=True,
help="output image size",
help="output only the mask",
)
@click.argument(
"input",
@ -240,18 +206,14 @@ def s(port: int, log_level: str):
af: int = Query(240, ge=0),
ab: int = Query(10, ge=0),
ae: int = Query(10, ge=0),
az: int = Query(1000, ge=0),
width: Optional[int] = Query(None, gt=0),
height: Optional[int] = Query(None, gt=0),
om: bool = Query(False),
):
self.model = model
self.width = width
self.height = height
self.a = a
self.af = af
self.ab = ab
self.ae = ae
self.az = az
self.om = om
def im_without_bg(content: bytes, commons: CommonQueryParams) -> Response:
return Response(
@ -260,13 +222,11 @@ def s(port: int, log_level: str):
session=sessions.setdefault(
commons.model.value, ort_session(commons.model.value)
),
width=commons.width,
height=commons.height,
alpha_matting=commons.a,
alpha_matting_foreground_threshold=commons.af,
alpha_matting_background_threshold=commons.ab,
alpha_matting_erode_size=commons.ae,
alpha_matting_base_size=commons.az,
only_mask=commons.om,
),
media_type="image/png",
)

View File

@ -1 +0,0 @@
onnxruntime==1.10.0

View File

@ -5,6 +5,7 @@ fastapi==0.72.0
filetype==1.0.9
gdown==4.2.0
numpy==1.21.5
onnxruntime==1.10.0
pillow==9.0.0
pymatting==1.1.5
python-multipart==0.0.5

View File

@ -14,12 +14,8 @@ long_description = (here / "README.md").read_text(encoding="utf-8")
with open("requirements.txt") as f:
requireds = f.read().splitlines()
if os.getenv("GPU") is None:
with open("requirements-cpu.txt") as f:
requireds += f.read().splitlines()
else:
with open("requirements-gpu.txt") as f:
requireds += f.read().splitlines()
with open("requirements-gpu.txt") as f:
gpu_requireds = f.read().splitlines()
setup(
name="rembg",
@ -42,6 +38,9 @@ setup(
"rembg=rembg.cli:main",
],
},
extras_require={
'gpu': gpu_requireds,
},
version=versioneer.get_version(),
cmdclass=versioneer.get_cmdclass(),
)