This commit is contained in:
Daniel Gatis 2022-01-20 15:05:45 -03:00
parent 1f6cce322f
commit 385d34da4a
6 changed files with 272 additions and 258 deletions

View File

@ -53,24 +53,24 @@ GPU=1 pip install rembg
Remove the background from a remote image Remove the background from a remote image
```bash ```bash
curl -s http://input.png | rembg > output.png curl -s http://input.png | rembg i > output.png
``` ```
Remove the background from a local file Remove the background from a local file
```bash ```bash
rembg -o path/to/output.png path/to/input.png rembg i path/to/input.png path/to/output.png
``` ```
Remove the background from all images in a folder Remove the background from all images in a folder
```bash ```bash
rembg -p path/to/input path/to/output rembg p path/to/input path/to/output
``` ```
### Usage as a server ### Usage as a server
Start the server Start the server
```bash ```bash
rembg-server rembg s
``` ```
Open your browser to Open your browser to
@ -140,14 +140,14 @@ docker build . -t rembg
Then run with: Then run with:
``` ```
docker run --rm -i rembg <in.png >out.png docker run --rm -i rembg i in.png out.png
``` ```
### Advance usage ### Advance usage
Sometimes it is possible to achieve better results by turning on alpha matting. Example: Sometimes it is possible to achieve better results by turning on alpha matting. Example:
```bash ```bash
curl -s http://input.png | rembg -a -ae 15 > output.png curl -s http://input.png | rembg i -a -ae 15 > output.png
``` ```
<table> <table>

View File

