mirror of
https://git.mirrors.martin98.com/https://github.com/danielgatis/rembg
synced 2025-08-15 17:55:55 +08:00
add mask option and better gpu install
This commit is contained in:
parent
521e49a18a
commit
722e23cc8d
@ -11,7 +11,7 @@ WORKDIR /rembg
|
|||||||
|
|
||||||
COPY . .
|
COPY . .
|
||||||
|
|
||||||
RUN GPU=1 pip3 install .
|
RUN ["pip3", "install", ".[gpu]"]
|
||||||
|
|
||||||
ENTRYPOINT ["rembg"]
|
ENTRYPOINT ["rembg"]
|
||||||
CMD []
|
CMD []
|
||||||
|
@ -45,7 +45,7 @@ pip install rembg
|
|||||||
|
|
||||||
GPU support:
|
GPU support:
|
||||||
```bash
|
```bash
|
||||||
GPU=1 pip install rembg
|
pip install rembg[gpu]
|
||||||
```
|
```
|
||||||
|
|
||||||
### Usage as a cli
|
### Usage as a cli
|
||||||
|
29
rembg/bg.py
29
rembg/bg.py
@ -18,13 +18,7 @@ def alpha_matting_cutout(
|
|||||||
foreground_threshold: int,
|
foreground_threshold: int,
|
||||||
background_threshold: int,
|
background_threshold: int,
|
||||||
erode_structure_size: int,
|
erode_structure_size: int,
|
||||||
base_size: int,
|
|
||||||
) -> Image:
|
) -> Image:
|
||||||
size = img.size
|
|
||||||
|
|
||||||
img.thumbnail((base_size, base_size), Image.LANCZOS)
|
|
||||||
mask = mask.resize(img.size, Image.LANCZOS)
|
|
||||||
|
|
||||||
img = np.asarray(img)
|
img = np.asarray(img)
|
||||||
mask = np.asarray(mask)
|
mask = np.asarray(mask)
|
||||||
|
|
||||||
@ -60,45 +54,37 @@ def alpha_matting_cutout(
|
|||||||
|
|
||||||
cutout = np.clip(cutout * 255, 0, 255).astype(np.uint8)
|
cutout = np.clip(cutout * 255, 0, 255).astype(np.uint8)
|
||||||
cutout = Image.fromarray(cutout)
|
cutout = Image.fromarray(cutout)
|
||||||
cutout = cutout.resize(size, Image.LANCZOS)
|
|
||||||
|
|
||||||
return cutout
|
return cutout
|
||||||
|
|
||||||
|
|
||||||
def naive_cutout(img: Image, mask: Image) -> Image:
|
def naive_cutout(img: Image, mask: Image) -> Image:
|
||||||
empty = Image.new("RGBA", (img.size), 0)
|
empty = Image.new("RGBA", (img.size), 0)
|
||||||
cutout = Image.composite(img, empty, mask.resize(img.size, Image.LANCZOS))
|
cutout = Image.composite(img, empty, mask)
|
||||||
return cutout
|
return cutout
|
||||||
|
|
||||||
|
|
||||||
def resize_image(img: Image, width: Optional[int], height: Optional[int]) -> Image:
|
|
||||||
original_width, original_height = img.size
|
|
||||||
width = original_width if width is None else width
|
|
||||||
height = original_height if height is None else height
|
|
||||||
return img.resize((width, height))
|
|
||||||
|
|
||||||
|
|
||||||
def remove(
|
def remove(
|
||||||
data: bytes,
|
data: bytes,
|
||||||
alpha_matting: bool = False,
|
alpha_matting: bool = False,
|
||||||
alpha_matting_foreground_threshold: int = 240,
|
alpha_matting_foreground_threshold: int = 240,
|
||||||
alpha_matting_background_threshold: int = 10,
|
alpha_matting_background_threshold: int = 10,
|
||||||
alpha_matting_erode_size: int = 10,
|
alpha_matting_erode_size: int = 10,
|
||||||
alpha_matting_base_size: int = 1000,
|
|
||||||
session: Optional[ort.InferenceSession] = None,
|
session: Optional[ort.InferenceSession] = None,
|
||||||
width: Optional[int] = None,
|
only_mask: bool = False,
|
||||||
height: Optional[int] = None,
|
|
||||||
) -> bytes:
|
) -> bytes:
|
||||||
img = Image.open(io.BytesIO(data)).convert("RGB")
|
img = Image.open(io.BytesIO(data)).convert("RGB")
|
||||||
if width is not None or height is not None:
|
|
||||||
img = resize_image(img, width, height)
|
|
||||||
|
|
||||||
if session is None:
|
if session is None:
|
||||||
session = ort_session("u2net")
|
session = ort_session("u2net")
|
||||||
|
|
||||||
mask = predict(session, np.array(img)).convert("L")
|
mask = predict(session, np.array(img)).convert("L")
|
||||||
|
mask = mask.resize(img.size, Image.LANCZOS)
|
||||||
|
|
||||||
if alpha_matting:
|
if only_mask:
|
||||||
|
cutout = mask
|
||||||
|
|
||||||
|
elif alpha_matting:
|
||||||
try:
|
try:
|
||||||
cutout = alpha_matting_cutout(
|
cutout = alpha_matting_cutout(
|
||||||
img,
|
img,
|
||||||
@ -106,7 +92,6 @@ def remove(
|
|||||||
alpha_matting_foreground_threshold,
|
alpha_matting_foreground_threshold,
|
||||||
alpha_matting_background_threshold,
|
alpha_matting_background_threshold,
|
||||||
alpha_matting_erode_size,
|
alpha_matting_erode_size,
|
||||||
alpha_matting_base_size,
|
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
cutout = naive_cutout(img, mask)
|
cutout = naive_cutout(img, mask)
|
||||||
|
62
rembg/cli.py
62
rembg/cli.py
@ -66,28 +66,11 @@ def main():
|
|||||||
help="erode size",
|
help="erode size",
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"-az",
|
"-om",
|
||||||
"--alpha-matting-base-size",
|
"--only-mask",
|
||||||
default=1000,
|
is_flag=True,
|
||||||
type=int,
|
|
||||||
show_default=True,
|
show_default=True,
|
||||||
help="image base size",
|
help="output only the mask",
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"-w",
|
|
||||||
"--width",
|
|
||||||
default=None,
|
|
||||||
type=int,
|
|
||||||
show_default=True,
|
|
||||||
help="output image size",
|
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"-h",
|
|
||||||
"--height",
|
|
||||||
default=None,
|
|
||||||
type=int,
|
|
||||||
show_default=True,
|
|
||||||
help="output image size",
|
|
||||||
)
|
)
|
||||||
@click.argument(
|
@click.argument(
|
||||||
"input", default=(None if sys.stdin.isatty() else "-"), type=click.File("rb")
|
"input", default=(None if sys.stdin.isatty() else "-"), type=click.File("rb")
|
||||||
@ -143,28 +126,11 @@ def i(model: str, input: IO, output: IO, **kwargs):
|
|||||||
help="erode size",
|
help="erode size",
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"-az",
|
"-om",
|
||||||
"--alpha-matting-base-size",
|
"--only-mask",
|
||||||
default=1000,
|
is_flag=True,
|
||||||
type=int,
|
|
||||||
show_default=True,
|
show_default=True,
|
||||||
help="image base size",
|
help="output only the mask",
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"-w",
|
|
||||||
"--width",
|
|
||||||
default=None,
|
|
||||||
type=int,
|
|
||||||
show_default=True,
|
|
||||||
help="output image size",
|
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"-h",
|
|
||||||
"--height",
|
|
||||||
default=None,
|
|
||||||
type=int,
|
|
||||||
show_default=True,
|
|
||||||
help="output image size",
|
|
||||||
)
|
)
|
||||||
@click.argument(
|
@click.argument(
|
||||||
"input",
|
"input",
|
||||||
@ -240,18 +206,14 @@ def s(port: int, log_level: str):
|
|||||||
af: int = Query(240, ge=0),
|
af: int = Query(240, ge=0),
|
||||||
ab: int = Query(10, ge=0),
|
ab: int = Query(10, ge=0),
|
||||||
ae: int = Query(10, ge=0),
|
ae: int = Query(10, ge=0),
|
||||||
az: int = Query(1000, ge=0),
|
om: bool = Query(False),
|
||||||
width: Optional[int] = Query(None, gt=0),
|
|
||||||
height: Optional[int] = Query(None, gt=0),
|
|
||||||
):
|
):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.width = width
|
|
||||||
self.height = height
|
|
||||||
self.a = a
|
self.a = a
|
||||||
self.af = af
|
self.af = af
|
||||||
self.ab = ab
|
self.ab = ab
|
||||||
self.ae = ae
|
self.ae = ae
|
||||||
self.az = az
|
self.om = om
|
||||||
|
|
||||||
def im_without_bg(content: bytes, commons: CommonQueryParams) -> Response:
|
def im_without_bg(content: bytes, commons: CommonQueryParams) -> Response:
|
||||||
return Response(
|
return Response(
|
||||||
@ -260,13 +222,11 @@ def s(port: int, log_level: str):
|
|||||||
session=sessions.setdefault(
|
session=sessions.setdefault(
|
||||||
commons.model.value, ort_session(commons.model.value)
|
commons.model.value, ort_session(commons.model.value)
|
||||||
),
|
),
|
||||||
width=commons.width,
|
|
||||||
height=commons.height,
|
|
||||||
alpha_matting=commons.a,
|
alpha_matting=commons.a,
|
||||||
alpha_matting_foreground_threshold=commons.af,
|
alpha_matting_foreground_threshold=commons.af,
|
||||||
alpha_matting_background_threshold=commons.ab,
|
alpha_matting_background_threshold=commons.ab,
|
||||||
alpha_matting_erode_size=commons.ae,
|
alpha_matting_erode_size=commons.ae,
|
||||||
alpha_matting_base_size=commons.az,
|
only_mask=commons.om,
|
||||||
),
|
),
|
||||||
media_type="image/png",
|
media_type="image/png",
|
||||||
)
|
)
|
||||||
|
@ -1 +0,0 @@
|
|||||||
onnxruntime==1.10.0
|
|
@ -5,6 +5,7 @@ fastapi==0.72.0
|
|||||||
filetype==1.0.9
|
filetype==1.0.9
|
||||||
gdown==4.2.0
|
gdown==4.2.0
|
||||||
numpy==1.21.5
|
numpy==1.21.5
|
||||||
|
onnxruntime==1.10.0
|
||||||
pillow==9.0.0
|
pillow==9.0.0
|
||||||
pymatting==1.1.5
|
pymatting==1.1.5
|
||||||
python-multipart==0.0.5
|
python-multipart==0.0.5
|
||||||
|
11
setup.py
11
setup.py
@ -14,12 +14,8 @@ long_description = (here / "README.md").read_text(encoding="utf-8")
|
|||||||
with open("requirements.txt") as f:
|
with open("requirements.txt") as f:
|
||||||
requireds = f.read().splitlines()
|
requireds = f.read().splitlines()
|
||||||
|
|
||||||
if os.getenv("GPU") is None:
|
with open("requirements-gpu.txt") as f:
|
||||||
with open("requirements-cpu.txt") as f:
|
gpu_requireds = f.read().splitlines()
|
||||||
requireds += f.read().splitlines()
|
|
||||||
else:
|
|
||||||
with open("requirements-gpu.txt") as f:
|
|
||||||
requireds += f.read().splitlines()
|
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="rembg",
|
name="rembg",
|
||||||
@ -42,6 +38,9 @@ setup(
|
|||||||
"rembg=rembg.cli:main",
|
"rembg=rembg.cli:main",
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
|
extras_require={
|
||||||
|
'gpu': gpu_requireds,
|
||||||
|
},
|
||||||
version=versioneer.get_version(),
|
version=versioneer.get_version(),
|
||||||
cmdclass=versioneer.get_cmdclass(),
|
cmdclass=versioneer.get_cmdclass(),
|
||||||
)
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user