mirror of
https://git.mirrors.martin98.com/https://github.com/danielgatis/rembg
synced 2025-08-14 20:26:01 +08:00
add mask option and better gpu install
This commit is contained in:
parent
521e49a18a
commit
722e23cc8d
@ -11,7 +11,7 @@ WORKDIR /rembg
|
||||
|
||||
COPY . .
|
||||
|
||||
RUN GPU=1 pip3 install .
|
||||
RUN ["pip3", "install", ".[gpu]"]
|
||||
|
||||
ENTRYPOINT ["rembg"]
|
||||
CMD []
|
||||
|
@ -45,7 +45,7 @@ pip install rembg
|
||||
|
||||
GPU support:
|
||||
```bash
|
||||
GPU=1 pip install rembg
|
||||
pip install rembg[gpu]
|
||||
```
|
||||
|
||||
### Usage as a cli
|
||||
|
29
rembg/bg.py
29
rembg/bg.py
@ -18,13 +18,7 @@ def alpha_matting_cutout(
|
||||
foreground_threshold: int,
|
||||
background_threshold: int,
|
||||
erode_structure_size: int,
|
||||
base_size: int,
|
||||
) -> Image:
|
||||
size = img.size
|
||||
|
||||
img.thumbnail((base_size, base_size), Image.LANCZOS)
|
||||
mask = mask.resize(img.size, Image.LANCZOS)
|
||||
|
||||
img = np.asarray(img)
|
||||
mask = np.asarray(mask)
|
||||
|
||||
@ -60,45 +54,37 @@ def alpha_matting_cutout(
|
||||
|
||||
cutout = np.clip(cutout * 255, 0, 255).astype(np.uint8)
|
||||
cutout = Image.fromarray(cutout)
|
||||
cutout = cutout.resize(size, Image.LANCZOS)
|
||||
|
||||
return cutout
|
||||
|
||||
|
||||
def naive_cutout(img: Image, mask: Image) -> Image:
|
||||
empty = Image.new("RGBA", (img.size), 0)
|
||||
cutout = Image.composite(img, empty, mask.resize(img.size, Image.LANCZOS))
|
||||
cutout = Image.composite(img, empty, mask)
|
||||
return cutout
|
||||
|
||||
|
||||
def resize_image(img: Image, width: Optional[int], height: Optional[int]) -> Image:
|
||||
original_width, original_height = img.size
|
||||
width = original_width if width is None else width
|
||||
height = original_height if height is None else height
|
||||
return img.resize((width, height))
|
||||
|
||||
|
||||
def remove(
|
||||
data: bytes,
|
||||
alpha_matting: bool = False,
|
||||
alpha_matting_foreground_threshold: int = 240,
|
||||
alpha_matting_background_threshold: int = 10,
|
||||
alpha_matting_erode_size: int = 10,
|
||||
alpha_matting_base_size: int = 1000,
|
||||
session: Optional[ort.InferenceSession] = None,
|
||||
width: Optional[int] = None,
|
||||
height: Optional[int] = None,
|
||||
only_mask: bool = False,
|
||||
) -> bytes:
|
||||
img = Image.open(io.BytesIO(data)).convert("RGB")
|
||||
if width is not None or height is not None:
|
||||
img = resize_image(img, width, height)
|
||||
|
||||
if session is None:
|
||||
session = ort_session("u2net")
|
||||
|
||||
mask = predict(session, np.array(img)).convert("L")
|
||||
mask = mask.resize(img.size, Image.LANCZOS)
|
||||
|
||||
if alpha_matting:
|
||||
if only_mask:
|
||||
cutout = mask
|
||||
|
||||
elif alpha_matting:
|
||||
try:
|
||||
cutout = alpha_matting_cutout(
|
||||
img,
|
||||
@ -106,7 +92,6 @@ def remove(
|
||||
alpha_matting_foreground_threshold,
|
||||
alpha_matting_background_threshold,
|
||||
alpha_matting_erode_size,
|
||||
alpha_matting_base_size,
|
||||
)
|
||||
except Exception:
|
||||
cutout = naive_cutout(img, mask)
|
||||
|
62
rembg/cli.py
62
rembg/cli.py
@ -66,28 +66,11 @@ def main():
|
||||
help="erode size",
|
||||
)
|
||||
@click.option(
|
||||
"-az",
|
||||
"--alpha-matting-base-size",
|
||||
default=1000,
|
||||
type=int,
|
||||
"-om",
|
||||
"--only-mask",
|
||||
is_flag=True,
|
||||
show_default=True,
|
||||
help="image base size",
|
||||
)
|
||||
@click.option(
|
||||
"-w",
|
||||
"--width",
|
||||
default=None,
|
||||
type=int,
|
||||
show_default=True,
|
||||
help="output image size",
|
||||
)
|
||||
@click.option(
|
||||
"-h",
|
||||
"--height",
|
||||
default=None,
|
||||
type=int,
|
||||
show_default=True,
|
||||
help="output image size",
|
||||
help="output only the mask",
|
||||
)
|
||||
@click.argument(
|
||||
"input", default=(None if sys.stdin.isatty() else "-"), type=click.File("rb")
|
||||
@ -143,28 +126,11 @@ def i(model: str, input: IO, output: IO, **kwargs):
|
||||
help="erode size",
|
||||
)
|
||||
@click.option(
|
||||
"-az",
|
||||
"--alpha-matting-base-size",
|
||||
default=1000,
|
||||
type=int,
|
||||
"-om",
|
||||
"--only-mask",
|
||||
is_flag=True,
|
||||
show_default=True,
|
||||
help="image base size",
|
||||
)
|
||||
@click.option(
|
||||
"-w",
|
||||
"--width",
|
||||
default=None,
|
||||
type=int,
|
||||
show_default=True,
|
||||
help="output image size",
|
||||
)
|
||||
@click.option(
|
||||
"-h",
|
||||
"--height",
|
||||
default=None,
|
||||
type=int,
|
||||
show_default=True,
|
||||
help="output image size",
|
||||
help="output only the mask",
|
||||
)
|
||||
@click.argument(
|
||||
"input",
|
||||
@ -240,18 +206,14 @@ def s(port: int, log_level: str):
|
||||
af: int = Query(240, ge=0),
|
||||
ab: int = Query(10, ge=0),
|
||||
ae: int = Query(10, ge=0),
|
||||
az: int = Query(1000, ge=0),
|
||||
width: Optional[int] = Query(None, gt=0),
|
||||
height: Optional[int] = Query(None, gt=0),
|
||||
om: bool = Query(False),
|
||||
):
|
||||
self.model = model
|
||||
self.width = width
|
||||
self.height = height
|
||||
self.a = a
|
||||
self.af = af
|
||||
self.ab = ab
|
||||
self.ae = ae
|
||||
self.az = az
|
||||
self.om = om
|
||||
|
||||
def im_without_bg(content: bytes, commons: CommonQueryParams) -> Response:
|
||||
return Response(
|
||||
@ -260,13 +222,11 @@ def s(port: int, log_level: str):
|
||||
session=sessions.setdefault(
|
||||
commons.model.value, ort_session(commons.model.value)
|
||||
),
|
||||
width=commons.width,
|
||||
height=commons.height,
|
||||
alpha_matting=commons.a,
|
||||
alpha_matting_foreground_threshold=commons.af,
|
||||
alpha_matting_background_threshold=commons.ab,
|
||||
alpha_matting_erode_size=commons.ae,
|
||||
alpha_matting_base_size=commons.az,
|
||||
only_mask=commons.om,
|
||||
),
|
||||
media_type="image/png",
|
||||
)
|
||||
|
@ -1 +0,0 @@
|
||||
onnxruntime==1.10.0
|
@ -5,6 +5,7 @@ fastapi==0.72.0
|
||||
filetype==1.0.9
|
||||
gdown==4.2.0
|
||||
numpy==1.21.5
|
||||
onnxruntime==1.10.0
|
||||
pillow==9.0.0
|
||||
pymatting==1.1.5
|
||||
python-multipart==0.0.5
|
||||
|
11
setup.py
11
setup.py
@ -14,12 +14,8 @@ long_description = (here / "README.md").read_text(encoding="utf-8")
|
||||
with open("requirements.txt") as f:
|
||||
requireds = f.read().splitlines()
|
||||
|
||||
if os.getenv("GPU") is None:
|
||||
with open("requirements-cpu.txt") as f:
|
||||
requireds += f.read().splitlines()
|
||||
else:
|
||||
with open("requirements-gpu.txt") as f:
|
||||
requireds += f.read().splitlines()
|
||||
with open("requirements-gpu.txt") as f:
|
||||
gpu_requireds = f.read().splitlines()
|
||||
|
||||
setup(
|
||||
name="rembg",
|
||||
@ -42,6 +38,9 @@ setup(
|
||||
"rembg=rembg.cli:main",
|
||||
],
|
||||
},
|
||||
extras_require={
|
||||
'gpu': gpu_requireds,
|
||||
},
|
||||
version=versioneer.get_version(),
|
||||
cmdclass=versioneer.get_cmdclass(),
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user