fix pylint

This commit is contained in:
Flippchen 2023-04-20 12:47:14 +02:00
parent d7828b0369
commit 394ab21ab9
2 changed files with 5 additions and 4 deletions

View File

@ -20,6 +20,7 @@ from scipy.ndimage import binary_erosion
from .session_base import BaseSession
from .session_factory import new_session
from .session_sam import SamSession
kernel = getStructuringElement(MORPH_ELLIPSE, (3, 3))
@ -119,7 +120,7 @@ def remove(
alpha_matting_foreground_threshold: int = 240,
alpha_matting_background_threshold: int = 10,
alpha_matting_erode_size: int = 10,
session: Optional[BaseSession] = None,
session: Optional[Union[BaseSession, SamSession]] = None,
only_mask: bool = False,
post_process_mask: bool = False,
bgcolor: Optional[Tuple[int, int, int, int]] = None,
@ -141,10 +142,10 @@ def remove(
if session is None:
session = new_session("u2net")
if session.model_name == "sam":
if isinstance(session, SamSession):
if input_point is None or input_label is None:
raise ValueError("Input point and label are required for SAM model.")
masks = session.predict(img, input_point, input_label)
masks = session.predict_sam(img, input_point, input_label)
else:
masks = session.predict(img)

View File

@ -68,7 +68,7 @@ class SamSession(BaseSession):
x = (img - pixel_mean) / pixel_std
return x
def predict(
def predict_sam(
self,
img: PILImage,
input_point: np.ndarray,