added input for remove function

This commit is contained in:
Flippchen 2023-04-20 12:39:57 +02:00
parent 3bdc06dff6
commit d7828b0369
3 changed files with 12 additions and 4 deletions

View File

@ -123,6 +123,8 @@ def remove(
only_mask: bool = False,
post_process_mask: bool = False,
bgcolor: Optional[Tuple[int, int, int, int]] = None,
input_point: Optional[np.ndarray] = None,
input_label: Optional[np.ndarray] = None,
) -> Union[bytes, PILImage, np.ndarray]:
if isinstance(data, PILImage):
return_type = ReturnType.PILLOW
@ -139,7 +141,13 @@ def remove(
if session is None:
session = new_session("u2net")
masks = session.predict(img)
if session.model_name == "sam":
if input_point is None or input_label is None:
raise ValueError("Input point and label are required for SAM model.")
masks = session.predict(img, input_point, input_label)
else:
masks = session.predict(img)
cutouts = []
for mask in masks:

View File

@ -64,7 +64,7 @@ def new_session(model_name: str = "u2net") -> BaseSession:
md5 = "fc16ebd8b0c10d971d3513d564d01e29"
url = "https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-general-use.onnx"
session_class = DisSession
elif model_name == "SAM":
elif model_name == "sam":
path = Path(u2net_home).expanduser()
fname_encoder = f"{model_name}_encoder.onnx"

View File

@ -71,8 +71,8 @@ class SamSession(BaseSession):
def predict(
self,
img: PILImage,
input_point=np.array([[500, 375]]),
input_label=np.array([1]),
input_point: np.ndarray,
input_label: np.ndarray,
) -> List[PILImage]:
# Preprocess image
image = resize_longes_side(img)