mirror of
https://git.mirrors.martin98.com/https://github.com/danielgatis/rembg
synced 2025-08-17 05:45:55 +08:00
fix docker
This commit is contained in:
parent
ccaa9005af
commit
b527af6af5
@ -28,11 +28,6 @@ class BaseSession:
|
|||||||
else:
|
else:
|
||||||
self.providers.extend(_providers)
|
self.providers.extend(_providers)
|
||||||
|
|
||||||
model_path = kwargs.get("model_path")
|
|
||||||
|
|
||||||
if model_path is None:
|
|
||||||
raise ValueError("model_path is required")
|
|
||||||
|
|
||||||
self.inner_session = ort.InferenceSession(
|
self.inner_session = ort.InferenceSession(
|
||||||
str(self.__class__.download_models(*args, **kwargs)),
|
str(self.__class__.download_models(*args, **kwargs)),
|
||||||
providers=self.providers,
|
providers=self.providers,
|
||||||
|
@ -10,6 +10,26 @@ from .base import BaseSession
|
|||||||
|
|
||||||
|
|
||||||
class U2netCustomSession(BaseSession):
|
class U2netCustomSession(BaseSession):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_name: str,
|
||||||
|
sess_opts: ort.SessionOptions,
|
||||||
|
providers=None,
|
||||||
|
*args,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
model_path = kwargs.get("model_path")
|
||||||
|
if model_path is None:
|
||||||
|
raise ValueError("model_path is required")
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
model_name,
|
||||||
|
sess_opts,
|
||||||
|
providers,
|
||||||
|
*args,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
||||||
ort_outs = self.inner_session.run(
|
ort_outs = self.inner_session.run(
|
||||||
None,
|
None,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user