mirror of
https://git.mirrors.martin98.com/https://github.com/danielgatis/rembg
synced 2025-08-16 14:05:53 +08:00
fix linters
This commit is contained in:
parent
df3744b613
commit
a30359986a
12
rembg/bg.py
12
rembg/bg.py
@ -55,11 +55,11 @@ def alpha_matting_cutout(
|
|||||||
if img.mode == "RGBA" or img.mode == "CMYK":
|
if img.mode == "RGBA" or img.mode == "CMYK":
|
||||||
img = img.convert("RGB")
|
img = img.convert("RGB")
|
||||||
|
|
||||||
img = np.asarray(img)
|
img_array = np.asarray(img)
|
||||||
mask = np.asarray(mask)
|
mask_array = np.asarray(mask)
|
||||||
|
|
||||||
is_foreground = mask > foreground_threshold
|
is_foreground = mask_array > foreground_threshold
|
||||||
is_background = mask < background_threshold
|
is_background = mask_array < background_threshold
|
||||||
|
|
||||||
structure = None
|
structure = None
|
||||||
if erode_structure_size > 0:
|
if erode_structure_size > 0:
|
||||||
@ -70,11 +70,11 @@ def alpha_matting_cutout(
|
|||||||
is_foreground = binary_erosion(is_foreground, structure=structure)
|
is_foreground = binary_erosion(is_foreground, structure=structure)
|
||||||
is_background = binary_erosion(is_background, structure=structure, border_value=1)
|
is_background = binary_erosion(is_background, structure=structure, border_value=1)
|
||||||
|
|
||||||
trimap = np.full(mask.shape, dtype=np.uint8, fill_value=128)
|
trimap = np.full(mask_array.shape, dtype=np.uint8, fill_value=128)
|
||||||
trimap[is_foreground] = 255
|
trimap[is_foreground] = 255
|
||||||
trimap[is_background] = 0
|
trimap[is_background] = 0
|
||||||
|
|
||||||
img_normalized = img / 255.0
|
img_normalized = img_array / 255.0
|
||||||
trimap_normalized = trimap / 255.0
|
trimap_normalized = trimap / 255.0
|
||||||
|
|
||||||
alpha = estimate_alpha_cf(img_normalized, trimap_normalized)
|
alpha = estimate_alpha_cf(img_normalized, trimap_normalized)
|
||||||
|
@ -6,7 +6,7 @@ import sys
|
|||||||
from typing import IO
|
from typing import IO
|
||||||
|
|
||||||
import click
|
import click
|
||||||
from PIL import Image
|
from PIL.Image import Image as PILImage
|
||||||
|
|
||||||
from ..bg import remove
|
from ..bg import remove
|
||||||
from ..session_factory import new_session
|
from ..session_factory import new_session
|
||||||
@ -134,7 +134,7 @@ def b_command(
|
|||||||
if not os.path.isdir(output_dir):
|
if not os.path.isdir(output_dir):
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
def img_to_byte_array(img: Image) -> bytes:
|
def img_to_byte_array(img: PILImage) -> bytes:
|
||||||
buff = io.BytesIO()
|
buff = io.BytesIO()
|
||||||
img.save(buff, format="PNG")
|
img.save(buff, format="PNG")
|
||||||
return buff.getvalue()
|
return buff.getvalue()
|
||||||
|
@ -186,9 +186,9 @@ def p_command(
|
|||||||
|
|
||||||
inputs = list(input.glob("**/*"))
|
inputs = list(input.glob("**/*"))
|
||||||
if not watch:
|
if not watch:
|
||||||
inputs = tqdm(inputs)
|
inputs_tqdm = tqdm(inputs)
|
||||||
|
|
||||||
for each_input in inputs:
|
for each_input in inputs_tqdm:
|
||||||
if not each_input.is_dir():
|
if not each_input.is_dir():
|
||||||
process(each_input)
|
process(each_input)
|
||||||
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import List
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -87,8 +87,9 @@ class SamSession(BaseSession):
|
|||||||
self,
|
self,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
sess_opts: ort.SessionOptions,
|
sess_opts: ort.SessionOptions,
|
||||||
|
providers=None,
|
||||||
*args,
|
*args,
|
||||||
**kwargs,
|
**kwargs
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize a new SamSession with the given model name and session options.
|
Initialize a new SamSession with the given model name and session options.
|
||||||
@ -101,52 +102,27 @@ class SamSession(BaseSession):
|
|||||||
"""
|
"""
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
|
|
||||||
self.providers = []
|
valid_providers = []
|
||||||
|
available_providers = ort.get_available_providers()
|
||||||
|
|
||||||
_providers = ort.get_available_providers()
|
for provider in (providers or []):
|
||||||
for provider in kwargs.get("providers", []):
|
if provider in available_providers:
|
||||||
if provider in _providers:
|
valid_providers.append(provider)
|
||||||
self.providers.append(provider)
|
|
||||||
else:
|
else:
|
||||||
self.providers.extend(_providers)
|
valid_providers.extend(available_providers)
|
||||||
|
|
||||||
paths = self.__class__.download_models(*args, **kwargs)
|
paths = self.__class__.download_models(*args, **kwargs)
|
||||||
self.encoder = ort.InferenceSession(
|
self.encoder = ort.InferenceSession(
|
||||||
str(paths[0]),
|
str(paths[0]),
|
||||||
providers=self.providers,
|
providers=valid_providers,
|
||||||
sess_options=sess_opts,
|
sess_options=sess_opts,
|
||||||
)
|
)
|
||||||
self.decoder = ort.InferenceSession(
|
self.decoder = ort.InferenceSession(
|
||||||
str(paths[1]),
|
str(paths[1]),
|
||||||
providers=self.providers,
|
providers=valid_providers,
|
||||||
sess_options=sess_opts,
|
sess_options=sess_opts,
|
||||||
)
|
)
|
||||||
|
|
||||||
def normalize(
|
|
||||||
self,
|
|
||||||
img: np.ndarray,
|
|
||||||
mean=(),
|
|
||||||
std=(),
|
|
||||||
size=(),
|
|
||||||
*args,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Normalize the input image by subtracting the mean and dividing by the standard deviation.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
img (np.ndarray): The input image.
|
|
||||||
mean (tuple, optional): The mean values for normalization. Defaults to ().
|
|
||||||
std (tuple, optional): The standard deviation values for normalization. Defaults to ().
|
|
||||||
size (tuple, optional): The target size of the image. Defaults to ().
|
|
||||||
*args: Variable length argument list.
|
|
||||||
**kwargs: Arbitrary keyword arguments.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
np.ndarray: The normalized image.
|
|
||||||
"""
|
|
||||||
return img
|
|
||||||
|
|
||||||
def predict(
|
def predict(
|
||||||
self,
|
self,
|
||||||
img: PILImage,
|
img: PILImage,
|
||||||
@ -269,8 +245,7 @@ class SamSession(BaseSession):
|
|||||||
for m in masks[0, :, :, :]:
|
for m in masks[0, :, :, :]:
|
||||||
mask[m > 0.0] = [255, 255, 255]
|
mask[m > 0.0] = [255, 255, 255]
|
||||||
|
|
||||||
mask = Image.fromarray(mask).convert("L")
|
return [Image.fromarray(mask).convert("L")]
|
||||||
return [mask]
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def download_models(cls, *args, **kwargs):
|
def download_models(cls, *args, **kwargs):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user