add gradio

This commit is contained in:
Daniel Gatis 2023-05-29 21:02:16 -03:00
parent 937230694d
commit f6159b45e0
4 changed files with 56 additions and 7 deletions

View File

@ -182,18 +182,18 @@ rembg p -w path/to/input path/to/output
Used to start http server. Used to start http server.
To see the complete endpoints documentation, go to: `http://localhost:5000/docs`. To see the complete endpoints documentation, go to: `http://localhost:5000/api`.
Remove the background from an image url Remove the background from an image url
``` ```
curl -s "http://localhost:5000/?url=http://input.png" -o output.png curl -s "http://localhost:5000/api/remove?url=http://input.png" -o output.png
``` ```
Remove the background from an uploaded image Remove the background from an uploaded image
``` ```
curl -s -F file=@/path/to/input.jpg "http://localhost:5000" -o output.png curl -s -F file=@/path/to/input.jpg "http://localhost:5000/api/remove" -o output.png
``` ```
### rembg `b` ### rembg `b`

View File

@ -1,8 +1,11 @@
import json import json
import os
import webbrowser
from typing import Optional, Tuple, cast from typing import Optional, Tuple, cast
import aiohttp import aiohttp
import click import click
import gradio as gr
import uvicorn import uvicorn
from asyncer import asyncify from asyncer import asyncify
from fastapi import Depends, FastAPI, File, Form, Query from fastapi import Depends, FastAPI, File, Form, Query
@ -70,6 +73,7 @@ def s_command(port: int, log_level: str, threads: int) -> None:
"url": "https://github.com/danielgatis/rembg/blob/main/LICENSE.txt", "url": "https://github.com/danielgatis/rembg/blob/main/LICENSE.txt",
}, },
openapi_tags=tags_metadata, openapi_tags=tags_metadata,
docs_url="/api",
) )
app.add_middleware( app.add_middleware(
@ -190,13 +194,18 @@ def s_command(port: int, log_level: str, threads: int) -> None:
only_mask=commons.om, only_mask=commons.om,
post_process_mask=commons.ppm, post_process_mask=commons.ppm,
bgcolor=commons.bgc, bgcolor=commons.bgc,
**kwargs **kwargs,
), ),
media_type="image/png", media_type="image/png",
) )
@app.on_event("startup") @app.on_event("startup")
def startup(): def startup():
try:
webbrowser.open(f"http://localhost:{port}")
except:
pass
if threads is not None: if threads is not None:
from anyio import CapacityLimiter from anyio import CapacityLimiter
from anyio.lowlevel import RunVar from anyio.lowlevel import RunVar
@ -204,7 +213,7 @@ def s_command(port: int, log_level: str, threads: int) -> None:
RunVar("_default_thread_limiter").set(CapacityLimiter(threads)) RunVar("_default_thread_limiter").set(CapacityLimiter(threads))
@app.get( @app.get(
path="/", path="/api/remove",
tags=["Background Removal"], tags=["Background Removal"],
summary="Remove from URL", summary="Remove from URL",
description="Removes the background from an image obtained by retrieving an URL.", description="Removes the background from an image obtained by retrieving an URL.",
@ -221,7 +230,7 @@ def s_command(port: int, log_level: str, threads: int) -> None:
return await asyncify(im_without_bg)(file, commons) return await asyncify(im_without_bg)(file, commons)
@app.post( @app.post(
path="/", path="/api/remove",
tags=["Background Removal"], tags=["Background Removal"],
summary="Remove from Stream", summary="Remove from Stream",
description="Removes the background from an image sent within the request itself.", description="Removes the background from an image sent within the request itself.",
@ -235,4 +244,42 @@ def s_command(port: int, log_level: str, threads: int) -> None:
): ):
return await asyncify(im_without_bg)(file, commons) # type: ignore return await asyncify(im_without_bg)(file, commons) # type: ignore
uvicorn.run(app, host="0.0.0.0", port=port, log_level=log_level) def gr_app(app):
def inference(input_path, model):
output_path = "output.png"
with open(input_path, "rb") as i:
with open(output_path, "wb") as o:
input = i.read()
output = remove(input, session=new_session(model))
o.write(output)
return os.path.join(output_path)
interface = gr.Interface(
inference,
[
gr.components.Image(type="filepath", label="Input"),
gr.components.Dropdown(
[
"u2net",
"u2netp",
"u2net_human_seg",
"u2net_cloth_seg",
"silueta",
"isnet-general-use",
"isnet-anime",
],
value="u2net",
label="Models",
),
],
gr.components.Image(type="filepath", label="Output"),
)
interface.queue(concurrency_count=3)
app = gr.mount_gradio_app(app, interface, path="/")
return app
print(f"To access the API documentation, go to http://localhost:{port}/api")
print(f"To access the UI, go to http://localhost:{port}")
uvicorn.run(gr_app(app), host="0.0.0.0", port=port, log_level=log_level)

View File

@ -3,6 +3,7 @@ asyncer==0.0.2
click==8.1.3 click==8.1.3
fastapi==0.92.0 fastapi==0.92.0
filetype==1.2.0 filetype==1.2.0
gradio==3.32.0
imagehash==4.3.1 imagehash==4.3.1
numpy==1.23.5 numpy==1.23.5
onnxruntime==1.14.1 onnxruntime==1.14.1

View File

@ -42,6 +42,7 @@ setup(
"click>=8.1.3", "click>=8.1.3",
"fastapi>=0.92.0", "fastapi>=0.92.0",
"filetype>=1.2.0", "filetype>=1.2.0",
"gradio>=3.32.0",
"imagehash>=4.3.1", "imagehash>=4.3.1",
"numpy>=1.23.5", "numpy>=1.23.5",
"onnxruntime>=1.14.1", "onnxruntime>=1.14.1",