throw on bogus model name

Previously if a bogus model name was provided,
it would silently ignore the error and load u2net.

It still default to u2net when no model name is given, but throw if a bogus model name is given.
This commit is contained in:
divinity76 2025-05-06 19:54:33 +02:00 committed by GitHub
parent 53e711f075
commit 58e10239c5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -22,20 +22,27 @@ def new_session(model_name: str = "u2net", *args, **kwargs) -> BaseSession:
*args: Additional positional arguments.
**kwargs: Additional keyword arguments.
Raises:
ValueError: If no session class with the given `model_name` is found.
Returns:
BaseSession: The created session object.
"""
session_class: Type[BaseSession] = U2netSession
session_class: Optional[Type[BaseSession]] = None
for sc in sessions_class:
if sc.name() == model_name:
session_class = sc
break
if session_class is None:
raise ValueError(f"No session class found for model '{model_name}'")
sess_opts = ort.SessionOptions()
if "OMP_NUM_THREADS" in os.environ:
sess_opts.inter_op_num_threads = int(os.environ["OMP_NUM_THREADS"])
sess_opts.intra_op_num_threads = int(os.environ["OMP_NUM_THREADS"])
threads = int(os.environ["OMP_NUM_THREADS"])
sess_opts.inter_op_num_threads = threads
sess_opts.intra_op_num_threads = threads
return session_class(model_name, sess_opts, *args, **kwargs)