@ -75,11 +75,7 @@ def resize_image(img: Image, width: Optional[int], height: Optional[int]) -> Ima
original_width, original_height = img.size original_width, original_height = img.size
width = original_width if width is None else width width = original_width if width is None else width
height = original_height if height is None else height height = original_height if height is None else height
return ( return img.resize((width, height))
img.resize((width, height))
if original_width != width or original_height != height
else img
)
def remove( def remove(
@ -87,7 +83,7 @@ def remove(
alpha_matting: bool = False, alpha_matting: bool = False,
alpha_matting_foreground_threshold: int = 240, alpha_matting_foreground_threshold: int = 240,
alpha_matting_background_threshold: int = 10, alpha_matting_background_threshold: int = 10,
alpha_matting_erode_structure_size: int = 10, alpha_matting_erode_size: int = 10,
alpha_matting_base_size: int = 1000, alpha_matting_base_size: int = 1000,
session: Optional[ort.InferenceSession] = None, session: Optional[ort.InferenceSession] = None,
width: Optional[int] = None, width: Optional[int] = None,
@ -109,7 +105,7 @@ def remove(
mask, mask,
alpha_matting_foreground_threshold, alpha_matting_foreground_threshold,
alpha_matting_background_threshold, alpha_matting_background_threshold,
alpha_matting_erode_structure_size, alpha_matting_erode_size,
alpha_matting_base_size, alpha_matting_base_size,
) )
except Exception: except Exception:

View File

@ -1,155 +1,283 @@
import argparse import pathlib
import glob
import os
from distutils.util import strtobool
from typing import BinaryIO
import sys import sys
from pathlib import Path from enum import Enum
from typing import IO, Optional
import click
import filetype import filetype
from tqdm import tqdm
import onnxruntime as ort import onnxruntime as ort
import requests
import uvicorn
from fastapi import Depends, FastAPI, File, Query
from starlette.responses import Response
from tqdm import tqdm
from .bg import remove from .bg import remove
from .detect import ort_session from .detect import ort_session
sessions: dict[str, ort.InferenceSession] = {}
@click.group()
@click.version_option()
def main(): def main():
ap = argparse.ArgumentParser() pass
ap.add_argument(
@main.command(help="for a file as input")
@click.option(
"-m", "-m",
"--model", "--model",
default="u2net", default="u2net",
type=str, type=click.Choice(["u2net", "u2netp", "u2net_human_seg"]),
choices=["u2net", "u2netp", "u2net_human_seg"], show_default=True,
help="The model name.", show_choices=True,
) help="model name",
)
ap.add_argument( @click.option(
"-a", "-a",
"--alpha-matting", "--alpha-matting",
nargs="?", is_flag=True,
const=True, show_default=True,
default=False, help="use alpha matting",
type=lambda x: bool(strtobool(x)), )
help="When true use alpha matting cutout.", @click.option(
)
ap.add_argument(
"-af", "-af",
"--alpha-matting-foreground-threshold", "--alpha-matting-foreground-threshold",
default=240, default=240,
type=int, type=int,
help="The trimap foreground threshold.", show_default=True,
) help="trimap fg threshold",
)
ap.add_argument( @click.option(
"-ab", "-ab",
"--alpha-matting-background-threshold", "--alpha-matting-background-threshold",
default=10, default=10,
type=int, type=int,
help="The trimap background threshold.", show_default=True,
) help="trimap bg threshold",
)
ap.add_argument( @click.option(
"-ae", "-ae",
"--alpha-matting-erode-size", "--alpha-matting-erode-size",
default=10, default=10,
type=int, type=int,
help="Size of element used for the erosion.", show_default=True,
) help="erode size",
)
ap.add_argument( @click.option(
"-az", "-az",
"--alpha-matting-base-size", "--alpha-matting-base-size",
default=1000, default=1000,
type=int, type=int,
help="The image base size.", show_default=True,
) help="image base size",
)
ap.add_argument( @click.option(
"-p", "-w",
"--path", "--width",
nargs=2, default=None,
help="An input folder and an output folder.", type=int,
) show_default=True,
help="output image size",
ap.add_argument( )
"input", @click.option(
nargs=(None if sys.stdin.isatty() else "?"), "-h",
default=(None if sys.stdin.isatty() else sys.stdin.buffer), "--height",
type=argparse.FileType("rb"), default=None,
help="Path to the input image.", type=int,
) show_default=True,
help="output image size",
ap.add_argument( )
@click.argument(
"input", default=(None if sys.stdin.isatty() else "-"), type=click.File("rb")
)
@click.argument(
"output", "output",
nargs=(None if sys.stdin.isatty() else "?"), default=(None if sys.stdin.isatty() else "-"),
default=(None if sys.stdin.isatty() else sys.stdout.buffer), type=click.File("wb", lazy=True),
type=argparse.FileType("wb"), )
help="Path to the output png image.", def i(model: str, input: IO, output: IO, **kwargs: dict):
) output.write(remove(input.read(), session=ort_session(model), **kwargs))
args = ap.parse_args()
session = sessions.setdefault(args.model, ort_session(args.model))
if args.path: @main.command(help="for a folder as input")
full_paths = [os.path.abspath(path) for path in args.path] @click.option(
"-m",
input_paths = [full_paths[0]] "--model",
output_path = full_paths[1] default="u2net",
type=click.Choice(["u2net", "u2netp", "u2net_human_seg"]),
if not os.path.exists(output_path): show_default=True,
os.makedirs(output_path) show_choices=True,
help="model name",
input_files = set() )
@click.option(
for input_path in input_paths: "-a",
if os.path.isfile(path): "--alpha-matting",
input_files.add(path) is_flag=True,
else: show_default=True,
input_paths += set(glob.glob(input_path + "/*")) help="use alpha matting",
)
for input_file in tqdm(input_files): @click.option(
input_file_type = filetype.guess(input_file) "-af",
"--alpha-matting-foreground-threshold",
if input_file_type is None: default=240,
type=int,
show_default=True,
help="trimap fg threshold",
)
@click.option(
"-ab",
"--alpha-matting-background-threshold",
default=10,
type=int,
show_default=True,
help="trimap bg threshold",
)
@click.option(
"-ae",
"--alpha-matting-erode-size",
default=10,
type=int,
show_default=True,
help="erode size",
)
@click.option(
"-az",
"--alpha-matting-base-size",
default=1000,
type=int,
show_default=True,
help="image base size",
)
@click.option(
"-w",
"--width",
default=None,
type=int,
show_default=True,
help="output image size",
)
@click.option(
"-h",
"--height",
default=None,
type=int,
show_default=True,
help="output image size",
)
@click.argument(
"input",
type=click.Path(
exists=True,
path_type=pathlib.Path,
file_okay=False,
dir_okay=True,
readable=True,
),
)
@click.argument(
"output",
type=click.Path(
exists=False,
path_type=pathlib.Path,
file_okay=False,
dir_okay=True,
writable=True,
),
)
def p(model: str, input: pathlib.Path, output: pathlib.Path, **kwargs: dict):
session = ort_session(model)
for each_input in tqdm(list(input.glob("**/*"))):
if each_input.is_dir():
continue continue
if input_file_type.mime.find("image") < 0: mimetype = filetype.guess(each_input)
if mimetype is None:
continue
if mimetype.mime.find("image") < 0:
continue continue
out_file = os.path.join( each_output = (output / each_input.name).with_suffix(".png")
output_path, os.path.splitext(os.path.basename(input_file))[0] + ".png" each_output.parents[0].mkdir(parents=True, exist_ok=True)
each_output.write_bytes(
remove(each_input.read_bytes(), session=session, **kwargs)
) )
Path(out_file).write_bytes(
@main.command(help="for a http server")
@click.option(
"-p",
"--port",
default=5000,
type=int,
show_default=True,
help="port",
)
@click.option(
"-l",
"--log_level",
default="info",
type=str,
show_default=True,
help="log level",
)
def s(port: int, log_level: str):
sessions: dict[str, ort.InferenceSession] = {}
app = FastAPI()
class ModelType(str, Enum):
u2net = "u2net"
u2netp = "u2netp"
u2net_human_seg = "u2net_human_seg"
class CommonQueryParams:
def __init__(
self,
model: ModelType = Query(ModelType.u2net),
a: bool = Query(False),
af: int = Query(240, ge=0),
ab: int = Query(10, ge=0),
ae: int = Query(10, ge=0),
az: int = Query(1000, ge=0),
width: Optional[int] = Query(None, gt=0),
height: Optional[int] = Query(None, gt=0),
):
self.model = model
self.width = width
self.height = height
self.a = a
self.af = af
self.ab = ab
self.ae = ae
self.az = az
def im_without_bg(content: bytes, commons: CommonQueryParams) -> Response:
return Response(
remove( remove(
Path(input_file).read_bytes(), content,
session=session, session=sessions.setdefault(
alpha_matting=args.alpha_matting, commons.model.value, ort_session(commons.model.value)
alpha_matting_foreground_threshold=args.alpha_matting_foreground_threshold, ),
alpha_matting_background_threshold=args.alpha_matting_background_threshold, width=commons.width,
alpha_matting_erode_structure_size=args.alpha_matting_erode_size, height=commons.height,
alpha_matting_base_size=args.alpha_matting_base_size, alpha_matting=commons.a,
) alpha_matting_foreground_threshold=commons.af,
alpha_matting_background_threshold=commons.ab,
alpha_matting_erode_size=commons.ae,
alpha_matting_base_size=commons.az,
),
media_type="image/png",
) )
else: @app.get("/")
args.output.write( def get_index(url: str, commons: CommonQueryParams = Depends()):
remove( return im_without_bg(requests.get(url).content, commons)
args.input.read(),
session=session, @app.post("/")
alpha_matting=args.alpha_matting, def post_index(file: bytes = File(...), commons: CommonQueryParams = Depends()):
alpha_matting_foreground_threshold=args.alpha_matting_foreground_threshold, return im_without_bg(file, commons)
alpha_matting_background_threshold=args.alpha_matting_background_threshold,
alpha_matting_erode_structure_size=args.alpha_matting_erode_size, uvicorn.run(app, host="0.0.0.0", port=port, log_level=log_level)
alpha_matting_base_size=args.alpha_matting_base_size,
)
)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -1,110 +0,0 @@
import argparse
from enum import Enum
from typing import Optional
import requests
import uvicorn
from fastapi import Depends, FastAPI, File, Form, Query, UploadFile
from PIL import Image
from starlette.responses import Response
import onnxruntime as ort
from .bg import remove
from .detect import ort_session
sessions: dict[str, ort.InferenceSession] = {}
app = FastAPI()
class ModelType(str, Enum):
u2net = "u2net"
u2netp = "u2netp"
u2net_human_seg = "u2net_human_seg"
class CommonQueryParams:
def __init__(
self,
model: ModelType = Query(ModelType.u2net),
a: bool = Query(False),
af: int = Query(240, ge=0),
ab: int = Query(10, ge=0),
ae: int = Query(10, ge=0),
az: int = Query(1000, ge=0),
width: Optional[int] = Query(None, gt=0),
height: Optional[int] = Query(None, gt=0),
):
self.model = model
self.width = width
self.height = height
self.a = a
self.af = af
self.ab = ab
self.ae = ae
self.az = az
def im_without_bg(content: bytes, commons: CommonQueryParams) -> Response:
return Response(
remove(
content,
session=sessions.setdefault(
commons.model.value, ort_session(commons.model.value)
),
width=commons.width,
height=commons.height,
alpha_matting=commons.a,
alpha_matting_foreground_threshold=commons.af,
alpha_matting_background_threshold=commons.ab,
alpha_matting_erode_structure_size=commons.ae,
alpha_matting_base_size=commons.az,
),
media_type="image/png",
)
@app.get("/")
def get_index(url: str, commons: CommonQueryParams = Depends()):
return im_without_bg(requests.get(url).content, commons)
@app.post("/")
def post_index(file: bytes = File(...), commons: CommonQueryParams = Depends()):
return im_without_bg(file, commons)
def main():
ap = argparse.ArgumentParser()
ap.add_argument(
"-a",
"--addr",
default="0.0.0.0",
type=str,
help="The IP address to bind to.",
)
ap.add_argument(
"-p",
"--port",
default=5000,
type=int,
help="The port to bind to.",
)
ap.add_argument(
"-l",
"--log_level",
default="info",
type=str,
help="The log level.",
)
args = ap.parse_args()
uvicorn.run(
"rembg.server:app", host=args.addr, port=args.port, log_level=args.log_level
)
if __name__ == "__main__":
main()

View File

@ -10,3 +10,4 @@ scikit-image==0.19.1
scipy==1.7.3 scipy==1.7.3
tqdm==4.62.3 tqdm==4.62.3
uvicorn==0.17.0 uvicorn==0.17.0
click==8.0.3

View File

@ -40,7 +40,6 @@ setup(
entry_points={ entry_points={
"console_scripts": [ "console_scripts": [
"rembg=rembg.cli:main", "rembg=rembg.cli:main",
"rembg-server=rembg.server:main",
], ],
}, },
version=versioneer.get_version(), version=versioneer.get_version(),