fix docker

This commit is contained in:
Daniel Gatis 2023-06-29 21:09:22 -03:00
parent ccaa9005af
commit b527af6af5
2 changed files with 20 additions and 5 deletions

View File

@ -28,11 +28,6 @@ class BaseSession:
else:
self.providers.extend(_providers)
model_path = kwargs.get("model_path")
if model_path is None:
raise ValueError("model_path is required")
self.inner_session = ort.InferenceSession(
str(self.__class__.download_models(*args, **kwargs)),
providers=self.providers,

View File

@ -10,6 +10,26 @@ from .base import BaseSession
class U2netCustomSession(BaseSession):
def __init__(
self,
model_name: str,
sess_opts: ort.SessionOptions,
providers=None,
*args,
**kwargs
):
model_path = kwargs.get("model_path")
if model_path is None:
raise ValueError("model_path is required")
super().__init__(
model_name,
sess_opts,
providers,
*args,
**kwargs
)
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
ort_outs = self.inner_session.run(
None,