mirror of
https://git.mirrors.martin98.com/https://github.com/danielgatis/rembg
synced 2025-08-16 15:35:54 +08:00
fix linters
This commit is contained in:
parent
df3744b613
commit
a30359986a
12
rembg/bg.py
12
rembg/bg.py
@ -55,11 +55,11 @@ def alpha_matting_cutout(
|
||||
if img.mode == "RGBA" or img.mode == "CMYK":
|
||||
img = img.convert("RGB")
|
||||
|
||||
img = np.asarray(img)
|
||||
mask = np.asarray(mask)
|
||||
img_array = np.asarray(img)
|
||||
mask_array = np.asarray(mask)
|
||||
|
||||
is_foreground = mask > foreground_threshold
|
||||
is_background = mask < background_threshold
|
||||
is_foreground = mask_array > foreground_threshold
|
||||
is_background = mask_array < background_threshold
|
||||
|
||||
structure = None
|
||||
if erode_structure_size > 0:
|
||||
@ -70,11 +70,11 @@ def alpha_matting_cutout(
|
||||
is_foreground = binary_erosion(is_foreground, structure=structure)
|
||||
is_background = binary_erosion(is_background, structure=structure, border_value=1)
|
||||
|
||||
trimap = np.full(mask.shape, dtype=np.uint8, fill_value=128)
|
||||
trimap = np.full(mask_array.shape, dtype=np.uint8, fill_value=128)
|
||||
trimap[is_foreground] = 255
|
||||
trimap[is_background] = 0
|
||||
|
||||
img_normalized = img / 255.0
|
||||
img_normalized = img_array / 255.0
|
||||
trimap_normalized = trimap / 255.0
|
||||
|
||||
alpha = estimate_alpha_cf(img_normalized, trimap_normalized)
|
||||
|
@ -6,7 +6,7 @@ import sys
|
||||
from typing import IO
|
||||
|
||||
import click
|
||||
from PIL import Image
|
||||
from PIL.Image import Image as PILImage
|
||||
|
||||
from ..bg import remove
|
||||
from ..session_factory import new_session
|
||||
@ -134,7 +134,7 @@ def b_command(
|
||||
if not os.path.isdir(output_dir):
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
def img_to_byte_array(img: Image) -> bytes:
|
||||
def img_to_byte_array(img: PILImage) -> bytes:
|
||||
buff = io.BytesIO()
|
||||
img.save(buff, format="PNG")
|
||||
return buff.getvalue()
|
||||
|
@ -186,9 +186,9 @@ def p_command(
|
||||
|
||||
inputs = list(input.glob("**/*"))
|
||||
if not watch:
|
||||
inputs = tqdm(inputs)
|
||||
inputs_tqdm = tqdm(inputs)
|
||||
|
||||
for each_input in inputs:
|
||||
for each_input in inputs_tqdm:
|
||||
if not each_input.is_dir():
|
||||
process(each_input)
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import List
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
@ -87,8 +87,9 @@ class SamSession(BaseSession):
|
||||
self,
|
||||
model_name: str,
|
||||
sess_opts: ort.SessionOptions,
|
||||
providers=None,
|
||||
*args,
|
||||
**kwargs,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
Initialize a new SamSession with the given model name and session options.
|
||||
@ -101,52 +102,27 @@ class SamSession(BaseSession):
|
||||
"""
|
||||
self.model_name = model_name
|
||||
|
||||
self.providers = []
|
||||
valid_providers = []
|
||||
available_providers = ort.get_available_providers()
|
||||
|
||||
_providers = ort.get_available_providers()
|
||||
for provider in kwargs.get("providers", []):
|
||||
if provider in _providers:
|
||||
self.providers.append(provider)
|
||||
for provider in (providers or []):
|
||||
if provider in available_providers:
|
||||
valid_providers.append(provider)
|
||||
else:
|
||||
self.providers.extend(_providers)
|
||||
valid_providers.extend(available_providers)
|
||||
|
||||
paths = self.__class__.download_models(*args, **kwargs)
|
||||
self.encoder = ort.InferenceSession(
|
||||
str(paths[0]),
|
||||
providers=self.providers,
|
||||
providers=valid_providers,
|
||||
sess_options=sess_opts,
|
||||
)
|
||||
self.decoder = ort.InferenceSession(
|
||||
str(paths[1]),
|
||||
providers=self.providers,
|
||||
providers=valid_providers,
|
||||
sess_options=sess_opts,
|
||||
)
|
||||
|
||||
def normalize(
|
||||
self,
|
||||
img: np.ndarray,
|
||||
mean=(),
|
||||
std=(),
|
||||
size=(),
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Normalize the input image by subtracting the mean and dividing by the standard deviation.
|
||||
|
||||
Args:
|
||||
img (np.ndarray): The input image.
|
||||
mean (tuple, optional): The mean values for normalization. Defaults to ().
|
||||
std (tuple, optional): The standard deviation values for normalization. Defaults to ().
|
||||
size (tuple, optional): The target size of the image. Defaults to ().
|
||||
*args: Variable length argument list.
|
||||
**kwargs: Arbitrary keyword arguments.
|
||||
|
||||
Returns:
|
||||
np.ndarray: The normalized image.
|
||||
"""
|
||||
return img
|
||||
|
||||
def predict(
|
||||
self,
|
||||
img: PILImage,
|
||||
@ -269,8 +245,7 @@ class SamSession(BaseSession):
|
||||
for m in masks[0, :, :, :]:
|
||||
mask[m > 0.0] = [255, 255, 255]
|
||||
|
||||
mask = Image.fromarray(mask).convert("L")
|
||||
return [mask]
|
||||
return [Image.fromarray(mask).convert("L")]
|
||||
|
||||
@classmethod
|
||||
def download_models(cls, *args, **kwargs):
|
||||
|
Loading…
x
Reference in New Issue
Block a user