mirror of
https://git.mirrors.martin98.com/https://github.com/danielgatis/rembg
synced 2025-08-06 01:06:05 +08:00
Update session_factory.py
This commit is contained in:
parent
eb1796898f
commit
022cb69f2e
@ -8,7 +8,7 @@ from .sessions.base import BaseSession
|
||||
from .sessions.u2net import U2netSession
|
||||
|
||||
|
||||
def new_session(model_name: str = "u2net", *args, **kwargs) -> BaseSession:
|
||||
def new_session(model_name: str = "u2net", providers=None, *args, **kwargs) -> BaseSession:
|
||||
session_class: Type[BaseSession] = U2netSession
|
||||
|
||||
for sc in sessions_class:
|
||||
@ -21,4 +21,4 @@ def new_session(model_name: str = "u2net", *args, **kwargs) -> BaseSession:
|
||||
if "OMP_NUM_THREADS" in os.environ:
|
||||
sess_opts.inter_op_num_threads = int(os.environ["OMP_NUM_THREADS"])
|
||||
|
||||
return session_class(model_name, sess_opts, *args, **kwargs)
|
||||
return session_class(model_name, sess_opts, providers, *args, **kwargs)
|
||||
|
Loading…
x
Reference in New Issue
Block a user