fix sam session (#531)
@ -159,10 +159,14 @@ rembg i -a path/to/input.png path/to/output.png
|
|||||||
Passing extras parameters
|
Passing extras parameters
|
||||||
|
|
||||||
```
|
```
|
||||||
rembg i -m sam -x '{"input_labels": [1], "input_points": [[100,100]]}' path/to/input.png path/to/output.png
|
SAM example
|
||||||
|
|
||||||
|
rembg i -m sam -x '{ "sam_prompt": [{"type": "point", "data": [724, 740], "label": 1}] }' examples/plants-1.jpg examples/plants-1.out.png
|
||||||
```
|
```
|
||||||
|
|
||||||
```
|
```
|
||||||
|
Custom model example
|
||||||
|
|
||||||
rembg i -m u2net_custom -x '{"model_path": "~/.u2net/u2net.onnx"}' path/to/input.png path/to/output.png
|
rembg i -m u2net_custom -x '{"model_path": "~/.u2net/u2net.onnx"}' path/to/input.png path/to/output.png
|
||||||
```
|
```
|
||||||
|
|
||||||
|
BIN
examples/plants-1.jpg
Normal file
After Width: | Height: | Size: 1.4 MiB |
BIN
examples/plants-1.out.png
Normal file
After Width: | Height: | Size: 29 KiB |
@ -1,9 +1,12 @@
|
|||||||
import os
|
import os
|
||||||
|
from copy import deepcopy
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import onnxruntime as ort
|
import onnxruntime as ort
|
||||||
import pooch
|
import pooch
|
||||||
|
from jsonschema import validate
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from PIL.Image import Image as PILImage
|
from PIL.Image import Image as PILImage
|
||||||
|
|
||||||
@ -15,37 +18,58 @@ def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int):
|
|||||||
newh, neww = oldh * scale, oldw * scale
|
newh, neww = oldh * scale, oldw * scale
|
||||||
neww = int(neww + 0.5)
|
neww = int(neww + 0.5)
|
||||||
newh = int(newh + 0.5)
|
newh = int(newh + 0.5)
|
||||||
|
|
||||||
return (newh, neww)
|
return (newh, neww)
|
||||||
|
|
||||||
|
|
||||||
def apply_coords(coords: np.ndarray, original_size, target_length) -> np.ndarray:
|
def apply_coords(coords: np.ndarray, original_size, target_length):
|
||||||
old_h, old_w = original_size
|
old_h, old_w = original_size
|
||||||
new_h, new_w = get_preprocess_shape(
|
new_h, new_w = get_preprocess_shape(
|
||||||
original_size[0], original_size[1], target_length
|
original_size[0], original_size[1], target_length
|
||||||
)
|
)
|
||||||
coords = coords.copy().astype(float)
|
|
||||||
|
coords = deepcopy(coords).astype(float)
|
||||||
coords[..., 0] = coords[..., 0] * (new_w / old_w)
|
coords[..., 0] = coords[..., 0] * (new_w / old_w)
|
||||||
coords[..., 1] = coords[..., 1] * (new_h / old_h)
|
coords[..., 1] = coords[..., 1] * (new_h / old_h)
|
||||||
|
|
||||||
return coords
|
return coords
|
||||||
|
|
||||||
|
|
||||||
def resize_longes_side(img: PILImage, size=1024):
|
def get_input_points(prompt):
|
||||||
w, h = img.size
|
points = []
|
||||||
if h > w:
|
labels = []
|
||||||
new_h, new_w = size, int(w * size / h)
|
|
||||||
else:
|
|
||||||
new_h, new_w = int(h * size / w), size
|
|
||||||
|
|
||||||
return img.resize((new_w, new_h))
|
for mark in prompt:
|
||||||
|
if mark["type"] == "point":
|
||||||
|
points.append(mark["data"])
|
||||||
|
labels.append(mark["label"])
|
||||||
|
elif mark["type"] == "rectangle":
|
||||||
|
points.append([mark["data"][0], mark["data"][1]])
|
||||||
|
points.append([mark["data"][2], mark["data"][3]])
|
||||||
|
labels.append(2)
|
||||||
|
labels.append(3)
|
||||||
|
|
||||||
|
points, labels = np.array(points), np.array(labels)
|
||||||
|
return points, labels
|
||||||
|
|
||||||
|
|
||||||
def pad_to_square(img: np.ndarray, size=1024):
|
def transform_masks(masks, original_size, transform_matrix):
|
||||||
h, w = img.shape[:2]
|
output_masks = []
|
||||||
padh = size - h
|
|
||||||
padw = size - w
|
for batch in range(masks.shape[0]):
|
||||||
img = np.pad(img, ((0, padh), (0, padw), (0, 0)), mode="constant")
|
batch_masks = []
|
||||||
img = img.astype(np.float32)
|
for mask_id in range(masks.shape[1]):
|
||||||
return img
|
mask = masks[batch, mask_id]
|
||||||
|
mask = cv2.warpAffine(
|
||||||
|
mask,
|
||||||
|
transform_matrix[:2],
|
||||||
|
(original_size[1], original_size[0]),
|
||||||
|
flags=cv2.INTER_LINEAR,
|
||||||
|
)
|
||||||
|
batch_masks.append(mask)
|
||||||
|
output_masks.append(batch_masks)
|
||||||
|
|
||||||
|
return np.array(output_masks)
|
||||||
|
|
||||||
|
|
||||||
class SamSession(BaseSession):
|
class SamSession(BaseSession):
|
||||||
@ -70,7 +94,7 @@ class SamSession(BaseSession):
|
|||||||
**kwargs: Arbitrary keyword arguments.
|
**kwargs: Arbitrary keyword arguments.
|
||||||
"""
|
"""
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
paths = self.__class__.download_models()
|
paths = self.__class__.download_models(*args, **kwargs)
|
||||||
self.encoder = ort.InferenceSession(
|
self.encoder = ort.InferenceSession(
|
||||||
str(paths[0]),
|
str(paths[0]),
|
||||||
providers=ort.get_available_providers(),
|
providers=ort.get_available_providers(),
|
||||||
@ -85,9 +109,9 @@ class SamSession(BaseSession):
|
|||||||
def normalize(
|
def normalize(
|
||||||
self,
|
self,
|
||||||
img: np.ndarray,
|
img: np.ndarray,
|
||||||
mean=(123.675, 116.28, 103.53),
|
mean=(),
|
||||||
std=(58.395, 57.12, 57.375),
|
std=(),
|
||||||
size=(1024, 1024),
|
size=(),
|
||||||
*args,
|
*args,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
@ -96,19 +120,16 @@ class SamSession(BaseSession):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
img (np.ndarray): The input image.
|
img (np.ndarray): The input image.
|
||||||
mean (tuple, optional): The mean values for normalization. Defaults to (123.675, 116.28, 103.53).
|
mean (tuple, optional): The mean values for normalization. Defaults to ().
|
||||||
std (tuple, optional): The standard deviation values for normalization. Defaults to (58.395, 57.12, 57.375).
|
std (tuple, optional): The standard deviation values for normalization. Defaults to ().
|
||||||
size (tuple, optional): The target size of the image. Defaults to (1024, 1024).
|
size (tuple, optional): The target size of the image. Defaults to ().
|
||||||
*args: Variable length argument list.
|
*args: Variable length argument list.
|
||||||
**kwargs: Arbitrary keyword arguments.
|
**kwargs: Arbitrary keyword arguments.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
np.ndarray: The normalized image.
|
np.ndarray: The normalized image.
|
||||||
"""
|
"""
|
||||||
pixel_mean = np.array([*mean]).reshape(1, 1, -1)
|
return img
|
||||||
pixel_std = np.array([*std]).reshape(1, 1, -1)
|
|
||||||
x = (img - pixel_mean) / pixel_std
|
|
||||||
return x
|
|
||||||
|
|
||||||
def predict(
|
def predict(
|
||||||
self,
|
self,
|
||||||
@ -129,36 +150,89 @@ class SamSession(BaseSession):
|
|||||||
Returns:
|
Returns:
|
||||||
List[PILImage]: A list of masks generated by the decoder.
|
List[PILImage]: A list of masks generated by the decoder.
|
||||||
"""
|
"""
|
||||||
# Preprocess image
|
prompt = kwargs.get("sam_prompt", "{}")
|
||||||
image = resize_longes_side(img)
|
schema = {
|
||||||
image = np.array(image)
|
"type": "array",
|
||||||
image = self.normalize(image)
|
"items": {
|
||||||
image = pad_to_square(image)
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"type": {"type": "string"},
|
||||||
|
"label": {"type": "integer"},
|
||||||
|
"data": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "number"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
input_labels = kwargs.get("input_labels")
|
validate(instance=prompt, schema=schema)
|
||||||
input_points = kwargs.get("input_points")
|
|
||||||
|
|
||||||
if input_labels is None:
|
target_size = 1024
|
||||||
raise ValueError("input_labels is required")
|
input_size = (684, 1024)
|
||||||
if input_points is None:
|
encoder_input_name = self.encoder.get_inputs()[0].name
|
||||||
raise ValueError("input_points is required")
|
|
||||||
|
|
||||||
# Transpose
|
img = img.convert("RGB")
|
||||||
image = image.transpose(2, 0, 1)[None, :, :, :]
|
cv_image = np.array(img)
|
||||||
# Run encoder (Image embedding)
|
original_size = cv_image.shape[:2]
|
||||||
encoded = self.encoder.run(None, {"x": image})
|
|
||||||
image_embedding = encoded[0]
|
|
||||||
|
|
||||||
# Add a batch index, concatenate a padding point, and transform.
|
scale_x = input_size[1] / cv_image.shape[1]
|
||||||
|
scale_y = input_size[0] / cv_image.shape[0]
|
||||||
|
scale = min(scale_x, scale_y)
|
||||||
|
|
||||||
|
transform_matrix = np.array(
|
||||||
|
[
|
||||||
|
[scale, 0, 0],
|
||||||
|
[0, scale, 0],
|
||||||
|
[0, 0, 1],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
cv_image = cv2.warpAffine(
|
||||||
|
cv_image,
|
||||||
|
transform_matrix[:2],
|
||||||
|
(input_size[1], input_size[0]),
|
||||||
|
flags=cv2.INTER_LINEAR,
|
||||||
|
)
|
||||||
|
|
||||||
|
## encoder
|
||||||
|
|
||||||
|
encoder_inputs = {
|
||||||
|
encoder_input_name: cv_image.astype(np.float32),
|
||||||
|
}
|
||||||
|
|
||||||
|
encoder_output = self.encoder.run(None, encoder_inputs)
|
||||||
|
image_embedding = encoder_output[0]
|
||||||
|
|
||||||
|
embedding = {
|
||||||
|
"image_embedding": image_embedding,
|
||||||
|
"original_size": original_size,
|
||||||
|
"transform_matrix": transform_matrix,
|
||||||
|
}
|
||||||
|
|
||||||
|
## decoder
|
||||||
|
|
||||||
|
input_points, input_labels = get_input_points(prompt)
|
||||||
onnx_coord = np.concatenate([input_points, np.array([[0.0, 0.0]])], axis=0)[
|
onnx_coord = np.concatenate([input_points, np.array([[0.0, 0.0]])], axis=0)[
|
||||||
None, :, :
|
None, :, :
|
||||||
]
|
]
|
||||||
onnx_label = np.concatenate([input_labels, np.array([-1])], axis=0)[
|
onnx_label = np.concatenate([input_labels, np.array([-1])], axis=0)[
|
||||||
None, :
|
None, :
|
||||||
].astype(np.float32)
|
].astype(np.float32)
|
||||||
onnx_coord = apply_coords(onnx_coord, img.size[::1], 1024).astype(np.float32)
|
onnx_coord = apply_coords(onnx_coord, input_size, target_size).astype(
|
||||||
|
np.float32
|
||||||
|
)
|
||||||
|
|
||||||
|
onnx_coord = np.concatenate(
|
||||||
|
[
|
||||||
|
onnx_coord,
|
||||||
|
np.ones((1, onnx_coord.shape[1], 1), dtype=np.float32),
|
||||||
|
],
|
||||||
|
axis=2,
|
||||||
|
)
|
||||||
|
onnx_coord = np.matmul(onnx_coord, transform_matrix.T)
|
||||||
|
onnx_coord = onnx_coord[:, :, :2].astype(np.float32)
|
||||||
|
|
||||||
# Create an empty mask input and an indicator for no mask.
|
|
||||||
onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
|
onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
|
||||||
onnx_has_mask_input = np.zeros(1, dtype=np.float32)
|
onnx_has_mask_input = np.zeros(1, dtype=np.float32)
|
||||||
|
|
||||||
@ -168,17 +242,19 @@ class SamSession(BaseSession):
|
|||||||
"point_labels": onnx_label,
|
"point_labels": onnx_label,
|
||||||
"mask_input": onnx_mask_input,
|
"mask_input": onnx_mask_input,
|
||||||
"has_mask_input": onnx_has_mask_input,
|
"has_mask_input": onnx_has_mask_input,
|
||||||
"orig_im_size": np.array(img.size[::-1], dtype=np.float32),
|
"orig_im_size": np.array(input_size, dtype=np.float32),
|
||||||
}
|
}
|
||||||
|
|
||||||
masks, _, low_res_logits = self.decoder.run(None, decoder_inputs)
|
masks, _, _ = self.decoder.run(None, decoder_inputs)
|
||||||
masks = masks > 0.0
|
inv_transform_matrix = np.linalg.inv(transform_matrix)
|
||||||
masks = [
|
masks = transform_masks(masks, original_size, inv_transform_matrix)
|
||||||
Image.fromarray((masks[i, 0] * 255).astype(np.uint8))
|
|
||||||
for i in range(masks.shape[0])
|
|
||||||
]
|
|
||||||
|
|
||||||
return masks
|
mask = np.zeros((masks.shape[2], masks.shape[3], 3), dtype=np.uint8)
|
||||||
|
for m in masks[0, :, :, :]:
|
||||||
|
mask[m > 0.0] = [255, 255, 255]
|
||||||
|
|
||||||
|
mask = Image.fromarray(mask).convert("L")
|
||||||
|
return [mask]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def download_models(cls, *args, **kwargs):
|
def download_models(cls, *args, **kwargs):
|
||||||
@ -195,29 +271,64 @@ class SamSession(BaseSession):
|
|||||||
Returns:
|
Returns:
|
||||||
tuple: A tuple containing the file paths of the downloaded encoder and decoder models.
|
tuple: A tuple containing the file paths of the downloaded encoder and decoder models.
|
||||||
"""
|
"""
|
||||||
fname_encoder = f"{cls.name(*args, **kwargs)}_encoder.onnx"
|
model_name = kwargs.get("sam_model", "sam_vit_b_01ec64")
|
||||||
fname_decoder = f"{cls.name(*args, **kwargs)}_decoder.onnx"
|
quant = kwargs.get("sam_quant", False)
|
||||||
|
|
||||||
|
fname_encoder = f"{model_name}.encoder.onnx"
|
||||||
|
fname_decoder = f"{model_name}.decoder.onnx"
|
||||||
|
|
||||||
|
if quant:
|
||||||
|
fname_encoder = f"{model_name}.encoder.quant.onnx"
|
||||||
|
fname_decoder = f"{model_name}.decoder.quant.onnx"
|
||||||
|
|
||||||
pooch.retrieve(
|
pooch.retrieve(
|
||||||
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-encoder-quant.onnx",
|
f"https://github.com/danielgatis/rembg/releases/download/v0.0.0/{fname_encoder}",
|
||||||
None
|
None,
|
||||||
if cls.checksum_disabled(*args, **kwargs)
|
|
||||||
else "md5:13d97c5c79ab13ef86d67cbde5f1b250",
|
|
||||||
fname=fname_encoder,
|
fname=fname_encoder,
|
||||||
path=cls.u2net_home(*args, **kwargs),
|
path=cls.u2net_home(*args, **kwargs),
|
||||||
progressbar=True,
|
progressbar=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
pooch.retrieve(
|
pooch.retrieve(
|
||||||
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-decoder-quant.onnx",
|
f"https://github.com/danielgatis/rembg/releases/download/v0.0.0/{fname_decoder}",
|
||||||
None
|
None,
|
||||||
if cls.checksum_disabled(*args, **kwargs)
|
|
||||||
else "md5:fa3d1c36a3187d3de1c8deebf33dd127",
|
|
||||||
fname=fname_decoder,
|
fname=fname_decoder,
|
||||||
path=cls.u2net_home(*args, **kwargs),
|
path=cls.u2net_home(*args, **kwargs),
|
||||||
progressbar=True,
|
progressbar=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if fname_encoder == "sam_vit_h_4b8939.encoder.onnx" and not os.path.exists(
|
||||||
|
os.path.join(
|
||||||
|
cls.u2net_home(*args, **kwargs), "sam_vit_h_4b8939.encoder_data.bin"
|
||||||
|
)
|
||||||
|
):
|
||||||
|
content = bytearray()
|
||||||
|
|
||||||
|
for i in range(1, 4):
|
||||||
|
pooch.retrieve(
|
||||||
|
f"https://github.com/danielgatis/rembg/releases/download/v0.0.0/sam_vit_h_4b8939.encoder_data.{i}.bin",
|
||||||
|
None,
|
||||||
|
fname=f"sam_vit_h_4b8939.encoder_data.{i}.bin",
|
||||||
|
path=cls.u2net_home(*args, **kwargs),
|
||||||
|
progressbar=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
fbin = os.path.join(
|
||||||
|
cls.u2net_home(*args, **kwargs),
|
||||||
|
f"sam_vit_h_4b8939.encoder_data.{i}.bin",
|
||||||
|
)
|
||||||
|
content.extend(open(fbin, "rb").read())
|
||||||
|
os.remove(fbin)
|
||||||
|
|
||||||
|
with open(
|
||||||
|
os.path.join(
|
||||||
|
cls.u2net_home(*args, **kwargs),
|
||||||
|
"sam_vit_h_4b8939.encoder_data.bin",
|
||||||
|
),
|
||||||
|
"wb",
|
||||||
|
) as fp:
|
||||||
|
fp.write(content)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
os.path.join(cls.u2net_home(*args, **kwargs), fname_encoder),
|
os.path.join(cls.u2net_home(*args, **kwargs), fname_encoder),
|
||||||
os.path.join(cls.u2net_home(*args, **kwargs), fname_decoder),
|
os.path.join(cls.u2net_home(*args, **kwargs), fname_decoder),
|
||||||
|
1
setup.py
@ -12,6 +12,7 @@ here = pathlib.Path(__file__).parent.resolve()
|
|||||||
long_description = (here / "README.md").read_text(encoding="utf-8")
|
long_description = (here / "README.md").read_text(encoding="utf-8")
|
||||||
|
|
||||||
install_requires = [
|
install_requires = [
|
||||||
|
"jsonschema",
|
||||||
"numpy",
|
"numpy",
|
||||||
"onnxruntime",
|
"onnxruntime",
|
||||||
"opencv-python-headless",
|
"opencv-python-headless",
|
||||||
|
BIN
tests/fixtures/plants-1.jpg
vendored
Normal file
After Width: | Height: | Size: 1.4 MiB |
Before Width: | Height: | Size: 24 KiB After Width: | Height: | Size: 21 KiB |
Before Width: | Height: | Size: 78 KiB After Width: | Height: | Size: 70 KiB |
Before Width: | Height: | Size: 104 KiB After Width: | Height: | Size: 165 KiB |
BIN
tests/results/plants-1.isnet-anime.png
Normal file
After Width: | Height: | Size: 510 KiB |
BIN
tests/results/plants-1.isnet-general-use.png
Normal file
After Width: | Height: | Size: 698 KiB |
BIN
tests/results/plants-1.sam.png
Normal file
After Width: | Height: | Size: 29 KiB |
BIN
tests/results/plants-1.silueta.png
Normal file
After Width: | Height: | Size: 608 KiB |
BIN
tests/results/plants-1.u2net.png
Normal file
After Width: | Height: | Size: 572 KiB |
BIN
tests/results/plants-1.u2net_cloth_seg.png
Normal file
After Width: | Height: | Size: 17 KiB |
BIN
tests/results/plants-1.u2net_human_seg.png
Normal file
After Width: | Height: | Size: 341 KiB |
BIN
tests/results/plants-1.u2netp.png
Normal file
After Width: | Height: | Size: 479 KiB |
@ -12,18 +12,19 @@ def test_remove():
|
|||||||
kwargs = {
|
kwargs = {
|
||||||
"sam": {
|
"sam": {
|
||||||
"anime-girl-1" : {
|
"anime-girl-1" : {
|
||||||
"input_points": [[400, 165]],
|
"sam_prompt" :[{"type": "point", "data": [400, 165], "label": 1}],
|
||||||
"input_labels": [1],
|
|
||||||
},
|
},
|
||||||
|
|
||||||
"car-1" : {
|
"car-1" : {
|
||||||
"input_points": [[250, 200]],
|
"sam_prompt" :[{"type": "point", "data": [250, 200], "label": 1}],
|
||||||
"input_labels": [1],
|
|
||||||
},
|
},
|
||||||
|
|
||||||
"cloth-1" : {
|
"cloth-1" : {
|
||||||
"input_points": [[370, 495]],
|
"sam_prompt" :[{"type": "point", "data": [370, 495], "label": 1}],
|
||||||
"input_labels": [1],
|
},
|
||||||
|
|
||||||
|
"plants-1" : {
|
||||||
|
"sam_prompt" :[{"type": "point", "data": [724, 740], "label": 1}],
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -38,7 +39,7 @@ def test_remove():
|
|||||||
"isnet-anime",
|
"isnet-anime",
|
||||||
"sam"
|
"sam"
|
||||||
]:
|
]:
|
||||||
for picture in ["anime-girl-1", "car-1", "cloth-1"]:
|
for picture in ["anime-girl-1", "car-1", "cloth-1", "plants-1"]:
|
||||||
image_path = Path(here / "fixtures" / f"{picture}.jpg")
|
image_path = Path(here / "fixtures" / f"{picture}.jpg")
|
||||||
image = image_path.read_bytes()
|
image = image_path.read_bytes()
|
||||||
|
|
||||||
|