refactoring

This commit is contained in:
Daniel Gatis 2023-04-20 21:39:13 -03:00
parent 1e311331e6
commit 1ca14ce058
25 changed files with 966 additions and 681 deletions

View File

@ -146,6 +146,12 @@ Remove the background applying an alpha matting
rembg i -a path/to/input.png path/to/output.png
```
Passing extra parameters
```
rembg i -m sam -x '{"input_labels": [1], "input_points": [[100,100]]}' path/to/input.png path/to/output.png
```
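The same extra parameters can also be passed through the library API; a minimal sketch of the equivalent call (the file paths are placeholders):
```
from rembg import new_session, remove

with open("path/to/input.png", "rb") as f:
    data = f.read()

# Extra keyword arguments are forwarded to the session's predict() call;
# for the sam model they provide the point prompts.
output = remove(
    data,
    session=new_session("sam"),
    input_points=[[100, 100]],
    input_labels=[1],
)

with open("path/to/output.png", "wb") as f:
    f.write(output)
```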
### rembg `p`
Used when input and output are folders.
@ -266,6 +272,7 @@ The available models are:
- u2net_cloth_seg ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_cloth_seg.onnx), [source](https://github.com/levindabhi/cloth-segmentation)): A pre-trained model for cloth parsing from human portraits. Here clothes are parsed into 3 categories: Upper body, Lower body and Full body.
- silueta ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/silueta.onnx), [source](https://github.com/xuebinqin/U-2-Net/issues/295)): Same as u2net, but the size is reduced to 43 MB.
- isnet-general-use ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-general-use.onnx), [source](https://github.com/xuebinqin/DIS)): A new pre-trained model for general use cases.
- sam ([encoder](https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-encoder-quant.onnx), [decoder](https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-decoder-quant.onnx), [source](https://github.com/facebookresearch/segment-anything)): The Segment Anything Model (SAM).
### Some differences between the models' results
@ -278,6 +285,7 @@ The available models are:
<th>u2net_cloth_seg</th>
<th>silueta</th>
<th>isnet-general-use</th>
<th>sam</th>
</tr>
<tr>
<th><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/tests/fixtures/car-1.jpg" width="100" /></th>
@ -287,6 +295,7 @@ The available models are:
<th><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/tests/results/car-1.u2net_cloth_seg.png" width="100" /></th>
<th><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/tests/results/car-1.silueta.png" width="100" /></th>
<th><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/tests/results/car-1.isnet-general-use.png" width="100" /></th>
<th><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/tests/results/car-1.sam.png" width="100" /></th>
</tr>
<th><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/tests/fixtures/cloth-1.jpg" width="100" /></th>
<th><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/tests/results/cloth-1.u2net.png" width="100" /></th>
@ -295,6 +304,7 @@ The available models are:
<th><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/tests/results/cloth-1.u2net_cloth_seg.png" width="100" /></th>
<th><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/tests/results/cloth-1.silueta.png" width="100" /></th>
<th><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/tests/results/cloth-1.isnet-general-use.png" width="100" /></th>
<th><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/tests/results/cloth-1.sam.png" width="100" /></th>
</tr>
</table>

View File

@ -1,6 +1,6 @@
import io
from enum import Enum
from typing import List, Optional, Tuple, Union
from typing import Any, List, Optional, Tuple, Union
import numpy as np
from cv2 import (
@ -18,9 +18,8 @@ from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
from pymatting.util.util import stack_images
from scipy.ndimage import binary_erosion
from .session_base import BaseSession
from .session_factory import new_session
from .session_sam import SamSession
from .sessions.base import BaseSession
kernel = getStructuringElement(MORPH_ELLIPSE, (3, 3))
@ -120,12 +119,12 @@ def remove(
alpha_matting_foreground_threshold: int = 240,
alpha_matting_background_threshold: int = 10,
alpha_matting_erode_size: int = 10,
session: Optional[Union[BaseSession, SamSession]] = None,
session: Optional[BaseSession] = None,
only_mask: bool = False,
post_process_mask: bool = False,
bgcolor: Optional[Tuple[int, int, int, int]] = None,
input_point: Optional[np.ndarray] = None,
input_label: Optional[np.ndarray] = None,
*args: Optional[Any],
**kwargs: Optional[Any]
) -> Union[bytes, PILImage, np.ndarray]:
if isinstance(data, PILImage):
return_type = ReturnType.PILLOW
@ -140,15 +139,9 @@ def remove(
raise ValueError("Input type {} is not supported.".format(type(data)))
if session is None:
session = new_session("u2net")
if isinstance(session, SamSession):
if input_point is None or input_label is None:
raise ValueError("Input point and label are required for SAM model.")
masks = session.predict_sam(img, input_point, input_label)
else:
masks = session.predict(img)
session = new_session("u2net", *args, **kwargs)
masks = session.predict(img, *args, **kwargs)
cutouts = []
for mask in masks:

View File

@ -1,25 +1,7 @@
import pathlib
import sys
import time
from enum import Enum
from typing import IO, Optional, Tuple, cast
import aiohttp
import click
import filetype
import uvicorn
from asyncer import asyncify
from fastapi import Depends, FastAPI, File, Form, Query
from fastapi.middleware.cors import CORSMiddleware
from starlette.responses import Response
from tqdm import tqdm
from watchdog.events import FileSystemEvent, FileSystemEventHandler
from watchdog.observers import Observer
from . import _version
from .bg import remove
from .session_base import BaseSession
from .session_factory import new_session
from .commands import command_functions
@click.group()
@ -28,457 +10,5 @@ def main() -> None:
pass
@main.command(help="for a file as input")
@click.option(
"-m",
"--model",
default="u2net",
type=click.Choice(
[
"u2net",
"u2netp",
"u2net_human_seg",
"u2net_cloth_seg",
"silueta",
"isnet-general-use",
]
),
show_default=True,
show_choices=True,
help="model name",
)
@click.option(
"-a",
"--alpha-matting",
is_flag=True,
show_default=True,
help="use alpha matting",
)
@click.option(
"-af",
"--alpha-matting-foreground-threshold",
default=240,
type=int,
show_default=True,
help="trimap fg threshold",
)
@click.option(
"-ab",
"--alpha-matting-background-threshold",
default=10,
type=int,
show_default=True,
help="trimap bg threshold",
)
@click.option(
"-ae",
"--alpha-matting-erode-size",
default=10,
type=int,
show_default=True,
help="erode size",
)
@click.option(
"-om",
"--only-mask",
is_flag=True,
show_default=True,
help="output only the mask",
)
@click.option(
"-ppm",
"--post-process-mask",
is_flag=True,
show_default=True,
help="post process the mask",
)
@click.option(
"-bgc",
"--bgcolor",
default=None,
type=(int, int, int, int),
nargs=4,
help="Background color (R G B A) to replace the removed background with",
)
@click.argument(
"input", default=(None if sys.stdin.isatty() else "-"), type=click.File("rb")
)
@click.argument(
"output",
default=(None if sys.stdin.isatty() else "-"),
type=click.File("wb", lazy=True),
)
def i(model: str, input: IO, output: IO, **kwargs) -> None:
output.write(remove(input.read(), session=new_session(model), **kwargs))
@main.command(help="for a folder as input")
@click.option(
"-m",
"--model",
default="u2net",
type=click.Choice(
[
"u2net",
"u2netp",
"u2net_human_seg",
"u2net_cloth_seg",
"silueta",
"isnet-general-use",
]
),
show_default=True,
show_choices=True,
help="model name",
)
@click.option(
"-a",
"--alpha-matting",
is_flag=True,
show_default=True,
help="use alpha matting",
)
@click.option(
"-af",
"--alpha-matting-foreground-threshold",
default=240,
type=int,
show_default=True,
help="trimap fg threshold",
)
@click.option(
"-ab",
"--alpha-matting-background-threshold",
default=10,
type=int,
show_default=True,
help="trimap bg threshold",
)
@click.option(
"-ae",
"--alpha-matting-erode-size",
default=10,
type=int,
show_default=True,
help="erode size",
)
@click.option(
"-om",
"--only-mask",
is_flag=True,
show_default=True,
help="output only the mask",
)
@click.option(
"-ppm",
"--post-process-mask",
is_flag=True,
show_default=True,
help="post process the mask",
)
@click.option(
"-w",
"--watch",
default=False,
is_flag=True,
show_default=True,
help="watches a folder for changes",
)
@click.option(
"-bgc",
"--bgcolor",
default=None,
type=(int, int, int, int),
nargs=4,
help="Background color (R G B A) to replace the removed background with",
)
@click.argument(
"input",
type=click.Path(
exists=True,
path_type=pathlib.Path,
file_okay=False,
dir_okay=True,
readable=True,
),
)
@click.argument(
"output",
type=click.Path(
exists=False,
path_type=pathlib.Path,
file_okay=False,
dir_okay=True,
writable=True,
),
)
def p(
model: str, input: pathlib.Path, output: pathlib.Path, watch: bool, **kwargs
) -> None:
session = new_session(model)
def process(each_input: pathlib.Path) -> None:
try:
mimetype = filetype.guess(each_input)
if mimetype is None:
return
if mimetype.mime.find("image") < 0:
return
each_output = (output / each_input.name).with_suffix(".png")
each_output.parents[0].mkdir(parents=True, exist_ok=True)
if not each_output.exists():
each_output.write_bytes(
cast(
bytes,
remove(each_input.read_bytes(), session=session, **kwargs),
)
)
if watch:
print(
f"processed: {each_input.absolute()} -> {each_output.absolute()}"
)
except Exception as e:
print(e)
inputs = list(input.glob("**/*"))
if not watch:
inputs = tqdm(inputs)
for each_input in inputs:
if not each_input.is_dir():
process(each_input)
if watch:
observer = Observer()
class EventHandler(FileSystemEventHandler):
def on_any_event(self, event: FileSystemEvent) -> None:
if not (
event.is_directory or event.event_type in ["deleted", "closed"]
):
process(pathlib.Path(event.src_path))
event_handler = EventHandler()
observer.schedule(event_handler, input, recursive=False)
observer.start()
try:
while True:
time.sleep(1)
finally:
observer.stop()
observer.join()
@main.command(help="for a http server")
@click.option(
"-p",
"--port",
default=5000,
type=int,
show_default=True,
help="port",
)
@click.option(
"-l",
"--log_level",
default="info",
type=str,
show_default=True,
help="log level",
)
@click.option(
"-t",
"--threads",
default=None,
type=int,
show_default=True,
help="number of worker threads",
)
def s(port: int, log_level: str, threads: int) -> None:
sessions: dict[str, BaseSession] = {}
tags_metadata = [
{
"name": "Background Removal",
"description": "Endpoints that perform background removal with different image sources.",
"externalDocs": {
"description": "GitHub Source",
"url": "https://github.com/danielgatis/rembg",
},
},
]
app = FastAPI(
title="Rembg",
description="Rembg is a tool to remove images background. That is it.",
version=_version.get_versions()["version"],
contact={
"name": "Daniel Gatis",
"url": "https://github.com/danielgatis",
"email": "danielgatis@gmail.com",
},
license_info={
"name": "MIT License",
"url": "https://github.com/danielgatis/rembg/blob/main/LICENSE.txt",
},
openapi_tags=tags_metadata,
)
app.add_middleware(
CORSMiddleware,
allow_credentials=True,
allow_origins=["*"],
allow_methods=["*"],
allow_headers=["*"],
)
class ModelType(str, Enum):
u2net = "u2net"
u2netp = "u2netp"
u2net_human_seg = "u2net_human_seg"
u2net_cloth_seg = "u2net_cloth_seg"
silueta = "silueta"
isnet_general_use = "isnet-general-use"
class CommonQueryParams:
def __init__(
self,
model: ModelType = Query(
default=ModelType.u2net,
description="Model to use when processing image",
),
a: bool = Query(default=False, description="Enable Alpha Matting"),
af: int = Query(
default=240,
ge=0,
le=255,
description="Alpha Matting (Foreground Threshold)",
),
ab: int = Query(
default=10,
ge=0,
le=255,
description="Alpha Matting (Background Threshold)",
),
ae: int = Query(
default=10, ge=0, description="Alpha Matting (Erode Structure Size)"
),
om: bool = Query(default=False, description="Only Mask"),
ppm: bool = Query(default=False, description="Post Process Mask"),
bgc: Optional[str] = Query(default=None, description="Background Color"),
):
self.model = model
self.a = a
self.af = af
self.ab = ab
self.ae = ae
self.om = om
self.ppm = ppm
self.bgc = (
cast(Tuple[int, int, int, int], tuple(map(int, bgc.split(","))))
if bgc
else None
)
class CommonQueryPostParams:
def __init__(
self,
model: ModelType = Form(
default=ModelType.u2net,
description="Model to use when processing image",
),
a: bool = Form(default=False, description="Enable Alpha Matting"),
af: int = Form(
default=240,
ge=0,
le=255,
description="Alpha Matting (Foreground Threshold)",
),
ab: int = Form(
default=10,
ge=0,
le=255,
description="Alpha Matting (Background Threshold)",
),
ae: int = Form(
default=10, ge=0, description="Alpha Matting (Erode Structure Size)"
),
om: bool = Form(default=False, description="Only Mask"),
ppm: bool = Form(default=False, description="Post Process Mask"),
bgc: Optional[str] = Query(default=None, description="Background Color"),
):
self.model = model
self.a = a
self.af = af
self.ab = ab
self.ae = ae
self.om = om
self.ppm = ppm
self.bgc = (
cast(Tuple[int, int, int, int], tuple(map(int, bgc.split(","))))
if bgc
else None
)
def im_without_bg(content: bytes, commons: CommonQueryParams) -> Response:
return Response(
remove(
content,
session=sessions.setdefault(
commons.model.value, new_session(commons.model.value)
),
alpha_matting=commons.a,
alpha_matting_foreground_threshold=commons.af,
alpha_matting_background_threshold=commons.ab,
alpha_matting_erode_size=commons.ae,
only_mask=commons.om,
post_process_mask=commons.ppm,
bgcolor=commons.bgc,
),
media_type="image/png",
)
@app.on_event("startup")
def startup():
if threads is not None:
from anyio import CapacityLimiter
from anyio.lowlevel import RunVar
RunVar("_default_thread_limiter").set(CapacityLimiter(threads))
@app.get(
path="/",
tags=["Background Removal"],
summary="Remove from URL",
description="Removes the background from an image obtained by retrieving an URL.",
)
async def get_index(
url: str = Query(
default=..., description="URL of the image that has to be processed."
),
commons: CommonQueryParams = Depends(),
):
async with aiohttp.ClientSession() as session:
async with session.get(url) as response:
file = await response.read()
return await asyncify(im_without_bg)(file, commons)
@app.post(
path="/",
tags=["Background Removal"],
summary="Remove from Stream",
description="Removes the background from an image sent within the request itself.",
)
async def post_index(
file: bytes = File(
default=...,
description="Image file (byte stream) that has to be processed.",
),
commons: CommonQueryPostParams = Depends(),
):
return await asyncify(im_without_bg)(file, commons) # type: ignore
uvicorn.run(app, host="0.0.0.0", port=port, log_level=log_level)
for command in command_functions:
main.add_command(command)

View File

@ -0,0 +1,13 @@
from importlib import import_module
from pathlib import Path
from pkgutil import iter_modules
command_functions = []
package_dir = Path(__file__).resolve().parent
for _b, module_name, _p in iter_modules([str(package_dir)]):
module = import_module(f"{__name__}.{module_name}")
for attribute_name in dir(module):
attribute = getattr(module, attribute_name)
if attribute_name.endswith("_command"):
command_functions.append(attribute)
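This loader collects any module attribute whose name ends in `_command`, so adding a CLI subcommand only requires dropping a new module into `rembg/commands/`. A hypothetical example (the `v` subcommand below is made up for illustration and is not part of this commit):
```
# rembg/commands/v_command.py (hypothetical)
import click

from .._version import get_versions


@click.command(name="v", help="print the rembg version")
def v_command() -> None:
    # Picked up by the loader above because the name ends in "_command",
    # then registered on the main click group in rembg/__main__.py.
    click.echo(get_versions()["version"])
```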

View File

@ -0,0 +1,93 @@
import json
import sys
from typing import IO
import click
from ..bg import remove
from ..session_factory import new_session
from ..sessions import sessions_names
@click.command(
name="i",
help="for a file as input",
)
@click.option(
"-m",
"--model",
default="u2net",
type=click.Choice(sessions_names),
show_default=True,
show_choices=True,
help="model name",
)
@click.option(
"-a",
"--alpha-matting",
is_flag=True,
show_default=True,
help="use alpha matting",
)
@click.option(
"-af",
"--alpha-matting-foreground-threshold",
default=240,
type=int,
show_default=True,
help="trimap fg threshold",
)
@click.option(
"-ab",
"--alpha-matting-background-threshold",
default=10,
type=int,
show_default=True,
help="trimap bg threshold",
)
@click.option(
"-ae",
"--alpha-matting-erode-size",
default=10,
type=int,
show_default=True,
help="erode size",
)
@click.option(
"-om",
"--only-mask",
is_flag=True,
show_default=True,
help="output only the mask",
)
@click.option(
"-ppm",
"--post-process-mask",
is_flag=True,
show_default=True,
help="post process the mask",
)
@click.option(
"-bgc",
"--bgcolor",
default=None,
type=(int, int, int, int),
nargs=4,
help="Background color (R G B A) to replace the removed background with",
)
@click.option("-x", "--extras", type=str)
@click.argument(
"input", default=(None if sys.stdin.isatty() else "-"), type=click.File("rb")
)
@click.argument(
"output",
default=(None if sys.stdin.isatty() else "-"),
type=click.File("wb", lazy=True),
)
def i_command(model: str, extras: str, input: IO, output: IO, **kwargs) -> None:
try:
kwargs.update(json.loads(extras))
except Exception:
pass
output.write(remove(input.read(), session=new_session(model), **kwargs))

rembg/commands/p_command.py (new file, 181 lines)
View File

@ -0,0 +1,181 @@
import json
import pathlib
import time
from typing import cast
import click
import filetype
from tqdm import tqdm
from watchdog.events import FileSystemEvent, FileSystemEventHandler
from watchdog.observers import Observer
from ..bg import remove
from ..session_factory import new_session
from ..sessions import sessions_names
@click.command(
name="p",
help="for a folder as input",
)
@click.option(
"-m",
"--model",
default="u2net",
type=click.Choice(sessions_names),
show_default=True,
show_choices=True,
help="model name",
)
@click.option(
"-a",
"--alpha-matting",
is_flag=True,
show_default=True,
help="use alpha matting",
)
@click.option(
"-af",
"--alpha-matting-foreground-threshold",
default=240,
type=int,
show_default=True,
help="trimap fg threshold",
)
@click.option(
"-ab",
"--alpha-matting-background-threshold",
default=10,
type=int,
show_default=True,
help="trimap bg threshold",
)
@click.option(
"-ae",
"--alpha-matting-erode-size",
default=10,
type=int,
show_default=True,
help="erode size",
)
@click.option(
"-om",
"--only-mask",
is_flag=True,
show_default=True,
help="output only the mask",
)
@click.option(
"-ppm",
"--post-process-mask",
is_flag=True,
show_default=True,
help="post process the mask",
)
@click.option(
"-w",
"--watch",
default=False,
is_flag=True,
show_default=True,
help="watches a folder for changes",
)
@click.option(
"-bgc",
"--bgcolor",
default=None,
type=(int, int, int, int),
nargs=4,
help="Background color (R G B A) to replace the removed background with",
)
@click.option("-x", "--extras", type=str)
@click.argument(
"input",
type=click.Path(
exists=True,
path_type=pathlib.Path,
file_okay=False,
dir_okay=True,
readable=True,
),
)
@click.argument(
"output",
type=click.Path(
exists=False,
path_type=pathlib.Path,
file_okay=False,
dir_okay=True,
writable=True,
),
)
def p_command(
model: str,
extras: str,
input: pathlib.Path,
output: pathlib.Path,
watch: bool,
**kwargs,
) -> None:
try:
kwargs.update(json.loads(extras))
except Exception:
pass
session = new_session(model)
def process(each_input: pathlib.Path) -> None:
try:
mimetype = filetype.guess(each_input)
if mimetype is None:
return
if mimetype.mime.find("image") < 0:
return
each_output = (output / each_input.name).with_suffix(".png")
each_output.parents[0].mkdir(parents=True, exist_ok=True)
if not each_output.exists():
each_output.write_bytes(
cast(
bytes,
remove(each_input.read_bytes(), session=session, **kwargs),
)
)
if watch:
print(
f"processed: {each_input.absolute()} -> {each_output.absolute()}"
)
except Exception as e:
print(e)
inputs = list(input.glob("**/*"))
if not watch:
inputs = tqdm(inputs)
for each_input in inputs:
if not each_input.is_dir():
process(each_input)
if watch:
observer = Observer()
class EventHandler(FileSystemEventHandler):
def on_any_event(self, event: FileSystemEvent) -> None:
if not (
event.is_directory or event.event_type in ["deleted", "closed"]
):
process(pathlib.Path(event.src_path))
event_handler = EventHandler()
observer.schedule(event_handler, input, recursive=False)
observer.start()
try:
while True:
time.sleep(1)
finally:
observer.stop()
observer.join()

rembg/commands/s_command.py (new file, 239 lines)
View File

@ -0,0 +1,239 @@
import json
from typing import Annotated, Optional, Tuple, cast
import aiohttp
import click
import uvicorn
from asyncer import asyncify
from fastapi import Depends, FastAPI, File, Form, Query
from fastapi.middleware.cors import CORSMiddleware
from starlette.responses import Response
from .._version import get_versions
from ..bg import remove
from ..session_factory import new_session
from ..sessions import sessions_names
from ..sessions.base import BaseSession
@click.command(
name="s",
help="for a http server",
)
@click.option(
"-p",
"--port",
default=5000,
type=int,
show_default=True,
help="port",
)
@click.option(
"-l",
"--log_level",
default="info",
type=str,
show_default=True,
help="log level",
)
@click.option(
"-t",
"--threads",
default=None,
type=int,
show_default=True,
help="number of worker threads",
)
def s_command(port: int, log_level: str, threads: int) -> None:
sessions: dict[str, BaseSession] = {}
tags_metadata = [
{
"name": "Background Removal",
"description": "Endpoints that perform background removal with different image sources.",
"externalDocs": {
"description": "GitHub Source",
"url": "https://github.com/danielgatis/rembg",
},
},
]
app = FastAPI(
title="Rembg",
description="Rembg is a tool to remove images background. That is it.",
version=get_versions()["version"],
contact={
"name": "Daniel Gatis",
"url": "https://github.com/danielgatis",
"email": "danielgatis@gmail.com",
},
license_info={
"name": "MIT License",
"url": "https://github.com/danielgatis/rembg/blob/main/LICENSE.txt",
},
openapi_tags=tags_metadata,
)
app.add_middleware(
CORSMiddleware,
allow_credentials=True,
allow_origins=["*"],
allow_methods=["*"],
allow_headers=["*"],
)
class CommonQueryParams:
def __init__(
self,
model: Annotated[
str, Query(regex=r"(" + "|".join(sessions_names) + ")")
] = Query(
default="u2net",
description="Model to use when processing image",
),
a: bool = Query(default=False, description="Enable Alpha Matting"),
af: int = Query(
default=240,
ge=0,
le=255,
description="Alpha Matting (Foreground Threshold)",
),
ab: int = Query(
default=10,
ge=0,
le=255,
description="Alpha Matting (Background Threshold)",
),
ae: int = Query(
default=10, ge=0, description="Alpha Matting (Erode Structure Size)"
),
om: bool = Query(default=False, description="Only Mask"),
ppm: bool = Query(default=False, description="Post Process Mask"),
bgc: Optional[str] = Query(default=None, description="Background Color"),
extras: Optional[str] = Query(
default=None, description="Extra parameters as JSON"
),
):
self.model = model
self.a = a
self.af = af
self.ab = ab
self.ae = ae
self.om = om
self.ppm = ppm
self.extras = extras
self.bgc = (
cast(Tuple[int, int, int, int], tuple(map(int, bgc.split(","))))
if bgc
else None
)
class CommonQueryPostParams:
def __init__(
self,
model: Annotated[
str, Form(regex=r"(" + "|".join(sessions_names) + ")")
] = Form(
default="u2net",
description="Model to use when processing image",
),
a: bool = Form(default=False, description="Enable Alpha Matting"),
af: int = Form(
default=240,
ge=0,
le=255,
description="Alpha Matting (Foreground Threshold)",
),
ab: int = Form(
default=10,
ge=0,
le=255,
description="Alpha Matting (Background Threshold)",
),
ae: int = Form(
default=10, ge=0, description="Alpha Matting (Erode Structure Size)"
),
om: bool = Form(default=False, description="Only Mask"),
ppm: bool = Form(default=False, description="Post Process Mask"),
bgc: Optional[str] = Query(default=None, description="Background Color"),
extras: Optional[str] = Query(
default=None, description="Extra parameters as JSON"
),
):
self.model = model
self.a = a
self.af = af
self.ab = ab
self.ae = ae
self.om = om
self.ppm = ppm
self.extras = extras
self.bgc = (
cast(Tuple[int, int, int, int], tuple(map(int, bgc.split(","))))
if bgc
else None
)
def im_without_bg(content: bytes, commons: CommonQueryParams) -> Response:
kwargs = dict()
try:
kwargs.update(json.loads(commons.extras))
except Exception:
pass
return Response(
remove(
content,
session=sessions.setdefault(commons.model, new_session(commons.model)),
alpha_matting=commons.a,
alpha_matting_foreground_threshold=commons.af,
alpha_matting_background_threshold=commons.ab,
alpha_matting_erode_size=commons.ae,
only_mask=commons.om,
post_process_mask=commons.ppm,
bgcolor=commons.bgc,
**kwargs
),
media_type="image/png",
)
@app.on_event("startup")
def startup():
if threads is not None:
from anyio import CapacityLimiter
from anyio.lowlevel import RunVar
RunVar("_default_thread_limiter").set(CapacityLimiter(threads))
@app.get(
path="/",
tags=["Background Removal"],
summary="Remove from URL",
description="Removes the background from an image obtained by retrieving an URL.",
)
async def get_index(
url: str = Query(
default=..., description="URL of the image that has to be processed."
),
commons: CommonQueryParams = Depends(),
):
async with aiohttp.ClientSession() as session:
async with session.get(url) as response:
file = await response.read()
return await asyncify(im_without_bg)(file, commons)
@app.post(
path="/",
tags=["Background Removal"],
summary="Remove from Stream",
description="Removes the background from an image sent within the request itself.",
)
async def post_index(
file: bytes = File(
default=...,
description="Image file (byte stream) that has to be processed.",
),
commons: CommonQueryPostParams = Depends(),
):
return await asyncify(im_without_bg)(file, commons) # type: ignore
uvicorn.run(app, host="0.0.0.0", port=port, log_level=log_level)
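For illustration, a hypothetical client call against a server started with the defaults above (the image URL and output path are placeholders; `extras` is forwarded as a JSON string, just like the `-x` CLI option):
```
import json
import urllib.parse
import urllib.request

params = urllib.parse.urlencode({
    "url": "https://example.com/image.jpg",  # placeholder image URL
    "model": "sam",
    "extras": json.dumps({"input_points": [[100, 100]], "input_labels": [1]}),
})

# GET / fetches the image from `url`, removes the background and returns image/png bytes.
with urllib.request.urlopen(f"http://localhost:5000/?{params}") as resp:
    with open("output.png", "wb") as f:
        f.write(resp.read())
```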

View File

@ -1,28 +0,0 @@
from typing import List
import numpy as np
from PIL import Image
from PIL.Image import Image as PILImage
from .session_base import BaseSession
class DisSession(BaseSession):
def predict(self, img: PILImage) -> List[PILImage]:
ort_outs = self.inner_session.run(
None,
self.normalize(img, (0.485, 0.456, 0.406), (1.0, 1.0, 1.0), (1024, 1024)),
)
pred = ort_outs[0][:, 0, :, :]
ma = np.max(pred)
mi = np.min(pred)
pred = (pred - mi) / (ma - mi)
pred = np.squeeze(pred)
mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
mask = mask.resize(img.size, Image.LANCZOS)
return [mask]

View File

@ -1,114 +1,24 @@
import hashlib
import os
import sys
from contextlib import redirect_stdout
from pathlib import Path
from typing import Type
import onnxruntime as ort
import pooch
from .session_base import BaseSession
from .session_cloth import ClothSession
from .session_dis import DisSession
from .session_sam import SamSession
from .session_simple import SimpleSession
from .sessions import sessions_class
from .sessions.base import BaseSession
from .sessions.u2net import U2netSession
def download_model(url: str, md5: str, fname: str, path: Path):
pooch.retrieve(
url,
f"md5:{md5}",
fname=fname,
path=path,
progressbar=True,
)
def new_session(model_name: str = "u2net", *args, **kwargs) -> BaseSession:
session_class: Type[BaseSession] = U2netSession
def new_session(model_name: str = "u2net") -> BaseSession:
# Define the model path
u2net_home = os.getenv(
"U2NET_HOME", os.path.join(os.getenv("XDG_DATA_HOME", "~"), ".u2net")
)
fname = f"{model_name}.onnx"
path = Path(u2net_home).expanduser()
full_path = Path(u2net_home).expanduser() / fname
session_class: Type[BaseSession]
md5 = "60024c5c889badc19c04ad937298a77b"
url = "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net.onnx"
session_class = SimpleSession
if model_name == "u2netp":
md5 = "8e83ca70e441ab06c318d82300c84806"
url = (
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2netp.onnx"
)
session_class = SimpleSession
elif model_name == "u2net_human_seg":
md5 = "c09ddc2e0104f800e3e1bb4652583d1f"
url = "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_human_seg.onnx"
session_class = SimpleSession
elif model_name == "u2net_cloth_seg":
md5 = "2434d1f3cb744e0e49386c906e5a08bb"
url = "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_cloth_seg.onnx"
session_class = ClothSession
elif model_name == "silueta":
md5 = "55e59e0d8062d2f5d013f4725ee84782"
url = (
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/silueta.onnx"
)
session_class = SimpleSession
elif model_name == "isnet-general-use":
md5 = "fc16ebd8b0c10d971d3513d564d01e29"
url = "https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-general-use.onnx"
session_class = DisSession
elif model_name == "sam":
path = Path(u2net_home).expanduser()
fname_encoder = f"{model_name}_encoder.onnx"
encoder_md5 = "13d97c5c79ab13ef86d67cbde5f1b250"
encoder_url = "https://github.com/Flippchen/rembg/releases/download/test/vit_b-encoder-quant.onnx"
fname_decoder = f"{model_name}_decoder.onnx"
decoder_md5 = "fa3d1c36a3187d3de1c8deebf33dd127"
decoder_url = "https://github.com/Flippchen/rembg/releases/download/test/vit_b-decoder-quant.onnx"
download_model(encoder_url, encoder_md5, fname_encoder, path)
download_model(decoder_url, decoder_md5, fname_decoder, path)
sess_opts = ort.SessionOptions()
if "OMP_NUM_THREADS" in os.environ:
sess_opts.inter_op_num_threads = int(os.environ["OMP_NUM_THREADS"])
return SamSession(
model_name,
ort.InferenceSession(
str(path / fname_encoder),
providers=ort.get_available_providers(),
sess_options=sess_opts,
),
ort.InferenceSession(
str(path / fname_decoder),
providers=ort.get_available_providers(),
sess_options=sess_opts,
),
)
download_model(url, md5, fname, path)
for sc in sessions_class:
if sc.name() == model_name:
session_class = sc
break
sess_opts = ort.SessionOptions()
if "OMP_NUM_THREADS" in os.environ:
sess_opts.inter_op_num_threads = int(os.environ["OMP_NUM_THREADS"])
return session_class(
model_name,
ort.InferenceSession(
str(full_path),
providers=ort.get_available_providers(),
sess_options=sess_opts,
),
)
return session_class(model_name, sess_opts, *args, **kwargs)

View File

@ -1,30 +0,0 @@
from typing import List
import numpy as np
from PIL import Image
from PIL.Image import Image as PILImage
from .session_base import BaseSession
class SimpleSession(BaseSession):
def predict(self, img: PILImage) -> List[PILImage]:
ort_outs = self.inner_session.run(
None,
self.normalize(
img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (320, 320)
),
)
pred = ort_outs[0][:, 0, :, :]
ma = np.max(pred)
mi = np.min(pred)
pred = (pred - mi) / (ma - mi)
pred = np.squeeze(pred)
mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
mask = mask.resize(img.size, Image.LANCZOS)
return [mask]

View File

@ -0,0 +1,22 @@
from importlib import import_module
from inspect import isclass
from pathlib import Path
from pkgutil import iter_modules
from .base import BaseSession
sessions_class = []
sessions_names = []
package_dir = Path(__file__).resolve().parent
for _b, module_name, _p in iter_modules([str(package_dir)]):
module = import_module(f"{__name__}.{module_name}")
for attribute_name in dir(module):
attribute = getattr(module, attribute_name)
if (
isclass(attribute)
and issubclass(attribute, BaseSession)
and attribute != BaseSession
):
sessions_class.append(attribute)
sessions_names.append(attribute.name())
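A quick, hypothetical sanity check of the discovery mechanism from a Python shell (the printed names depend on the session modules present):
```
from rembg import new_session
from rembg.sessions import sessions_names

print(sessions_names)           # every name() of the discovered BaseSession subclasses
session = new_session("u2net")  # session_factory resolves the name against sessions_class
```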

View File

@ -1,3 +1,4 @@
import os
from typing import Dict, List, Tuple
import numpy as np
@ -7,9 +8,13 @@ from PIL.Image import Image as PILImage
class BaseSession:
def __init__(self, model_name: str, inner_session: ort.InferenceSession):
def __init__(self, model_name: str, sess_opts: ort.SessionOptions, *args, **kwargs):
self.model_name = model_name
self.inner_session = inner_session
self.inner_session = ort.InferenceSession(
str(self.__class__.download_models()),
providers=ort.get_available_providers(),
sess_options=sess_opts,
)
def normalize(
self,
@ -17,6 +22,8 @@ class BaseSession:
mean: Tuple[float, float, float],
std: Tuple[float, float, float],
size: Tuple[int, int],
*args,
**kwargs
) -> Dict[str, np.ndarray]:
im = img.convert("RGB").resize(size, Image.LANCZOS)
@ -36,5 +43,21 @@ class BaseSession:
.astype(np.float32)
}
def predict(self, img: PILImage) -> List[PILImage]:
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
raise NotImplementedError
@classmethod
def u2net_home(cls, *args, **kwargs):
return os.path.expanduser(
os.getenv(
"U2NET_HOME", os.path.join(os.getenv("XDG_DATA_HOME", "~"), ".u2net")
)
)
@classmethod
def download_models(cls, *args, **kwargs):
raise NotImplementedError
@classmethod
def name(cls, *args, **kwargs):
raise NotImplementedError

rembg/sessions/dis.py (new file, 47 lines)
View File

@ -0,0 +1,47 @@
import os
from typing import List
import numpy as np
import pooch
from PIL import Image
from PIL.Image import Image as PILImage
from .base import BaseSession
class DisSession(BaseSession):
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
ort_outs = self.inner_session.run(
None,
self.normalize(img, (0.485, 0.456, 0.406), (1.0, 1.0, 1.0), (1024, 1024)),
)
pred = ort_outs[0][:, 0, :, :]
ma = np.max(pred)
mi = np.min(pred)
pred = (pred - mi) / (ma - mi)
pred = np.squeeze(pred)
mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
mask = mask.resize(img.size, Image.LANCZOS)
return [mask]
@classmethod
def download_models(cls, *args, **kwargs):
fname = f"{cls.name()}.onnx"
pooch.retrieve(
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-general-use.onnx",
f"md5:fc16ebd8b0c10d971d3513d564d01e29",
fname=fname,
path=cls.u2net_home(),
progressbar=True,
)
return os.path.join(cls.u2net_home(), fname)
@classmethod
def name(cls, *args, **kwargs):
return "isnet-general-use"

View File

@ -1,11 +1,13 @@
import os
from typing import List
import numpy as np
import onnxruntime as ort
import pooch
from PIL import Image
from PIL.Image import Image as PILImage
from .session_base import BaseSession
from .base import BaseSession
def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int):
@ -47,14 +49,19 @@ def pad_to_square(img: np.ndarray, size=1024):
class SamSession(BaseSession):
def __init__(
self,
model_name: str,
encoder: ort.InferenceSession,
decoder: ort.InferenceSession,
):
super().__init__(model_name, encoder)
self.decoder = decoder
def __init__(self, model_name: str, sess_opts: ort.SessionOptions, *args, **kwargs):
self.model_name = model_name
paths = self.__class__.download_models()
self.encoder = ort.InferenceSession(
str(paths[0]),
providers=ort.get_available_providers(),
sess_options=sess_opts,
)
self.decoder = ort.InferenceSession(
str(paths[1]),
providers=ort.get_available_providers(),
sess_options=sess_opts,
)
def normalize(
self,
@ -62,17 +69,19 @@ class SamSession(BaseSession):
mean=(123.675, 116.28, 103.53),
std=(58.395, 57.12, 57.375),
size=(1024, 1024),
*args,
**kwargs,
):
pixel_mean = np.array([*mean]).reshape(1, 1, -1)
pixel_std = np.array([*std]).reshape(1, 1, -1)
x = (img - pixel_mean) / pixel_std
return x
def predict_sam(
def predict(
self,
img: PILImage,
input_point: np.ndarray,
input_label: np.ndarray,
*args,
**kwargs,
) -> List[PILImage]:
# Preprocess image
image = resize_longes_side(img)
@ -80,17 +89,25 @@ class SamSession(BaseSession):
image = self.normalize(image)
image = pad_to_square(image)
input_labels = kwargs.get("input_labels")
input_points = kwargs.get("input_points")
if input_labels is None:
raise ValueError("input_labels is required")
if input_points is None:
raise ValueError("input_points is required")
# Transpose
image = image.transpose(2, 0, 1)[None, :, :, :]
# Run encoder (Image embedding)
encoded = self.inner_session.run(None, {"x": image})
encoded = self.encoder.run(None, {"x": image})
image_embedding = encoded[0]
# Add a batch index, concatenate a padding point, and transform.
onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[
onnx_coord = np.concatenate([input_points, np.array([[0.0, 0.0]])], axis=0)[
None, :, :
]
onnx_label = np.concatenate([input_label, np.array([-1])], axis=0)[
onnx_label = np.concatenate([input_labels, np.array([-1])], axis=0)[
None, :
].astype(np.float32)
onnx_coord = apply_coords(onnx_coord, img.size[::-1], 1024).astype(np.float32)
@ -116,3 +133,33 @@ class SamSession(BaseSession):
]
return masks
@classmethod
def download_models(cls, *args, **kwargs):
fname_encoder = f"{cls.name()}_encoder.onnx"
fname_decoder = f"{cls.name()}_decoder.onnx"
pooch.retrieve(
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-encoder-quant.onnx",
f"md5:13d97c5c79ab13ef86d67cbde5f1b250",
fname=fname_encoder,
path=cls.u2net_home(),
progressbar=True,
)
pooch.retrieve(
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-decoder-quant.onnx",
f"md5:fa3d1c36a3187d3de1c8deebf33dd127",
fname=fname_decoder,
path=cls.u2net_home(),
progressbar=True,
)
return (
os.path.join(cls.u2net_home(), fname_encoder),
os.path.join(cls.u2net_home(), fname_decoder),
)
@classmethod
def name(cls, *args, **kwargs):
return "sam"

rembg/sessions/silueta.py (new file, 49 lines)
View File

@ -0,0 +1,49 @@
import os
from typing import List
import numpy as np
import pooch
from PIL import Image
from PIL.Image import Image as PILImage
from .base import BaseSession
class SiluetaSession(BaseSession):
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
ort_outs = self.inner_session.run(
None,
self.normalize(
img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (320, 320)
),
)
pred = ort_outs[0][:, 0, :, :]
ma = np.max(pred)
mi = np.min(pred)
pred = (pred - mi) / (ma - mi)
pred = np.squeeze(pred)
mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
mask = mask.resize(img.size, Image.LANCZOS)
return [mask]
@classmethod
def download_models(cls, *args, **kwargs):
fname = f"{cls.name()}.onnx"
pooch.retrieve(
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/silueta.onnx",
f"md5:55e59e0d8062d2f5d013f4725ee84782",
fname=fname,
path=cls.u2net_home(),
progressbar=True,
)
return os.path.join(cls.u2net_home(), fname)
@classmethod
def name(cls, *args, **kwargs):
return "silueta"

rembg/sessions/u2net.py (new file, 49 lines)
View File

@ -0,0 +1,49 @@
import os
from typing import List
import numpy as np
import pooch
from PIL import Image
from PIL.Image import Image as PILImage
from .base import BaseSession
class U2netSession(BaseSession):
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
ort_outs = self.inner_session.run(
None,
self.normalize(
img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (320, 320)
),
)
pred = ort_outs[0][:, 0, :, :]
ma = np.max(pred)
mi = np.min(pred)
pred = (pred - mi) / (ma - mi)
pred = np.squeeze(pred)
mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
mask = mask.resize(img.size, Image.LANCZOS)
return [mask]
@classmethod
def download_models(cls, *args, **kwargs):
fname = f"{cls.name()}.onnx"
pooch.retrieve(
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net.onnx",
f"md5:60024c5c889badc19c04ad937298a77b",
fname=fname,
path=cls.u2net_home(),
progressbar=True,
)
return os.path.join(cls.u2net_home(), fname)
@classmethod
def name(cls, *args, **kwargs):
return "u2net"

View File

@ -1,11 +1,13 @@
import os
from typing import List
import numpy as np
import pooch
from PIL import Image
from PIL.Image import Image as PILImage
from scipy.special import log_softmax
from .session_base import BaseSession
from .base import BaseSession
pallete1 = [
0,
@ -53,8 +55,8 @@ pallete3 = [
]
class ClothSession(BaseSession):
def predict(self, img: PILImage) -> List[PILImage]:
class Unet2ClothSession(BaseSession):
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
ort_outs = self.inner_session.run(
None,
self.normalize(
@ -89,3 +91,20 @@ class ClothSession(BaseSession):
masks.append(mask3)
return masks
@classmethod
def download_models(cls, *args, **kwargs):
fname = f"{cls.name()}.onnx"
pooch.retrieve(
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_cloth_seg.onnx",
f"md5:2434d1f3cb744e0e49386c906e5a08bb",
fname=fname,
path=cls.u2net_home(),
progressbar=True,
)
return os.path.join(cls.u2net_home(), fname)
@classmethod
def name(cls, *args, **kwargs):
return "u2net_cloth_seg"

View File

@ -0,0 +1,49 @@
import os
from typing import List
import numpy as np
import pooch
from PIL import Image
from PIL.Image import Image as PILImage
from .base import BaseSession
class U2netHumanSegSession(BaseSession):
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
ort_outs = self.inner_session.run(
None,
self.normalize(
img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (320, 320)
),
)
pred = ort_outs[0][:, 0, :, :]
ma = np.max(pred)
mi = np.min(pred)
pred = (pred - mi) / (ma - mi)
pred = np.squeeze(pred)
mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
mask = mask.resize(img.size, Image.LANCZOS)
return [mask]
@classmethod
def download_models(cls, *args, **kwargs):
fname = f"{cls.name()}.onnx"
pooch.retrieve(
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_human_seg.onnx",
f"md5:c09ddc2e0104f800e3e1bb4652583d1f",
fname=fname,
path=cls.u2net_home(),
progressbar=True,
)
return os.path.join(cls.u2net_home(), fname)
@classmethod
def name(cls, *args, **kwargs):
return "u2net_human_seg"

rembg/sessions/u2netp.py (new file, 49 lines)
View File

@ -0,0 +1,49 @@
import os
from typing import List
import numpy as np
import pooch
from PIL import Image
from PIL.Image import Image as PILImage
from .base import BaseSession
class U2netpSession(BaseSession):
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
ort_outs = self.inner_session.run(
None,
self.normalize(
img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (320, 320)
),
)
pred = ort_outs[0][:, 0, :, :]
ma = np.max(pred)
mi = np.min(pred)
pred = (pred - mi) / (ma - mi)
pred = np.squeeze(pred)
mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
mask = mask.resize(img.size, Image.LANCZOS)
return [mask]
@classmethod
def download_models(cls, *args, **kwargs):
fname = f"{cls.name()}.onnx"
pooch.retrieve(
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2netp.onnx",
f"md5:8e83ca70e441ab06c318d82300c84806",
fname=fname,
path=cls.u2net_home(),
progressbar=True,
)
return os.path.join(cls.u2net_home(), fname)
@classmethod
def name(cls, *args, **kwargs):
return "u2netp"

View File

@ -1 +1 @@
onnxruntime-gpu==1.13.1
onnxruntime-gpu==1.14.1

View File

@ -1,14 +1,14 @@
aiohttp==3.8.1
asyncer==0.0.2
click==8.1.3
fastapi==0.87.0
fastapi==0.92.0
filetype==1.2.0
pooch==1.6.0
imagehash==4.3.1
numpy==1.23.5
onnxruntime==1.14.1
opencv-python-headless==4.6.0.66
pillow==9.3.0
pooch==1.6.0
pymatting==1.1.8
python-multipart==0.0.5
scikit-image==0.19.3

View File

@ -42,12 +42,12 @@ setup(
"click>=8.1.3",
"fastapi>=0.92.0",
"filetype>=1.2.0",
"pooch>=1.6.0",
"imagehash>=4.3.1",
"numpy>=1.23.5",
"onnxruntime>=1.13.1",
"onnxruntime>=1.14.1",
"opencv-python-headless>=4.6.0.66",
"pillow>=9.3.0",
"pooch>=1.6.0",
"pymatting>=1.1.8",
"python-multipart>=0.0.5",
"scikit-image>=0.19.3",
@ -62,7 +62,7 @@ setup(
],
},
extras_require={
"gpu": ["onnxruntime-gpu>=1.13.1"],
"gpu": ["onnxruntime-gpu>=1.14.1"],
},
version=versioneer.get_version(),
cmdclass=versioneer.get_cmdclass(),

tests/results/car-1.sam.png (new binary file, 78 KiB, not shown)

A second new binary file (104 KiB, not shown)

View File

@ -4,28 +4,48 @@ from pathlib import Path
from imagehash import phash as hash_img
from PIL import Image
from rembg import remove
from rembg import new_session
from rembg import new_session, remove
here = Path(__file__).parent.resolve()
def test_remove():
for model in ["u2net", "u2netp", "u2net_human_seg", "u2net_cloth_seg", "silueta", "isnet-general-use"]:
kwargs = {
"sam": {
"car-1" : {
"input_points": [[250, 200]],
"input_labels": [1],
},
"cloth-1" : {
"input_points": [[370, 495]],
"input_labels": [1],
}
}
}
for model in [
"u2net",
"u2netp",
"u2net_human_seg",
"u2net_cloth_seg",
"silueta",
"isnet-general-use",
"sam"
]:
for picture in ["car-1", "cloth-1"]:
image_path = Path(here / "fixtures" / f"{picture}.jpg")
expected_path = Path(here / "results" / f"{picture}.{model}.png")
image_path = Path(here / "fixtures" / f"{picture}.jpg")
image = image_path.read_bytes()
expected = expected_path.read_bytes()
actual = remove(image, session=new_session(model))
actual = remove(image, session=new_session(model), **kwargs.get(model, {}).get(picture, {}))
actual_hash = hash_img(Image.open(BytesIO(actual)))
expected_path = Path(here / "results" / f"{picture}.{model}.png")
# Uncomment to update the expected results
# f = open(expected_path, "ab")
# f.write(actual)
# f.close()
actual_hash = hash_img(Image.open(BytesIO(actual)))
expected = expected_path.read_bytes()
expected_hash = hash_img(Image.open(BytesIO(expected)))
print(f"image_path: {image_path}")