mirror of
https://git.mirrors.martin98.com/https://github.com/danielgatis/rembg
synced 2025-08-05 18:50:36 +08:00
fix pylint
This commit is contained in:
parent
d7828b0369
commit
394ab21ab9
@ -20,6 +20,7 @@ from scipy.ndimage import binary_erosion
|
||||
|
||||
from .session_base import BaseSession
|
||||
from .session_factory import new_session
|
||||
from .session_sam import SamSession
|
||||
|
||||
kernel = getStructuringElement(MORPH_ELLIPSE, (3, 3))
|
||||
|
||||
@ -119,7 +120,7 @@ def remove(
|
||||
alpha_matting_foreground_threshold: int = 240,
|
||||
alpha_matting_background_threshold: int = 10,
|
||||
alpha_matting_erode_size: int = 10,
|
||||
session: Optional[BaseSession] = None,
|
||||
session: Optional[Union[BaseSession, SamSession]] = None,
|
||||
only_mask: bool = False,
|
||||
post_process_mask: bool = False,
|
||||
bgcolor: Optional[Tuple[int, int, int, int]] = None,
|
||||
@ -141,10 +142,10 @@ def remove(
|
||||
if session is None:
|
||||
session = new_session("u2net")
|
||||
|
||||
if session.model_name == "sam":
|
||||
if isinstance(session, SamSession):
|
||||
if input_point is None or input_label is None:
|
||||
raise ValueError("Input point and label are required for SAM model.")
|
||||
masks = session.predict(img, input_point, input_label)
|
||||
masks = session.predict_sam(img, input_point, input_label)
|
||||
else:
|
||||
masks = session.predict(img)
|
||||
|
||||
|
@ -68,7 +68,7 @@ class SamSession(BaseSession):
|
||||
x = (img - pixel_mean) / pixel_std
|
||||
return x
|
||||
|
||||
def predict(
|
||||
def predict_sam(
|
||||
self,
|
||||
img: PILImage,
|
||||
input_point: np.ndarray,
|
||||
|
Loading…
x
Reference in New Issue
Block a user