diff --git a/rembg/session_factory.py b/rembg/session_factory.py index 1456796..47d2ec7 100644 --- a/rembg/session_factory.py +++ b/rembg/session_factory.py @@ -22,20 +22,27 @@ def new_session(model_name: str = "u2net", *args, **kwargs) -> BaseSession: *args: Additional positional arguments. **kwargs: Additional keyword arguments. + Raises: + ValueError: If no session class with the given `model_name` is found. + Returns: BaseSession: The created session object. """ - session_class: Type[BaseSession] = U2netSession + session_class: Optional[Type[BaseSession]] = None for sc in sessions_class: if sc.name() == model_name: session_class = sc break + if session_class is None: + raise ValueError(f"No session class found for model '{model_name}'") + sess_opts = ort.SessionOptions() if "OMP_NUM_THREADS" in os.environ: - sess_opts.inter_op_num_threads = int(os.environ["OMP_NUM_THREADS"]) - sess_opts.intra_op_num_threads = int(os.environ["OMP_NUM_THREADS"]) + threads = int(os.environ["OMP_NUM_THREADS"]) + sess_opts.inter_op_num_threads = threads + sess_opts.intra_op_num_threads = threads return session_class(model_name, sess_opts, *args, **kwargs)