From 58e10239c508a22ce158c039bb2426b3e55fae59 Mon Sep 17 00:00:00 2001 From: divinity76 Date: Tue, 6 May 2025 19:54:33 +0200 Subject: [PATCH] throw on bogus model name Previously if a bogus model name was provided, it would silently ignore the error and load u2net. It still default to u2net when no model name is given, but throw if a bogus model name is given. --- rembg/session_factory.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/rembg/session_factory.py b/rembg/session_factory.py index 1456796..47d2ec7 100644 --- a/rembg/session_factory.py +++ b/rembg/session_factory.py @@ -22,20 +22,27 @@ def new_session(model_name: str = "u2net", *args, **kwargs) -> BaseSession: *args: Additional positional arguments. **kwargs: Additional keyword arguments. + Raises: + ValueError: If no session class with the given `model_name` is found. + Returns: BaseSession: The created session object. """ - session_class: Type[BaseSession] = U2netSession + session_class: Optional[Type[BaseSession]] = None for sc in sessions_class: if sc.name() == model_name: session_class = sc break + if session_class is None: + raise ValueError(f"No session class found for model '{model_name}'") + sess_opts = ort.SessionOptions() if "OMP_NUM_THREADS" in os.environ: - sess_opts.inter_op_num_threads = int(os.environ["OMP_NUM_THREADS"]) - sess_opts.intra_op_num_threads = int(os.environ["OMP_NUM_THREADS"]) + threads = int(os.environ["OMP_NUM_THREADS"]) + sess_opts.inter_op_num_threads = threads + sess_opts.intra_op_num_threads = threads return session_class(model_name, sess_opts, *args, **kwargs)