add alpha matting

Daniel Gatis 2020-10-09 14:41:49 -03:00
parent df748667e6
commit 215cb3e934
10 changed files with 176 additions and 22 deletions


@@ -96,10 +96,37 @@ Then run
cat input.png | python app.py > out.png
```
### Advanced usage
Sometimes it is possible to achieve better results by turning on alpha matting:
```bash
curl -s http://input.png | rembg -a -ae 15 > output.png
```
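The same options are exposed through the Python API. A minimal sketch, assuming the `remove` signature introduced in this commit (the thresholds shown are the CLI defaults); the file names are placeholders:
```python
from rembg.bg import remove  # import path assumed from the package layout in this diff

with open("input.png", "rb") as i:  # placeholder input file
    data = i.read()

result = remove(
    data,
    alpha_matting=True,                      # same as the -a flag
    alpha_matting_foreground_threshold=235,  # -af: above this the mask is sure foreground
    alpha_matting_background_threshold=15,   # -ab: below this the mask is sure background
    alpha_matting_erode_structure_size=15,   # -ae: erosion size used to build the trimap
)

with open("output.png", "wb") as o:          # placeholder output file
    o.write(result)
```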
Example:
<table>
<thead>
<tr>
<td>Original</td>
<td>Without alpha matting</td>
<td>With alpha matting (-a -ae 15)</td>
</tr>
</thead>
<tbody>
<tr>
<td><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/examples/food-1.jpg" width="100" /></td>
<td><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/examples/food-1.out.jpg" width="100" /></td>
<td><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/examples/food-1.out.alpha.jpg" width="100" /></td>
</tr>
</tbody>
</table>
### References
- https://arxiv.org/pdf/2005.09007.pdf
- https://github.com/NathanUA/U-2-Net
- https://github.com/pymatting/pymatting
### License

BIN  examples/food-1.jpg (new binary file, 4.8 MiB; not shown)
BIN  (filename not captured; new binary file, 3.0 MiB; not shown)
BIN  examples/food-1.out.jpg (new binary file, 4.0 MiB; not shown)


@@ -7,3 +7,5 @@ torchvision==0.7.0
waitress==1.4.4
tqdm==4.48.2
requests==2.24.0
scipy==1.5.2
pymatting==1.0.6


@@ -11,7 +11,7 @@ with open("requirements.txt") as f:
setup(
name="rembg",
version="1.0.10",
version="1.0.11",
description="Remove image background",
long_description=long_description,
long_description_content_type="text/markdown",


@@ -2,6 +2,10 @@ import io
import numpy as np
from PIL import Image
from pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf
from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
from pymatting.util.util import stack_images
from scipy.ndimage.morphology import binary_erosion
from .u2net import detect
@@ -9,20 +13,87 @@ model_u2net = detect.load_model(model_name="u2net")
model_u2netp = detect.load_model(model_name="u2netp")
def remove(data, model_name="u2net"):
def alpha_matting_cutout(
img, mask, foreground_threshold, background_threshold, erode_structure_size,
):
base_size = (1000, 1000)
size = img.size
img.thumbnail(base_size, Image.LANCZOS)
mask = mask.resize(img.size, Image.LANCZOS)
img = np.asarray(img)
mask = np.asarray(mask)
# guess likely foreground/background
is_foreground = mask > foreground_threshold
is_background = mask < background_threshold
# erode foreground/background
structure = None
if erode_structure_size > 0:
structure = np.ones((erode_structure_size, erode_structure_size), dtype=np.int)
is_foreground = binary_erosion(is_foreground, structure=structure)
is_background = binary_erosion(is_background, structure=structure, border_value=1)
# build trimap
# 0 = background
# 128 = unknown
# 255 = foreground
trimap = np.full(mask.shape, dtype=np.uint8, fill_value=128)
trimap[is_foreground] = 255
trimap[is_background] = 0
# build the cutout image
img_normalized = img / 255.0
trimap_normalized = trimap / 255.0
alpha = estimate_alpha_cf(img_normalized, trimap_normalized)
foreground = estimate_foreground_ml(img_normalized, alpha)
cutout = stack_images(foreground, alpha)
cutout = np.clip(cutout * 255, 0, 255).astype(np.uint8)
cutout = Image.fromarray(cutout)
cutout = cutout.resize(size, Image.LANCZOS)
return cutout
def naive_cutout(img, mask):
empty = Image.new("RGBA", (img.size), 0)
cutout = Image.composite(img, empty, mask.resize(img.size, Image.LANCZOS))
return cutout
def remove(
data,
model_name="u2net",
alpha_matting=False,
alpha_matting_foreground_threshold=235,
alpha_matting_background_threshold=15,
alpha_matting_erode_structure_size=15,
):
model = model_u2net
if model_name == "u2netp":
model = model_u2netp
img = Image.open(io.BytesIO(data))
roi = detect.predict(model, np.array(img))
roi = roi.resize((img.size), resample=Image.LANCZOS)
img = Image.open(io.BytesIO(data)).convert("RGB")
mask = detect.predict(model, np.array(img)).convert("L")
empty = Image.new("RGBA", (img.size), 0)
out = Image.composite(img, empty, roi.convert("L"))
if alpha_matting:
cutout = alpha_matting_cutout(
img,
mask,
alpha_matting_foreground_threshold,
alpha_matting_background_threshold,
alpha_matting_erode_structure_size,
)
else:
cutout = naive_cutout(img, mask)
bio = io.BytesIO()
out.save(bio, "PNG")
cutout.save(bio, "PNG")
return bio.getbuffer()
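For illustration, a small self-contained sketch of the trimap construction used by `alpha_matting_cutout` above, applied to a toy mask with the default thresholds (235 and 15); the mask values are made up, and the erosion step is noted but skipped to keep the example tiny:
```python
import numpy as np

# Toy 1x5 mask of predicted alpha values (illustrative values only).
mask = np.array([[250, 240, 128, 10, 3]], dtype=np.uint8)

is_foreground = mask > 235  # confident foreground pixels
is_background = mask < 15   # confident background pixels

# The real code additionally erodes both regions with a square structuring
# element (binary_erosion) to widen the unknown band before matting.

trimap = np.full(mask.shape, 128, dtype=np.uint8)  # 128 = unknown
trimap[is_foreground] = 255                         # 255 = foreground
trimap[is_background] = 0                           # 0 = background

print(trimap)  # [[255 255 128   0   0]]
```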


@ -2,6 +2,7 @@ import argparse
import glob
import imghdr
import os
from distutils.util import strtobool
from ..bg import remove
@@ -18,6 +19,40 @@ def main():
help="The model name.",
)
ap.add_argument(
"-a",
"--alpha-matting",
nargs="?",
const=True,
default=False,
type=lambda x: bool(strtobool(x)),
help="When true use alpha matting cutout.",
)
ap.add_argument(
"-af",
"--alpha-matting-foreground-threshold",
default=235,
type=int,
help="The trimap foreground threshold.",
)
ap.add_argument(
"-ab",
"--alpha-matting-background-threshold",
default=15,
type=int,
help="The trimap background threshold.",
)
ap.add_argument(
"-ae",
"--alpha-matting-erode-size",
default=15,
type=int,
help="Size of element used for the erosion.",
)
ap.add_argument(
"-p", "--path", nargs="+", help="Path of a file or a folder of files.",
)
@@ -60,10 +95,30 @@ def main():
with open(fi, "rb") as input:
with open(os.path.splitext(fi)[0] + ".out.png", "wb") as output:
w(output, remove(r(input), args.model))
w(
output,
remove(
r(input),
model_name=args.model,
alpha_matting=args.alpha_matting,
alpha_matting_foreground_threshold=args.alpha_matting_foreground_threshold,
alpha_matting_background_threshold=args.alpha_matting_background_threshold,
alpha_matting_erode_structure_size=args.alpha_matting_erode_size,
),
)
else:
w(args.output, remove(r(args.input), args.model))
w(
args.output,
remove(
r(args.input),
model_name=args.model,
alpha_matting=args.alpha_matting,
alpha_matting_foreground_threshold=args.alpha_matting_foreground_threshold,
alpha_matting_background_threshold=args.alpha_matting_background_threshold,
alpha_matting_erode_structure_size=args.alpha_matting_erode_size,
),
)
if __name__ == "__main__":


@@ -11,24 +11,24 @@ from ..bg import remove
app = Flask(__name__)
@app.route('/', methods=['GET', 'POST'])
@app.route("/", methods=["GET", "POST"])
def index():
file_content = ''
file_content = ""
if request.method == 'POST':
if 'file' not in request.files:
if request.method == "POST":
if "file" not in request.files:
return {"error": "missing post form param 'file'"}, 400
file_content = request.files['file'].read()
file_content = request.files["file"].read()
if request.method == 'GET':
if request.method == "GET":
url = request.args.get("url", type=str)
if url is None:
return {"error": "missing query param 'url'"}, 400
file_content = urlopen(unquote_plus(url)).read()
if file_content == '':
if file_content == "":
return {"error": "File content is empty"}, 400
model = request.args.get("model", type=str, default="u2net")
@@ -36,10 +36,7 @@ def index():
return {"error": "invalid query param 'model'"}, 400
try:
return send_file(
BytesIO(remove(file_content, model)),
mimetype="image/png",
)
return send_file(BytesIO(remove(file_content, model)), mimetype="image/png",)
except Exception as e:
app.logger.exception(e, exc_info=True)
return {"error": "oops, something went wrong!"}, 500


@@ -107,7 +107,9 @@ def predict(net, item):
with torch.no_grad():
if torch.cuda.is_available():
inputs_test = torch.cuda.FloatTensor(sample["image"].unsqueeze(0).cuda().float())
inputs_test = torch.cuda.FloatTensor(
sample["image"].unsqueeze(0).cuda().float()
)
else:
inputs_test = torch.FloatTensor(sample["image"].unsqueeze(0).float())