fix linters

This commit is contained in:
Daniel Gatis 2024-05-23 14:58:41 -03:00
parent df3744b613
commit a30359986a
4 changed files with 22 additions and 47 deletions

View File

@ -55,11 +55,11 @@ def alpha_matting_cutout(
if img.mode == "RGBA" or img.mode == "CMYK": if img.mode == "RGBA" or img.mode == "CMYK":
img = img.convert("RGB") img = img.convert("RGB")
img = np.asarray(img) img_array = np.asarray(img)
mask = np.asarray(mask) mask_array = np.asarray(mask)
is_foreground = mask > foreground_threshold is_foreground = mask_array > foreground_threshold
is_background = mask < background_threshold is_background = mask_array < background_threshold
structure = None structure = None
if erode_structure_size > 0: if erode_structure_size > 0:
@ -70,11 +70,11 @@ def alpha_matting_cutout(
is_foreground = binary_erosion(is_foreground, structure=structure) is_foreground = binary_erosion(is_foreground, structure=structure)
is_background = binary_erosion(is_background, structure=structure, border_value=1) is_background = binary_erosion(is_background, structure=structure, border_value=1)
trimap = np.full(mask.shape, dtype=np.uint8, fill_value=128) trimap = np.full(mask_array.shape, dtype=np.uint8, fill_value=128)
trimap[is_foreground] = 255 trimap[is_foreground] = 255
trimap[is_background] = 0 trimap[is_background] = 0
img_normalized = img / 255.0 img_normalized = img_array / 255.0
trimap_normalized = trimap / 255.0 trimap_normalized = trimap / 255.0
alpha = estimate_alpha_cf(img_normalized, trimap_normalized) alpha = estimate_alpha_cf(img_normalized, trimap_normalized)

View File

@ -6,7 +6,7 @@ import sys
from typing import IO from typing import IO
import click import click
from PIL import Image from PIL.Image import Image as PILImage
from ..bg import remove from ..bg import remove
from ..session_factory import new_session from ..session_factory import new_session
@ -134,7 +134,7 @@ def b_command(
if not os.path.isdir(output_dir): if not os.path.isdir(output_dir):
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
def img_to_byte_array(img: Image) -> bytes: def img_to_byte_array(img: PILImage) -> bytes:
buff = io.BytesIO() buff = io.BytesIO()
img.save(buff, format="PNG") img.save(buff, format="PNG")
return buff.getvalue() return buff.getvalue()

View File

@ -186,9 +186,9 @@ def p_command(
inputs = list(input.glob("**/*")) inputs = list(input.glob("**/*"))
if not watch: if not watch:
inputs = tqdm(inputs) inputs_tqdm = tqdm(inputs)
for each_input in inputs: for each_input in inputs_tqdm:
if not each_input.is_dir(): if not each_input.is_dir():
process(each_input) process(each_input)

View File

@ -1,6 +1,6 @@
import os import os
from copy import deepcopy from copy import deepcopy
from typing import List from typing import Dict, List, Tuple
import cv2 import cv2
import numpy as np import numpy as np
@ -87,8 +87,9 @@ class SamSession(BaseSession):
self, self,
model_name: str, model_name: str,
sess_opts: ort.SessionOptions, sess_opts: ort.SessionOptions,
providers=None,
*args, *args,
**kwargs, **kwargs
): ):
""" """
Initialize a new SamSession with the given model name and session options. Initialize a new SamSession with the given model name and session options.
@ -101,52 +102,27 @@ class SamSession(BaseSession):
""" """
self.model_name = model_name self.model_name = model_name
self.providers = [] valid_providers = []
available_providers = ort.get_available_providers()
_providers = ort.get_available_providers() for provider in (providers or []):
for provider in kwargs.get("providers", []): if provider in available_providers:
if provider in _providers: valid_providers.append(provider)
self.providers.append(provider)
else: else:
self.providers.extend(_providers) valid_providers.extend(available_providers)
paths = self.__class__.download_models(*args, **kwargs) paths = self.__class__.download_models(*args, **kwargs)
self.encoder = ort.InferenceSession( self.encoder = ort.InferenceSession(
str(paths[0]), str(paths[0]),
providers=self.providers, providers=valid_providers,
sess_options=sess_opts, sess_options=sess_opts,
) )
self.decoder = ort.InferenceSession( self.decoder = ort.InferenceSession(
str(paths[1]), str(paths[1]),
providers=self.providers, providers=valid_providers,
sess_options=sess_opts, sess_options=sess_opts,
) )
def normalize(
self,
img: np.ndarray,
mean=(),
std=(),
size=(),
*args,
**kwargs,
):
"""
Normalize the input image by subtracting the mean and dividing by the standard deviation.
Args:
img (np.ndarray): The input image.
mean (tuple, optional): The mean values for normalization. Defaults to ().
std (tuple, optional): The standard deviation values for normalization. Defaults to ().
size (tuple, optional): The target size of the image. Defaults to ().
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.
Returns:
np.ndarray: The normalized image.
"""
return img
def predict( def predict(
self, self,
img: PILImage, img: PILImage,
@ -269,8 +245,7 @@ class SamSession(BaseSession):
for m in masks[0, :, :, :]: for m in masks[0, :, :, :]:
mask[m > 0.0] = [255, 255, 255] mask[m > 0.0] = [255, 255, 255]
mask = Image.fromarray(mask).convert("L") return [Image.fromarray(mask).convert("L")]
return [mask]
@classmethod @classmethod
def download_models(cls, *args, **kwargs): def download_models(cls, *args, **kwargs):