add base_size arg

This commit is contained in:
Daniel Gatis 2021-01-22 13:56:19 -03:00
parent 3558a53d1f
commit 24711986b7
4 changed files with 20 additions and 6 deletions

View File

@ -11,7 +11,7 @@ with open("requirements.txt") as f:
setup(
name="rembg",
version="1.0.18",
version="1.0.19",
description="Remove image background",
long_description=long_description,
long_description_content_type="text/markdown",

View File

@ -17,11 +17,11 @@ def alpha_matting_cutout(
foreground_threshold,
background_threshold,
erode_structure_size,
base_size,
):
base_size = (1000, 1000)
size = img.size
img.thumbnail(base_size, Image.LANCZOS)
img.thumbnail((base_size, base_size), Image.LANCZOS)
mask = mask.resize(img.size, Image.LANCZOS)
img = np.asarray(img)
@ -83,6 +83,7 @@ def remove(
alpha_matting_foreground_threshold=240,
alpha_matting_background_threshold=10,
alpha_matting_erode_structure_size=10,
alpha_matting_base_size=1000,
):
model = get_model(model_name)
img = Image.open(io.BytesIO(data)).convert("RGB")
@ -95,6 +96,7 @@ def remove(
alpha_matting_foreground_threshold,
alpha_matting_background_threshold,
alpha_matting_erode_structure_size,
alpha_matting_base_size,
)
else:
cutout = naive_cutout(img, mask)

View File

@ -55,6 +55,14 @@ def main():
help="Size of element used for the erosion.",
)
ap.add_argument(
"-az",
"--alpha-matting-base-size",
default=1000,
type=int,
help="The image base size.",
)
ap.add_argument(
"-p",
"--path",
@ -113,6 +121,7 @@ def main():
alpha_matting_foreground_threshold=args.alpha_matting_foreground_threshold,
alpha_matting_background_threshold=args.alpha_matting_background_threshold,
alpha_matting_erode_structure_size=args.alpha_matting_erode_size,
alpha_matting_base_size=args.alpha_matting_base_size,
),
)
@ -126,6 +135,7 @@ def main():
alpha_matting_foreground_threshold=args.alpha_matting_foreground_threshold,
alpha_matting_background_threshold=args.alpha_matting_background_threshold,
alpha_matting_erode_structure_size=args.alpha_matting_erode_size,
alpha_matting_base_size=args.alpha_matting_base_size,
),
)

View File

@ -35,7 +35,8 @@ def index():
af = request.values.get("af", type=int, default=240)
ab = request.values.get("ab", type=int, default=10)
ae = request.values.get("ae", type=int, default=10)
az = request.values.get("az", type=int, default=1000)
model = request.args.get("model", type=str, default="u2net")
if model not in ("u2net", "u2netp"):
return {"error": "invalid query param 'model'"}, 400
@ -44,12 +45,13 @@ def index():
return send_file(
BytesIO(
remove(
file_content,
file_content,
model_name=model,
alpha_matting=alpha_matting,
alpha_matting_foreground_threshold=af,
alpha_matting_background_threshold=ab,
alpha_matting_erode_structure_size=ae
alpha_matting_erode_structure_size=ae,
alpha_matting_base_size=az,
)
),
mimetype="image/png",