fix sam session (#531)

Daniel Gatis 2023-10-25 23:25:12 -03:00 committed by GitHub
parent 342fc54d3a
commit 47701001ab
18 changed files with 189 additions and 72 deletions

README.md

@@ -159,10 +159,14 @@ rembg i -a path/to/input.png path/to/output.png
 Passing extras parameters
 
 ```
-rembg i -m sam -x '{"input_labels": [1], "input_points": [[100,100]]}' path/to/input.png path/to/output.png
+SAM example
+
+rembg i -m sam -x '{ "sam_prompt": [{"type": "point", "data": [724, 740], "label": 1}] }' examples/plants-1.jpg examples/plants-1.out.png
 ```
 
 ```
+Custom model example
+
 rembg i -m u2net_custom -x '{"model_path": "~/.u2net/u2net.onnx"}' path/to/input.png path/to/output.png
 ```
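Note: the same prompt can also be passed through the Python API. A minimal sketch, assuming `remove()` forwards extra keyword arguments (like the CLI `-x` extras) down to the session's `predict`; the paths are placeholders:

```
from pathlib import Path

from rembg import new_session, remove

# One foreground click (label 1), mirroring the CLI example above.
sam_prompt = [{"type": "point", "data": [724, 740], "label": 1}]

session = new_session("sam")
input_bytes = Path("examples/plants-1.jpg").read_bytes()
output_bytes = remove(input_bytes, session=session, sam_prompt=sam_prompt)
Path("examples/plants-1.out.png").write_bytes(output_bytes)
```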

BIN
examples/plants-1.jpg (new file, 1.4 MiB)

BIN
examples/plants-1.out.png (new file, 29 KiB)

rembg/sessions/sam.py

@@ -1,9 +1,12 @@
 import os
+from copy import deepcopy
 from typing import List
 
+import cv2
 import numpy as np
 import onnxruntime as ort
 import pooch
+from jsonschema import validate
 from PIL import Image
 from PIL.Image import Image as PILImage
@@ -15,37 +18,58 @@ def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int):
     newh, neww = oldh * scale, oldw * scale
     neww = int(neww + 0.5)
     newh = int(newh + 0.5)
     return (newh, neww)
 
 
-def apply_coords(coords: np.ndarray, original_size, target_length) -> np.ndarray:
+def apply_coords(coords: np.ndarray, original_size, target_length):
     old_h, old_w = original_size
     new_h, new_w = get_preprocess_shape(
         original_size[0], original_size[1], target_length
     )
-    coords = coords.copy().astype(float)
+    coords = deepcopy(coords).astype(float)
     coords[..., 0] = coords[..., 0] * (new_w / old_w)
     coords[..., 1] = coords[..., 1] * (new_h / old_h)
     return coords
 
 
-def resize_longes_side(img: PILImage, size=1024):
-    w, h = img.size
-    if h > w:
-        new_h, new_w = size, int(w * size / h)
-    else:
-        new_h, new_w = int(h * size / w), size
-
-    return img.resize((new_w, new_h))
+def get_input_points(prompt):
+    points = []
+    labels = []
+    for mark in prompt:
+        if mark["type"] == "point":
+            points.append(mark["data"])
+            labels.append(mark["label"])
+        elif mark["type"] == "rectangle":
+            points.append([mark["data"][0], mark["data"][1]])
+            points.append([mark["data"][2], mark["data"][3]])
+            labels.append(2)
+            labels.append(3)
+    points, labels = np.array(points), np.array(labels)
+    return points, labels
 
 
 def pad_to_square(img: np.ndarray, size=1024):
     h, w = img.shape[:2]
     padh = size - h
     padw = size - w
     img = np.pad(img, ((0, padh), (0, padw), (0, 0)), mode="constant")
     img = img.astype(np.float32)
     return img
 
 
+def transform_masks(masks, original_size, transform_matrix):
+    output_masks = []
+    for batch in range(masks.shape[0]):
+        batch_masks = []
+        for mask_id in range(masks.shape[1]):
+            mask = masks[batch, mask_id]
+            mask = cv2.warpAffine(
+                mask,
+                transform_matrix[:2],
+                (original_size[1], original_size[0]),
+                flags=cv2.INTER_LINEAR,
+            )
+            batch_masks.append(mask)
+        output_masks.append(batch_masks)
+    return np.array(output_masks)
+
+
 class SamSession(BaseSession):
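Note: `get_input_points` encodes the prompt the way SAM's decoder expects: a point keeps its own label, while a rectangle `[x0, y0, x1, y1]` becomes two corner points with the reserved box-corner labels 2 and 3. A standalone sketch of that mapping (the rectangle values are made up for illustration):

```
import numpy as np

prompt = [
    {"type": "point", "data": [724, 740], "label": 1},
    {"type": "rectangle", "data": [100, 100, 300, 250], "label": 1},
]

points, labels = [], []
for mark in prompt:
    if mark["type"] == "point":
        points.append(mark["data"])
        labels.append(mark["label"])
    elif mark["type"] == "rectangle":
        # Top-left and bottom-right corners, labeled 2 and 3 respectively.
        points.append([mark["data"][0], mark["data"][1]])
        points.append([mark["data"][2], mark["data"][3]])
        labels.extend([2, 3])

print(np.array(points))  # [[724 740] [100 100] [300 250]]
print(np.array(labels))  # [1 2 3]
```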
@@ -70,7 +94,7 @@ class SamSession(BaseSession):
             **kwargs: Arbitrary keyword arguments.
         """
         self.model_name = model_name
-        paths = self.__class__.download_models()
+        paths = self.__class__.download_models(*args, **kwargs)
         self.encoder = ort.InferenceSession(
             str(paths[0]),
             providers=ort.get_available_providers(),
@@ -85,9 +109,9 @@ class SamSession(BaseSession):
     def normalize(
         self,
         img: np.ndarray,
-        mean=(123.675, 116.28, 103.53),
-        std=(58.395, 57.12, 57.375),
-        size=(1024, 1024),
+        mean=(),
+        std=(),
+        size=(),
         *args,
         **kwargs,
     ):
@@ -96,19 +120,16 @@ class SamSession(BaseSession):
 
         Args:
             img (np.ndarray): The input image.
-            mean (tuple, optional): The mean values for normalization. Defaults to (123.675, 116.28, 103.53).
-            std (tuple, optional): The standard deviation values for normalization. Defaults to (58.395, 57.12, 57.375).
-            size (tuple, optional): The target size of the image. Defaults to (1024, 1024).
+            mean (tuple, optional): The mean values for normalization. Defaults to ().
+            std (tuple, optional): The standard deviation values for normalization. Defaults to ().
+            size (tuple, optional): The target size of the image. Defaults to ().
             *args: Variable length argument list.
             **kwargs: Arbitrary keyword arguments.
 
         Returns:
             np.ndarray: The normalized image.
         """
-        pixel_mean = np.array([*mean]).reshape(1, 1, -1)
-        pixel_std = np.array([*std]).reshape(1, 1, -1)
-        x = (img - pixel_mean) / pixel_std
-        return x
+        return img
 
     def predict(
         self,
@@ -129,36 +150,89 @@ class SamSession(BaseSession):
         Returns:
             List[PILImage]: A list of masks generated by the decoder.
         """
-        # Preprocess image
-        image = resize_longes_side(img)
-        image = np.array(image)
-        image = self.normalize(image)
-        image = pad_to_square(image)
+        prompt = kwargs.get("sam_prompt", "{}")
+        schema = {
+            "type": "array",
+            "items": {
+                "type": "object",
+                "properties": {
+                    "type": {"type": "string"},
+                    "label": {"type": "integer"},
+                    "data": {
+                        "type": "array",
+                        "items": {"type": "number"},
+                    },
+                },
+            },
+        }
 
-        input_labels = kwargs.get("input_labels")
-        input_points = kwargs.get("input_points")
+        validate(instance=prompt, schema=schema)
 
-        if input_labels is None:
-            raise ValueError("input_labels is required")
-        if input_points is None:
-            raise ValueError("input_points is required")
+        target_size = 1024
+        input_size = (684, 1024)
+        encoder_input_name = self.encoder.get_inputs()[0].name
 
-        # Transpose
-        image = image.transpose(2, 0, 1)[None, :, :, :]
-        # Run encoder (Image embedding)
-        encoded = self.encoder.run(None, {"x": image})
-        image_embedding = encoded[0]
+        img = img.convert("RGB")
+        cv_image = np.array(img)
+        original_size = cv_image.shape[:2]
 
-        # Add a batch index, concatenate a padding point, and transform.
+        scale_x = input_size[1] / cv_image.shape[1]
+        scale_y = input_size[0] / cv_image.shape[0]
+        scale = min(scale_x, scale_y)
+
+        transform_matrix = np.array(
+            [
+                [scale, 0, 0],
+                [0, scale, 0],
+                [0, 0, 1],
+            ]
+        )
+
+        cv_image = cv2.warpAffine(
+            cv_image,
+            transform_matrix[:2],
+            (input_size[1], input_size[0]),
+            flags=cv2.INTER_LINEAR,
+        )
+
+        ## encoder
+        encoder_inputs = {
+            encoder_input_name: cv_image.astype(np.float32),
+        }
+
+        encoder_output = self.encoder.run(None, encoder_inputs)
+        image_embedding = encoder_output[0]
+
+        embedding = {
+            "image_embedding": image_embedding,
+            "original_size": original_size,
+            "transform_matrix": transform_matrix,
+        }
+
+        ## decoder
+        input_points, input_labels = get_input_points(prompt)
+
         onnx_coord = np.concatenate([input_points, np.array([[0.0, 0.0]])], axis=0)[
             None, :, :
         ]
         onnx_label = np.concatenate([input_labels, np.array([-1])], axis=0)[
             None, :
         ].astype(np.float32)
-        onnx_coord = apply_coords(onnx_coord, img.size[::1], 1024).astype(np.float32)
+        onnx_coord = apply_coords(onnx_coord, input_size, target_size).astype(
+            np.float32
+        )
+
+        onnx_coord = np.concatenate(
+            [
+                onnx_coord,
+                np.ones((1, onnx_coord.shape[1], 1), dtype=np.float32),
+            ],
+            axis=2,
+        )
+        onnx_coord = np.matmul(onnx_coord, transform_matrix.T)
+        onnx_coord = onnx_coord[:, :, :2].astype(np.float32)
 
         # Create an empty mask input and an indicator for no mask.
         onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
         onnx_has_mask_input = np.zeros(1, dtype=np.float32)
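Note: `predict` now maps the image into the fixed 684x1024 encoder frame with a scale-only affine matrix, pushes the prompt points through the same matrix, and (in the next hunk) maps the masks back with its inverse. A round-trip sketch of that transform; the source image size here is an arbitrary example:

```
import numpy as np

input_size = (684, 1024)      # fixed encoder frame (h, w)
original_size = (1365, 2048)  # example source image (h, w)

scale = min(input_size[1] / original_size[1], input_size[0] / original_size[0])
transform_matrix = np.array([[scale, 0, 0], [0, scale, 0], [0, 0, 1]])

pt = np.array([724.0, 740.0, 1.0])                    # homogeneous image point
pt_model = transform_matrix @ pt                      # into the encoder frame
pt_back = np.linalg.inv(transform_matrix) @ pt_model  # back out, as for masks
assert np.allclose(pt_back[:2], [724.0, 740.0])
```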
@@ -168,17 +242,19 @@ class SamSession(BaseSession):
             "point_labels": onnx_label,
             "mask_input": onnx_mask_input,
             "has_mask_input": onnx_has_mask_input,
-            "orig_im_size": np.array(img.size[::-1], dtype=np.float32),
+            "orig_im_size": np.array(input_size, dtype=np.float32),
         }
 
-        masks, _, low_res_logits = self.decoder.run(None, decoder_inputs)
-        masks = masks > 0.0
-        masks = [
-            Image.fromarray((masks[i, 0] * 255).astype(np.uint8))
-            for i in range(masks.shape[0])
-        ]
+        masks, _, _ = self.decoder.run(None, decoder_inputs)
+        inv_transform_matrix = np.linalg.inv(transform_matrix)
+        masks = transform_masks(masks, original_size, inv_transform_matrix)
 
-        return masks
+        mask = np.zeros((masks.shape[2], masks.shape[3], 3), dtype=np.uint8)
+        for m in masks[0, :, :, :]:
+            mask[m > 0.0] = [255, 255, 255]
+        mask = Image.fromarray(mask).convert("L")
+
+        return [mask]
 
     @classmethod
     def download_models(cls, *args, **kwargs):
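Note: the new return path ORs every decoder mask into one grayscale image instead of returning a PIL image per mask. A toy run of that compositing, with 2x2 arrays standing in for real decoder output of shape `(batch, num_masks, h, w)`:

```
import numpy as np
from PIL import Image

masks = np.array([[[[1.0, -1.0], [-1.0, -1.0]],
                   [[-1.0, -1.0], [-1.0, 1.0]]]])  # shape (1, 2, 2, 2)

# Any pixel covered by at least one mask turns white.
mask = np.zeros((masks.shape[2], masks.shape[3], 3), dtype=np.uint8)
for m in masks[0, :, :, :]:
    mask[m > 0.0] = [255, 255, 255]

out = Image.fromarray(mask).convert("L")
print(np.array(out))  # [[255   0]
                      #  [  0 255]]
```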
@@ -195,29 +271,64 @@ class SamSession(BaseSession):
         Returns:
             tuple: A tuple containing the file paths of the downloaded encoder and decoder models.
         """
-        fname_encoder = f"{cls.name(*args, **kwargs)}_encoder.onnx"
-        fname_decoder = f"{cls.name(*args, **kwargs)}_decoder.onnx"
+        model_name = kwargs.get("sam_model", "sam_vit_b_01ec64")
+        quant = kwargs.get("sam_quant", False)
+
+        fname_encoder = f"{model_name}.encoder.onnx"
+        fname_decoder = f"{model_name}.decoder.onnx"
+
+        if quant:
+            fname_encoder = f"{model_name}.encoder.quant.onnx"
+            fname_decoder = f"{model_name}.decoder.quant.onnx"
 
         pooch.retrieve(
-            "https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-encoder-quant.onnx",
-            None
-            if cls.checksum_disabled(*args, **kwargs)
-            else "md5:13d97c5c79ab13ef86d67cbde5f1b250",
+            f"https://github.com/danielgatis/rembg/releases/download/v0.0.0/{fname_encoder}",
+            None,
             fname=fname_encoder,
             path=cls.u2net_home(*args, **kwargs),
             progressbar=True,
         )
 
         pooch.retrieve(
-            "https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-decoder-quant.onnx",
-            None
-            if cls.checksum_disabled(*args, **kwargs)
-            else "md5:fa3d1c36a3187d3de1c8deebf33dd127",
+            f"https://github.com/danielgatis/rembg/releases/download/v0.0.0/{fname_decoder}",
+            None,
             fname=fname_decoder,
             path=cls.u2net_home(*args, **kwargs),
             progressbar=True,
         )
 
+        if fname_encoder == "sam_vit_h_4b8939.encoder.onnx" and not os.path.exists(
+            os.path.join(
+                cls.u2net_home(*args, **kwargs), "sam_vit_h_4b8939.encoder_data.bin"
+            )
+        ):
+            content = bytearray()
+
+            for i in range(1, 4):
+                pooch.retrieve(
+                    f"https://github.com/danielgatis/rembg/releases/download/v0.0.0/sam_vit_h_4b8939.encoder_data.{i}.bin",
+                    None,
+                    fname=f"sam_vit_h_4b8939.encoder_data.{i}.bin",
+                    path=cls.u2net_home(*args, **kwargs),
+                    progressbar=True,
+                )
+
+                fbin = os.path.join(
+                    cls.u2net_home(*args, **kwargs),
+                    f"sam_vit_h_4b8939.encoder_data.{i}.bin",
+                )
+                content.extend(open(fbin, "rb").read())
+                os.remove(fbin)
+
+            with open(
+                os.path.join(
+                    cls.u2net_home(*args, **kwargs),
+                    "sam_vit_h_4b8939.encoder_data.bin",
+                ),
+                "wb",
+            ) as fp:
+                fp.write(content)
+
         return (
             os.path.join(cls.u2net_home(*args, **kwargs), fname_encoder),
             os.path.join(cls.u2net_home(*args, **kwargs), fname_decoder),
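Note: since `__init__` now calls `download_models(*args, **kwargs)` (see the earlier hunk), the checkpoint and quantization can be picked per session. A sketch, assuming `new_session` forwards extra keyword arguments on to the session constructor:

```
from rembg import new_session

# "sam_vit_b_01ec64" is the default checkpoint name; sam_quant=True
# selects the *.quant.onnx files downloaded above.
session = new_session("sam", sam_model="sam_vit_b_01ec64", sam_quant=True)
```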

setup.py

@@ -12,6 +12,7 @@ here = pathlib.Path(__file__).parent.resolve()
 long_description = (here / "README.md").read_text(encoding="utf-8")
 
 install_requires = [
+    "jsonschema",
     "numpy",
     "onnxruntime",
     "opencv-python-headless",

BIN
tests/fixtures/plants-1.jpg (vendored, new file, 1.4 MiB)

BIN
(11 additional binary files changed; filenames not shown)

tests/test_remove.py

@@ -12,18 +12,19 @@ def test_remove():
     kwargs = {
         "sam": {
             "anime-girl-1" : {
-                "input_points": [[400, 165]],
-                "input_labels": [1],
+                "sam_prompt" :[{"type": "point", "data": [400, 165], "label": 1}],
             },
             "car-1" : {
-                "input_points": [[250, 200]],
-                "input_labels": [1],
+                "sam_prompt" :[{"type": "point", "data": [250, 200], "label": 1}],
             },
             "cloth-1" : {
-                "input_points": [[370, 495]],
-                "input_labels": [1],
+                "sam_prompt" :[{"type": "point", "data": [370, 495], "label": 1}],
             },
+            "plants-1" : {
+                "sam_prompt" :[{"type": "point", "data": [724, 740], "label": 1}],
+            },
         }
     }
@@ -38,7 +39,7 @@ def test_remove():
         "isnet-anime",
         "sam"
     ]:
-        for picture in ["anime-girl-1", "car-1", "cloth-1"]:
+        for picture in ["anime-girl-1", "car-1", "cloth-1", "plants-1"]:
             image_path = Path(here / "fixtures" / f"{picture}.jpg")
             image = image_path.read_bytes()