mirror of
https://git.mirrors.martin98.com/https://github.com/danielgatis/rembg
synced 2025-08-06 08:26:01 +08:00
add color bg param
This commit is contained in:
parent
a9e1f08036
commit
54cf4f8c11
14
rembg/bg.py
14
rembg/bg.py
@ -1,6 +1,6 @@
|
||||
import io
|
||||
from enum import Enum
|
||||
from typing import List, Optional, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
from cv2 import (
|
||||
@ -105,9 +105,9 @@ def post_process(mask: np.ndarray) -> np.ndarray:
|
||||
return mask
|
||||
|
||||
|
||||
def apply_background_color(img: PILImage, color: List[int]) -> PILImage:
|
||||
r, g, b = color
|
||||
colored_image = Image.new("RGBA", img.size, (r, g, b, 255))
|
||||
def apply_background_color(img: PILImage, color: Tuple[int, int, int, int]) -> PILImage:
|
||||
r, g, b, a = color
|
||||
colored_image = Image.new("RGBA", img.size, (r, g, b, a))
|
||||
colored_image.paste(img, mask=img)
|
||||
|
||||
return colored_image
|
||||
@ -122,7 +122,7 @@ def remove(
|
||||
session: Optional[BaseSession] = None,
|
||||
only_mask: bool = False,
|
||||
post_process_mask: bool = False,
|
||||
color: Optional[List[int]] = None,
|
||||
bgcolor: Optional[Tuple[int, int, int, int]] = None,
|
||||
) -> Union[bytes, PILImage, np.ndarray]:
|
||||
if isinstance(data, PILImage):
|
||||
return_type = ReturnType.PILLOW
|
||||
@ -170,8 +170,8 @@ def remove(
|
||||
if len(cutouts) > 0:
|
||||
cutout = get_concat_v_multi(cutouts)
|
||||
|
||||
if color is not None:
|
||||
cutout = apply_background_color(cutout, color)
|
||||
if bgcolor is not None and not only_mask:
|
||||
cutout = apply_background_color(cutout, bgcolor)
|
||||
|
||||
if ReturnType.PILLOW == return_type:
|
||||
return cutout
|
||||
|
28
rembg/cli.py
28
rembg/cli.py
@ -2,7 +2,7 @@ import pathlib
|
||||
import sys
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import IO, cast
|
||||
from typing import IO, Optional, cast
|
||||
|
||||
import aiohttp
|
||||
import click
|
||||
@ -93,12 +93,12 @@ def main() -> None:
|
||||
help="post process the mask",
|
||||
)
|
||||
@click.option(
|
||||
"-c",
|
||||
"--color",
|
||||
"-bgc",
|
||||
"--bgcolor",
|
||||
default=None,
|
||||
nargs=3,
|
||||
type=int,
|
||||
help="Background color (R G B) to replace the removed background with",
|
||||
type=(int, int, int, int),
|
||||
nargs=4,
|
||||
help="Background color (R G B A) to replace the removed background with",
|
||||
)
|
||||
@click.argument(
|
||||
"input", default=(None if sys.stdin.isatty() else "-"), type=click.File("rb")
|
||||
@ -185,13 +185,12 @@ def i(model: str, input: IO, output: IO, **kwargs) -> None:
|
||||
help="watches a folder for changes",
|
||||
)
|
||||
@click.option(
|
||||
"-c",
|
||||
"--color",
|
||||
"-bgc",
|
||||
"--bgcolor",
|
||||
default=None,
|
||||
type=(int, int, int),
|
||||
nargs=3,
|
||||
metavar="R G B",
|
||||
help="background color (RGB) to replace removed areas",
|
||||
type=(int, int, int, int),
|
||||
nargs=4,
|
||||
help="Background color (R G B A) to replace the removed background with",
|
||||
)
|
||||
@click.argument(
|
||||
"input",
|
||||
@ -369,6 +368,7 @@ def s(port: int, log_level: str, threads: int) -> None:
|
||||
),
|
||||
om: bool = Query(default=False, description="Only Mask"),
|
||||
ppm: bool = Query(default=False, description="Post Process Mask"),
|
||||
bgc: Optional[str] = Query(default=None, description="Background Color"),
|
||||
):
|
||||
self.model = model
|
||||
self.a = a
|
||||
@ -377,6 +377,7 @@ def s(port: int, log_level: str, threads: int) -> None:
|
||||
self.ae = ae
|
||||
self.om = om
|
||||
self.ppm = ppm
|
||||
self.bgc = map(int, bgc.split(",")) if bgc else None
|
||||
|
||||
class CommonQueryPostParams:
|
||||
def __init__(
|
||||
@ -403,6 +404,7 @@ def s(port: int, log_level: str, threads: int) -> None:
|
||||
),
|
||||
om: bool = Form(default=False, description="Only Mask"),
|
||||
ppm: bool = Form(default=False, description="Post Process Mask"),
|
||||
bgc: Optional[str] = Query(default=None, description="Background Color"),
|
||||
):
|
||||
self.model = model
|
||||
self.a = a
|
||||
@ -411,6 +413,7 @@ def s(port: int, log_level: str, threads: int) -> None:
|
||||
self.ae = ae
|
||||
self.om = om
|
||||
self.ppm = ppm
|
||||
self.bgc = map(int, bgc.split(",")) if bgc else None
|
||||
|
||||
def im_without_bg(content: bytes, commons: CommonQueryParams) -> Response:
|
||||
return Response(
|
||||
@ -425,6 +428,7 @@ def s(port: int, log_level: str, threads: int) -> None:
|
||||
alpha_matting_erode_size=commons.ae,
|
||||
only_mask=commons.om,
|
||||
post_process_mask=commons.ppm,
|
||||
bgcolor=commons.bgc,
|
||||
),
|
||||
media_type="image/png",
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user