diff --git a/README.md b/README.md
index bf6f61c..1f6f82e 100644
--- a/README.md
+++ b/README.md
@@ -96,10 +96,37 @@ Then run
cat input.png | python app.py > out.png
```
+### Advance usage
+
+Sometimes it is possible to achieve better results by turning on alpha matting
+```bash
+ curl -s http://input.png -a -ae 15 | rembg > output.png
+```
+
+Example:
+
+
+
+
+ Original |
+ Without alpha matting |
+ With alpha matting (-a -ae 15) |
+
+
+
+
+  |
+  |
+  |
+
+
+
+
### References
- https://arxiv.org/pdf/2005.09007.pdf
- https://github.com/NathanUA/U-2-Net
+- https://github.com/pymatting/pymatting
### License
diff --git a/examples/food-1.jpg b/examples/food-1.jpg
new file mode 100644
index 0000000..2fb4e06
Binary files /dev/null and b/examples/food-1.jpg differ
diff --git a/examples/food-1.out.alpha.jpg b/examples/food-1.out.alpha.jpg
new file mode 100644
index 0000000..8db3d60
Binary files /dev/null and b/examples/food-1.out.alpha.jpg differ
diff --git a/examples/food-1.out.jpg b/examples/food-1.out.jpg
new file mode 100644
index 0000000..c763d1d
Binary files /dev/null and b/examples/food-1.out.jpg differ
diff --git a/requirements.txt b/requirements.txt
index 266f869..e2686cb 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,3 +7,5 @@ torchvision==0.7.0
waitress==1.4.4
tqdm==4.48.2
requests==2.24.0
+scipy==1.5.2
+pymatting==1.0.6
diff --git a/setup.py b/setup.py
index 76a63bc..f6edeea 100644
--- a/setup.py
+++ b/setup.py
@@ -11,7 +11,7 @@ with open("requirements.txt") as f:
setup(
name="rembg",
- version="1.0.10",
+ version="1.0.11",
description="Remove image background",
long_description=long_description,
long_description_content_type="text/markdown",
diff --git a/src/rembg/bg.py b/src/rembg/bg.py
index db9d90c..2a6633b 100644
--- a/src/rembg/bg.py
+++ b/src/rembg/bg.py
@@ -2,6 +2,10 @@ import io
import numpy as np
from PIL import Image
+from pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf
+from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
+from pymatting.util.util import stack_images
+from scipy.ndimage.morphology import binary_erosion
from .u2net import detect
@@ -9,20 +13,87 @@ model_u2net = detect.load_model(model_name="u2net")
model_u2netp = detect.load_model(model_name="u2netp")
-def remove(data, model_name="u2net"):
+def alpha_matting_cutout(
+ img, mask, foreground_threshold, background_threshold, erode_structure_size,
+):
+ base_size = (1000, 1000)
+ size = img.size
+
+ img.thumbnail(base_size, Image.LANCZOS)
+ mask = mask.resize(img.size, Image.LANCZOS)
+
+ img = np.asarray(img)
+ mask = np.asarray(mask)
+
+ # guess likely foreground/background
+ is_foreground = mask > foreground_threshold
+ is_background = mask < background_threshold
+
+ # erode foreground/background
+ structure = None
+ if erode_structure_size > 0:
+ structure = np.ones((erode_structure_size, erode_structure_size), dtype=np.int)
+
+ is_foreground = binary_erosion(is_foreground, structure=structure)
+ is_background = binary_erosion(is_background, structure=structure, border_value=1)
+
+ # build trimap
+ # 0 = background
+ # 128 = unknown
+ # 255 = foreground
+ trimap = np.full(mask.shape, dtype=np.uint8, fill_value=128)
+ trimap[is_foreground] = 255
+ trimap[is_background] = 0
+
+ # build the cutout image
+ img_normalized = img / 255.0
+ trimap_normalized = trimap / 255.0
+
+ alpha = estimate_alpha_cf(img_normalized, trimap_normalized)
+ foreground = estimate_foreground_ml(img_normalized, alpha)
+ cutout = stack_images(foreground, alpha)
+
+ cutout = np.clip(cutout * 255, 0, 255).astype(np.uint8)
+ cutout = Image.fromarray(cutout)
+ cutout = cutout.resize(size, Image.LANCZOS)
+
+ return cutout
+
+
+def naive_cutout(img, mask):
+ empty = Image.new("RGBA", (img.size), 0)
+ cutout = Image.composite(img, empty, mask.resize(img.size, Image.LANCZOS))
+ return cutout
+
+
+def remove(
+ data,
+ model_name="u2net",
+ alpha_matting=False,
+ alpha_matting_foreground_threshold=235,
+ alpha_matting_background_threshold=15,
+ alpha_matting_erode_structure_size=15,
+):
model = model_u2net
if model == "u2netp":
model = model_u2netp
- img = Image.open(io.BytesIO(data))
- roi = detect.predict(model, np.array(img))
- roi = roi.resize((img.size), resample=Image.LANCZOS)
+ img = Image.open(io.BytesIO(data)).convert("RGB")
+ mask = detect.predict(model, np.array(img)).convert("L")
- empty = Image.new("RGBA", (img.size), 0)
- out = Image.composite(img, empty, roi.convert("L"))
+ if alpha_matting:
+ cutout = alpha_matting_cutout(
+ img,
+ mask,
+ alpha_matting_foreground_threshold,
+ alpha_matting_background_threshold,
+ alpha_matting_erode_structure_size,
+ )
+ else:
+ cutout = naive_cutout(img, mask)
bio = io.BytesIO()
- out.save(bio, "PNG")
+ cutout.save(bio, "PNG")
return bio.getbuffer()
diff --git a/src/rembg/cmd/cli.py b/src/rembg/cmd/cli.py
index 2018f50..f076c6b 100644
--- a/src/rembg/cmd/cli.py
+++ b/src/rembg/cmd/cli.py
@@ -2,6 +2,7 @@ import argparse
import glob
import imghdr
import os
+from distutils.util import strtobool
from ..bg import remove
@@ -18,6 +19,40 @@ def main():
help="The model name.",
)
+ ap.add_argument(
+ "-a",
+ "--alpha-matting",
+ nargs="?",
+ const=True,
+ default=False,
+ type=lambda x: bool(strtobool(x)),
+ help="When true use alpha matting cutout.",
+ )
+
+ ap.add_argument(
+ "-af",
+ "--alpha-matting-foreground-threshold",
+ default=235,
+ type=int,
+ help="The trimap foreground threshold.",
+ )
+
+ ap.add_argument(
+ "-ab",
+ "--alpha-matting-background-threshold",
+ default=15,
+ type=int,
+ help="The trimap background threshold.",
+ )
+
+ ap.add_argument(
+ "-ae",
+ "--alpha-matting-erode-size",
+ default=15,
+ type=int,
+ help="Size of element used for the erosion.",
+ )
+
ap.add_argument(
"-p", "--path", nargs="+", help="Path of a file or a folder of files.",
)
@@ -60,10 +95,30 @@ def main():
with open(fi, "rb") as input:
with open(os.path.splitext(fi)[0] + ".out.png", "wb") as output:
- w(output, remove(r(input), args.model))
+ w(
+ output,
+ remove(
+ r(input),
+ model_name=args.model,
+ alpha_matting=args.alpha_matting,
+ alpha_matting_foreground_threshold=args.alpha_matting_foreground_threshold,
+ alpha_matting_background_threshold=args.alpha_matting_background_threshold,
+ alpha_matting_erode_structure_size=args.alpha_matting_erode_size,
+ ),
+ )
else:
- w(args.output, remove(r(args.input), args.model))
+ w(
+ args.output,
+ remove(
+ r(args.input),
+ model_name=args.model,
+ alpha_matting=args.alpha_matting,
+ alpha_matting_foreground_threshold=args.alpha_matting_foreground_threshold,
+ alpha_matting_background_threshold=args.alpha_matting_background_threshold,
+ alpha_matting_erode_structure_size=args.alpha_matting_erode_size,
+ ),
+ )
if __name__ == "__main__":
diff --git a/src/rembg/cmd/server.py b/src/rembg/cmd/server.py
index dd1d7a1..f946fd7 100644
--- a/src/rembg/cmd/server.py
+++ b/src/rembg/cmd/server.py
@@ -11,24 +11,24 @@ from ..bg import remove
app = Flask(__name__)
-@app.route('/', methods=['GET', 'POST'])
+@app.route("/", methods=["GET", "POST"])
def index():
- file_content = ''
+ file_content = ""
- if request.method == 'POST':
- if 'file' not in request.files:
+ if request.method == "POST":
+ if "file" not in request.files:
return {"error": "missing post form param 'file'"}, 400
- file_content = request.files['file'].read()
+ file_content = request.files["file"].read()
- if request.method == 'GET':
+ if request.method == "GET":
url = request.args.get("url", type=str)
if url is None:
return {"error": "missing query param 'url'"}, 400
file_content = urlopen(unquote_plus(url)).read()
- if file_content == '':
+ if file_content == "":
return {"error": "File content is empty"}, 400
model = request.args.get("model", type=str, default="u2net")
@@ -36,10 +36,7 @@ def index():
return {"error": "invalid query param 'model'"}, 400
try:
- return send_file(
- BytesIO(remove(file_content, model)),
- mimetype="image/png",
- )
+ return send_file(BytesIO(remove(file_content, model)), mimetype="image/png",)
except Exception as e:
app.logger.exception(e, exc_info=True)
return {"error": "oops, something went wrong!"}, 500
diff --git a/src/rembg/u2net/detect.py b/src/rembg/u2net/detect.py
index 924336c..5afa475 100644
--- a/src/rembg/u2net/detect.py
+++ b/src/rembg/u2net/detect.py
@@ -107,7 +107,9 @@ def predict(net, item):
with torch.no_grad():
if torch.cuda.is_available():
- inputs_test = torch.cuda.FloatTensor(sample["image"].unsqueeze(0).cuda().float())
+ inputs_test = torch.cuda.FloatTensor(
+ sample["image"].unsqueeze(0).cuda().float()
+ )
else:
inputs_test = torch.FloatTensor(sample["image"].unsqueeze(0).float())