mirror of
https://git.mirrors.martin98.com/https://github.com/danielgatis/rembg
synced 2025-08-14 06:26:02 +08:00
refact
This commit is contained in:
parent
1f6cce322f
commit
385d34da4a
12
README.md
12
README.md
@ -53,24 +53,24 @@ GPU=1 pip install rembg
|
|||||||
|
|
||||||
Remove the background from a remote image
|
Remove the background from a remote image
|
||||||
```bash
|
```bash
|
||||||
curl -s http://input.png | rembg > output.png
|
curl -s http://input.png | rembg i > output.png
|
||||||
```
|
```
|
||||||
|
|
||||||
Remove the background from a local file
|
Remove the background from a local file
|
||||||
```bash
|
```bash
|
||||||
rembg -o path/to/output.png path/to/input.png
|
rembg i path/to/input.png path/to/output.png
|
||||||
```
|
```
|
||||||
|
|
||||||
Remove the background from all images in a folder
|
Remove the background from all images in a folder
|
||||||
```bash
|
```bash
|
||||||
rembg -p path/to/input path/to/output
|
rembg p path/to/input path/to/output
|
||||||
```
|
```
|
||||||
|
|
||||||
### Usage as a server
|
### Usage as a server
|
||||||
|
|
||||||
Start the server
|
Start the server
|
||||||
```bash
|
```bash
|
||||||
rembg-server
|
rembg s
|
||||||
```
|
```
|
||||||
|
|
||||||
Open your browser to
|
Open your browser to
|
||||||
@ -140,14 +140,14 @@ docker build . -t rembg
|
|||||||
Then run with:
|
Then run with:
|
||||||
|
|
||||||
```
|
```
|
||||||
docker run --rm -i rembg <in.png >out.png
|
docker run --rm -i rembg i in.png out.png
|
||||||
```
|
```
|
||||||
|
|
||||||
### Advance usage
|
### Advance usage
|
||||||
|
|
||||||
Sometimes it is possible to achieve better results by turning on alpha matting. Example:
|
Sometimes it is possible to achieve better results by turning on alpha matting. Example:
|
||||||
```bash
|
```bash
|
||||||
curl -s http://input.png | rembg -a -ae 15 > output.png
|
curl -s http://input.png | rembg i -a -ae 15 > output.png
|
||||||
```
|
```
|
||||||
|
|
||||||
<table>
|
<table>
|
||||||
|
10
rembg/bg.py
10
rembg/bg.py
@ -75,11 +75,7 @@ def resize_image(img: Image, width: Optional[int], height: Optional[int]) -> Ima
|
|||||||
original_width, original_height = img.size
|
original_width, original_height = img.size
|
||||||
width = original_width if width is None else width
|
width = original_width if width is None else width
|
||||||
height = original_height if height is None else height
|
height = original_height if height is None else height
|
||||||
return (
|
return img.resize((width, height))
|
||||||
img.resize((width, height))
|
|
||||||
if original_width != width or original_height != height
|
|
||||||
else img
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def remove(
|
def remove(
|
||||||
@ -87,7 +83,7 @@ def remove(
|
|||||||
alpha_matting: bool = False,
|
alpha_matting: bool = False,
|
||||||
alpha_matting_foreground_threshold: int = 240,
|
alpha_matting_foreground_threshold: int = 240,
|
||||||
alpha_matting_background_threshold: int = 10,
|
alpha_matting_background_threshold: int = 10,
|
||||||
alpha_matting_erode_structure_size: int = 10,
|
alpha_matting_erode_size: int = 10,
|
||||||
alpha_matting_base_size: int = 1000,
|
alpha_matting_base_size: int = 1000,
|
||||||
session: Optional[ort.InferenceSession] = None,
|
session: Optional[ort.InferenceSession] = None,
|
||||||
width: Optional[int] = None,
|
width: Optional[int] = None,
|
||||||
@ -109,7 +105,7 @@ def remove(
|
|||||||
mask,
|
mask,
|
||||||
alpha_matting_foreground_threshold,
|
alpha_matting_foreground_threshold,
|
||||||
alpha_matting_background_threshold,
|
alpha_matting_background_threshold,
|
||||||
alpha_matting_erode_structure_size,
|
alpha_matting_erode_size,
|
||||||
alpha_matting_base_size,
|
alpha_matting_base_size,
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
344
rembg/cli.py
344
rembg/cli.py
@ -1,155 +1,283 @@
|
|||||||
import argparse
|
import pathlib
|
||||||
import glob
|
|
||||||
import os
|
|
||||||
from distutils.util import strtobool
|
|
||||||
from typing import BinaryIO
|
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from enum import Enum
|
||||||
|
from typing import IO, Optional
|
||||||
|
|
||||||
|
import click
|
||||||
import filetype
|
import filetype
|
||||||
from tqdm import tqdm
|
|
||||||
import onnxruntime as ort
|
import onnxruntime as ort
|
||||||
|
import requests
|
||||||
|
import uvicorn
|
||||||
|
from fastapi import Depends, FastAPI, File, Query
|
||||||
|
from starlette.responses import Response
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
from .bg import remove
|
from .bg import remove
|
||||||
from .detect import ort_session
|
from .detect import ort_session
|
||||||
|
|
||||||
sessions: dict[str, ort.InferenceSession] = {}
|
|
||||||
|
|
||||||
|
|
||||||
|
@click.group()
|
||||||
|
@click.version_option()
|
||||||
def main():
|
def main():
|
||||||
ap = argparse.ArgumentParser()
|
pass
|
||||||
|
|
||||||
ap.add_argument(
|
|
||||||
|
@main.command(help="for a file as input")
|
||||||
|
@click.option(
|
||||||
"-m",
|
"-m",
|
||||||
"--model",
|
"--model",
|
||||||
default="u2net",
|
default="u2net",
|
||||||
type=str,
|
type=click.Choice(["u2net", "u2netp", "u2net_human_seg"]),
|
||||||
choices=["u2net", "u2netp", "u2net_human_seg"],
|
show_default=True,
|
||||||
help="The model name.",
|
show_choices=True,
|
||||||
)
|
help="model name",
|
||||||
|
)
|
||||||
ap.add_argument(
|
@click.option(
|
||||||
"-a",
|
"-a",
|
||||||
"--alpha-matting",
|
"--alpha-matting",
|
||||||
nargs="?",
|
is_flag=True,
|
||||||
const=True,
|
show_default=True,
|
||||||
default=False,
|
help="use alpha matting",
|
||||||
type=lambda x: bool(strtobool(x)),
|
)
|
||||||
help="When true use alpha matting cutout.",
|
@click.option(
|
||||||
)
|
|
||||||
|
|
||||||
ap.add_argument(
|
|
||||||
"-af",
|
"-af",
|
||||||
"--alpha-matting-foreground-threshold",
|
"--alpha-matting-foreground-threshold",
|
||||||
default=240,
|
default=240,
|
||||||
type=int,
|
type=int,
|
||||||
help="The trimap foreground threshold.",
|
show_default=True,
|
||||||
)
|
help="trimap fg threshold",
|
||||||
|
)
|
||||||
ap.add_argument(
|
@click.option(
|
||||||
"-ab",
|
"-ab",
|
||||||
"--alpha-matting-background-threshold",
|
"--alpha-matting-background-threshold",
|
||||||
default=10,
|
default=10,
|
||||||
type=int,
|
type=int,
|
||||||
help="The trimap background threshold.",
|
show_default=True,
|
||||||
)
|
help="trimap bg threshold",
|
||||||
|
)
|
||||||
ap.add_argument(
|
@click.option(
|
||||||
"-ae",
|
"-ae",
|
||||||
"--alpha-matting-erode-size",
|
"--alpha-matting-erode-size",
|
||||||
default=10,
|
default=10,
|
||||||
type=int,
|
type=int,
|
||||||
help="Size of element used for the erosion.",
|
show_default=True,
|
||||||
)
|
help="erode size",
|
||||||
|
)
|
||||||
ap.add_argument(
|
@click.option(
|
||||||
"-az",
|
"-az",
|
||||||
"--alpha-matting-base-size",
|
"--alpha-matting-base-size",
|
||||||
default=1000,
|
default=1000,
|
||||||
type=int,
|
type=int,
|
||||||
help="The image base size.",
|
show_default=True,
|
||||||
)
|
help="image base size",
|
||||||
|
)
|
||||||
ap.add_argument(
|
@click.option(
|
||||||
"-p",
|
"-w",
|
||||||
"--path",
|
"--width",
|
||||||
nargs=2,
|
default=None,
|
||||||
help="An input folder and an output folder.",
|
type=int,
|
||||||
)
|
show_default=True,
|
||||||
|
help="output image size",
|
||||||
ap.add_argument(
|
)
|
||||||
"input",
|
@click.option(
|
||||||
nargs=(None if sys.stdin.isatty() else "?"),
|
"-h",
|
||||||
default=(None if sys.stdin.isatty() else sys.stdin.buffer),
|
"--height",
|
||||||
type=argparse.FileType("rb"),
|
default=None,
|
||||||
help="Path to the input image.",
|
type=int,
|
||||||
)
|
show_default=True,
|
||||||
|
help="output image size",
|
||||||
ap.add_argument(
|
)
|
||||||
|
@click.argument(
|
||||||
|
"input", default=(None if sys.stdin.isatty() else "-"), type=click.File("rb")
|
||||||
|
)
|
||||||
|
@click.argument(
|
||||||
"output",
|
"output",
|
||||||
nargs=(None if sys.stdin.isatty() else "?"),
|
default=(None if sys.stdin.isatty() else "-"),
|
||||||
default=(None if sys.stdin.isatty() else sys.stdout.buffer),
|
type=click.File("wb", lazy=True),
|
||||||
type=argparse.FileType("wb"),
|
)
|
||||||
help="Path to the output png image.",
|
def i(model: str, input: IO, output: IO, **kwargs: dict):
|
||||||
)
|
output.write(remove(input.read(), session=ort_session(model), **kwargs))
|
||||||
|
|
||||||
args = ap.parse_args()
|
|
||||||
session = sessions.setdefault(args.model, ort_session(args.model))
|
|
||||||
|
|
||||||
if args.path:
|
@main.command(help="for a folder as input")
|
||||||
full_paths = [os.path.abspath(path) for path in args.path]
|
@click.option(
|
||||||
|
"-m",
|
||||||
input_paths = [full_paths[0]]
|
"--model",
|
||||||
output_path = full_paths[1]
|
default="u2net",
|
||||||
|
type=click.Choice(["u2net", "u2netp", "u2net_human_seg"]),
|
||||||
if not os.path.exists(output_path):
|
show_default=True,
|
||||||
os.makedirs(output_path)
|
show_choices=True,
|
||||||
|
help="model name",
|
||||||
input_files = set()
|
)
|
||||||
|
@click.option(
|
||||||
for input_path in input_paths:
|
"-a",
|
||||||
if os.path.isfile(path):
|
"--alpha-matting",
|
||||||
input_files.add(path)
|
is_flag=True,
|
||||||
else:
|
show_default=True,
|
||||||
input_paths += set(glob.glob(input_path + "/*"))
|
help="use alpha matting",
|
||||||
|
)
|
||||||
for input_file in tqdm(input_files):
|
@click.option(
|
||||||
input_file_type = filetype.guess(input_file)
|
"-af",
|
||||||
|
"--alpha-matting-foreground-threshold",
|
||||||
if input_file_type is None:
|
default=240,
|
||||||
|
type=int,
|
||||||
|
show_default=True,
|
||||||
|
help="trimap fg threshold",
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"-ab",
|
||||||
|
"--alpha-matting-background-threshold",
|
||||||
|
default=10,
|
||||||
|
type=int,
|
||||||
|
show_default=True,
|
||||||
|
help="trimap bg threshold",
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"-ae",
|
||||||
|
"--alpha-matting-erode-size",
|
||||||
|
default=10,
|
||||||
|
type=int,
|
||||||
|
show_default=True,
|
||||||
|
help="erode size",
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"-az",
|
||||||
|
"--alpha-matting-base-size",
|
||||||
|
default=1000,
|
||||||
|
type=int,
|
||||||
|
show_default=True,
|
||||||
|
help="image base size",
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"-w",
|
||||||
|
"--width",
|
||||||
|
default=None,
|
||||||
|
type=int,
|
||||||
|
show_default=True,
|
||||||
|
help="output image size",
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"-h",
|
||||||
|
"--height",
|
||||||
|
default=None,
|
||||||
|
type=int,
|
||||||
|
show_default=True,
|
||||||
|
help="output image size",
|
||||||
|
)
|
||||||
|
@click.argument(
|
||||||
|
"input",
|
||||||
|
type=click.Path(
|
||||||
|
exists=True,
|
||||||
|
path_type=pathlib.Path,
|
||||||
|
file_okay=False,
|
||||||
|
dir_okay=True,
|
||||||
|
readable=True,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
@click.argument(
|
||||||
|
"output",
|
||||||
|
type=click.Path(
|
||||||
|
exists=False,
|
||||||
|
path_type=pathlib.Path,
|
||||||
|
file_okay=False,
|
||||||
|
dir_okay=True,
|
||||||
|
writable=True,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
def p(model: str, input: pathlib.Path, output: pathlib.Path, **kwargs: dict):
|
||||||
|
session = ort_session(model)
|
||||||
|
for each_input in tqdm(list(input.glob("**/*"))):
|
||||||
|
if each_input.is_dir():
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if input_file_type.mime.find("image") < 0:
|
mimetype = filetype.guess(each_input)
|
||||||
|
if mimetype is None:
|
||||||
|
continue
|
||||||
|
if mimetype.mime.find("image") < 0:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
out_file = os.path.join(
|
each_output = (output / each_input.name).with_suffix(".png")
|
||||||
output_path, os.path.splitext(os.path.basename(input_file))[0] + ".png"
|
each_output.parents[0].mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
each_output.write_bytes(
|
||||||
|
remove(each_input.read_bytes(), session=session, **kwargs)
|
||||||
)
|
)
|
||||||
|
|
||||||
Path(out_file).write_bytes(
|
|
||||||
|
@main.command(help="for a http server")
|
||||||
|
@click.option(
|
||||||
|
"-p",
|
||||||
|
"--port",
|
||||||
|
default=5000,
|
||||||
|
type=int,
|
||||||
|
show_default=True,
|
||||||
|
help="port",
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"-l",
|
||||||
|
"--log_level",
|
||||||
|
default="info",
|
||||||
|
type=str,
|
||||||
|
show_default=True,
|
||||||
|
help="log level",
|
||||||
|
)
|
||||||
|
def s(port: int, log_level: str):
|
||||||
|
sessions: dict[str, ort.InferenceSession] = {}
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
class ModelType(str, Enum):
|
||||||
|
u2net = "u2net"
|
||||||
|
u2netp = "u2netp"
|
||||||
|
u2net_human_seg = "u2net_human_seg"
|
||||||
|
|
||||||
|
class CommonQueryParams:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: ModelType = Query(ModelType.u2net),
|
||||||
|
a: bool = Query(False),
|
||||||
|
af: int = Query(240, ge=0),
|
||||||
|
ab: int = Query(10, ge=0),
|
||||||
|
ae: int = Query(10, ge=0),
|
||||||
|
az: int = Query(1000, ge=0),
|
||||||
|
width: Optional[int] = Query(None, gt=0),
|
||||||
|
height: Optional[int] = Query(None, gt=0),
|
||||||
|
):
|
||||||
|
self.model = model
|
||||||
|
self.width = width
|
||||||
|
self.height = height
|
||||||
|
self.a = a
|
||||||
|
self.af = af
|
||||||
|
self.ab = ab
|
||||||
|
self.ae = ae
|
||||||
|
self.az = az
|
||||||
|
|
||||||
|
def im_without_bg(content: bytes, commons: CommonQueryParams) -> Response:
|
||||||
|
return Response(
|
||||||
remove(
|
remove(
|
||||||
Path(input_file).read_bytes(),
|
content,
|
||||||
session=session,
|
session=sessions.setdefault(
|
||||||
alpha_matting=args.alpha_matting,
|
commons.model.value, ort_session(commons.model.value)
|
||||||
alpha_matting_foreground_threshold=args.alpha_matting_foreground_threshold,
|
),
|
||||||
alpha_matting_background_threshold=args.alpha_matting_background_threshold,
|
width=commons.width,
|
||||||
alpha_matting_erode_structure_size=args.alpha_matting_erode_size,
|
height=commons.height,
|
||||||
alpha_matting_base_size=args.alpha_matting_base_size,
|
alpha_matting=commons.a,
|
||||||
)
|
alpha_matting_foreground_threshold=commons.af,
|
||||||
|
alpha_matting_background_threshold=commons.ab,
|
||||||
|
alpha_matting_erode_size=commons.ae,
|
||||||
|
alpha_matting_base_size=commons.az,
|
||||||
|
),
|
||||||
|
media_type="image/png",
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
@app.get("/")
|
||||||
args.output.write(
|
def get_index(url: str, commons: CommonQueryParams = Depends()):
|
||||||
remove(
|
return im_without_bg(requests.get(url).content, commons)
|
||||||
args.input.read(),
|
|
||||||
session=session,
|
@app.post("/")
|
||||||
alpha_matting=args.alpha_matting,
|
def post_index(file: bytes = File(...), commons: CommonQueryParams = Depends()):
|
||||||
alpha_matting_foreground_threshold=args.alpha_matting_foreground_threshold,
|
return im_without_bg(file, commons)
|
||||||
alpha_matting_background_threshold=args.alpha_matting_background_threshold,
|
|
||||||
alpha_matting_erode_structure_size=args.alpha_matting_erode_size,
|
uvicorn.run(app, host="0.0.0.0", port=port, log_level=log_level)
|
||||||
alpha_matting_base_size=args.alpha_matting_base_size,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
110
rembg/server.py
110
rembg/server.py
@ -1,110 +0,0 @@
|
|||||||
import argparse
|
|
||||||
from enum import Enum
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import requests
|
|
||||||
import uvicorn
|
|
||||||
from fastapi import Depends, FastAPI, File, Form, Query, UploadFile
|
|
||||||
from PIL import Image
|
|
||||||
from starlette.responses import Response
|
|
||||||
import onnxruntime as ort
|
|
||||||
|
|
||||||
from .bg import remove
|
|
||||||
from .detect import ort_session
|
|
||||||
|
|
||||||
sessions: dict[str, ort.InferenceSession] = {}
|
|
||||||
app = FastAPI()
|
|
||||||
|
|
||||||
|
|
||||||
class ModelType(str, Enum):
|
|
||||||
u2net = "u2net"
|
|
||||||
u2netp = "u2netp"
|
|
||||||
u2net_human_seg = "u2net_human_seg"
|
|
||||||
|
|
||||||
|
|
||||||
class CommonQueryParams:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model: ModelType = Query(ModelType.u2net),
|
|
||||||
a: bool = Query(False),
|
|
||||||
af: int = Query(240, ge=0),
|
|
||||||
ab: int = Query(10, ge=0),
|
|
||||||
ae: int = Query(10, ge=0),
|
|
||||||
az: int = Query(1000, ge=0),
|
|
||||||
width: Optional[int] = Query(None, gt=0),
|
|
||||||
height: Optional[int] = Query(None, gt=0),
|
|
||||||
):
|
|
||||||
self.model = model
|
|
||||||
self.width = width
|
|
||||||
self.height = height
|
|
||||||
self.a = a
|
|
||||||
self.af = af
|
|
||||||
self.ab = ab
|
|
||||||
self.ae = ae
|
|
||||||
self.az = az
|
|
||||||
|
|
||||||
|
|
||||||
def im_without_bg(content: bytes, commons: CommonQueryParams) -> Response:
|
|
||||||
return Response(
|
|
||||||
remove(
|
|
||||||
content,
|
|
||||||
session=sessions.setdefault(
|
|
||||||
commons.model.value, ort_session(commons.model.value)
|
|
||||||
),
|
|
||||||
width=commons.width,
|
|
||||||
height=commons.height,
|
|
||||||
alpha_matting=commons.a,
|
|
||||||
alpha_matting_foreground_threshold=commons.af,
|
|
||||||
alpha_matting_background_threshold=commons.ab,
|
|
||||||
alpha_matting_erode_structure_size=commons.ae,
|
|
||||||
alpha_matting_base_size=commons.az,
|
|
||||||
),
|
|
||||||
media_type="image/png",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/")
|
|
||||||
def get_index(url: str, commons: CommonQueryParams = Depends()):
|
|
||||||
return im_without_bg(requests.get(url).content, commons)
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/")
|
|
||||||
def post_index(file: bytes = File(...), commons: CommonQueryParams = Depends()):
|
|
||||||
return im_without_bg(file, commons)
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
ap = argparse.ArgumentParser()
|
|
||||||
|
|
||||||
ap.add_argument(
|
|
||||||
"-a",
|
|
||||||
"--addr",
|
|
||||||
default="0.0.0.0",
|
|
||||||
type=str,
|
|
||||||
help="The IP address to bind to.",
|
|
||||||
)
|
|
||||||
|
|
||||||
ap.add_argument(
|
|
||||||
"-p",
|
|
||||||
"--port",
|
|
||||||
default=5000,
|
|
||||||
type=int,
|
|
||||||
help="The port to bind to.",
|
|
||||||
)
|
|
||||||
|
|
||||||
ap.add_argument(
|
|
||||||
"-l",
|
|
||||||
"--log_level",
|
|
||||||
default="info",
|
|
||||||
type=str,
|
|
||||||
help="The log level.",
|
|
||||||
)
|
|
||||||
|
|
||||||
args = ap.parse_args()
|
|
||||||
uvicorn.run(
|
|
||||||
"rembg.server:app", host=args.addr, port=args.port, log_level=args.log_level
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
@ -10,3 +10,4 @@ scikit-image==0.19.1
|
|||||||
scipy==1.7.3
|
scipy==1.7.3
|
||||||
tqdm==4.62.3
|
tqdm==4.62.3
|
||||||
uvicorn==0.17.0
|
uvicorn==0.17.0
|
||||||
|
click==8.0.3
|
||||||
|
Loading…
x
Reference in New Issue
Block a user