mirror of
https://git.mirrors.martin98.com/https://github.com/danielgatis/rembg
synced 2025-08-05 16:50:38 +08:00
added input for remove function
This commit is contained in:
parent
3bdc06dff6
commit
d7828b0369
10
rembg/bg.py
10
rembg/bg.py
@ -123,6 +123,8 @@ def remove(
|
||||
only_mask: bool = False,
|
||||
post_process_mask: bool = False,
|
||||
bgcolor: Optional[Tuple[int, int, int, int]] = None,
|
||||
input_point: Optional[np.ndarray] = None,
|
||||
input_label: Optional[np.ndarray] = None,
|
||||
) -> Union[bytes, PILImage, np.ndarray]:
|
||||
if isinstance(data, PILImage):
|
||||
return_type = ReturnType.PILLOW
|
||||
@ -139,7 +141,13 @@ def remove(
|
||||
if session is None:
|
||||
session = new_session("u2net")
|
||||
|
||||
masks = session.predict(img)
|
||||
if session.model_name == "sam":
|
||||
if input_point is None or input_label is None:
|
||||
raise ValueError("Input point and label are required for SAM model.")
|
||||
masks = session.predict(img, input_point, input_label)
|
||||
else:
|
||||
masks = session.predict(img)
|
||||
|
||||
cutouts = []
|
||||
|
||||
for mask in masks:
|
||||
|
@ -64,7 +64,7 @@ def new_session(model_name: str = "u2net") -> BaseSession:
|
||||
md5 = "fc16ebd8b0c10d971d3513d564d01e29"
|
||||
url = "https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-general-use.onnx"
|
||||
session_class = DisSession
|
||||
elif model_name == "SAM":
|
||||
elif model_name == "sam":
|
||||
path = Path(u2net_home).expanduser()
|
||||
|
||||
fname_encoder = f"{model_name}_encoder.onnx"
|
||||
|
@ -71,8 +71,8 @@ class SamSession(BaseSession):
|
||||
def predict(
|
||||
self,
|
||||
img: PILImage,
|
||||
input_point=np.array([[500, 375]]),
|
||||
input_label=np.array([1]),
|
||||
input_point: np.ndarray,
|
||||
input_label: np.ndarray,
|
||||
) -> List[PILImage]:
|
||||
# Preprocess image
|
||||
image = resize_longes_side(img)
|
||||
|
Loading…
x
Reference in New Issue
Block a user