From f6159b45e00e975ac5855d0e47363482f2e4c0b5 Mon Sep 17 00:00:00 2001 From: Daniel Gatis Date: Mon, 29 May 2023 21:02:16 -0300 Subject: [PATCH] add gradio --- README.md | 6 ++-- rembg/commands/s_command.py | 55 ++++++++++++++++++++++++++++++++++--- requirements.txt | 1 + setup.py | 1 + 4 files changed, 56 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 111e892..a4f4216 100644 --- a/README.md +++ b/README.md @@ -182,18 +182,18 @@ rembg p -w path/to/input path/to/output Used to start http server. -To see the complete endpoints documentation, go to: `http://localhost:5000/docs`. +To see the complete endpoints documentation, go to: `http://localhost:5000/api`. Remove the background from an image url ``` -curl -s "http://localhost:5000/?url=http://input.png" -o output.png +curl -s "http://localhost:5000/api/remove?url=http://input.png" -o output.png ``` Remove the background from an uploaded image ``` -curl -s -F file=@/path/to/input.jpg "http://localhost:5000" -o output.png +curl -s -F file=@/path/to/input.jpg "http://localhost:5000/api/remove" -o output.png ``` ### rembg `b` diff --git a/rembg/commands/s_command.py b/rembg/commands/s_command.py index c04a0c4..dece3eb 100644 --- a/rembg/commands/s_command.py +++ b/rembg/commands/s_command.py @@ -1,8 +1,11 @@ import json +import os +import webbrowser from typing import Optional, Tuple, cast import aiohttp import click +import gradio as gr import uvicorn from asyncer import asyncify from fastapi import Depends, FastAPI, File, Form, Query @@ -70,6 +73,7 @@ def s_command(port: int, log_level: str, threads: int) -> None: "url": "https://github.com/danielgatis/rembg/blob/main/LICENSE.txt", }, openapi_tags=tags_metadata, + docs_url="/api", ) app.add_middleware( @@ -190,13 +194,18 @@ def s_command(port: int, log_level: str, threads: int) -> None: only_mask=commons.om, post_process_mask=commons.ppm, bgcolor=commons.bgc, - **kwargs + **kwargs, ), media_type="image/png", ) @app.on_event("startup") def startup(): + try: + webbrowser.open(f"http://localhost:{port}") + except: + pass + if threads is not None: from anyio import CapacityLimiter from anyio.lowlevel import RunVar @@ -204,7 +213,7 @@ def s_command(port: int, log_level: str, threads: int) -> None: RunVar("_default_thread_limiter").set(CapacityLimiter(threads)) @app.get( - path="/", + path="/api/remove", tags=["Background Removal"], summary="Remove from URL", description="Removes the background from an image obtained by retrieving an URL.", @@ -221,7 +230,7 @@ def s_command(port: int, log_level: str, threads: int) -> None: return await asyncify(im_without_bg)(file, commons) @app.post( - path="/", + path="/api/remove", tags=["Background Removal"], summary="Remove from Stream", description="Removes the background from an image sent within the request itself.", @@ -235,4 +244,42 @@ def s_command(port: int, log_level: str, threads: int) -> None: ): return await asyncify(im_without_bg)(file, commons) # type: ignore - uvicorn.run(app, host="0.0.0.0", port=port, log_level=log_level) + def gr_app(app): + def inference(input_path, model): + output_path = "output.png" + with open(input_path, "rb") as i: + with open(output_path, "wb") as o: + input = i.read() + output = remove(input, session=new_session(model)) + o.write(output) + return os.path.join(output_path) + + interface = gr.Interface( + inference, + [ + gr.components.Image(type="filepath", label="Input"), + gr.components.Dropdown( + [ + "u2net", + "u2netp", + "u2net_human_seg", + "u2net_cloth_seg", + "silueta", + "isnet-general-use", + "isnet-anime", + ], + value="u2net", + label="Models", + ), + ], + gr.components.Image(type="filepath", label="Output"), + ) + + interface.queue(concurrency_count=3) + app = gr.mount_gradio_app(app, interface, path="/") + return app + + print(f"To access the API documentation, go to http://localhost:{port}/api") + print(f"To access the UI, go to http://localhost:{port}") + + uvicorn.run(gr_app(app), host="0.0.0.0", port=port, log_level=log_level) diff --git a/requirements.txt b/requirements.txt index da62c0f..28458ed 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,7 @@ asyncer==0.0.2 click==8.1.3 fastapi==0.92.0 filetype==1.2.0 +gradio==3.32.0 imagehash==4.3.1 numpy==1.23.5 onnxruntime==1.14.1 diff --git a/setup.py b/setup.py index 9ae15b4..6f0345a 100644 --- a/setup.py +++ b/setup.py @@ -42,6 +42,7 @@ setup( "click>=8.1.3", "fastapi>=0.92.0", "filetype>=1.2.0", + "gradio>=3.32.0", "imagehash>=4.3.1", "numpy>=1.23.5", "onnxruntime>=1.14.1",