mirror of
https://git.mirrors.martin98.com/https://github.com/danielgatis/rembg
synced 2025-08-14 22:25:56 +08:00
add alpha matting
This commit is contained in:
parent
df748667e6
commit
215cb3e934
27
README.md
27
README.md
@ -96,10 +96,37 @@ Then run
|
||||
cat input.png | python app.py > out.png
|
||||
```
|
||||
|
||||
### Advance usage
|
||||
|
||||
Sometimes it is possible to achieve better results by turning on alpha matting
|
||||
```bash
|
||||
curl -s http://input.png -a -ae 15 | rembg > output.png
|
||||
```
|
||||
|
||||
Example:
|
||||
|
||||
<table>
|
||||
<thead>
|
||||
<tr>
|
||||
<td>Original</td>
|
||||
<td>Without alpha matting</td>
|
||||
<td>With alpha matting (-a -ae 15)</td>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
<tr>
|
||||
<td><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/examples/food-1.jpg" width="100" /></td>
|
||||
<td><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/examples/food-1.out.jpg" width="100" /></td>
|
||||
<td><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/examples/food-1.out.alpha.jpg" width="100" /></td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
### References
|
||||
|
||||
- https://arxiv.org/pdf/2005.09007.pdf
|
||||
- https://github.com/NathanUA/U-2-Net
|
||||
- https://github.com/pymatting/pymatting
|
||||
|
||||
### License
|
||||
|
||||
|
BIN
examples/food-1.jpg
Normal file
BIN
examples/food-1.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 4.8 MiB |
BIN
examples/food-1.out.alpha.jpg
Normal file
BIN
examples/food-1.out.alpha.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 3.0 MiB |
BIN
examples/food-1.out.jpg
Normal file
BIN
examples/food-1.out.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 4.0 MiB |
@ -7,3 +7,5 @@ torchvision==0.7.0
|
||||
waitress==1.4.4
|
||||
tqdm==4.48.2
|
||||
requests==2.24.0
|
||||
scipy==1.5.2
|
||||
pymatting==1.0.6
|
||||
|
2
setup.py
2
setup.py
@ -11,7 +11,7 @@ with open("requirements.txt") as f:
|
||||
|
||||
setup(
|
||||
name="rembg",
|
||||
version="1.0.10",
|
||||
version="1.0.11",
|
||||
description="Remove image background",
|
||||
long_description=long_description,
|
||||
long_description_content_type="text/markdown",
|
||||
|
@ -2,6 +2,10 @@ import io
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf
|
||||
from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
|
||||
from pymatting.util.util import stack_images
|
||||
from scipy.ndimage.morphology import binary_erosion
|
||||
|
||||
from .u2net import detect
|
||||
|
||||
@ -9,20 +13,87 @@ model_u2net = detect.load_model(model_name="u2net")
|
||||
model_u2netp = detect.load_model(model_name="u2netp")
|
||||
|
||||
|
||||
def remove(data, model_name="u2net"):
|
||||
def alpha_matting_cutout(
|
||||
img, mask, foreground_threshold, background_threshold, erode_structure_size,
|
||||
):
|
||||
base_size = (1000, 1000)
|
||||
size = img.size
|
||||
|
||||
img.thumbnail(base_size, Image.LANCZOS)
|
||||
mask = mask.resize(img.size, Image.LANCZOS)
|
||||
|
||||
img = np.asarray(img)
|
||||
mask = np.asarray(mask)
|
||||
|
||||
# guess likely foreground/background
|
||||
is_foreground = mask > foreground_threshold
|
||||
is_background = mask < background_threshold
|
||||
|
||||
# erode foreground/background
|
||||
structure = None
|
||||
if erode_structure_size > 0:
|
||||
structure = np.ones((erode_structure_size, erode_structure_size), dtype=np.int)
|
||||
|
||||
is_foreground = binary_erosion(is_foreground, structure=structure)
|
||||
is_background = binary_erosion(is_background, structure=structure, border_value=1)
|
||||
|
||||
# build trimap
|
||||
# 0 = background
|
||||
# 128 = unknown
|
||||
# 255 = foreground
|
||||
trimap = np.full(mask.shape, dtype=np.uint8, fill_value=128)
|
||||
trimap[is_foreground] = 255
|
||||
trimap[is_background] = 0
|
||||
|
||||
# build the cutout image
|
||||
img_normalized = img / 255.0
|
||||
trimap_normalized = trimap / 255.0
|
||||
|
||||
alpha = estimate_alpha_cf(img_normalized, trimap_normalized)
|
||||
foreground = estimate_foreground_ml(img_normalized, alpha)
|
||||
cutout = stack_images(foreground, alpha)
|
||||
|
||||
cutout = np.clip(cutout * 255, 0, 255).astype(np.uint8)
|
||||
cutout = Image.fromarray(cutout)
|
||||
cutout = cutout.resize(size, Image.LANCZOS)
|
||||
|
||||
return cutout
|
||||
|
||||
|
||||
def naive_cutout(img, mask):
|
||||
empty = Image.new("RGBA", (img.size), 0)
|
||||
cutout = Image.composite(img, empty, mask.resize(img.size, Image.LANCZOS))
|
||||
return cutout
|
||||
|
||||
|
||||
def remove(
|
||||
data,
|
||||
model_name="u2net",
|
||||
alpha_matting=False,
|
||||
alpha_matting_foreground_threshold=235,
|
||||
alpha_matting_background_threshold=15,
|
||||
alpha_matting_erode_structure_size=15,
|
||||
):
|
||||
model = model_u2net
|
||||
|
||||
if model == "u2netp":
|
||||
model = model_u2netp
|
||||
|
||||
img = Image.open(io.BytesIO(data))
|
||||
roi = detect.predict(model, np.array(img))
|
||||
roi = roi.resize((img.size), resample=Image.LANCZOS)
|
||||
img = Image.open(io.BytesIO(data)).convert("RGB")
|
||||
mask = detect.predict(model, np.array(img)).convert("L")
|
||||
|
||||
empty = Image.new("RGBA", (img.size), 0)
|
||||
out = Image.composite(img, empty, roi.convert("L"))
|
||||
if alpha_matting:
|
||||
cutout = alpha_matting_cutout(
|
||||
img,
|
||||
mask,
|
||||
alpha_matting_foreground_threshold,
|
||||
alpha_matting_background_threshold,
|
||||
alpha_matting_erode_structure_size,
|
||||
)
|
||||
else:
|
||||
cutout = naive_cutout(img, mask)
|
||||
|
||||
bio = io.BytesIO()
|
||||
out.save(bio, "PNG")
|
||||
cutout.save(bio, "PNG")
|
||||
|
||||
return bio.getbuffer()
|
||||
|
@ -2,6 +2,7 @@ import argparse
|
||||
import glob
|
||||
import imghdr
|
||||
import os
|
||||
from distutils.util import strtobool
|
||||
|
||||
from ..bg import remove
|
||||
|
||||
@ -18,6 +19,40 @@ def main():
|
||||
help="The model name.",
|
||||
)
|
||||
|
||||
ap.add_argument(
|
||||
"-a",
|
||||
"--alpha-matting",
|
||||
nargs="?",
|
||||
const=True,
|
||||
default=False,
|
||||
type=lambda x: bool(strtobool(x)),
|
||||
help="When true use alpha matting cutout.",
|
||||
)
|
||||
|
||||
ap.add_argument(
|
||||
"-af",
|
||||
"--alpha-matting-foreground-threshold",
|
||||
default=235,
|
||||
type=int,
|
||||
help="The trimap foreground threshold.",
|
||||
)
|
||||
|
||||
ap.add_argument(
|
||||
"-ab",
|
||||
"--alpha-matting-background-threshold",
|
||||
default=15,
|
||||
type=int,
|
||||
help="The trimap background threshold.",
|
||||
)
|
||||
|
||||
ap.add_argument(
|
||||
"-ae",
|
||||
"--alpha-matting-erode-size",
|
||||
default=15,
|
||||
type=int,
|
||||
help="Size of element used for the erosion.",
|
||||
)
|
||||
|
||||
ap.add_argument(
|
||||
"-p", "--path", nargs="+", help="Path of a file or a folder of files.",
|
||||
)
|
||||
@ -60,10 +95,30 @@ def main():
|
||||
|
||||
with open(fi, "rb") as input:
|
||||
with open(os.path.splitext(fi)[0] + ".out.png", "wb") as output:
|
||||
w(output, remove(r(input), args.model))
|
||||
w(
|
||||
output,
|
||||
remove(
|
||||
r(input),
|
||||
model_name=args.model,
|
||||
alpha_matting=args.alpha_matting,
|
||||
alpha_matting_foreground_threshold=args.alpha_matting_foreground_threshold,
|
||||
alpha_matting_background_threshold=args.alpha_matting_background_threshold,
|
||||
alpha_matting_erode_structure_size=args.alpha_matting_erode_size,
|
||||
),
|
||||
)
|
||||
|
||||
else:
|
||||
w(args.output, remove(r(args.input), args.model))
|
||||
w(
|
||||
args.output,
|
||||
remove(
|
||||
r(args.input),
|
||||
model_name=args.model,
|
||||
alpha_matting=args.alpha_matting,
|
||||
alpha_matting_foreground_threshold=args.alpha_matting_foreground_threshold,
|
||||
alpha_matting_background_threshold=args.alpha_matting_background_threshold,
|
||||
alpha_matting_erode_structure_size=args.alpha_matting_erode_size,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -11,24 +11,24 @@ from ..bg import remove
|
||||
app = Flask(__name__)
|
||||
|
||||
|
||||
@app.route('/', methods=['GET', 'POST'])
|
||||
@app.route("/", methods=["GET", "POST"])
|
||||
def index():
|
||||
file_content = ''
|
||||
file_content = ""
|
||||
|
||||
if request.method == 'POST':
|
||||
if 'file' not in request.files:
|
||||
if request.method == "POST":
|
||||
if "file" not in request.files:
|
||||
return {"error": "missing post form param 'file'"}, 400
|
||||
|
||||
file_content = request.files['file'].read()
|
||||
file_content = request.files["file"].read()
|
||||
|
||||
if request.method == 'GET':
|
||||
if request.method == "GET":
|
||||
url = request.args.get("url", type=str)
|
||||
if url is None:
|
||||
return {"error": "missing query param 'url'"}, 400
|
||||
|
||||
file_content = urlopen(unquote_plus(url)).read()
|
||||
|
||||
if file_content == '':
|
||||
if file_content == "":
|
||||
return {"error": "File content is empty"}, 400
|
||||
|
||||
model = request.args.get("model", type=str, default="u2net")
|
||||
@ -36,10 +36,7 @@ def index():
|
||||
return {"error": "invalid query param 'model'"}, 400
|
||||
|
||||
try:
|
||||
return send_file(
|
||||
BytesIO(remove(file_content, model)),
|
||||
mimetype="image/png",
|
||||
)
|
||||
return send_file(BytesIO(remove(file_content, model)), mimetype="image/png",)
|
||||
except Exception as e:
|
||||
app.logger.exception(e, exc_info=True)
|
||||
return {"error": "oops, something went wrong!"}, 500
|
||||
|
@ -107,7 +107,9 @@ def predict(net, item):
|
||||
with torch.no_grad():
|
||||
|
||||
if torch.cuda.is_available():
|
||||
inputs_test = torch.cuda.FloatTensor(sample["image"].unsqueeze(0).cuda().float())
|
||||
inputs_test = torch.cuda.FloatTensor(
|
||||
sample["image"].unsqueeze(0).cuda().float()
|
||||
)
|
||||
else:
|
||||
inputs_test = torch.FloatTensor(sample["image"].unsqueeze(0).float())
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user