mirror of
https://git.mirrors.martin98.com/https://github.com/danielgatis/rembg
synced 2025-08-15 14:25:57 +08:00
add alpha matting
This commit is contained in:
parent
df748667e6
commit
215cb3e934
27
README.md
27
README.md
@ -96,10 +96,37 @@ Then run
|
|||||||
cat input.png | python app.py > out.png
|
cat input.png | python app.py > out.png
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Advance usage
|
||||||
|
|
||||||
|
Sometimes it is possible to achieve better results by turning on alpha matting
|
||||||
|
```bash
|
||||||
|
curl -s http://input.png -a -ae 15 | rembg > output.png
|
||||||
|
```
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
<table>
|
||||||
|
<thead>
|
||||||
|
<tr>
|
||||||
|
<td>Original</td>
|
||||||
|
<td>Without alpha matting</td>
|
||||||
|
<td>With alpha matting (-a -ae 15)</td>
|
||||||
|
</tr>
|
||||||
|
</thead>
|
||||||
|
<tbody>
|
||||||
|
<tr>
|
||||||
|
<td><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/examples/food-1.jpg" width="100" /></td>
|
||||||
|
<td><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/examples/food-1.out.jpg" width="100" /></td>
|
||||||
|
<td><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/examples/food-1.out.alpha.jpg" width="100" /></td>
|
||||||
|
</tr>
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
|
|
||||||
### References
|
### References
|
||||||
|
|
||||||
- https://arxiv.org/pdf/2005.09007.pdf
|
- https://arxiv.org/pdf/2005.09007.pdf
|
||||||
- https://github.com/NathanUA/U-2-Net
|
- https://github.com/NathanUA/U-2-Net
|
||||||
|
- https://github.com/pymatting/pymatting
|
||||||
|
|
||||||
### License
|
### License
|
||||||
|
|
||||||
|
BIN
examples/food-1.jpg
Normal file
BIN
examples/food-1.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 4.8 MiB |
BIN
examples/food-1.out.alpha.jpg
Normal file
BIN
examples/food-1.out.alpha.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 3.0 MiB |
BIN
examples/food-1.out.jpg
Normal file
BIN
examples/food-1.out.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 4.0 MiB |
@ -7,3 +7,5 @@ torchvision==0.7.0
|
|||||||
waitress==1.4.4
|
waitress==1.4.4
|
||||||
tqdm==4.48.2
|
tqdm==4.48.2
|
||||||
requests==2.24.0
|
requests==2.24.0
|
||||||
|
scipy==1.5.2
|
||||||
|
pymatting==1.0.6
|
||||||
|
2
setup.py
2
setup.py
@ -11,7 +11,7 @@ with open("requirements.txt") as f:
|
|||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="rembg",
|
name="rembg",
|
||||||
version="1.0.10",
|
version="1.0.11",
|
||||||
description="Remove image background",
|
description="Remove image background",
|
||||||
long_description=long_description,
|
long_description=long_description,
|
||||||
long_description_content_type="text/markdown",
|
long_description_content_type="text/markdown",
|
||||||
|
@ -2,6 +2,10 @@ import io
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf
|
||||||
|
from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
|
||||||
|
from pymatting.util.util import stack_images
|
||||||
|
from scipy.ndimage.morphology import binary_erosion
|
||||||
|
|
||||||
from .u2net import detect
|
from .u2net import detect
|
||||||
|
|
||||||
@ -9,20 +13,87 @@ model_u2net = detect.load_model(model_name="u2net")
|
|||||||
model_u2netp = detect.load_model(model_name="u2netp")
|
model_u2netp = detect.load_model(model_name="u2netp")
|
||||||
|
|
||||||
|
|
||||||
def remove(data, model_name="u2net"):
|
def alpha_matting_cutout(
|
||||||
|
img, mask, foreground_threshold, background_threshold, erode_structure_size,
|
||||||
|
):
|
||||||
|
base_size = (1000, 1000)
|
||||||
|
size = img.size
|
||||||
|
|
||||||
|
img.thumbnail(base_size, Image.LANCZOS)
|
||||||
|
mask = mask.resize(img.size, Image.LANCZOS)
|
||||||
|
|
||||||
|
img = np.asarray(img)
|
||||||
|
mask = np.asarray(mask)
|
||||||
|
|
||||||
|
# guess likely foreground/background
|
||||||
|
is_foreground = mask > foreground_threshold
|
||||||
|
is_background = mask < background_threshold
|
||||||
|
|
||||||
|
# erode foreground/background
|
||||||
|
structure = None
|
||||||
|
if erode_structure_size > 0:
|
||||||
|
structure = np.ones((erode_structure_size, erode_structure_size), dtype=np.int)
|
||||||
|
|
||||||
|
is_foreground = binary_erosion(is_foreground, structure=structure)
|
||||||
|
is_background = binary_erosion(is_background, structure=structure, border_value=1)
|
||||||
|
|
||||||
|
# build trimap
|
||||||
|
# 0 = background
|
||||||
|
# 128 = unknown
|
||||||
|
# 255 = foreground
|
||||||
|
trimap = np.full(mask.shape, dtype=np.uint8, fill_value=128)
|
||||||
|
trimap[is_foreground] = 255
|
||||||
|
trimap[is_background] = 0
|
||||||
|
|
||||||
|
# build the cutout image
|
||||||
|
img_normalized = img / 255.0
|
||||||
|
trimap_normalized = trimap / 255.0
|
||||||
|
|
||||||
|
alpha = estimate_alpha_cf(img_normalized, trimap_normalized)
|
||||||
|
foreground = estimate_foreground_ml(img_normalized, alpha)
|
||||||
|
cutout = stack_images(foreground, alpha)
|
||||||
|
|
||||||
|
cutout = np.clip(cutout * 255, 0, 255).astype(np.uint8)
|
||||||
|
cutout = Image.fromarray(cutout)
|
||||||
|
cutout = cutout.resize(size, Image.LANCZOS)
|
||||||
|
|
||||||
|
return cutout
|
||||||
|
|
||||||
|
|
||||||
|
def naive_cutout(img, mask):
|
||||||
|
empty = Image.new("RGBA", (img.size), 0)
|
||||||
|
cutout = Image.composite(img, empty, mask.resize(img.size, Image.LANCZOS))
|
||||||
|
return cutout
|
||||||
|
|
||||||
|
|
||||||
|
def remove(
|
||||||
|
data,
|
||||||
|
model_name="u2net",
|
||||||
|
alpha_matting=False,
|
||||||
|
alpha_matting_foreground_threshold=235,
|
||||||
|
alpha_matting_background_threshold=15,
|
||||||
|
alpha_matting_erode_structure_size=15,
|
||||||
|
):
|
||||||
model = model_u2net
|
model = model_u2net
|
||||||
|
|
||||||
if model == "u2netp":
|
if model == "u2netp":
|
||||||
model = model_u2netp
|
model = model_u2netp
|
||||||
|
|
||||||
img = Image.open(io.BytesIO(data))
|
img = Image.open(io.BytesIO(data)).convert("RGB")
|
||||||
roi = detect.predict(model, np.array(img))
|
mask = detect.predict(model, np.array(img)).convert("L")
|
||||||
roi = roi.resize((img.size), resample=Image.LANCZOS)
|
|
||||||
|
|
||||||
empty = Image.new("RGBA", (img.size), 0)
|
if alpha_matting:
|
||||||
out = Image.composite(img, empty, roi.convert("L"))
|
cutout = alpha_matting_cutout(
|
||||||
|
img,
|
||||||
|
mask,
|
||||||
|
alpha_matting_foreground_threshold,
|
||||||
|
alpha_matting_background_threshold,
|
||||||
|
alpha_matting_erode_structure_size,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
cutout = naive_cutout(img, mask)
|
||||||
|
|
||||||
bio = io.BytesIO()
|
bio = io.BytesIO()
|
||||||
out.save(bio, "PNG")
|
cutout.save(bio, "PNG")
|
||||||
|
|
||||||
return bio.getbuffer()
|
return bio.getbuffer()
|
||||||
|
@ -2,6 +2,7 @@ import argparse
|
|||||||
import glob
|
import glob
|
||||||
import imghdr
|
import imghdr
|
||||||
import os
|
import os
|
||||||
|
from distutils.util import strtobool
|
||||||
|
|
||||||
from ..bg import remove
|
from ..bg import remove
|
||||||
|
|
||||||
@ -18,6 +19,40 @@ def main():
|
|||||||
help="The model name.",
|
help="The model name.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
ap.add_argument(
|
||||||
|
"-a",
|
||||||
|
"--alpha-matting",
|
||||||
|
nargs="?",
|
||||||
|
const=True,
|
||||||
|
default=False,
|
||||||
|
type=lambda x: bool(strtobool(x)),
|
||||||
|
help="When true use alpha matting cutout.",
|
||||||
|
)
|
||||||
|
|
||||||
|
ap.add_argument(
|
||||||
|
"-af",
|
||||||
|
"--alpha-matting-foreground-threshold",
|
||||||
|
default=235,
|
||||||
|
type=int,
|
||||||
|
help="The trimap foreground threshold.",
|
||||||
|
)
|
||||||
|
|
||||||
|
ap.add_argument(
|
||||||
|
"-ab",
|
||||||
|
"--alpha-matting-background-threshold",
|
||||||
|
default=15,
|
||||||
|
type=int,
|
||||||
|
help="The trimap background threshold.",
|
||||||
|
)
|
||||||
|
|
||||||
|
ap.add_argument(
|
||||||
|
"-ae",
|
||||||
|
"--alpha-matting-erode-size",
|
||||||
|
default=15,
|
||||||
|
type=int,
|
||||||
|
help="Size of element used for the erosion.",
|
||||||
|
)
|
||||||
|
|
||||||
ap.add_argument(
|
ap.add_argument(
|
||||||
"-p", "--path", nargs="+", help="Path of a file or a folder of files.",
|
"-p", "--path", nargs="+", help="Path of a file or a folder of files.",
|
||||||
)
|
)
|
||||||
@ -60,10 +95,30 @@ def main():
|
|||||||
|
|
||||||
with open(fi, "rb") as input:
|
with open(fi, "rb") as input:
|
||||||
with open(os.path.splitext(fi)[0] + ".out.png", "wb") as output:
|
with open(os.path.splitext(fi)[0] + ".out.png", "wb") as output:
|
||||||
w(output, remove(r(input), args.model))
|
w(
|
||||||
|
output,
|
||||||
|
remove(
|
||||||
|
r(input),
|
||||||
|
model_name=args.model,
|
||||||
|
alpha_matting=args.alpha_matting,
|
||||||
|
alpha_matting_foreground_threshold=args.alpha_matting_foreground_threshold,
|
||||||
|
alpha_matting_background_threshold=args.alpha_matting_background_threshold,
|
||||||
|
alpha_matting_erode_structure_size=args.alpha_matting_erode_size,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
w(args.output, remove(r(args.input), args.model))
|
w(
|
||||||
|
args.output,
|
||||||
|
remove(
|
||||||
|
r(args.input),
|
||||||
|
model_name=args.model,
|
||||||
|
alpha_matting=args.alpha_matting,
|
||||||
|
alpha_matting_foreground_threshold=args.alpha_matting_foreground_threshold,
|
||||||
|
alpha_matting_background_threshold=args.alpha_matting_background_threshold,
|
||||||
|
alpha_matting_erode_structure_size=args.alpha_matting_erode_size,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -11,24 +11,24 @@ from ..bg import remove
|
|||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
|
|
||||||
|
|
||||||
@app.route('/', methods=['GET', 'POST'])
|
@app.route("/", methods=["GET", "POST"])
|
||||||
def index():
|
def index():
|
||||||
file_content = ''
|
file_content = ""
|
||||||
|
|
||||||
if request.method == 'POST':
|
if request.method == "POST":
|
||||||
if 'file' not in request.files:
|
if "file" not in request.files:
|
||||||
return {"error": "missing post form param 'file'"}, 400
|
return {"error": "missing post form param 'file'"}, 400
|
||||||
|
|
||||||
file_content = request.files['file'].read()
|
file_content = request.files["file"].read()
|
||||||
|
|
||||||
if request.method == 'GET':
|
if request.method == "GET":
|
||||||
url = request.args.get("url", type=str)
|
url = request.args.get("url", type=str)
|
||||||
if url is None:
|
if url is None:
|
||||||
return {"error": "missing query param 'url'"}, 400
|
return {"error": "missing query param 'url'"}, 400
|
||||||
|
|
||||||
file_content = urlopen(unquote_plus(url)).read()
|
file_content = urlopen(unquote_plus(url)).read()
|
||||||
|
|
||||||
if file_content == '':
|
if file_content == "":
|
||||||
return {"error": "File content is empty"}, 400
|
return {"error": "File content is empty"}, 400
|
||||||
|
|
||||||
model = request.args.get("model", type=str, default="u2net")
|
model = request.args.get("model", type=str, default="u2net")
|
||||||
@ -36,10 +36,7 @@ def index():
|
|||||||
return {"error": "invalid query param 'model'"}, 400
|
return {"error": "invalid query param 'model'"}, 400
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return send_file(
|
return send_file(BytesIO(remove(file_content, model)), mimetype="image/png",)
|
||||||
BytesIO(remove(file_content, model)),
|
|
||||||
mimetype="image/png",
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
app.logger.exception(e, exc_info=True)
|
app.logger.exception(e, exc_info=True)
|
||||||
return {"error": "oops, something went wrong!"}, 500
|
return {"error": "oops, something went wrong!"}, 500
|
||||||
|
@ -107,7 +107,9 @@ def predict(net, item):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
inputs_test = torch.cuda.FloatTensor(sample["image"].unsqueeze(0).cuda().float())
|
inputs_test = torch.cuda.FloatTensor(
|
||||||
|
sample["image"].unsqueeze(0).cuda().float()
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
inputs_test = torch.FloatTensor(sample["image"].unsqueeze(0).float())
|
inputs_test = torch.FloatTensor(sample["image"].unsqueeze(0).float())
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user