diff --git a/README.md b/README.md index c89c276..81039a3 100644 --- a/README.md +++ b/README.md @@ -146,6 +146,12 @@ Remove the background applying an alpha matting rembg i -a path/to/input.png path/to/output.png ``` +Passing extra parameters + +``` +rembg i -m sam -x '{"input_labels": [1], "input_points": [[100,100]]}' path/to/input.png path/to/output.png +``` + ### rembg `p` Used when input and output are folders. @@ -266,6 +272,7 @@ The available models are: - u2net_cloth_seg ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_cloth_seg.onnx), [source](https://github.com/levindabhi/cloth-segmentation)): A pre-trained model for Cloths Parsing from human portrait. Here clothes are parsed into 3 category: Upper body, Lower body and Full body. - silueta ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/silueta.onnx), [source](https://github.com/xuebinqin/U-2-Net/issues/295)): Same as u2net but the size is reduced to 43Mb. - isnet-general-use ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-general-use.onnx), [source](https://github.com/xuebinqin/DIS)): A new pre-trained model for general use cases. +- sam ([encoder](https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-encoder-quant.onnx), [decoder](https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-decoder-quant.onnx), [source](https://github.com/facebookresearch/segment-anything)): The Segment Anything Model (SAM). 
### Some differences between the models result @@ -278,6 +285,7 @@ The available models are: u2net_cloth_seg silueta isnet-general-use + sam @@ -287,6 +295,7 @@ The available models are: + @@ -295,6 +304,7 @@ The available models are: + diff --git a/rembg/bg.py b/rembg/bg.py index ab342cb..eed42c5 100644 --- a/rembg/bg.py +++ b/rembg/bg.py @@ -1,6 +1,6 @@ import io from enum import Enum -from typing import List, Optional, Tuple, Union +from typing import Any, List, Optional, Tuple, Union import numpy as np from cv2 import ( @@ -18,9 +18,8 @@ from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml from pymatting.util.util import stack_images from scipy.ndimage import binary_erosion -from .session_base import BaseSession from .session_factory import new_session -from .session_sam import SamSession +from .sessions.base import BaseSession kernel = getStructuringElement(MORPH_ELLIPSE, (3, 3)) @@ -120,12 +119,12 @@ def remove( alpha_matting_foreground_threshold: int = 240, alpha_matting_background_threshold: int = 10, alpha_matting_erode_size: int = 10, - session: Optional[Union[BaseSession, SamSession]] = None, + session: Optional[BaseSession] = None, only_mask: bool = False, post_process_mask: bool = False, bgcolor: Optional[Tuple[int, int, int, int]] = None, - input_point: Optional[np.ndarray] = None, - input_label: Optional[np.ndarray] = None, + *args: Optional[Any], + **kwargs: Optional[Any] ) -> Union[bytes, PILImage, np.ndarray]: if isinstance(data, PILImage): return_type = ReturnType.PILLOW @@ -140,15 +139,9 @@ def remove( raise ValueError("Input type {} is not supported.".format(type(data))) if session is None: - session = new_session("u2net") - - if isinstance(session, SamSession): - if input_point is None or input_label is None: - raise ValueError("Input point and label are required for SAM model.") - masks = session.predict_sam(img, input_point, input_label) - else: - masks = session.predict(img) + session = new_session("u2net", *args, 
**kwargs) + masks = session.predict(img, *args, **kwargs) cutouts = [] for mask in masks: diff --git a/rembg/cli.py b/rembg/cli.py index cfb1760..bd3ac26 100644 --- a/rembg/cli.py +++ b/rembg/cli.py @@ -1,25 +1,7 @@ -import pathlib -import sys -import time -from enum import Enum -from typing import IO, Optional, Tuple, cast - -import aiohttp import click -import filetype -import uvicorn -from asyncer import asyncify -from fastapi import Depends, FastAPI, File, Form, Query -from fastapi.middleware.cors import CORSMiddleware -from starlette.responses import Response -from tqdm import tqdm -from watchdog.events import FileSystemEvent, FileSystemEventHandler -from watchdog.observers import Observer from . import _version -from .bg import remove -from .session_base import BaseSession -from .session_factory import new_session +from .commands import command_functions @click.group() @@ -28,457 +10,5 @@ def main() -> None: pass -@main.command(help="for a file as input") -@click.option( - "-m", - "--model", - default="u2net", - type=click.Choice( - [ - "u2net", - "u2netp", - "u2net_human_seg", - "u2net_cloth_seg", - "silueta", - "isnet-general-use", - ] - ), - show_default=True, - show_choices=True, - help="model name", -) -@click.option( - "-a", - "--alpha-matting", - is_flag=True, - show_default=True, - help="use alpha matting", -) -@click.option( - "-af", - "--alpha-matting-foreground-threshold", - default=240, - type=int, - show_default=True, - help="trimap fg threshold", -) -@click.option( - "-ab", - "--alpha-matting-background-threshold", - default=10, - type=int, - show_default=True, - help="trimap bg threshold", -) -@click.option( - "-ae", - "--alpha-matting-erode-size", - default=10, - type=int, - show_default=True, - help="erode size", -) -@click.option( - "-om", - "--only-mask", - is_flag=True, - show_default=True, - help="output only the mask", -) -@click.option( - "-ppm", - "--post-process-mask", - is_flag=True, - show_default=True, - help="post process the 
mask", -) -@click.option( - "-bgc", - "--bgcolor", - default=None, - type=(int, int, int, int), - nargs=4, - help="Background color (R G B A) to replace the removed background with", -) -@click.argument( - "input", default=(None if sys.stdin.isatty() else "-"), type=click.File("rb") -) -@click.argument( - "output", - default=(None if sys.stdin.isatty() else "-"), - type=click.File("wb", lazy=True), -) -def i(model: str, input: IO, output: IO, **kwargs) -> None: - output.write(remove(input.read(), session=new_session(model), **kwargs)) - - -@main.command(help="for a folder as input") -@click.option( - "-m", - "--model", - default="u2net", - type=click.Choice( - [ - "u2net", - "u2netp", - "u2net_human_seg", - "u2net_cloth_seg", - "silueta", - "isnet-general-use", - ] - ), - show_default=True, - show_choices=True, - help="model name", -) -@click.option( - "-a", - "--alpha-matting", - is_flag=True, - show_default=True, - help="use alpha matting", -) -@click.option( - "-af", - "--alpha-matting-foreground-threshold", - default=240, - type=int, - show_default=True, - help="trimap fg threshold", -) -@click.option( - "-ab", - "--alpha-matting-background-threshold", - default=10, - type=int, - show_default=True, - help="trimap bg threshold", -) -@click.option( - "-ae", - "--alpha-matting-erode-size", - default=10, - type=int, - show_default=True, - help="erode size", -) -@click.option( - "-om", - "--only-mask", - is_flag=True, - show_default=True, - help="output only the mask", -) -@click.option( - "-ppm", - "--post-process-mask", - is_flag=True, - show_default=True, - help="post process the mask", -) -@click.option( - "-w", - "--watch", - default=False, - is_flag=True, - show_default=True, - help="watches a folder for changes", -) -@click.option( - "-bgc", - "--bgcolor", - default=None, - type=(int, int, int, int), - nargs=4, - help="Background color (R G B A) to replace the removed background with", -) -@click.argument( - "input", - type=click.Path( - exists=True, - 
path_type=pathlib.Path, - file_okay=False, - dir_okay=True, - readable=True, - ), -) -@click.argument( - "output", - type=click.Path( - exists=False, - path_type=pathlib.Path, - file_okay=False, - dir_okay=True, - writable=True, - ), -) -def p( - model: str, input: pathlib.Path, output: pathlib.Path, watch: bool, **kwargs -) -> None: - session = new_session(model) - - def process(each_input: pathlib.Path) -> None: - try: - mimetype = filetype.guess(each_input) - if mimetype is None: - return - if mimetype.mime.find("image") < 0: - return - - each_output = (output / each_input.name).with_suffix(".png") - each_output.parents[0].mkdir(parents=True, exist_ok=True) - - if not each_output.exists(): - each_output.write_bytes( - cast( - bytes, - remove(each_input.read_bytes(), session=session, **kwargs), - ) - ) - - if watch: - print( - f"processed: {each_input.absolute()} -> {each_output.absolute()}" - ) - except Exception as e: - print(e) - - inputs = list(input.glob("**/*")) - if not watch: - inputs = tqdm(inputs) - - for each_input in inputs: - if not each_input.is_dir(): - process(each_input) - - if watch: - observer = Observer() - - class EventHandler(FileSystemEventHandler): - def on_any_event(self, event: FileSystemEvent) -> None: - if not ( - event.is_directory or event.event_type in ["deleted", "closed"] - ): - process(pathlib.Path(event.src_path)) - - event_handler = EventHandler() - observer.schedule(event_handler, input, recursive=False) - observer.start() - - try: - while True: - time.sleep(1) - - finally: - observer.stop() - observer.join() - - -@main.command(help="for a http server") -@click.option( - "-p", - "--port", - default=5000, - type=int, - show_default=True, - help="port", -) -@click.option( - "-l", - "--log_level", - default="info", - type=str, - show_default=True, - help="log level", -) -@click.option( - "-t", - "--threads", - default=None, - type=int, - show_default=True, - help="number of worker threads", -) -def s(port: int, log_level: str, 
threads: int) -> None: - sessions: dict[str, BaseSession] = {} - tags_metadata = [ - { - "name": "Background Removal", - "description": "Endpoints that perform background removal with different image sources.", - "externalDocs": { - "description": "GitHub Source", - "url": "https://github.com/danielgatis/rembg", - }, - }, - ] - app = FastAPI( - title="Rembg", - description="Rembg is a tool to remove images background. That is it.", - version=_version.get_versions()["version"], - contact={ - "name": "Daniel Gatis", - "url": "https://github.com/danielgatis", - "email": "danielgatis@gmail.com", - }, - license_info={ - "name": "MIT License", - "url": "https://github.com/danielgatis/rembg/blob/main/LICENSE.txt", - }, - openapi_tags=tags_metadata, - ) - - app.add_middleware( - CORSMiddleware, - allow_credentials=True, - allow_origins=["*"], - allow_methods=["*"], - allow_headers=["*"], - ) - - class ModelType(str, Enum): - u2net = "u2net" - u2netp = "u2netp" - u2net_human_seg = "u2net_human_seg" - u2net_cloth_seg = "u2net_cloth_seg" - silueta = "silueta" - isnet_general_use = "isnet-general-use" - - class CommonQueryParams: - def __init__( - self, - model: ModelType = Query( - default=ModelType.u2net, - description="Model to use when processing image", - ), - a: bool = Query(default=False, description="Enable Alpha Matting"), - af: int = Query( - default=240, - ge=0, - le=255, - description="Alpha Matting (Foreground Threshold)", - ), - ab: int = Query( - default=10, - ge=0, - le=255, - description="Alpha Matting (Background Threshold)", - ), - ae: int = Query( - default=10, ge=0, description="Alpha Matting (Erode Structure Size)" - ), - om: bool = Query(default=False, description="Only Mask"), - ppm: bool = Query(default=False, description="Post Process Mask"), - bgc: Optional[str] = Query(default=None, description="Background Color"), - ): - self.model = model - self.a = a - self.af = af - self.ab = ab - self.ae = ae - self.om = om - self.ppm = ppm - self.bgc = ( - 
cast(Tuple[int, int, int, int], tuple(map(int, bgc.split(",")))) - if bgc - else None - ) - - class CommonQueryPostParams: - def __init__( - self, - model: ModelType = Form( - default=ModelType.u2net, - description="Model to use when processing image", - ), - a: bool = Form(default=False, description="Enable Alpha Matting"), - af: int = Form( - default=240, - ge=0, - le=255, - description="Alpha Matting (Foreground Threshold)", - ), - ab: int = Form( - default=10, - ge=0, - le=255, - description="Alpha Matting (Background Threshold)", - ), - ae: int = Form( - default=10, ge=0, description="Alpha Matting (Erode Structure Size)" - ), - om: bool = Form(default=False, description="Only Mask"), - ppm: bool = Form(default=False, description="Post Process Mask"), - bgc: Optional[str] = Query(default=None, description="Background Color"), - ): - self.model = model - self.a = a - self.af = af - self.ab = ab - self.ae = ae - self.om = om - self.ppm = ppm - self.bgc = ( - cast(Tuple[int, int, int, int], tuple(map(int, bgc.split(",")))) - if bgc - else None - ) - - def im_without_bg(content: bytes, commons: CommonQueryParams) -> Response: - return Response( - remove( - content, - session=sessions.setdefault( - commons.model.value, new_session(commons.model.value) - ), - alpha_matting=commons.a, - alpha_matting_foreground_threshold=commons.af, - alpha_matting_background_threshold=commons.ab, - alpha_matting_erode_size=commons.ae, - only_mask=commons.om, - post_process_mask=commons.ppm, - bgcolor=commons.bgc, - ), - media_type="image/png", - ) - - @app.on_event("startup") - def startup(): - if threads is not None: - from anyio import CapacityLimiter - from anyio.lowlevel import RunVar - - RunVar("_default_thread_limiter").set(CapacityLimiter(threads)) - - @app.get( - path="/", - tags=["Background Removal"], - summary="Remove from URL", - description="Removes the background from an image obtained by retrieving an URL.", - ) - async def get_index( - url: str = Query( - 
default=..., description="URL of the image that has to be processed." - ), - commons: CommonQueryParams = Depends(), - ): - async with aiohttp.ClientSession() as session: - async with session.get(url) as response: - file = await response.read() - return await asyncify(im_without_bg)(file, commons) - - @app.post( - path="/", - tags=["Background Removal"], - summary="Remove from Stream", - description="Removes the background from an image sent within the request itself.", - ) - async def post_index( - file: bytes = File( - default=..., - description="Image file (byte stream) that has to be processed.", - ), - commons: CommonQueryPostParams = Depends(), - ): - return await asyncify(im_without_bg)(file, commons) # type: ignore - - uvicorn.run(app, host="0.0.0.0", port=port, log_level=log_level) +for command in command_functions: + main.add_command(command) diff --git a/rembg/commands/__init__.py b/rembg/commands/__init__.py new file mode 100644 index 0000000..64f8993 --- /dev/null +++ b/rembg/commands/__init__.py @@ -0,0 +1,13 @@ +from importlib import import_module +from pathlib import Path +from pkgutil import iter_modules + +command_functions = [] + +package_dir = Path(__file__).resolve().parent +for _b, module_name, _p in iter_modules([str(package_dir)]): + module = import_module(f"{__name__}.{module_name}") + for attribute_name in dir(module): + attribute = getattr(module, attribute_name) + if attribute_name.endswith("_command"): + command_functions.append(attribute) diff --git a/rembg/commands/i_command.py b/rembg/commands/i_command.py new file mode 100644 index 0000000..d65313c --- /dev/null +++ b/rembg/commands/i_command.py @@ -0,0 +1,93 @@ +import json +import sys +from typing import IO + +import click + +from ..bg import remove +from ..session_factory import new_session +from ..sessions import sessions_names + + +@click.command( + name="i", + help="for a file as input", +) +@click.option( + "-m", + "--model", + default="u2net", + 
type=click.Choice(sessions_names), + show_default=True, + show_choices=True, + help="model name", +) +@click.option( + "-a", + "--alpha-matting", + is_flag=True, + show_default=True, + help="use alpha matting", +) +@click.option( + "-af", + "--alpha-matting-foreground-threshold", + default=240, + type=int, + show_default=True, + help="trimap fg threshold", +) +@click.option( + "-ab", + "--alpha-matting-background-threshold", + default=10, + type=int, + show_default=True, + help="trimap bg threshold", +) +@click.option( + "-ae", + "--alpha-matting-erode-size", + default=10, + type=int, + show_default=True, + help="erode size", +) +@click.option( + "-om", + "--only-mask", + is_flag=True, + show_default=True, + help="output only the mask", +) +@click.option( + "-ppm", + "--post-process-mask", + is_flag=True, + show_default=True, + help="post process the mask", +) +@click.option( + "-bgc", + "--bgcolor", + default=None, + type=(int, int, int, int), + nargs=4, + help="Background color (R G B A) to replace the removed background with", +) +@click.option("-x", "--extras", type=str) +@click.argument( + "input", default=(None if sys.stdin.isatty() else "-"), type=click.File("rb") +) +@click.argument( + "output", + default=(None if sys.stdin.isatty() else "-"), + type=click.File("wb", lazy=True), +) +def i_command(model: str, extras: str, input: IO, output: IO, **kwargs) -> None: + try: + kwargs.update(json.loads(extras)) + except Exception: + pass + + output.write(remove(input.read(), session=new_session(model), **kwargs)) diff --git a/rembg/commands/p_command.py b/rembg/commands/p_command.py new file mode 100644 index 0000000..2163bfb --- /dev/null +++ b/rembg/commands/p_command.py @@ -0,0 +1,181 @@ +import json +import pathlib +import time +from typing import cast + +import click +import filetype +from tqdm import tqdm +from watchdog.events import FileSystemEvent, FileSystemEventHandler +from watchdog.observers import Observer + +from ..bg import remove +from 
..session_factory import new_session +from ..sessions import sessions_names + + +@click.command( + name="p", + help="for a folder as input", +) +@click.option( + "-m", + "--model", + default="u2net", + type=click.Choice(sessions_names), + show_default=True, + show_choices=True, + help="model name", +) +@click.option( + "-a", + "--alpha-matting", + is_flag=True, + show_default=True, + help="use alpha matting", +) +@click.option( + "-af", + "--alpha-matting-foreground-threshold", + default=240, + type=int, + show_default=True, + help="trimap fg threshold", +) +@click.option( + "-ab", + "--alpha-matting-background-threshold", + default=10, + type=int, + show_default=True, + help="trimap bg threshold", +) +@click.option( + "-ae", + "--alpha-matting-erode-size", + default=10, + type=int, + show_default=True, + help="erode size", +) +@click.option( + "-om", + "--only-mask", + is_flag=True, + show_default=True, + help="output only the mask", +) +@click.option( + "-ppm", + "--post-process-mask", + is_flag=True, + show_default=True, + help="post process the mask", +) +@click.option( + "-w", + "--watch", + default=False, + is_flag=True, + show_default=True, + help="watches a folder for changes", +) +@click.option( + "-bgc", + "--bgcolor", + default=None, + type=(int, int, int, int), + nargs=4, + help="Background color (R G B A) to replace the removed background with", +) +@click.option("-x", "--extras", type=str) +@click.argument( + "input", + type=click.Path( + exists=True, + path_type=pathlib.Path, + file_okay=False, + dir_okay=True, + readable=True, + ), +) +@click.argument( + "output", + type=click.Path( + exists=False, + path_type=pathlib.Path, + file_okay=False, + dir_okay=True, + writable=True, + ), +) +def p_command( + model: str, + extras: str, + input: pathlib.Path, + output: pathlib.Path, + watch: bool, + **kwargs, +) -> None: + try: + kwargs.update(json.loads(extras)) + except Exception: + pass + + session = new_session(model) + + def process(each_input: 
pathlib.Path) -> None: + try: + mimetype = filetype.guess(each_input) + if mimetype is None: + return + if mimetype.mime.find("image") < 0: + return + + each_output = (output / each_input.name).with_suffix(".png") + each_output.parents[0].mkdir(parents=True, exist_ok=True) + + if not each_output.exists(): + each_output.write_bytes( + cast( + bytes, + remove(each_input.read_bytes(), session=session, **kwargs), + ) + ) + + if watch: + print( + f"processed: {each_input.absolute()} -> {each_output.absolute()}" + ) + except Exception as e: + print(e) + + inputs = list(input.glob("**/*")) + if not watch: + inputs = tqdm(inputs) + + for each_input in inputs: + if not each_input.is_dir(): + process(each_input) + + if watch: + observer = Observer() + + class EventHandler(FileSystemEventHandler): + def on_any_event(self, event: FileSystemEvent) -> None: + if not ( + event.is_directory or event.event_type in ["deleted", "closed"] + ): + process(pathlib.Path(event.src_path)) + + event_handler = EventHandler() + observer.schedule(event_handler, input, recursive=False) + observer.start() + + try: + while True: + time.sleep(1) + + finally: + observer.stop() + observer.join() diff --git a/rembg/commands/s_command.py b/rembg/commands/s_command.py new file mode 100644 index 0000000..ac015b8 --- /dev/null +++ b/rembg/commands/s_command.py @@ -0,0 +1,239 @@ +import json +from typing import Annotated, Optional, Tuple, cast + +import aiohttp +import click +import uvicorn +from asyncer import asyncify +from fastapi import Depends, FastAPI, File, Form, Query +from fastapi.middleware.cors import CORSMiddleware +from starlette.responses import Response + +from .._version import get_versions +from ..bg import remove +from ..session_factory import new_session +from ..sessions import sessions_names +from ..sessions.base import BaseSession + + +@click.command( + name="s", + help="for a http server", +) +@click.option( + "-p", + "--port", + default=5000, + type=int, + show_default=True, + 
help="port", +) +@click.option( + "-l", + "--log_level", + default="info", + type=str, + show_default=True, + help="log level", +) +@click.option( + "-t", + "--threads", + default=None, + type=int, + show_default=True, + help="number of worker threads", +) +def s_command(port: int, log_level: str, threads: int) -> None: + sessions: dict[str, BaseSession] = {} + tags_metadata = [ + { + "name": "Background Removal", + "description": "Endpoints that perform background removal with different image sources.", + "externalDocs": { + "description": "GitHub Source", + "url": "https://github.com/danielgatis/rembg", + }, + }, + ] + app = FastAPI( + title="Rembg", + description="Rembg is a tool to remove images background. That is it.", + version=get_versions()["version"], + contact={ + "name": "Daniel Gatis", + "url": "https://github.com/danielgatis", + "email": "danielgatis@gmail.com", + }, + license_info={ + "name": "MIT License", + "url": "https://github.com/danielgatis/rembg/blob/main/LICENSE.txt", + }, + openapi_tags=tags_metadata, + ) + + app.add_middleware( + CORSMiddleware, + allow_credentials=True, + allow_origins=["*"], + allow_methods=["*"], + allow_headers=["*"], + ) + + class CommonQueryParams: + def __init__( + self, + model: Annotated[ + str, Query(regex=r"(" + "|".join(sessions_names) + ")") + ] = Query( + default="u2net", + description="Model to use when processing image", + ), + a: bool = Query(default=False, description="Enable Alpha Matting"), + af: int = Query( + default=240, + ge=0, + le=255, + description="Alpha Matting (Foreground Threshold)", + ), + ab: int = Query( + default=10, + ge=0, + le=255, + description="Alpha Matting (Background Threshold)", + ), + ae: int = Query( + default=10, ge=0, description="Alpha Matting (Erode Structure Size)" + ), + om: bool = Query(default=False, description="Only Mask"), + ppm: bool = Query(default=False, description="Post Process Mask"), + bgc: Optional[str] = Query(default=None, description="Background Color"), + 
extras: Optional[str] = Query( + default=None, description="Extra parameters as JSON" + ), + ): + self.model = model + self.a = a + self.af = af + self.ab = ab + self.ae = ae + self.om = om + self.ppm = ppm + self.extras = extras + self.bgc = ( + cast(Tuple[int, int, int, int], tuple(map(int, bgc.split(",")))) + if bgc + else None + ) + + class CommonQueryPostParams: + def __init__( + self, + model: Annotated[ + str, Form(regex=r"(" + "|".join(sessions_names) + ")") + ] = Form( + default="u2net", + description="Model to use when processing image", + ), + a: bool = Form(default=False, description="Enable Alpha Matting"), + af: int = Form( + default=240, + ge=0, + le=255, + description="Alpha Matting (Foreground Threshold)", + ), + ab: int = Form( + default=10, + ge=0, + le=255, + description="Alpha Matting (Background Threshold)", + ), + ae: int = Form( + default=10, ge=0, description="Alpha Matting (Erode Structure Size)" + ), + om: bool = Form(default=False, description="Only Mask"), + ppm: bool = Form(default=False, description="Post Process Mask"), + bgc: Optional[str] = Query(default=None, description="Background Color"), + extras: Optional[str] = Query( + default=None, description="Extra parameters as JSON" + ), + ): + self.model = model + self.a = a + self.af = af + self.ab = ab + self.ae = ae + self.om = om + self.ppm = ppm + self.extras = extras + self.bgc = ( + cast(Tuple[int, int, int, int], tuple(map(int, bgc.split(",")))) + if bgc + else None + ) + + def im_without_bg(content: bytes, commons: CommonQueryParams) -> Response: + kwargs = dict() + + try: + kwargs.update(json.loads(commons.extras)) + except Exception: + pass + + return Response( + remove( + content, + session=sessions.setdefault(commons.model, new_session(commons.model)), + alpha_matting=commons.a, + alpha_matting_foreground_threshold=commons.af, + alpha_matting_background_threshold=commons.ab, + alpha_matting_erode_size=commons.ae, + only_mask=commons.om, + post_process_mask=commons.ppm, + 
bgcolor=commons.bgc, + **kwargs + ), + media_type="image/png", + ) + + @app.on_event("startup") + def startup(): + if threads is not None: + from anyio import CapacityLimiter + from anyio.lowlevel import RunVar + + RunVar("_default_thread_limiter").set(CapacityLimiter(threads)) + + @app.get( + path="/", + tags=["Background Removal"], + summary="Remove from URL", + description="Removes the background from an image obtained by retrieving an URL.", + ) + async def get_index( + url: str = Query( + default=..., description="URL of the image that has to be processed." + ), + commons: CommonQueryParams = Depends(), + ): + async with aiohttp.ClientSession() as session: + async with session.get(url) as response: + file = await response.read() + return await asyncify(im_without_bg)(file, commons) + + @app.post( + path="/", + tags=["Background Removal"], + summary="Remove from Stream", + description="Removes the background from an image sent within the request itself.", + ) + async def post_index( + file: bytes = File( + default=..., + description="Image file (byte stream) that has to be processed.", + ), + commons: CommonQueryPostParams = Depends(), + ): + return await asyncify(im_without_bg)(file, commons) # type: ignore + + uvicorn.run(app, host="0.0.0.0", port=port, log_level=log_level) diff --git a/rembg/session_dis.py b/rembg/session_dis.py deleted file mode 100644 index e215806..0000000 --- a/rembg/session_dis.py +++ /dev/null @@ -1,28 +0,0 @@ -from typing import List - -import numpy as np -from PIL import Image -from PIL.Image import Image as PILImage - -from .session_base import BaseSession - - -class DisSession(BaseSession): - def predict(self, img: PILImage) -> List[PILImage]: - ort_outs = self.inner_session.run( - None, - self.normalize(img, (0.485, 0.456, 0.406), (1.0, 1.0, 1.0), (1024, 1024)), - ) - - pred = ort_outs[0][:, 0, :, :] - - ma = np.max(pred) - mi = np.min(pred) - - pred = (pred - mi) / (ma - mi) - pred = np.squeeze(pred) - - mask = 
Image.fromarray((pred * 255).astype("uint8"), mode="L") - mask = mask.resize(img.size, Image.LANCZOS) - - return [mask] diff --git a/rembg/session_factory.py b/rembg/session_factory.py index 9d021c4..b1e1cb3 100644 --- a/rembg/session_factory.py +++ b/rembg/session_factory.py @@ -1,114 +1,24 @@ -import hashlib import os -import sys -from contextlib import redirect_stdout -from pathlib import Path from typing import Type import onnxruntime as ort -import pooch -from .session_base import BaseSession -from .session_cloth import ClothSession -from .session_dis import DisSession -from .session_sam import SamSession -from .session_simple import SimpleSession +from .sessions import sessions_class +from .sessions.base import BaseSession +from .sessions.u2net import U2netSession -def download_model(url: str, md5: str, fname: str, path: Path): - pooch.retrieve( - url, - f"md5:{md5}", - fname=fname, - path=path, - progressbar=True, - ) +def new_session(model_name: str = "u2net", *args, **kwargs) -> BaseSession: + session_class: Type[BaseSession] = U2netSession - -def new_session(model_name: str = "u2net") -> BaseSession: - # Define the model path - u2net_home = os.getenv( - "U2NET_HOME", os.path.join(os.getenv("XDG_DATA_HOME", "~"), ".u2net") - ) - - fname = f"{model_name}.onnx" - path = Path(u2net_home).expanduser() - full_path = Path(u2net_home).expanduser() / fname - - session_class: Type[BaseSession] - md5 = "60024c5c889badc19c04ad937298a77b" - url = "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net.onnx" - session_class = SimpleSession - - if model_name == "u2netp": - md5 = "8e83ca70e441ab06c318d82300c84806" - url = ( - "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2netp.onnx" - ) - session_class = SimpleSession - elif model_name == "u2net_human_seg": - md5 = "c09ddc2e0104f800e3e1bb4652583d1f" - url = "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_human_seg.onnx" - session_class = SimpleSession - elif model_name 
== "u2net_cloth_seg": - md5 = "2434d1f3cb744e0e49386c906e5a08bb" - url = "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_cloth_seg.onnx" - session_class = ClothSession - elif model_name == "silueta": - md5 = "55e59e0d8062d2f5d013f4725ee84782" - url = ( - "https://github.com/danielgatis/rembg/releases/download/v0.0.0/silueta.onnx" - ) - session_class = SimpleSession - elif model_name == "isnet-general-use": - md5 = "fc16ebd8b0c10d971d3513d564d01e29" - url = "https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-general-use.onnx" - session_class = DisSession - elif model_name == "sam": - path = Path(u2net_home).expanduser() - - fname_encoder = f"{model_name}_encoder.onnx" - encoder_md5 = "13d97c5c79ab13ef86d67cbde5f1b250" - encoder_url = "https://github.com/Flippchen/rembg/releases/download/test/vit_b-encoder-quant.onnx" - - fname_decoder = f"{model_name}_decoder.onnx" - decoder_md5 = "fa3d1c36a3187d3de1c8deebf33dd127" - decoder_url = "https://github.com/Flippchen/rembg/releases/download/test/vit_b-decoder-quant.onnx" - - download_model(encoder_url, encoder_md5, fname_encoder, path) - download_model(decoder_url, decoder_md5, fname_decoder, path) - - sess_opts = ort.SessionOptions() - - if "OMP_NUM_THREADS" in os.environ: - sess_opts.inter_op_num_threads = int(os.environ["OMP_NUM_THREADS"]) - - return SamSession( - model_name, - ort.InferenceSession( - str(path / fname_encoder), - providers=ort.get_available_providers(), - sess_options=sess_opts, - ), - ort.InferenceSession( - str(path / fname_decoder), - providers=ort.get_available_providers(), - sess_options=sess_opts, - ), - ) - - download_model(url, md5, fname, path) + for sc in sessions_class: + if sc.name() == model_name: + session_class = sc + break sess_opts = ort.SessionOptions() if "OMP_NUM_THREADS" in os.environ: sess_opts.inter_op_num_threads = int(os.environ["OMP_NUM_THREADS"]) - return session_class( - model_name, - ort.InferenceSession( - str(full_path), - 
providers=ort.get_available_providers(), - sess_options=sess_opts, - ), - ) + return session_class(model_name, sess_opts, *args, **kwargs) diff --git a/rembg/session_simple.py b/rembg/session_simple.py deleted file mode 100644 index 7ec3181..0000000 --- a/rembg/session_simple.py +++ /dev/null @@ -1,30 +0,0 @@ -from typing import List - -import numpy as np -from PIL import Image -from PIL.Image import Image as PILImage - -from .session_base import BaseSession - - -class SimpleSession(BaseSession): - def predict(self, img: PILImage) -> List[PILImage]: - ort_outs = self.inner_session.run( - None, - self.normalize( - img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (320, 320) - ), - ) - - pred = ort_outs[0][:, 0, :, :] - - ma = np.max(pred) - mi = np.min(pred) - - pred = (pred - mi) / (ma - mi) - pred = np.squeeze(pred) - - mask = Image.fromarray((pred * 255).astype("uint8"), mode="L") - mask = mask.resize(img.size, Image.LANCZOS) - - return [mask] diff --git a/rembg/sessions/__init__.py b/rembg/sessions/__init__.py new file mode 100644 index 0000000..08ca20a --- /dev/null +++ b/rembg/sessions/__init__.py @@ -0,0 +1,22 @@ +from importlib import import_module +from inspect import isclass +from pathlib import Path +from pkgutil import iter_modules + +from .base import BaseSession + +sessions_class = [] +sessions_names = [] + +package_dir = Path(__file__).resolve().parent +for _b, module_name, _p in iter_modules([str(package_dir)]): + module = import_module(f"{__name__}.{module_name}") + for attribute_name in dir(module): + attribute = getattr(module, attribute_name) + if ( + isclass(attribute) + and issubclass(attribute, BaseSession) + and attribute != BaseSession + ): + sessions_class.append(attribute) + sessions_names.append(attribute.name()) diff --git a/rembg/session_base.py b/rembg/sessions/base.py similarity index 55% rename from rembg/session_base.py rename to rembg/sessions/base.py index aa98693..6c6c0d8 100644 --- a/rembg/session_base.py +++ 
b/rembg/sessions/base.py @@ -1,3 +1,4 @@ +import os from typing import Dict, List, Tuple import numpy as np @@ -7,9 +8,13 @@ from PIL.Image import Image as PILImage class BaseSession: - def __init__(self, model_name: str, inner_session: ort.InferenceSession): + def __init__(self, model_name: str, sess_opts: ort.SessionOptions, *args, **kwargs): self.model_name = model_name - self.inner_session = inner_session + self.inner_session = ort.InferenceSession( + str(self.__class__.download_models()), + providers=ort.get_available_providers(), + sess_options=sess_opts, + ) def normalize( self, @@ -17,6 +22,8 @@ class BaseSession: mean: Tuple[float, float, float], std: Tuple[float, float, float], size: Tuple[int, int], + *args, + **kwargs ) -> Dict[str, np.ndarray]: im = img.convert("RGB").resize(size, Image.LANCZOS) @@ -36,5 +43,21 @@ class BaseSession: .astype(np.float32) } - def predict(self, img: PILImage) -> List[PILImage]: + def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]: + raise NotImplementedError + + @classmethod + def u2net_home(cls, *args, **kwargs): + return os.path.expanduser( + os.getenv( + "U2NET_HOME", os.path.join(os.getenv("XDG_DATA_HOME", "~"), ".u2net") + ) + ) + + @classmethod + def download_models(cls, *args, **kwargs): + raise NotImplementedError + + @classmethod + def name(cls, *args, **kwargs): raise NotImplementedError diff --git a/rembg/sessions/dis.py b/rembg/sessions/dis.py new file mode 100644 index 0000000..8c38276 --- /dev/null +++ b/rembg/sessions/dis.py @@ -0,0 +1,47 @@ +import os +from typing import List + +import numpy as np +import pooch +from PIL import Image +from PIL.Image import Image as PILImage + +from .base import BaseSession + + +class DisSession(BaseSession): + def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]: + ort_outs = self.inner_session.run( + None, + self.normalize(img, (0.485, 0.456, 0.406), (1.0, 1.0, 1.0), (1024, 1024)), + ) + + pred = ort_outs[0][:, 0, :, :] + + ma = np.max(pred) 
+ mi = np.min(pred) + + pred = (pred - mi) / (ma - mi) + pred = np.squeeze(pred) + + mask = Image.fromarray((pred * 255).astype("uint8"), mode="L") + mask = mask.resize(img.size, Image.LANCZOS) + + return [mask] + + @classmethod + def download_models(cls, *args, **kwargs): + fname = f"{cls.name()}.onnx" + pooch.retrieve( + "https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-general-use.onnx", + f"md5:fc16ebd8b0c10d971d3513d564d01e29", + fname=fname, + path=cls.u2net_home(), + progressbar=True, + ) + + return os.path.join(cls.u2net_home(), fname) + + @classmethod + def name(cls, *args, **kwargs): + return "isnet-general-use" diff --git a/rembg/session_sam.py b/rembg/sessions/sam.py similarity index 58% rename from rembg/session_sam.py rename to rembg/sessions/sam.py index 5bf2067..131ec53 100644 --- a/rembg/session_sam.py +++ b/rembg/sessions/sam.py @@ -1,11 +1,13 @@ +import os from typing import List import numpy as np import onnxruntime as ort +import pooch from PIL import Image from PIL.Image import Image as PILImage -from .session_base import BaseSession +from .base import BaseSession def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int): @@ -47,14 +49,19 @@ def pad_to_square(img: np.ndarray, size=1024): class SamSession(BaseSession): - def __init__( - self, - model_name: str, - encoder: ort.InferenceSession, - decoder: ort.InferenceSession, - ): - super().__init__(model_name, encoder) - self.decoder = decoder + def __init__(self, model_name: str, sess_opts: ort.SessionOptions, *args, **kwargs): + self.model_name = model_name + paths = self.__class__.download_models() + self.encoder = ort.InferenceSession( + str(paths[0]), + providers=ort.get_available_providers(), + sess_options=sess_opts, + ) + self.decoder = ort.InferenceSession( + str(paths[1]), + providers=ort.get_available_providers(), + sess_options=sess_opts, + ) def normalize( self, @@ -62,17 +69,19 @@ class SamSession(BaseSession): mean=(123.675, 116.28, 103.53), 
std=(58.395, 57.12, 57.375), size=(1024, 1024), + *args, + **kwargs, ): pixel_mean = np.array([*mean]).reshape(1, 1, -1) pixel_std = np.array([*std]).reshape(1, 1, -1) x = (img - pixel_mean) / pixel_std return x - def predict_sam( + def predict( self, img: PILImage, - input_point: np.ndarray, - input_label: np.ndarray, + *args, + **kwargs, ) -> List[PILImage]: # Preprocess image image = resize_longes_side(img) @@ -80,17 +89,25 @@ class SamSession(BaseSession): image = self.normalize(image) image = pad_to_square(image) + input_labels = kwargs.get("input_labels") + input_points = kwargs.get("input_points") + + if input_labels is None: + raise ValueError("input_labels is required") + if input_points is None: + raise ValueError("input_points is required") + # Transpose image = image.transpose(2, 0, 1)[None, :, :, :] # Run encoder (Image embedding) - encoded = self.inner_session.run(None, {"x": image}) + encoded = self.encoder.run(None, {"x": image}) image_embedding = encoded[0] # Add a batch index, concatenate a padding point, and transform. 
- onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[ + onnx_coord = np.concatenate([input_points, np.array([[0.0, 0.0]])], axis=0)[ None, :, : ] - onnx_label = np.concatenate([input_label, np.array([-1])], axis=0)[ + onnx_label = np.concatenate([input_labels, np.array([-1])], axis=0)[ None, : ].astype(np.float32) onnx_coord = apply_coords(onnx_coord, img.size[::1], 1024).astype(np.float32) @@ -116,3 +133,33 @@ class SamSession(BaseSession): ] return masks + + @classmethod + def download_models(cls, *args, **kwargs): + fname_encoder = f"{cls.name()}_encoder.onnx" + fname_decoder = f"{cls.name()}_decoder.onnx" + + pooch.retrieve( + "https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-encoder-quant.onnx", + f"md5:13d97c5c79ab13ef86d67cbde5f1b250", + fname=fname_encoder, + path=cls.u2net_home(), + progressbar=True, + ) + + pooch.retrieve( + "https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-decoder-quant.onnx", + f"md5:fa3d1c36a3187d3de1c8deebf33dd127", + fname=fname_decoder, + path=cls.u2net_home(), + progressbar=True, + ) + + return ( + os.path.join(cls.u2net_home(), fname_encoder), + os.path.join(cls.u2net_home(), fname_decoder), + ) + + @classmethod + def name(cls, *args, **kwargs): + return "sam" diff --git a/rembg/sessions/silueta.py b/rembg/sessions/silueta.py new file mode 100644 index 0000000..c09e3f2 --- /dev/null +++ b/rembg/sessions/silueta.py @@ -0,0 +1,49 @@ +import os +from typing import List + +import numpy as np +import pooch +from PIL import Image +from PIL.Image import Image as PILImage + +from .base import BaseSession + + +class SiluetaSession(BaseSession): + def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]: + ort_outs = self.inner_session.run( + None, + self.normalize( + img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (320, 320) + ), + ) + + pred = ort_outs[0][:, 0, :, :] + + ma = np.max(pred) + mi = np.min(pred) + + pred = (pred - mi) / (ma - mi) + pred = 
np.squeeze(pred) + + mask = Image.fromarray((pred * 255).astype("uint8"), mode="L") + mask = mask.resize(img.size, Image.LANCZOS) + + return [mask] + + @classmethod + def download_models(cls, *args, **kwargs): + fname = f"{cls.name()}.onnx" + pooch.retrieve( + "https://github.com/danielgatis/rembg/releases/download/v0.0.0/silueta.onnx", + f"md5:55e59e0d8062d2f5d013f4725ee84782", + fname=fname, + path=cls.u2net_home(), + progressbar=True, + ) + + return os.path.join(cls.u2net_home(), fname) + + @classmethod + def name(cls, *args, **kwargs): + return "silueta" diff --git a/rembg/sessions/u2net.py b/rembg/sessions/u2net.py new file mode 100644 index 0000000..2702a96 --- /dev/null +++ b/rembg/sessions/u2net.py @@ -0,0 +1,49 @@ +import os +from typing import List + +import numpy as np +import pooch +from PIL import Image +from PIL.Image import Image as PILImage + +from .base import BaseSession + + +class U2netSession(BaseSession): + def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]: + ort_outs = self.inner_session.run( + None, + self.normalize( + img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (320, 320) + ), + ) + + pred = ort_outs[0][:, 0, :, :] + + ma = np.max(pred) + mi = np.min(pred) + + pred = (pred - mi) / (ma - mi) + pred = np.squeeze(pred) + + mask = Image.fromarray((pred * 255).astype("uint8"), mode="L") + mask = mask.resize(img.size, Image.LANCZOS) + + return [mask] + + @classmethod + def download_models(cls, *args, **kwargs): + fname = f"{cls.name()}.onnx" + pooch.retrieve( + "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net.onnx", + f"md5:60024c5c889badc19c04ad937298a77b", + fname=fname, + path=cls.u2net_home(), + progressbar=True, + ) + + return os.path.join(cls.u2net_home(), fname) + + @classmethod + def name(cls, *args, **kwargs): + return "u2net" diff --git a/rembg/session_cloth.py b/rembg/sessions/u2net_cloth_seg.py similarity index 67% rename from rembg/session_cloth.py rename to 
rembg/sessions/u2net_cloth_seg.py index 575aff1..7765109 100644 --- a/rembg/session_cloth.py +++ b/rembg/sessions/u2net_cloth_seg.py @@ -1,11 +1,13 @@ +import os from typing import List import numpy as np +import pooch from PIL import Image from PIL.Image import Image as PILImage from scipy.special import log_softmax -from .session_base import BaseSession +from .base import BaseSession pallete1 = [ 0, @@ -53,8 +55,8 @@ pallete3 = [ ] -class ClothSession(BaseSession): - def predict(self, img: PILImage) -> List[PILImage]: +class Unet2ClothSession(BaseSession): + def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]: ort_outs = self.inner_session.run( None, self.normalize( @@ -89,3 +91,20 @@ class ClothSession(BaseSession): masks.append(mask3) return masks + + @classmethod + def download_models(cls, *args, **kwargs): + fname = f"{cls.name()}.onnx" + pooch.retrieve( + "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_cloth_seg.onnx", + f"md5:2434d1f3cb744e0e49386c906e5a08bb", + fname=fname, + path=cls.u2net_home(), + progressbar=True, + ) + + return os.path.join(cls.u2net_home(), fname) + + @classmethod + def name(cls, *args, **kwargs): + return "u2net_cloth_seg" diff --git a/rembg/sessions/u2net_human_seg.py b/rembg/sessions/u2net_human_seg.py new file mode 100644 index 0000000..8156bef --- /dev/null +++ b/rembg/sessions/u2net_human_seg.py @@ -0,0 +1,49 @@ +import os +from typing import List + +import numpy as np +import pooch +from PIL import Image +from PIL.Image import Image as PILImage + +from .base import BaseSession + + +class U2netHumanSegSession(BaseSession): + def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]: + ort_outs = self.inner_session.run( + None, + self.normalize( + img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (320, 320) + ), + ) + + pred = ort_outs[0][:, 0, :, :] + + ma = np.max(pred) + mi = np.min(pred) + + pred = (pred - mi) / (ma - mi) + pred = np.squeeze(pred) + + mask = 
Image.fromarray((pred * 255).astype("uint8"), mode="L") + mask = mask.resize(img.size, Image.LANCZOS) + + return [mask] + + @classmethod + def download_models(cls, *args, **kwargs): + fname = f"{cls.name()}.onnx" + pooch.retrieve( + "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_human_seg.onnx", + f"md5:c09ddc2e0104f800e3e1bb4652583d1f", + fname=fname, + path=cls.u2net_home(), + progressbar=True, + ) + + return os.path.join(cls.u2net_home(), fname) + + @classmethod + def name(cls, *args, **kwargs): + return "u2net_human_seg" diff --git a/rembg/sessions/u2netp.py b/rembg/sessions/u2netp.py new file mode 100644 index 0000000..dc9edab --- /dev/null +++ b/rembg/sessions/u2netp.py @@ -0,0 +1,49 @@ +import os +from typing import List + +import numpy as np +import pooch +from PIL import Image +from PIL.Image import Image as PILImage + +from .base import BaseSession + + +class U2netpSession(BaseSession): + def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]: + ort_outs = self.inner_session.run( + None, + self.normalize( + img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (320, 320) + ), + ) + + pred = ort_outs[0][:, 0, :, :] + + ma = np.max(pred) + mi = np.min(pred) + + pred = (pred - mi) / (ma - mi) + pred = np.squeeze(pred) + + mask = Image.fromarray((pred * 255).astype("uint8"), mode="L") + mask = mask.resize(img.size, Image.LANCZOS) + + return [mask] + + @classmethod + def download_models(cls, *args, **kwargs): + fname = f"{cls.name()}.onnx" + pooch.retrieve( + "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2netp.onnx", + f"md5:8e83ca70e441ab06c318d82300c84806", + fname=fname, + path=cls.u2net_home(), + progressbar=True, + ) + + return os.path.join(cls.u2net_home(), fname) + + @classmethod + def name(cls, *args, **kwargs): + return "u2netp" diff --git a/requirements-gpu.txt b/requirements-gpu.txt index c58c556..f331549 100644 --- a/requirements-gpu.txt +++ b/requirements-gpu.txt @@ -1 +1 @@ 
-onnxruntime-gpu==1.13.1 +onnxruntime-gpu==1.14.1 diff --git a/requirements.txt b/requirements.txt index fb41d7f..da62c0f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,14 +1,14 @@ aiohttp==3.8.1 asyncer==0.0.2 click==8.1.3 -fastapi==0.87.0 +fastapi==0.92.0 filetype==1.2.0 -pooch==1.6.0 imagehash==4.3.1 numpy==1.23.5 onnxruntime==1.14.1 opencv-python-headless==4.6.0.66 pillow==9.3.0 +pooch==1.6.0 pymatting==1.1.8 python-multipart==0.0.5 scikit-image==0.19.3 diff --git a/setup.py b/setup.py index 91e3471..3412140 100644 --- a/setup.py +++ b/setup.py @@ -42,12 +42,12 @@ setup( "click>=8.1.3", "fastapi>=0.92.0", "filetype>=1.2.0", - "pooch>=1.6.0", "imagehash>=4.3.1", "numpy>=1.23.5", - "onnxruntime>=1.13.1", + "onnxruntime>=1.14.1", "opencv-python-headless>=4.6.0.66", "pillow>=9.3.0", + "pooch>=1.6.0", "pymatting>=1.1.8", "python-multipart>=0.0.5", "scikit-image>=0.19.3", @@ -62,7 +62,7 @@ setup( ], }, extras_require={ - "gpu": ["onnxruntime-gpu>=1.13.1"], + "gpu": ["onnxruntime-gpu>=1.14.1"], }, version=versioneer.get_version(), cmdclass=versioneer.get_cmdclass(), diff --git a/tests/results/car-1.sam.png b/tests/results/car-1.sam.png new file mode 100644 index 0000000..f36b669 Binary files /dev/null and b/tests/results/car-1.sam.png differ diff --git a/tests/results/cloth-1.sam.png b/tests/results/cloth-1.sam.png new file mode 100644 index 0000000..664a7dc Binary files /dev/null and b/tests/results/cloth-1.sam.png differ diff --git a/tests/test_remove.py b/tests/test_remove.py index 7c384b0..6c2fc44 100644 --- a/tests/test_remove.py +++ b/tests/test_remove.py @@ -4,28 +4,48 @@ from pathlib import Path from imagehash import phash as hash_img from PIL import Image -from rembg import remove -from rembg import new_session +from rembg import new_session, remove here = Path(__file__).parent.resolve() def test_remove(): - for model in ["u2net", "u2netp", "u2net_human_seg", "u2net_cloth_seg", "silueta", "isnet-general-use"]: + kwargs = { + "sam": { + "car-1" : { 
+ "input_points": [[250, 200]], + "input_labels": [1], + }, + + "cloth-1" : { + "input_points": [[370, 495]], + "input_labels": [1], + } + } + } + + for model in [ + "u2net", + "u2netp", + "u2net_human_seg", + "u2net_cloth_seg", + "silueta", + "isnet-general-use", + "sam" + ]: for picture in ["car-1", "cloth-1"]: - image_path = Path(here / "fixtures" / f"{picture}.jpg") - expected_path = Path(here / "results" / f"{picture}.{model}.png") - + image_path = Path(here / "fixtures" / f"{picture}.jpg") image = image_path.read_bytes() - expected = expected_path.read_bytes() - actual = remove(image, session=new_session(model)) + actual = remove(image, session=new_session(model), **kwargs.get(model, {}).get(picture, {})) + actual_hash = hash_img(Image.open(BytesIO(actual))) + expected_path = Path(here / "results" / f"{picture}.{model}.png") # Uncomment to update the expected results # f = open(expected_path, "ab") # f.write(actual) # f.close() - actual_hash = hash_img(Image.open(BytesIO(actual))) + expected = expected_path.read_bytes() expected_hash = hash_img(Image.open(BytesIO(expected))) print(f"image_path: {image_path}")