mirror of
https://git.mirrors.martin98.com/https://github.com/danielgatis/rembg
synced 2025-08-19 19:09:10 +08:00
parent
07a46d478a
commit
ca4be25f90
10
README.md
10
README.md
@ -80,7 +80,15 @@ rembg -o path/to/output.png path/to/input.png
|
|||||||
|
|
||||||
Remove the background from all images in a folder
|
Remove the background from all images in a folder
|
||||||
```bash
|
```bash
|
||||||
rembg -p path/to/inputs
|
rembg -p path/to/inputs path/to/output
|
||||||
|
```
|
||||||
|
|
||||||
|
### Add a custom model
|
||||||
|
|
||||||
|
Copy the `custom-model.pth` file to `~/.u2net` and run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -s http://input.png | rembg -m custom-model > output.png
|
||||||
```
|
```
|
||||||
|
|
||||||
### Usage as a server
|
### Usage as a server
|
||||||
|
2
setup.py
2
setup.py
@ -11,7 +11,7 @@ with open("requirements.txt") as f:
|
|||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="rembg",
|
name="rembg",
|
||||||
version="1.0.24",
|
version="1.0.25",
|
||||||
description="Remove image background",
|
description="Remove image background",
|
||||||
long_description=long_description,
|
long_description=long_description,
|
||||||
long_description_content_type="text/markdown",
|
long_description_content_type="text/markdown",
|
||||||
|
@ -10,6 +10,13 @@ from ..bg import remove
|
|||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
model_path = os.environ.get(
|
||||||
|
"U2NETP_PATH",
|
||||||
|
os.path.expanduser(os.path.join("~", ".u2net")),
|
||||||
|
)
|
||||||
|
model_choices = [os.path.splitext(os.path.basename(x))[0] for x in set(glob.glob(model_path + "/*"))]
|
||||||
|
|
||||||
|
|
||||||
ap = argparse.ArgumentParser()
|
ap = argparse.ArgumentParser()
|
||||||
|
|
||||||
ap.add_argument(
|
ap.add_argument(
|
||||||
@ -17,7 +24,7 @@ def main():
|
|||||||
"--model",
|
"--model",
|
||||||
default="u2net",
|
default="u2net",
|
||||||
type=str,
|
type=str,
|
||||||
choices=("u2net", "u2net_human_seg", "u2netp"),
|
choices=model_choices,
|
||||||
help="The model name.",
|
help="The model name.",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -66,8 +73,8 @@ def main():
|
|||||||
ap.add_argument(
|
ap.add_argument(
|
||||||
"-p",
|
"-p",
|
||||||
"--path",
|
"--path",
|
||||||
nargs="+",
|
nargs=2,
|
||||||
help="Path of a file or a folder of files.",
|
help="An input folder and an output folder.",
|
||||||
)
|
)
|
||||||
|
|
||||||
ap.add_argument(
|
ap.add_argument(
|
||||||
@ -94,15 +101,20 @@ def main():
|
|||||||
|
|
||||||
if args.path:
|
if args.path:
|
||||||
full_paths = [os.path.abspath(path) for path in args.path]
|
full_paths = [os.path.abspath(path) for path in args.path]
|
||||||
|
|
||||||
|
input_paths = [full_paths[0]]
|
||||||
|
output_path = full_paths[1]
|
||||||
|
|
||||||
|
if not os.path.exists(output_path):
|
||||||
|
os.makedirs(output_path)
|
||||||
|
|
||||||
files = set()
|
files = set()
|
||||||
|
|
||||||
for path in full_paths:
|
for path in input_paths:
|
||||||
if os.path.isfile(path):
|
if os.path.isfile(path):
|
||||||
files.add(path)
|
files.add(path)
|
||||||
else:
|
else:
|
||||||
full_paths += set(glob.glob(path + "/*")) - set(
|
input_paths += set(glob.glob(path + "/*"))
|
||||||
glob.glob(path + "/*.out.png")
|
|
||||||
)
|
|
||||||
|
|
||||||
for fi in tqdm(files):
|
for fi in tqdm(files):
|
||||||
fi_type = filetype.guess(fi)
|
fi_type = filetype.guess(fi)
|
||||||
@ -113,7 +125,7 @@ def main():
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
with open(fi, "rb") as input:
|
with open(fi, "rb") as input:
|
||||||
with open(os.path.splitext(fi)[0] + ".out.png", "wb") as output:
|
with open(os.path.join(output_path, os.path.splitext(os.path.basename(fi))[0] + ".png"), "wb") as output:
|
||||||
w(
|
w(
|
||||||
output,
|
output,
|
||||||
remove(
|
remove(
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
import os
|
||||||
|
import glob
|
||||||
import argparse
|
import argparse
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from urllib.parse import unquote_plus
|
from urllib.parse import unquote_plus
|
||||||
@ -38,8 +40,14 @@ def index():
|
|||||||
az = request.values.get("az", type=int, default=1000)
|
az = request.values.get("az", type=int, default=1000)
|
||||||
|
|
||||||
model = request.args.get("model", type=str, default="u2net")
|
model = request.args.get("model", type=str, default="u2net")
|
||||||
if model not in ("u2net", "u2net_human_seg", "u2netp"):
|
model_path = os.environ.get(
|
||||||
return {"error": "invalid query param 'model'"}, 400
|
"U2NETP_PATH",
|
||||||
|
os.path.expanduser(os.path.join("~", ".u2net")),
|
||||||
|
)
|
||||||
|
model_choices = [os.path.splitext(os.path.basename(x))[0] for x in set(glob.glob(model_path + "/*"))]
|
||||||
|
|
||||||
|
if model not in model_choices:
|
||||||
|
return {"error": f"invalid query param 'model'. Available options are {model_choices}"}, 400
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return send_file(
|
return send_file(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user