fix linters

This commit is contained in:
Daniel Gatis 2024-05-23 14:58:41 -03:00
parent df3744b613
commit a30359986a
4 changed files with 22 additions and 47 deletions

View File

@ -55,11 +55,11 @@ def alpha_matting_cutout(
if img.mode == "RGBA" or img.mode == "CMYK":
img = img.convert("RGB")
img = np.asarray(img)
mask = np.asarray(mask)
img_array = np.asarray(img)
mask_array = np.asarray(mask)
is_foreground = mask > foreground_threshold
is_background = mask < background_threshold
is_foreground = mask_array > foreground_threshold
is_background = mask_array < background_threshold
structure = None
if erode_structure_size > 0:
@ -70,11 +70,11 @@ def alpha_matting_cutout(
is_foreground = binary_erosion(is_foreground, structure=structure)
is_background = binary_erosion(is_background, structure=structure, border_value=1)
trimap = np.full(mask.shape, dtype=np.uint8, fill_value=128)
trimap = np.full(mask_array.shape, dtype=np.uint8, fill_value=128)
trimap[is_foreground] = 255
trimap[is_background] = 0
img_normalized = img / 255.0
img_normalized = img_array / 255.0
trimap_normalized = trimap / 255.0
alpha = estimate_alpha_cf(img_normalized, trimap_normalized)

View File

@ -6,7 +6,7 @@ import sys
from typing import IO
import click
from PIL import Image
from PIL.Image import Image as PILImage
from ..bg import remove
from ..session_factory import new_session
@ -134,7 +134,7 @@ def b_command(
if not os.path.isdir(output_dir):
os.makedirs(output_dir, exist_ok=True)
def img_to_byte_array(img: Image) -> bytes:
def img_to_byte_array(img: PILImage) -> bytes:
buff = io.BytesIO()
img.save(buff, format="PNG")
return buff.getvalue()

View File

@ -186,9 +186,9 @@ def p_command(
inputs = list(input.glob("**/*"))
if not watch:
inputs = tqdm(inputs)
inputs_tqdm = tqdm(inputs)
for each_input in inputs:
for each_input in inputs_tqdm:
if not each_input.is_dir():
process(each_input)

View File

@ -1,6 +1,6 @@
import os
from copy import deepcopy
from typing import List
from typing import Dict, List, Tuple
import cv2
import numpy as np
@ -87,8 +87,9 @@ class SamSession(BaseSession):
self,
model_name: str,
sess_opts: ort.SessionOptions,
providers=None,
*args,
**kwargs,
**kwargs
):
"""
Initialize a new SamSession with the given model name and session options.
@ -101,52 +102,27 @@ class SamSession(BaseSession):
"""
self.model_name = model_name
self.providers = []
valid_providers = []
available_providers = ort.get_available_providers()
_providers = ort.get_available_providers()
for provider in kwargs.get("providers", []):
if provider in _providers:
self.providers.append(provider)
for provider in (providers or []):
if provider in available_providers:
valid_providers.append(provider)
else:
self.providers.extend(_providers)
valid_providers.extend(available_providers)
paths = self.__class__.download_models(*args, **kwargs)
self.encoder = ort.InferenceSession(
str(paths[0]),
providers=self.providers,
providers=valid_providers,
sess_options=sess_opts,
)
self.decoder = ort.InferenceSession(
str(paths[1]),
providers=self.providers,
providers=valid_providers,
sess_options=sess_opts,
)
def normalize(
self,
img: np.ndarray,
mean=(),
std=(),
size=(),
*args,
**kwargs,
):
"""
Normalize the input image by subtracting the mean and dividing by the standard deviation.
Args:
img (np.ndarray): The input image.
mean (tuple, optional): The mean values for normalization. Defaults to ().
std (tuple, optional): The standard deviation values for normalization. Defaults to ().
size (tuple, optional): The target size of the image. Defaults to ().
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.
Returns:
np.ndarray: The normalized image.
"""
return img
def predict(
self,
img: PILImage,
@ -269,8 +245,7 @@ class SamSession(BaseSession):
for m in masks[0, :, :, :]:
mask[m > 0.0] = [255, 255, 255]
mask = Image.fromarray(mask).convert("L")
return [mask]
return [Image.fromarray(mask).convert("L")]
@classmethod
def download_models(cls, *args, **kwargs):