fix sam session (#531)

Daniel Gatis authored 2023-10-25 23:25:12 -03:00, committed by GitHub
parent 342fc54d3a, commit 47701001ab
18 changed files with 189 additions and 72 deletions

README.md

@@ -159,10 +159,14 @@ rembg i -a path/to/input.png path/to/output.png
 Passing extras parameters

+SAM example
 ```
-rembg i -m sam -x '{"input_labels": [1], "input_points": [[100,100]]}' path/to/input.png path/to/output.png
+rembg i -m sam -x '{ "sam_prompt": [{"type": "point", "data": [724, 740], "label": 1}] }' examples/plants-1.jpg examples/plants-1.out.png
 ```

+Custom model example
 ```
 rembg i -m u2net_custom -x '{"model_path": "~/.u2net/u2net.onnx"}' path/to/input.png path/to/output.png
 ```
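
For library users, the same prompt should be expressible through the Python API; a minimal sketch, assuming remove() forwards extra keyword arguments (here sam_prompt) on to the session's predict(), as elsewhere in rembg:

```
from rembg import new_session, remove

# Assumption: sam_prompt is forwarded by remove() to SamSession.predict.
with open("examples/plants-1.jpg", "rb") as f:
    output = remove(
        f.read(),
        session=new_session("sam"),
        sam_prompt=[{"type": "point", "data": [724, 740], "label": 1}],
    )
with open("examples/plants-1.out.png", "wb") as fp:
    fp.write(output)
```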

BIN  examples/plants-1.jpg (new file, 1.4 MiB)
BIN  examples/plants-1.out.png (new file, 29 KiB)

rembg/sessions/sam.py

@@ -1,9 +1,12 @@
 import os
+from copy import deepcopy
 from typing import List

+import cv2
 import numpy as np
 import onnxruntime as ort
 import pooch
+from jsonschema import validate
 from PIL import Image
 from PIL.Image import Image as PILImage

@@ -15,37 +18,58 @@ def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int):
     newh, neww = oldh * scale, oldw * scale
     neww = int(neww + 0.5)
     newh = int(newh + 0.5)
     return (newh, neww)


-def apply_coords(coords: np.ndarray, original_size, target_length) -> np.ndarray:
+def apply_coords(coords: np.ndarray, original_size, target_length):
     old_h, old_w = original_size
     new_h, new_w = get_preprocess_shape(
         original_size[0], original_size[1], target_length
     )
-    coords = coords.copy().astype(float)
+    coords = deepcopy(coords).astype(float)
     coords[..., 0] = coords[..., 0] * (new_w / old_w)
     coords[..., 1] = coords[..., 1] * (new_h / old_h)
     return coords


-def resize_longes_side(img: PILImage, size=1024):
-    w, h = img.size
-    if h > w:
-        new_h, new_w = size, int(w * size / h)
-    else:
-        new_h, new_w = int(h * size / w), size
-    return img.resize((new_w, new_h))
+def get_input_points(prompt):
+    points = []
+    labels = []
+
+    for mark in prompt:
+        if mark["type"] == "point":
+            points.append(mark["data"])
+            labels.append(mark["label"])
+        elif mark["type"] == "rectangle":
+            points.append([mark["data"][0], mark["data"][1]])
+            points.append([mark["data"][2], mark["data"][3]])
+            labels.append(2)
+            labels.append(3)
+
+    points, labels = np.array(points), np.array(labels)
+    return points, labels


-def pad_to_square(img: np.ndarray, size=1024):
-    h, w = img.shape[:2]
-    padh = size - h
-    padw = size - w
-    img = np.pad(img, ((0, padh), (0, padw), (0, 0)), mode="constant")
-    img = img.astype(np.float32)
-    return img
+def transform_masks(masks, original_size, transform_matrix):
+    output_masks = []
+
+    for batch in range(masks.shape[0]):
+        batch_masks = []
+        for mask_id in range(masks.shape[1]):
+            mask = masks[batch, mask_id]
+            mask = cv2.warpAffine(
+                mask,
+                transform_matrix[:2],
+                (original_size[1], original_size[0]),
+                flags=cv2.INTER_LINEAR,
+            )
+            batch_masks.append(mask)
+        output_masks.append(batch_masks)
+
+    return np.array(output_masks)


 class SamSession(BaseSession):
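
Side note: the rectangle branch in get_input_points encodes a box as its two corner points with SAM's reserved labels 2 (top-left) and 3 (bottom-right). A standalone sketch of that logic on a hypothetical mixed prompt:

```
import numpy as np

# Hypothetical prompt: one foreground point plus one bounding box.
prompt = [
    {"type": "point", "data": [724, 740], "label": 1},
    {"type": "rectangle", "data": [100, 100, 300, 400], "label": 1},
]

points, labels = [], []
for mark in prompt:
    if mark["type"] == "point":
        points.append(mark["data"])
        labels.append(mark["label"])
    elif mark["type"] == "rectangle":
        points.append(mark["data"][:2])  # top-left corner -> label 2
        points.append(mark["data"][2:])  # bottom-right corner -> label 3
        labels.extend([2, 3])

print(np.array(points).shape, np.array(labels))  # (3, 2) [1 2 3]
```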
@@ -70,7 +94,7 @@ class SamSession(BaseSession):
             **kwargs: Arbitrary keyword arguments.
         """
         self.model_name = model_name
-        paths = self.__class__.download_models()
+        paths = self.__class__.download_models(*args, **kwargs)
         self.encoder = ort.InferenceSession(
             str(paths[0]),
             providers=ort.get_available_providers(),
@@ -85,9 +109,9 @@ class SamSession(BaseSession):
     def normalize(
         self,
         img: np.ndarray,
-        mean=(123.675, 116.28, 103.53),
-        std=(58.395, 57.12, 57.375),
-        size=(1024, 1024),
+        mean=(),
+        std=(),
+        size=(),
         *args,
         **kwargs,
     ):
@@ -96,19 +120,16 @@ class SamSession(BaseSession):
         Args:
             img (np.ndarray): The input image.
-            mean (tuple, optional): The mean values for normalization. Defaults to (123.675, 116.28, 103.53).
-            std (tuple, optional): The standard deviation values for normalization. Defaults to (58.395, 57.12, 57.375).
-            size (tuple, optional): The target size of the image. Defaults to (1024, 1024).
+            mean (tuple, optional): The mean values for normalization. Defaults to ().
+            std (tuple, optional): The standard deviation values for normalization. Defaults to ().
+            size (tuple, optional): The target size of the image. Defaults to ().
             *args: Variable length argument list.
             **kwargs: Arbitrary keyword arguments.

         Returns:
             np.ndarray: The normalized image.
         """
-        pixel_mean = np.array([*mean]).reshape(1, 1, -1)
-        pixel_std = np.array([*std]).reshape(1, 1, -1)
-        x = (img - pixel_mean) / pixel_std
-        return x
+        return img

     def predict(
         self,
@@ -129,36 +150,89 @@ class SamSession(BaseSession):
         Returns:
             List[PILImage]: A list of masks generated by the decoder.
         """
-        # Preprocess image
-        image = resize_longes_side(img)
-        image = np.array(image)
-        image = self.normalize(image)
-        image = pad_to_square(image)
-
-        input_labels = kwargs.get("input_labels")
-        input_points = kwargs.get("input_points")
-
-        if input_labels is None:
-            raise ValueError("input_labels is required")
-        if input_points is None:
-            raise ValueError("input_points is required")
-
-        # Transpose
-        image = image.transpose(2, 0, 1)[None, :, :, :]
-        # Run encoder (Image embedding)
-        encoded = self.encoder.run(None, {"x": image})
-        image_embedding = encoded[0]
-
-        # Add a batch index, concatenate a padding point, and transform.
+        prompt = kwargs.get("sam_prompt", "{}")
+        schema = {
+            "type": "array",
+            "items": {
+                "type": "object",
+                "properties": {
+                    "type": {"type": "string"},
+                    "label": {"type": "integer"},
+                    "data": {
+                        "type": "array",
+                        "items": {"type": "number"},
+                    },
+                },
+            },
+        }
+
+        validate(instance=prompt, schema=schema)
+
+        target_size = 1024
+        input_size = (684, 1024)
+        encoder_input_name = self.encoder.get_inputs()[0].name
+
+        img = img.convert("RGB")
+        cv_image = np.array(img)
+        original_size = cv_image.shape[:2]
+
+        scale_x = input_size[1] / cv_image.shape[1]
+        scale_y = input_size[0] / cv_image.shape[0]
+        scale = min(scale_x, scale_y)
+
+        transform_matrix = np.array(
+            [
+                [scale, 0, 0],
+                [0, scale, 0],
+                [0, 0, 1],
+            ]
+        )
+
+        cv_image = cv2.warpAffine(
+            cv_image,
+            transform_matrix[:2],
+            (input_size[1], input_size[0]),
+            flags=cv2.INTER_LINEAR,
+        )
+
+        ## encoder
+        encoder_inputs = {
+            encoder_input_name: cv_image.astype(np.float32),
+        }
+
+        encoder_output = self.encoder.run(None, encoder_inputs)
+        image_embedding = encoder_output[0]
+
+        embedding = {
+            "image_embedding": image_embedding,
+            "original_size": original_size,
+            "transform_matrix": transform_matrix,
+        }
+
+        ## decoder
+        input_points, input_labels = get_input_points(prompt)
+
         onnx_coord = np.concatenate([input_points, np.array([[0.0, 0.0]])], axis=0)[
             None, :, :
         ]
         onnx_label = np.concatenate([input_labels, np.array([-1])], axis=0)[
             None, :
         ].astype(np.float32)
-        onnx_coord = apply_coords(onnx_coord, img.size[::1], 1024).astype(np.float32)
+        onnx_coord = apply_coords(onnx_coord, input_size, target_size).astype(
+            np.float32
+        )
+
+        onnx_coord = np.concatenate(
+            [
+                onnx_coord,
+                np.ones((1, onnx_coord.shape[1], 1), dtype=np.float32),
+            ],
+            axis=2,
+        )
+        onnx_coord = np.matmul(onnx_coord, transform_matrix.T)
+        onnx_coord = onnx_coord[:, :, :2].astype(np.float32)

-        # Create an empty mask input and an indicator for no mask.
         onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
         onnx_has_mask_input = np.zeros(1, dtype=np.float32)
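
The prompt points above travel in homogeneous coordinates: apply_coords rescales them for SAM's 1024-pixel long side, then the matmul with transform_matrix.T drops them into the same fixed 684x1024 frame the image was warped into. A small numeric check, with assumed example dimensions:

```
import numpy as np

# Assumed example: a 2048x1365 (w x h) photo scaled into the 684x1024 frame.
original_size = (1365, 2048)  # (h, w)
input_size = (684, 1024)      # (h, w), the fixed encoder frame from the diff

scale = min(input_size[1] / original_size[1], input_size[0] / original_size[0])
transform_matrix = np.array([[scale, 0, 0], [0, scale, 0], [0, 0, 1]])

point = np.array([724.0, 740.0, 1.0])  # homogeneous (x, y, 1)
x, y, _ = transform_matrix @ point
print(scale, x, y)  # 0.5 362.0 370.0
```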
@@ -168,17 +242,19 @@ class SamSession(BaseSession):
             "point_labels": onnx_label,
             "mask_input": onnx_mask_input,
             "has_mask_input": onnx_has_mask_input,
-            "orig_im_size": np.array(img.size[::-1], dtype=np.float32),
+            "orig_im_size": np.array(input_size, dtype=np.float32),
         }

-        masks, _, low_res_logits = self.decoder.run(None, decoder_inputs)
-        masks = masks > 0.0
-        masks = [
-            Image.fromarray((masks[i, 0] * 255).astype(np.uint8))
-            for i in range(masks.shape[0])
-        ]
+        masks, _, _ = self.decoder.run(None, decoder_inputs)
+        inv_transform_matrix = np.linalg.inv(transform_matrix)
+        masks = transform_masks(masks, original_size, inv_transform_matrix)

-        return masks
+        mask = np.zeros((masks.shape[2], masks.shape[3], 3), dtype=np.uint8)
+        for m in masks[0, :, :, :]:
+            mask[m > 0.0] = [255, 255, 255]
+
+        mask = Image.fromarray(mask).convert("L")
+        return [mask]

     @classmethod
     def download_models(cls, *args, **kwargs):
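
Note the return-type change in predict(): instead of one image per decoder channel, it now returns a single grayscale mask, the union of all positive logits, warped back to the original size. A self-contained sketch of that merge step, with assumed decoder output shapes:

```
import numpy as np
from PIL import Image

# Assumed decoder output: (batch, num_masks, H, W) logits.
masks = np.random.randn(1, 4, 684, 1024).astype(np.float32)

merged = np.zeros((masks.shape[2], masks.shape[3], 3), dtype=np.uint8)
for m in masks[0]:  # any positive logit turns the pixel white
    merged[m > 0.0] = [255, 255, 255]

mask = Image.fromarray(merged).convert("L")
print(mask.size, mask.mode)  # (1024, 684) L
```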
@@ -195,29 +271,64 @@ class SamSession(BaseSession):
         Returns:
             tuple: A tuple containing the file paths of the downloaded encoder and decoder models.
         """
-        fname_encoder = f"{cls.name(*args, **kwargs)}_encoder.onnx"
-        fname_decoder = f"{cls.name(*args, **kwargs)}_decoder.onnx"
+        model_name = kwargs.get("sam_model", "sam_vit_b_01ec64")
+        quant = kwargs.get("sam_quant", False)
+
+        fname_encoder = f"{model_name}.encoder.onnx"
+        fname_decoder = f"{model_name}.decoder.onnx"
+
+        if quant:
+            fname_encoder = f"{model_name}.encoder.quant.onnx"
+            fname_decoder = f"{model_name}.decoder.quant.onnx"

         pooch.retrieve(
-            "https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-encoder-quant.onnx",
-            None
-            if cls.checksum_disabled(*args, **kwargs)
-            else "md5:13d97c5c79ab13ef86d67cbde5f1b250",
+            f"https://github.com/danielgatis/rembg/releases/download/v0.0.0/{fname_encoder}",
+            None,
             fname=fname_encoder,
             path=cls.u2net_home(*args, **kwargs),
             progressbar=True,
         )

         pooch.retrieve(
-            "https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-decoder-quant.onnx",
-            None
-            if cls.checksum_disabled(*args, **kwargs)
-            else "md5:fa3d1c36a3187d3de1c8deebf33dd127",
+            f"https://github.com/danielgatis/rembg/releases/download/v0.0.0/{fname_decoder}",
+            None,
             fname=fname_decoder,
             path=cls.u2net_home(*args, **kwargs),
             progressbar=True,
         )

+        if fname_encoder == "sam_vit_h_4b8939.encoder.onnx" and not os.path.exists(
+            os.path.join(
+                cls.u2net_home(*args, **kwargs), "sam_vit_h_4b8939.encoder_data.bin"
+            )
+        ):
+            content = bytearray()
+
+            for i in range(1, 4):
+                pooch.retrieve(
+                    f"https://github.com/danielgatis/rembg/releases/download/v0.0.0/sam_vit_h_4b8939.encoder_data.{i}.bin",
+                    None,
+                    fname=f"sam_vit_h_4b8939.encoder_data.{i}.bin",
+                    path=cls.u2net_home(*args, **kwargs),
+                    progressbar=True,
+                )
+
+                fbin = os.path.join(
+                    cls.u2net_home(*args, **kwargs),
+                    f"sam_vit_h_4b8939.encoder_data.{i}.bin",
+                )
+                content.extend(open(fbin, "rb").read())
+                os.remove(fbin)
+
+            with open(
+                os.path.join(
+                    cls.u2net_home(*args, **kwargs),
+                    "sam_vit_h_4b8939.encoder_data.bin",
+                ),
+                "wb",
+            ) as fp:
+                fp.write(content)
+
         return (
             os.path.join(cls.u2net_home(*args, **kwargs), fname_encoder),
             os.path.join(cls.u2net_home(*args, **kwargs), fname_decoder),
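
With the renamed release artifacts, model selection moves into keyword arguments. A hedged usage sketch, assuming new_session() passes its extra kwargs down to the session constructor (and from there to download_models, as the __init__ change above suggests):

```
from rembg import new_session

# "sam_model" and "sam_quant" are the kwargs read by download_models above.
session = new_session("sam", sam_model="sam_vit_b_01ec64", sam_quant=True)
```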

setup.py

@@ -12,6 +12,7 @@ here = pathlib.Path(__file__).parent.resolve()
 long_description = (here / "README.md").read_text(encoding="utf-8")

 install_requires = [
+    "jsonschema",
     "numpy",
     "onnxruntime",
     "opencv-python-headless",

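jsonschema joins the dependencies because predict() now validates the user-supplied prompt before running the model. A minimal sketch of the same check, using the schema from the diff:

```
from jsonschema import validate

schema = {
    "type": "array",
    "items": {
        "type": "object",
        "properties": {
            "type": {"type": "string"},
            "label": {"type": "integer"},
            "data": {"type": "array", "items": {"type": "number"}},
        },
    },
}

# Raises jsonschema.exceptions.ValidationError on a malformed prompt.
validate(
    instance=[{"type": "point", "data": [724, 740], "label": 1}],
    schema=schema,
)
```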
BIN  tests/fixtures/plants-1.jpg (new vendored file, 1.4 MiB)

BIN  11 more binary image files changed (filenames not shown): 3 modified (24 -> 21 KiB, 78 -> 70 KiB, 104 -> 165 KiB) and 8 added (510 KiB, 698 KiB, 29 KiB, 608 KiB, 572 KiB, 17 KiB, 341 KiB, 479 KiB)

tests/test_remove.py

@@ -12,18 +12,19 @@ def test_remove():
     kwargs = {
         "sam": {
             "anime-girl-1" : {
-                "input_points": [[400, 165]],
-                "input_labels": [1],
+                "sam_prompt" :[{"type": "point", "data": [400, 165], "label": 1}],
             },
             "car-1" : {
-                "input_points": [[250, 200]],
-                "input_labels": [1],
+                "sam_prompt" :[{"type": "point", "data": [250, 200], "label": 1}],
             },
             "cloth-1" : {
-                "input_points": [[370, 495]],
-                "input_labels": [1],
+                "sam_prompt" :[{"type": "point", "data": [370, 495], "label": 1}],
+            },
+            "plants-1" : {
+                "sam_prompt" :[{"type": "point", "data": [724, 740], "label": 1}],
             },
         }
     }
@@ -38,7 +39,7 @@ def test_remove():
         "isnet-anime",
         "sam"
     ]:
-        for picture in ["anime-girl-1", "car-1", "cloth-1"]:
+        for picture in ["anime-girl-1", "car-1", "cloth-1", "plants-1"]:
             image_path = Path(here / "fixtures" / f"{picture}.jpg")
             image = image_path.read_bytes()
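
To exercise the new prompt path locally, one might run just this module (the path tests/test_remove.py is an assumption based on the fixture layout):

```
import pytest

# Runs the updated test_remove, including the new "plants-1" SAM case.
pytest.main(["-q", "tests/test_remove.py"])
```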