mirror of
https://git.mirrors.martin98.com/https://github.com/danielgatis/rembg
synced 2025-08-14 03:15:56 +08:00
refact
This commit is contained in:
parent
1f6cce322f
commit
385d34da4a
12
README.md
12
README.md
@ -53,24 +53,24 @@ GPU=1 pip install rembg
|
||||
|
||||
Remove the background from a remote image
|
||||
```bash
|
||||
curl -s http://input.png | rembg > output.png
|
||||
curl -s http://input.png | rembg i > output.png
|
||||
```
|
||||
|
||||
Remove the background from a local file
|
||||
```bash
|
||||
rembg -o path/to/output.png path/to/input.png
|
||||
rembg i path/to/input.png path/to/output.png
|
||||
```
|
||||
|
||||
Remove the background from all images in a folder
|
||||
```bash
|
||||
rembg -p path/to/input path/to/output
|
||||
rembg p path/to/input path/to/output
|
||||
```
|
||||
|
||||
### Usage as a server
|
||||
|
||||
Start the server
|
||||
```bash
|
||||
rembg-server
|
||||
rembg s
|
||||
```
|
||||
|
||||
Open your browser to
|
||||
@ -140,14 +140,14 @@ docker build . -t rembg
|
||||
Then run with:
|
||||
|
||||
```
|
||||
docker run --rm -i rembg <in.png >out.png
|
||||
docker run --rm -i rembg i in.png out.png
|
||||
```
|
||||
|
||||
### Advance usage
|
||||
|
||||
Sometimes it is possible to achieve better results by turning on alpha matting. Example:
|
||||
```bash
|
||||
curl -s http://input.png | rembg -a -ae 15 > output.png
|
||||
curl -s http://input.png | rembg i -a -ae 15 > output.png
|
||||
```
|
||||
|
||||
<table>
|
||||
|
10
rembg/bg.py
10
rembg/bg.py
@ -75,11 +75,7 @@ def resize_image(img: Image, width: Optional[int], height: Optional[int]) -> Ima
|
||||
original_width, original_height = img.size
|
||||
width = original_width if width is None else width
|
||||
height = original_height if height is None else height
|
||||
return (
|
||||
img.resize((width, height))
|
||||
if original_width != width or original_height != height
|
||||
else img
|
||||
)
|
||||
return img.resize((width, height))
|
||||
|
||||
|
||||
def remove(
|
||||
@ -87,7 +83,7 @@ def remove(
|
||||
alpha_matting: bool = False,
|
||||
alpha_matting_foreground_threshold: int = 240,
|
||||
alpha_matting_background_threshold: int = 10,
|
||||
alpha_matting_erode_structure_size: int = 10,
|
||||
alpha_matting_erode_size: int = 10,
|
||||
alpha_matting_base_size: int = 1000,
|
||||
session: Optional[ort.InferenceSession] = None,
|
||||
width: Optional[int] = None,
|
||||
@ -109,7 +105,7 @@ def remove(
|
||||
mask,
|
||||
alpha_matting_foreground_threshold,
|
||||
alpha_matting_background_threshold,
|
||||
alpha_matting_erode_structure_size,
|
||||
alpha_matting_erode_size,
|
||||
alpha_matting_base_size,
|
||||
)
|
||||
except Exception:
|
||||
|
396
rembg/cli.py
396
rembg/cli.py
@ -1,156 +1,284 @@
|
||||
import argparse
|
||||
import glob
|
||||
import os
|
||||
from distutils.util import strtobool
|
||||
from typing import BinaryIO
|
||||
import pathlib
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from enum import Enum
|
||||
from typing import IO, Optional
|
||||
|
||||
import click
|
||||
import filetype
|
||||
from tqdm import tqdm
|
||||
import onnxruntime as ort
|
||||
import requests
|
||||
import uvicorn
|
||||
from fastapi import Depends, FastAPI, File, Query
|
||||
from starlette.responses import Response
|
||||
from tqdm import tqdm
|
||||
|
||||
from .bg import remove
|
||||
from .detect import ort_session
|
||||
|
||||
sessions: dict[str, ort.InferenceSession] = {}
|
||||
|
||||
|
||||
@click.group()
|
||||
@click.version_option()
|
||||
def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
pass
|
||||
|
||||
ap.add_argument(
|
||||
"-m",
|
||||
"--model",
|
||||
default="u2net",
|
||||
type=str,
|
||||
choices=["u2net", "u2netp", "u2net_human_seg"],
|
||||
help="The model name.",
|
||||
)
|
||||
|
||||
ap.add_argument(
|
||||
"-a",
|
||||
"--alpha-matting",
|
||||
nargs="?",
|
||||
const=True,
|
||||
default=False,
|
||||
type=lambda x: bool(strtobool(x)),
|
||||
help="When true use alpha matting cutout.",
|
||||
)
|
||||
@main.command(help="for a file as input")
|
||||
@click.option(
|
||||
"-m",
|
||||
"--model",
|
||||
default="u2net",
|
||||
type=click.Choice(["u2net", "u2netp", "u2net_human_seg"]),
|
||||
show_default=True,
|
||||
show_choices=True,
|
||||
help="model name",
|
||||
)
|
||||
@click.option(
|
||||
"-a",
|
||||
"--alpha-matting",
|
||||
is_flag=True,
|
||||
show_default=True,
|
||||
help="use alpha matting",
|
||||
)
|
||||
@click.option(
|
||||
"-af",
|
||||
"--alpha-matting-foreground-threshold",
|
||||
default=240,
|
||||
type=int,
|
||||
show_default=True,
|
||||
help="trimap fg threshold",
|
||||
)
|
||||
@click.option(
|
||||
"-ab",
|
||||
"--alpha-matting-background-threshold",
|
||||
default=10,
|
||||
type=int,
|
||||
show_default=True,
|
||||
help="trimap bg threshold",
|
||||
)
|
||||
@click.option(
|
||||
"-ae",
|
||||
"--alpha-matting-erode-size",
|
||||
default=10,
|
||||
type=int,
|
||||
show_default=True,
|
||||
help="erode size",
|
||||
)
|
||||
@click.option(
|
||||
"-az",
|
||||
"--alpha-matting-base-size",
|
||||
default=1000,
|
||||
type=int,
|
||||
show_default=True,
|
||||
help="image base size",
|
||||
)
|
||||
@click.option(
|
||||
"-w",
|
||||
"--width",
|
||||
default=None,
|
||||
type=int,
|
||||
show_default=True,
|
||||
help="output image size",
|
||||
)
|
||||
@click.option(
|
||||
"-h",
|
||||
"--height",
|
||||
default=None,
|
||||
type=int,
|
||||
show_default=True,
|
||||
help="output image size",
|
||||
)
|
||||
@click.argument(
|
||||
"input", default=(None if sys.stdin.isatty() else "-"), type=click.File("rb")
|
||||
)
|
||||
@click.argument(
|
||||
"output",
|
||||
default=(None if sys.stdin.isatty() else "-"),
|
||||
type=click.File("wb", lazy=True),
|
||||
)
|
||||
def i(model: str, input: IO, output: IO, **kwargs: dict):
|
||||
output.write(remove(input.read(), session=ort_session(model), **kwargs))
|
||||
|
||||
ap.add_argument(
|
||||
"-af",
|
||||
"--alpha-matting-foreground-threshold",
|
||||
default=240,
|
||||
type=int,
|
||||
help="The trimap foreground threshold.",
|
||||
)
|
||||
|
||||
ap.add_argument(
|
||||
"-ab",
|
||||
"--alpha-matting-background-threshold",
|
||||
default=10,
|
||||
type=int,
|
||||
help="The trimap background threshold.",
|
||||
)
|
||||
@main.command(help="for a folder as input")
|
||||
@click.option(
|
||||
"-m",
|
||||
"--model",
|
||||
default="u2net",
|
||||
type=click.Choice(["u2net", "u2netp", "u2net_human_seg"]),
|
||||
show_default=True,
|
||||
show_choices=True,
|
||||
help="model name",
|
||||
)
|
||||
@click.option(
|
||||
"-a",
|
||||
"--alpha-matting",
|
||||
is_flag=True,
|
||||
show_default=True,
|
||||
help="use alpha matting",
|
||||
)
|
||||
@click.option(
|
||||
"-af",
|
||||
"--alpha-matting-foreground-threshold",
|
||||
default=240,
|
||||
type=int,
|
||||
show_default=True,
|
||||
help="trimap fg threshold",
|
||||
)
|
||||
@click.option(
|
||||
"-ab",
|
||||
"--alpha-matting-background-threshold",
|
||||
default=10,
|
||||
type=int,
|
||||
show_default=True,
|
||||
help="trimap bg threshold",
|
||||
)
|
||||
@click.option(
|
||||
"-ae",
|
||||
"--alpha-matting-erode-size",
|
||||
default=10,
|
||||
type=int,
|
||||
show_default=True,
|
||||
help="erode size",
|
||||
)
|
||||
@click.option(
|
||||
"-az",
|
||||
"--alpha-matting-base-size",
|
||||
default=1000,
|
||||
type=int,
|
||||
show_default=True,
|
||||
help="image base size",
|
||||
)
|
||||
@click.option(
|
||||
"-w",
|
||||
"--width",
|
||||
default=None,
|
||||
type=int,
|
||||
show_default=True,
|
||||
help="output image size",
|
||||
)
|
||||
@click.option(
|
||||
"-h",
|
||||
"--height",
|
||||
default=None,
|
||||
type=int,
|
||||
show_default=True,
|
||||
help="output image size",
|
||||
)
|
||||
@click.argument(
|
||||
"input",
|
||||
type=click.Path(
|
||||
exists=True,
|
||||
path_type=pathlib.Path,
|
||||
file_okay=False,
|
||||
dir_okay=True,
|
||||
readable=True,
|
||||
),
|
||||
)
|
||||
@click.argument(
|
||||
"output",
|
||||
type=click.Path(
|
||||
exists=False,
|
||||
path_type=pathlib.Path,
|
||||
file_okay=False,
|
||||
dir_okay=True,
|
||||
writable=True,
|
||||
),
|
||||
)
|
||||
def p(model: str, input: pathlib.Path, output: pathlib.Path, **kwargs: dict):
|
||||
session = ort_session(model)
|
||||
for each_input in tqdm(list(input.glob("**/*"))):
|
||||
if each_input.is_dir():
|
||||
continue
|
||||
|
||||
ap.add_argument(
|
||||
"-ae",
|
||||
"--alpha-matting-erode-size",
|
||||
default=10,
|
||||
type=int,
|
||||
help="Size of element used for the erosion.",
|
||||
)
|
||||
mimetype = filetype.guess(each_input)
|
||||
if mimetype is None:
|
||||
continue
|
||||
if mimetype.mime.find("image") < 0:
|
||||
continue
|
||||
|
||||
ap.add_argument(
|
||||
"-az",
|
||||
"--alpha-matting-base-size",
|
||||
default=1000,
|
||||
type=int,
|
||||
help="The image base size.",
|
||||
)
|
||||
each_output = (output / each_input.name).with_suffix(".png")
|
||||
each_output.parents[0].mkdir(parents=True, exist_ok=True)
|
||||
|
||||
ap.add_argument(
|
||||
"-p",
|
||||
"--path",
|
||||
nargs=2,
|
||||
help="An input folder and an output folder.",
|
||||
)
|
||||
|
||||
ap.add_argument(
|
||||
"input",
|
||||
nargs=(None if sys.stdin.isatty() else "?"),
|
||||
default=(None if sys.stdin.isatty() else sys.stdin.buffer),
|
||||
type=argparse.FileType("rb"),
|
||||
help="Path to the input image.",
|
||||
)
|
||||
|
||||
ap.add_argument(
|
||||
"output",
|
||||
nargs=(None if sys.stdin.isatty() else "?"),
|
||||
default=(None if sys.stdin.isatty() else sys.stdout.buffer),
|
||||
type=argparse.FileType("wb"),
|
||||
help="Path to the output png image.",
|
||||
)
|
||||
|
||||
args = ap.parse_args()
|
||||
session = sessions.setdefault(args.model, ort_session(args.model))
|
||||
|
||||
if args.path:
|
||||
full_paths = [os.path.abspath(path) for path in args.path]
|
||||
|
||||
input_paths = [full_paths[0]]
|
||||
output_path = full_paths[1]
|
||||
|
||||
if not os.path.exists(output_path):
|
||||
os.makedirs(output_path)
|
||||
|
||||
input_files = set()
|
||||
|
||||
for input_path in input_paths:
|
||||
if os.path.isfile(path):
|
||||
input_files.add(path)
|
||||
else:
|
||||
input_paths += set(glob.glob(input_path + "/*"))
|
||||
|
||||
for input_file in tqdm(input_files):
|
||||
input_file_type = filetype.guess(input_file)
|
||||
|
||||
if input_file_type is None:
|
||||
continue
|
||||
|
||||
if input_file_type.mime.find("image") < 0:
|
||||
continue
|
||||
|
||||
out_file = os.path.join(
|
||||
output_path, os.path.splitext(os.path.basename(input_file))[0] + ".png"
|
||||
)
|
||||
|
||||
Path(out_file).write_bytes(
|
||||
remove(
|
||||
Path(input_file).read_bytes(),
|
||||
session=session,
|
||||
alpha_matting=args.alpha_matting,
|
||||
alpha_matting_foreground_threshold=args.alpha_matting_foreground_threshold,
|
||||
alpha_matting_background_threshold=args.alpha_matting_background_threshold,
|
||||
alpha_matting_erode_structure_size=args.alpha_matting_erode_size,
|
||||
alpha_matting_base_size=args.alpha_matting_base_size,
|
||||
)
|
||||
)
|
||||
|
||||
else:
|
||||
args.output.write(
|
||||
remove(
|
||||
args.input.read(),
|
||||
session=session,
|
||||
alpha_matting=args.alpha_matting,
|
||||
alpha_matting_foreground_threshold=args.alpha_matting_foreground_threshold,
|
||||
alpha_matting_background_threshold=args.alpha_matting_background_threshold,
|
||||
alpha_matting_erode_structure_size=args.alpha_matting_erode_size,
|
||||
alpha_matting_base_size=args.alpha_matting_base_size,
|
||||
)
|
||||
each_output.write_bytes(
|
||||
remove(each_input.read_bytes(), session=session, **kwargs)
|
||||
)
|
||||
|
||||
|
||||
@main.command(help="for a http server")
|
||||
@click.option(
|
||||
"-p",
|
||||
"--port",
|
||||
default=5000,
|
||||
type=int,
|
||||
show_default=True,
|
||||
help="port",
|
||||
)
|
||||
@click.option(
|
||||
"-l",
|
||||
"--log_level",
|
||||
default="info",
|
||||
type=str,
|
||||
show_default=True,
|
||||
help="log level",
|
||||
)
|
||||
def s(port: int, log_level: str):
|
||||
sessions: dict[str, ort.InferenceSession] = {}
|
||||
app = FastAPI()
|
||||
|
||||
class ModelType(str, Enum):
|
||||
u2net = "u2net"
|
||||
u2netp = "u2netp"
|
||||
u2net_human_seg = "u2net_human_seg"
|
||||
|
||||
class CommonQueryParams:
|
||||
def __init__(
|
||||
self,
|
||||
model: ModelType = Query(ModelType.u2net),
|
||||
a: bool = Query(False),
|
||||
af: int = Query(240, ge=0),
|
||||
ab: int = Query(10, ge=0),
|
||||
ae: int = Query(10, ge=0),
|
||||
az: int = Query(1000, ge=0),
|
||||
width: Optional[int] = Query(None, gt=0),
|
||||
height: Optional[int] = Query(None, gt=0),
|
||||
):
|
||||
self.model = model
|
||||
self.width = width
|
||||
self.height = height
|
||||
self.a = a
|
||||
self.af = af
|
||||
self.ab = ab
|
||||
self.ae = ae
|
||||
self.az = az
|
||||
|
||||
def im_without_bg(content: bytes, commons: CommonQueryParams) -> Response:
|
||||
return Response(
|
||||
remove(
|
||||
content,
|
||||
session=sessions.setdefault(
|
||||
commons.model.value, ort_session(commons.model.value)
|
||||
),
|
||||
width=commons.width,
|
||||
height=commons.height,
|
||||
alpha_matting=commons.a,
|
||||
alpha_matting_foreground_threshold=commons.af,
|
||||
alpha_matting_background_threshold=commons.ab,
|
||||
alpha_matting_erode_size=commons.ae,
|
||||
alpha_matting_base_size=commons.az,
|
||||
),
|
||||
media_type="image/png",
|
||||
)
|
||||
|
||||
@app.get("/")
|
||||
def get_index(url: str, commons: CommonQueryParams = Depends()):
|
||||
return im_without_bg(requests.get(url).content, commons)
|
||||
|
||||
@app.post("/")
|
||||
def post_index(file: bytes = File(...), commons: CommonQueryParams = Depends()):
|
||||
return im_without_bg(file, commons)
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=port, log_level=log_level)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
110
rembg/server.py
110
rembg/server.py
@ -1,110 +0,0 @@
|
||||
import argparse
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
import uvicorn
|
||||
from fastapi import Depends, FastAPI, File, Form, Query, UploadFile
|
||||
from PIL import Image
|
||||
from starlette.responses import Response
|
||||
import onnxruntime as ort
|
||||
|
||||
from .bg import remove
|
||||
from .detect import ort_session
|
||||
|
||||
sessions: dict[str, ort.InferenceSession] = {}
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
class ModelType(str, Enum):
|
||||
u2net = "u2net"
|
||||
u2netp = "u2netp"
|
||||
u2net_human_seg = "u2net_human_seg"
|
||||
|
||||
|
||||
class CommonQueryParams:
|
||||
def __init__(
|
||||
self,
|
||||
model: ModelType = Query(ModelType.u2net),
|
||||
a: bool = Query(False),
|
||||
af: int = Query(240, ge=0),
|
||||
ab: int = Query(10, ge=0),
|
||||
ae: int = Query(10, ge=0),
|
||||
az: int = Query(1000, ge=0),
|
||||
width: Optional[int] = Query(None, gt=0),
|
||||
height: Optional[int] = Query(None, gt=0),
|
||||
):
|
||||
self.model = model
|
||||
self.width = width
|
||||
self.height = height
|
||||
self.a = a
|
||||
self.af = af
|
||||
self.ab = ab
|
||||
self.ae = ae
|
||||
self.az = az
|
||||
|
||||
|
||||
def im_without_bg(content: bytes, commons: CommonQueryParams) -> Response:
|
||||
return Response(
|
||||
remove(
|
||||
content,
|
||||
session=sessions.setdefault(
|
||||
commons.model.value, ort_session(commons.model.value)
|
||||
),
|
||||
width=commons.width,
|
||||
height=commons.height,
|
||||
alpha_matting=commons.a,
|
||||
alpha_matting_foreground_threshold=commons.af,
|
||||
alpha_matting_background_threshold=commons.ab,
|
||||
alpha_matting_erode_structure_size=commons.ae,
|
||||
alpha_matting_base_size=commons.az,
|
||||
),
|
||||
media_type="image/png",
|
||||
)
|
||||
|
||||
|
||||
@app.get("/")
|
||||
def get_index(url: str, commons: CommonQueryParams = Depends()):
|
||||
return im_without_bg(requests.get(url).content, commons)
|
||||
|
||||
|
||||
@app.post("/")
|
||||
def post_index(file: bytes = File(...), commons: CommonQueryParams = Depends()):
|
||||
return im_without_bg(file, commons)
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
|
||||
ap.add_argument(
|
||||
"-a",
|
||||
"--addr",
|
||||
default="0.0.0.0",
|
||||
type=str,
|
||||
help="The IP address to bind to.",
|
||||
)
|
||||
|
||||
ap.add_argument(
|
||||
"-p",
|
||||
"--port",
|
||||
default=5000,
|
||||
type=int,
|
||||
help="The port to bind to.",
|
||||
)
|
||||
|
||||
ap.add_argument(
|
||||
"-l",
|
||||
"--log_level",
|
||||
default="info",
|
||||
type=str,
|
||||
help="The log level.",
|
||||
)
|
||||
|
||||
args = ap.parse_args()
|
||||
uvicorn.run(
|
||||
"rembg.server:app", host=args.addr, port=args.port, log_level=args.log_level
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -10,3 +10,4 @@ scikit-image==0.19.1
|
||||
scipy==1.7.3
|
||||
tqdm==4.62.3
|
||||
uvicorn==0.17.0
|
||||
click==8.0.3
|
||||
|
Loading…
x
Reference in New Issue
Block a user