mirror of
https://git.mirrors.martin98.com/https://github.com/danielgatis/rembg
synced 2025-08-17 09:05:53 +08:00
throw on bogus model name
Previously if a bogus model name was provided, it would silently ignore the error and load u2net. It still default to u2net when no model name is given, but throw if a bogus model name is given.
This commit is contained in:
parent
53e711f075
commit
58e10239c5
@ -22,20 +22,27 @@ def new_session(model_name: str = "u2net", *args, **kwargs) -> BaseSession:
|
|||||||
*args: Additional positional arguments.
|
*args: Additional positional arguments.
|
||||||
**kwargs: Additional keyword arguments.
|
**kwargs: Additional keyword arguments.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If no session class with the given `model_name` is found.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
BaseSession: The created session object.
|
BaseSession: The created session object.
|
||||||
"""
|
"""
|
||||||
session_class: Type[BaseSession] = U2netSession
|
session_class: Optional[Type[BaseSession]] = None
|
||||||
|
|
||||||
for sc in sessions_class:
|
for sc in sessions_class:
|
||||||
if sc.name() == model_name:
|
if sc.name() == model_name:
|
||||||
session_class = sc
|
session_class = sc
|
||||||
break
|
break
|
||||||
|
|
||||||
|
if session_class is None:
|
||||||
|
raise ValueError(f"No session class found for model '{model_name}'")
|
||||||
|
|
||||||
sess_opts = ort.SessionOptions()
|
sess_opts = ort.SessionOptions()
|
||||||
|
|
||||||
if "OMP_NUM_THREADS" in os.environ:
|
if "OMP_NUM_THREADS" in os.environ:
|
||||||
sess_opts.inter_op_num_threads = int(os.environ["OMP_NUM_THREADS"])
|
threads = int(os.environ["OMP_NUM_THREADS"])
|
||||||
sess_opts.intra_op_num_threads = int(os.environ["OMP_NUM_THREADS"])
|
sess_opts.inter_op_num_threads = threads
|
||||||
|
sess_opts.intra_op_num_threads = threads
|
||||||
|
|
||||||
return session_class(model_name, sess_opts, *args, **kwargs)
|
return session_class(model_name, sess_opts, *args, **kwargs)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user