This commit is contained in:
Daniel Gatis 2022-01-20 15:05:45 -03:00
parent 1f6cce322f
commit 385d34da4a
6 changed files with 272 additions and 258 deletions

View File

@ -53,24 +53,24 @@ GPU=1 pip install rembg
Remove the background from a remote image
```bash
curl -s http://input.png | rembg > output.png
curl -s http://input.png | rembg i > output.png
```
Remove the background from a local file
```bash
rembg -o path/to/output.png path/to/input.png
rembg i path/to/input.png path/to/output.png
```
Remove the background from all images in a folder
```bash
rembg -p path/to/input path/to/output
rembg p path/to/input path/to/output
```
### Usage as a server
Start the server
```bash
rembg-server
rembg s
```
Open your browser to
@ -140,14 +140,14 @@ docker build . -t rembg
Then run with:
```
docker run --rm -i rembg <in.png >out.png
docker run --rm -i rembg i in.png out.png
```
### Advance usage
Sometimes it is possible to achieve better results by turning on alpha matting. Example:
```bash
curl -s http://input.png | rembg -a -ae 15 > output.png
curl -s http://input.png | rembg i -a -ae 15 > output.png
```
<table>

View File

@ -75,11 +75,7 @@ def resize_image(img: Image, width: Optional[int], height: Optional[int]) -> Ima
original_width, original_height = img.size
width = original_width if width is None else width
height = original_height if height is None else height
return (
img.resize((width, height))
if original_width != width or original_height != height
else img
)
return img.resize((width, height))
def remove(
@ -87,7 +83,7 @@ def remove(
alpha_matting: bool = False,
alpha_matting_foreground_threshold: int = 240,
alpha_matting_background_threshold: int = 10,
alpha_matting_erode_structure_size: int = 10,
alpha_matting_erode_size: int = 10,
alpha_matting_base_size: int = 1000,
session: Optional[ort.InferenceSession] = None,
width: Optional[int] = None,
@ -109,7 +105,7 @@ def remove(
mask,
alpha_matting_foreground_threshold,
alpha_matting_background_threshold,
alpha_matting_erode_structure_size,
alpha_matting_erode_size,
alpha_matting_base_size,
)
except Exception:

View File

@ -1,156 +1,284 @@
import argparse
import glob
import os
from distutils.util import strtobool
from typing import BinaryIO
import pathlib
import sys
from pathlib import Path
from enum import Enum
from typing import IO, Optional
import click
import filetype
from tqdm import tqdm
import onnxruntime as ort
import requests
import uvicorn
from fastapi import Depends, FastAPI, File, Query
from starlette.responses import Response
from tqdm import tqdm
from .bg import remove
from .detect import ort_session
sessions: dict[str, ort.InferenceSession] = {}
@click.group()
@click.version_option()
def main():
ap = argparse.ArgumentParser()
pass
ap.add_argument(
"-m",
"--model",
default="u2net",
type=str,
choices=["u2net", "u2netp", "u2net_human_seg"],
help="The model name.",
)
ap.add_argument(
"-a",
"--alpha-matting",
nargs="?",
const=True,
default=False,
type=lambda x: bool(strtobool(x)),
help="When true use alpha matting cutout.",
)
@main.command(help="for a file as input")
@click.option(
"-m",
"--model",
default="u2net",
type=click.Choice(["u2net", "u2netp", "u2net_human_seg"]),
show_default=True,
show_choices=True,
help="model name",
)
@click.option(
"-a",
"--alpha-matting",
is_flag=True,
show_default=True,
help="use alpha matting",
)
@click.option(
"-af",
"--alpha-matting-foreground-threshold",
default=240,
type=int,
show_default=True,
help="trimap fg threshold",
)
@click.option(
"-ab",
"--alpha-matting-background-threshold",
default=10,
type=int,
show_default=True,
help="trimap bg threshold",
)
@click.option(
"-ae",
"--alpha-matting-erode-size",
default=10,
type=int,
show_default=True,
help="erode size",
)
@click.option(
"-az",
"--alpha-matting-base-size",
default=1000,
type=int,
show_default=True,
help="image base size",
)
@click.option(
"-w",
"--width",
default=None,
type=int,
show_default=True,
help="output image size",
)
@click.option(
"-h",
"--height",
default=None,
type=int,
show_default=True,
help="output image size",
)
@click.argument(
"input", default=(None if sys.stdin.isatty() else "-"), type=click.File("rb")
)
@click.argument(
"output",
default=(None if sys.stdin.isatty() else "-"),
type=click.File("wb", lazy=True),
)
def i(model: str, input: IO, output: IO, **kwargs: dict):
output.write(remove(input.read(), session=ort_session(model), **kwargs))
ap.add_argument(
"-af",
"--alpha-matting-foreground-threshold",
default=240,
type=int,
help="The trimap foreground threshold.",
)
ap.add_argument(
"-ab",
"--alpha-matting-background-threshold",
default=10,
type=int,
help="The trimap background threshold.",
)
@main.command(help="for a folder as input")
@click.option(
"-m",
"--model",
default="u2net",
type=click.Choice(["u2net", "u2netp", "u2net_human_seg"]),
show_default=True,
show_choices=True,
help="model name",
)
@click.option(
"-a",
"--alpha-matting",
is_flag=True,
show_default=True,
help="use alpha matting",
)
@click.option(
"-af",
"--alpha-matting-foreground-threshold",
default=240,
type=int,
show_default=True,
help="trimap fg threshold",
)
@click.option(
"-ab",
"--alpha-matting-background-threshold",
default=10,
type=int,
show_default=True,
help="trimap bg threshold",
)
@click.option(
"-ae",
"--alpha-matting-erode-size",
default=10,
type=int,
show_default=True,
help="erode size",
)
@click.option(
"-az",
"--alpha-matting-base-size",
default=1000,
type=int,
show_default=True,
help="image base size",
)
@click.option(
"-w",
"--width",
default=None,
type=int,
show_default=True,
help="output image size",
)
@click.option(
"-h",
"--height",
default=None,
type=int,
show_default=True,
help="output image size",
)
@click.argument(
"input",
type=click.Path(
exists=True,
path_type=pathlib.Path,
file_okay=False,
dir_okay=True,
readable=True,
),
)
@click.argument(
"output",
type=click.Path(
exists=False,
path_type=pathlib.Path,
file_okay=False,
dir_okay=True,
writable=True,
),
)
def p(model: str, input: pathlib.Path, output: pathlib.Path, **kwargs: dict):
session = ort_session(model)
for each_input in tqdm(list(input.glob("**/*"))):
if each_input.is_dir():
continue
ap.add_argument(
"-ae",
"--alpha-matting-erode-size",
default=10,
type=int,
help="Size of element used for the erosion.",
)
mimetype = filetype.guess(each_input)
if mimetype is None:
continue
if mimetype.mime.find("image") < 0:
continue
ap.add_argument(
"-az",
"--alpha-matting-base-size",
default=1000,
type=int,
help="The image base size.",
)
each_output = (output / each_input.name).with_suffix(".png")
each_output.parents[0].mkdir(parents=True, exist_ok=True)
ap.add_argument(
"-p",
"--path",
nargs=2,
help="An input folder and an output folder.",
)
ap.add_argument(
"input",
nargs=(None if sys.stdin.isatty() else "?"),
default=(None if sys.stdin.isatty() else sys.stdin.buffer),
type=argparse.FileType("rb"),
help="Path to the input image.",
)
ap.add_argument(
"output",
nargs=(None if sys.stdin.isatty() else "?"),
default=(None if sys.stdin.isatty() else sys.stdout.buffer),
type=argparse.FileType("wb"),
help="Path to the output png image.",
)
args = ap.parse_args()
session = sessions.setdefault(args.model, ort_session(args.model))
if args.path:
full_paths = [os.path.abspath(path) for path in args.path]
input_paths = [full_paths[0]]
output_path = full_paths[1]
if not os.path.exists(output_path):
os.makedirs(output_path)
input_files = set()
for input_path in input_paths:
if os.path.isfile(path):
input_files.add(path)
else:
input_paths += set(glob.glob(input_path + "/*"))
for input_file in tqdm(input_files):
input_file_type = filetype.guess(input_file)
if input_file_type is None:
continue
if input_file_type.mime.find("image") < 0:
continue
out_file = os.path.join(
output_path, os.path.splitext(os.path.basename(input_file))[0] + ".png"
)
Path(out_file).write_bytes(
remove(
Path(input_file).read_bytes(),
session=session,
alpha_matting=args.alpha_matting,
alpha_matting_foreground_threshold=args.alpha_matting_foreground_threshold,
alpha_matting_background_threshold=args.alpha_matting_background_threshold,
alpha_matting_erode_structure_size=args.alpha_matting_erode_size,
alpha_matting_base_size=args.alpha_matting_base_size,
)
)
else:
args.output.write(
remove(
args.input.read(),
session=session,
alpha_matting=args.alpha_matting,
alpha_matting_foreground_threshold=args.alpha_matting_foreground_threshold,
alpha_matting_background_threshold=args.alpha_matting_background_threshold,
alpha_matting_erode_structure_size=args.alpha_matting_erode_size,
alpha_matting_base_size=args.alpha_matting_base_size,
)
each_output.write_bytes(
remove(each_input.read_bytes(), session=session, **kwargs)
)
@main.command(help="for a http server")
@click.option(
"-p",
"--port",
default=5000,
type=int,
show_default=True,
help="port",
)
@click.option(
"-l",
"--log_level",
default="info",
type=str,
show_default=True,
help="log level",
)
def s(port: int, log_level: str):
sessions: dict[str, ort.InferenceSession] = {}
app = FastAPI()
class ModelType(str, Enum):
u2net = "u2net"
u2netp = "u2netp"
u2net_human_seg = "u2net_human_seg"
class CommonQueryParams:
def __init__(
self,
model: ModelType = Query(ModelType.u2net),
a: bool = Query(False),
af: int = Query(240, ge=0),
ab: int = Query(10, ge=0),
ae: int = Query(10, ge=0),
az: int = Query(1000, ge=0),
width: Optional[int] = Query(None, gt=0),
height: Optional[int] = Query(None, gt=0),
):
self.model = model
self.width = width
self.height = height
self.a = a
self.af = af
self.ab = ab
self.ae = ae
self.az = az
def im_without_bg(content: bytes, commons: CommonQueryParams) -> Response:
return Response(
remove(
content,
session=sessions.setdefault(
commons.model.value, ort_session(commons.model.value)
),
width=commons.width,
height=commons.height,
alpha_matting=commons.a,
alpha_matting_foreground_threshold=commons.af,
alpha_matting_background_threshold=commons.ab,
alpha_matting_erode_size=commons.ae,
alpha_matting_base_size=commons.az,
),
media_type="image/png",
)
@app.get("/")
def get_index(url: str, commons: CommonQueryParams = Depends()):
return im_without_bg(requests.get(url).content, commons)
@app.post("/")
def post_index(file: bytes = File(...), commons: CommonQueryParams = Depends()):
return im_without_bg(file, commons)
uvicorn.run(app, host="0.0.0.0", port=port, log_level=log_level)
if __name__ == "__main__":
main()

View File

@ -1,110 +0,0 @@
import argparse
from enum import Enum
from typing import Optional
import requests
import uvicorn
from fastapi import Depends, FastAPI, File, Form, Query, UploadFile
from PIL import Image
from starlette.responses import Response
import onnxruntime as ort
from .bg import remove
from .detect import ort_session
sessions: dict[str, ort.InferenceSession] = {}
app = FastAPI()
class ModelType(str, Enum):
u2net = "u2net"
u2netp = "u2netp"
u2net_human_seg = "u2net_human_seg"
class CommonQueryParams:
def __init__(
self,
model: ModelType = Query(ModelType.u2net),
a: bool = Query(False),
af: int = Query(240, ge=0),
ab: int = Query(10, ge=0),
ae: int = Query(10, ge=0),
az: int = Query(1000, ge=0),
width: Optional[int] = Query(None, gt=0),
height: Optional[int] = Query(None, gt=0),
):
self.model = model
self.width = width
self.height = height
self.a = a
self.af = af
self.ab = ab
self.ae = ae
self.az = az
def im_without_bg(content: bytes, commons: CommonQueryParams) -> Response:
return Response(
remove(
content,
session=sessions.setdefault(
commons.model.value, ort_session(commons.model.value)
),
width=commons.width,
height=commons.height,
alpha_matting=commons.a,
alpha_matting_foreground_threshold=commons.af,
alpha_matting_background_threshold=commons.ab,
alpha_matting_erode_structure_size=commons.ae,
alpha_matting_base_size=commons.az,
),
media_type="image/png",
)
@app.get("/")
def get_index(url: str, commons: CommonQueryParams = Depends()):
return im_without_bg(requests.get(url).content, commons)
@app.post("/")
def post_index(file: bytes = File(...), commons: CommonQueryParams = Depends()):
return im_without_bg(file, commons)
def main():
ap = argparse.ArgumentParser()
ap.add_argument(
"-a",
"--addr",
default="0.0.0.0",
type=str,
help="The IP address to bind to.",
)
ap.add_argument(
"-p",
"--port",
default=5000,
type=int,
help="The port to bind to.",
)
ap.add_argument(
"-l",
"--log_level",
default="info",
type=str,
help="The log level.",
)
args = ap.parse_args()
uvicorn.run(
"rembg.server:app", host=args.addr, port=args.port, log_level=args.log_level
)
if __name__ == "__main__":
main()

View File

@ -10,3 +10,4 @@ scikit-image==0.19.1
scipy==1.7.3
tqdm==4.62.3
uvicorn==0.17.0
click==8.0.3

View File

@ -40,7 +40,6 @@ setup(
entry_points={
"console_scripts": [
"rembg=rembg.cli:main",
"rembg-server=rembg.server:main",
],
},
version=versioneer.get_version(),