mirror of
https://git.mirrors.martin98.com/https://github.com/danielgatis/rembg
synced 2025-07-09 07:11:51 +08:00
65 lines
1.8 KiB
Python
65 lines
1.8 KiB
Python
from io import BytesIO
|
|
from pathlib import Path
|
|
|
|
from imagehash import phash as hash_img
|
|
from PIL import Image
|
|
|
|
from rembg import new_session, remove
|
|
|
|
here = Path(__file__).parent.resolve()
|
|
|
|
def test_remove():
|
|
kwargs = {
|
|
"sam": {
|
|
"anime-girl-1" : {
|
|
"input_points": [[400, 165]],
|
|
"input_labels": [1],
|
|
},
|
|
|
|
"car-1" : {
|
|
"input_points": [[250, 200]],
|
|
"input_labels": [1],
|
|
},
|
|
|
|
"cloth-1" : {
|
|
"input_points": [[370, 495]],
|
|
"input_labels": [1],
|
|
},
|
|
}
|
|
}
|
|
|
|
for model in [
|
|
"u2net",
|
|
"u2netp",
|
|
"u2net_human_seg",
|
|
"u2net_cloth_seg",
|
|
"silueta",
|
|
"isnet-general-use",
|
|
"isnet-anime",
|
|
"sam"
|
|
]:
|
|
for picture in ["anime-girl-1", "car-1", "cloth-1"]:
|
|
image_path = Path(here / "fixtures" / f"{picture}.jpg")
|
|
image = image_path.read_bytes()
|
|
|
|
actual = remove(image, session=new_session(model), **kwargs.get(model, {}).get(picture, {}))
|
|
actual_hash = hash_img(Image.open(BytesIO(actual)))
|
|
|
|
expected_path = Path(here / "results" / f"{picture}.{model}.png")
|
|
# Uncomment to update the expected results
|
|
# f = open(expected_path, "ab")
|
|
# f.write(actual)
|
|
# f.close()
|
|
|
|
expected = expected_path.read_bytes()
|
|
expected_hash = hash_img(Image.open(BytesIO(expected)))
|
|
|
|
print(f"image_path: {image_path}")
|
|
print(f"expected_path: {expected_path}")
|
|
print(f"actual_hash: {actual_hash}")
|
|
print(f"expected_hash: {expected_hash}")
|
|
print(f"actual_hash == expected_hash: {actual_hash == expected_hash}")
|
|
print("---\n")
|
|
|
|
assert actual_hash == expected_hash
|