diff --git a/README.md b/README.md index bf6f61c..1f6f82e 100644 --- a/README.md +++ b/README.md @@ -96,10 +96,37 @@ Then run cat input.png | python app.py > out.png ``` +### Advance usage + +Sometimes it is possible to achieve better results by turning on alpha matting +```bash + curl -s http://input.png -a -ae 15 | rembg > output.png +``` + +Example: + + + + + + + + + + + + + + + + +
OriginalWithout alpha mattingWith alpha matting (-a -ae 15)
+ ### References - https://arxiv.org/pdf/2005.09007.pdf - https://github.com/NathanUA/U-2-Net +- https://github.com/pymatting/pymatting ### License diff --git a/examples/food-1.jpg b/examples/food-1.jpg new file mode 100644 index 0000000..2fb4e06 Binary files /dev/null and b/examples/food-1.jpg differ diff --git a/examples/food-1.out.alpha.jpg b/examples/food-1.out.alpha.jpg new file mode 100644 index 0000000..8db3d60 Binary files /dev/null and b/examples/food-1.out.alpha.jpg differ diff --git a/examples/food-1.out.jpg b/examples/food-1.out.jpg new file mode 100644 index 0000000..c763d1d Binary files /dev/null and b/examples/food-1.out.jpg differ diff --git a/requirements.txt b/requirements.txt index 266f869..e2686cb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,3 +7,5 @@ torchvision==0.7.0 waitress==1.4.4 tqdm==4.48.2 requests==2.24.0 +scipy==1.5.2 +pymatting==1.0.6 diff --git a/setup.py b/setup.py index 76a63bc..f6edeea 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ with open("requirements.txt") as f: setup( name="rembg", - version="1.0.10", + version="1.0.11", description="Remove image background", long_description=long_description, long_description_content_type="text/markdown", diff --git a/src/rembg/bg.py b/src/rembg/bg.py index db9d90c..2a6633b 100644 --- a/src/rembg/bg.py +++ b/src/rembg/bg.py @@ -2,6 +2,10 @@ import io import numpy as np from PIL import Image +from pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf +from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml +from pymatting.util.util import stack_images +from scipy.ndimage.morphology import binary_erosion from .u2net import detect @@ -9,20 +13,87 @@ model_u2net = detect.load_model(model_name="u2net") model_u2netp = detect.load_model(model_name="u2netp") -def remove(data, model_name="u2net"): +def alpha_matting_cutout( + img, mask, foreground_threshold, background_threshold, erode_structure_size, +): + base_size = (1000, 1000) + size = img.size + + img.thumbnail(base_size, Image.LANCZOS) + mask = mask.resize(img.size, Image.LANCZOS) + + img = np.asarray(img) + mask = np.asarray(mask) + + # guess likely foreground/background + is_foreground = mask > foreground_threshold + is_background = mask < background_threshold + + # erode foreground/background + structure = None + if erode_structure_size > 0: + structure = np.ones((erode_structure_size, erode_structure_size), dtype=np.int) + + is_foreground = binary_erosion(is_foreground, structure=structure) + is_background = binary_erosion(is_background, structure=structure, border_value=1) + + # build trimap + # 0 = background + # 128 = unknown + # 255 = foreground + trimap = np.full(mask.shape, dtype=np.uint8, fill_value=128) + trimap[is_foreground] = 255 + trimap[is_background] = 0 + + # build the cutout image + img_normalized = img / 255.0 + trimap_normalized = trimap / 255.0 + + alpha = estimate_alpha_cf(img_normalized, trimap_normalized) + foreground = estimate_foreground_ml(img_normalized, alpha) + cutout = stack_images(foreground, alpha) + + cutout = np.clip(cutout * 255, 0, 255).astype(np.uint8) + cutout = Image.fromarray(cutout) + cutout = cutout.resize(size, Image.LANCZOS) + + return cutout + + +def naive_cutout(img, mask): + empty = Image.new("RGBA", (img.size), 0) + cutout = Image.composite(img, empty, mask.resize(img.size, Image.LANCZOS)) + return cutout + + +def remove( + data, + model_name="u2net", + alpha_matting=False, + alpha_matting_foreground_threshold=235, + alpha_matting_background_threshold=15, + alpha_matting_erode_structure_size=15, +): model = model_u2net if model == "u2netp": model = model_u2netp - img = Image.open(io.BytesIO(data)) - roi = detect.predict(model, np.array(img)) - roi = roi.resize((img.size), resample=Image.LANCZOS) + img = Image.open(io.BytesIO(data)).convert("RGB") + mask = detect.predict(model, np.array(img)).convert("L") - empty = Image.new("RGBA", (img.size), 0) - out = Image.composite(img, empty, roi.convert("L")) + if alpha_matting: + cutout = alpha_matting_cutout( + img, + mask, + alpha_matting_foreground_threshold, + alpha_matting_background_threshold, + alpha_matting_erode_structure_size, + ) + else: + cutout = naive_cutout(img, mask) bio = io.BytesIO() - out.save(bio, "PNG") + cutout.save(bio, "PNG") return bio.getbuffer() diff --git a/src/rembg/cmd/cli.py b/src/rembg/cmd/cli.py index 2018f50..f076c6b 100644 --- a/src/rembg/cmd/cli.py +++ b/src/rembg/cmd/cli.py @@ -2,6 +2,7 @@ import argparse import glob import imghdr import os +from distutils.util import strtobool from ..bg import remove @@ -18,6 +19,40 @@ def main(): help="The model name.", ) + ap.add_argument( + "-a", + "--alpha-matting", + nargs="?", + const=True, + default=False, + type=lambda x: bool(strtobool(x)), + help="When true use alpha matting cutout.", + ) + + ap.add_argument( + "-af", + "--alpha-matting-foreground-threshold", + default=235, + type=int, + help="The trimap foreground threshold.", + ) + + ap.add_argument( + "-ab", + "--alpha-matting-background-threshold", + default=15, + type=int, + help="The trimap background threshold.", + ) + + ap.add_argument( + "-ae", + "--alpha-matting-erode-size", + default=15, + type=int, + help="Size of element used for the erosion.", + ) + ap.add_argument( "-p", "--path", nargs="+", help="Path of a file or a folder of files.", ) @@ -60,10 +95,30 @@ def main(): with open(fi, "rb") as input: with open(os.path.splitext(fi)[0] + ".out.png", "wb") as output: - w(output, remove(r(input), args.model)) + w( + output, + remove( + r(input), + model_name=args.model, + alpha_matting=args.alpha_matting, + alpha_matting_foreground_threshold=args.alpha_matting_foreground_threshold, + alpha_matting_background_threshold=args.alpha_matting_background_threshold, + alpha_matting_erode_structure_size=args.alpha_matting_erode_size, + ), + ) else: - w(args.output, remove(r(args.input), args.model)) + w( + args.output, + remove( + r(args.input), + model_name=args.model, + alpha_matting=args.alpha_matting, + alpha_matting_foreground_threshold=args.alpha_matting_foreground_threshold, + alpha_matting_background_threshold=args.alpha_matting_background_threshold, + alpha_matting_erode_structure_size=args.alpha_matting_erode_size, + ), + ) if __name__ == "__main__": diff --git a/src/rembg/cmd/server.py b/src/rembg/cmd/server.py index dd1d7a1..f946fd7 100644 --- a/src/rembg/cmd/server.py +++ b/src/rembg/cmd/server.py @@ -11,24 +11,24 @@ from ..bg import remove app = Flask(__name__) -@app.route('/', methods=['GET', 'POST']) +@app.route("/", methods=["GET", "POST"]) def index(): - file_content = '' + file_content = "" - if request.method == 'POST': - if 'file' not in request.files: + if request.method == "POST": + if "file" not in request.files: return {"error": "missing post form param 'file'"}, 400 - file_content = request.files['file'].read() + file_content = request.files["file"].read() - if request.method == 'GET': + if request.method == "GET": url = request.args.get("url", type=str) if url is None: return {"error": "missing query param 'url'"}, 400 file_content = urlopen(unquote_plus(url)).read() - if file_content == '': + if file_content == "": return {"error": "File content is empty"}, 400 model = request.args.get("model", type=str, default="u2net") @@ -36,10 +36,7 @@ def index(): return {"error": "invalid query param 'model'"}, 400 try: - return send_file( - BytesIO(remove(file_content, model)), - mimetype="image/png", - ) + return send_file(BytesIO(remove(file_content, model)), mimetype="image/png",) except Exception as e: app.logger.exception(e, exc_info=True) return {"error": "oops, something went wrong!"}, 500 diff --git a/src/rembg/u2net/detect.py b/src/rembg/u2net/detect.py index 924336c..5afa475 100644 --- a/src/rembg/u2net/detect.py +++ b/src/rembg/u2net/detect.py @@ -107,7 +107,9 @@ def predict(net, item): with torch.no_grad(): if torch.cuda.is_available(): - inputs_test = torch.cuda.FloatTensor(sample["image"].unsqueeze(0).cuda().float()) + inputs_test = torch.cuda.FloatTensor( + sample["image"].unsqueeze(0).cuda().float() + ) else: inputs_test = torch.FloatTensor(sample["image"].unsqueeze(0).float())