add providers member to BaseSession

This commit is contained in:
MCYBA 2023-05-24 02:03:05 +03:00 committed by GitHub
parent c360cf6d90
commit eb1796898f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -8,11 +8,25 @@ from PIL.Image import Image as PILImage
class BaseSession:
def __init__(self, model_name: str, sess_opts: ort.SessionOptions, *args, **kwargs):
def __init__(self, model_name: str, sess_opts: ort.SessionOptions, providers=None, *args, **kwargs):
self.model_name = model_name
self.providers = []
_providers = ort.get_available_providers()
if providers:
for provider in providers:
if provider in _providers:
self.providers.append(provider)
else:
self.providers.extend(_providers)
self.providers=
self.inner_session = ort.InferenceSession(
str(self.__class__.download_models()),
providers=ort.get_available_providers(),
providers=self.providers,
sess_options=sess_opts,
)