mirror of
https://git.mirrors.martin98.com/https://github.com/danielgatis/rembg
synced 2025-08-06 04:48:20 +08:00
reordered imports
This commit is contained in:
parent
ff38b9a377
commit
72d1c6c64c
@ -11,8 +11,8 @@ import pooch
|
|||||||
from .session_base import BaseSession
|
from .session_base import BaseSession
|
||||||
from .session_cloth import ClothSession
|
from .session_cloth import ClothSession
|
||||||
from .session_dis import DisSession
|
from .session_dis import DisSession
|
||||||
from .session_simple import SimpleSession
|
|
||||||
from .session_sam import SamSession
|
from .session_sam import SamSession
|
||||||
|
from .session_simple import SimpleSession
|
||||||
|
|
||||||
|
|
||||||
def download_model(url: str, md5: str, fname: str, path: Path):
|
def download_model(url: str, md5: str, fname: str, path: Path):
|
||||||
|
@ -1,11 +1,9 @@
|
|||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import numpy
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from PIL.Image import Image as PILImage
|
from PIL.Image import Image as PILImage
|
||||||
import onnxruntime as ort
|
import onnxruntime as ort
|
||||||
from matplotlib import pyplot as plt
|
|
||||||
|
|
||||||
from .session_base import BaseSession
|
from .session_base import BaseSession
|
||||||
|
|
||||||
@ -39,7 +37,7 @@ def resize_longes_side(img: PILImage, size=1024):
|
|||||||
return img.resize((new_w, new_h))
|
return img.resize((new_w, new_h))
|
||||||
|
|
||||||
|
|
||||||
def pad_to_square(img: numpy.ndarray, size=1024):
|
def pad_to_square(img: np.ndarray, size=1024):
|
||||||
h, w = img.shape[:2]
|
h, w = img.shape[:2]
|
||||||
padh = size - h
|
padh = size - h
|
||||||
padw = size - w
|
padw = size - w
|
||||||
@ -60,7 +58,7 @@ class SamSession(BaseSession):
|
|||||||
|
|
||||||
def normalize(
|
def normalize(
|
||||||
self,
|
self,
|
||||||
img: numpy.ndarray,
|
img: np.ndarray,
|
||||||
mean=(123.675, 116.28, 103.53),
|
mean=(123.675, 116.28, 103.53),
|
||||||
std=(58.395, 57.12, 57.375),
|
std=(58.395, 57.12, 57.375),
|
||||||
size=(1024, 1024),
|
size=(1024, 1024),
|
||||||
@ -78,7 +76,7 @@ class SamSession(BaseSession):
|
|||||||
) -> List[PILImage]:
|
) -> List[PILImage]:
|
||||||
# Preprocess image
|
# Preprocess image
|
||||||
image = resize_longes_side(img)
|
image = resize_longes_side(img)
|
||||||
image = numpy.array(image)
|
image = np.array(image)
|
||||||
image = self.normalize(image)
|
image = self.normalize(image)
|
||||||
image = pad_to_square(image)
|
image = pad_to_square(image)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user