diff --git a/rembg/sessions/base.py b/rembg/sessions/base.py index c24e7e4..f942840 100644 --- a/rembg/sessions/base.py +++ b/rembg/sessions/base.py @@ -9,6 +9,7 @@ from PIL.Image import Image as PILImage class BaseSession: """This is a base class for managing a session with a machine learning model.""" + def __init__( self, model_name: str, diff --git a/rembg/sessions/dis_anime.py b/rembg/sessions/dis_anime.py index 0aa4b1a..559915d 100644 --- a/rembg/sessions/dis_anime.py +++ b/rembg/sessions/dis_anime.py @@ -13,6 +13,7 @@ class DisSession(BaseSession): """ This class represents a session for object detection. """ + def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]: """ Use a pre-trained model to predict the object in the given image. diff --git a/rembg/sessions/sam.py b/rembg/sessions/sam.py index 4bfcb98..5e524b0 100644 --- a/rembg/sessions/sam.py +++ b/rembg/sessions/sam.py @@ -58,6 +58,7 @@ class SamSession(BaseSession): *args: Variable length argument list. **kwargs: Arbitrary keyword arguments. """ + def __init__(self, model_name: str, sess_opts: ort.SessionOptions, *args, **kwargs): """ Initialize a new SamSession with the given model name and session options. @@ -181,7 +182,7 @@ class SamSession(BaseSession): @classmethod def download_models(cls, *args, **kwargs): - ''' + """ Class method to download ONNX model files. This method is responsible for downloading two ONNX model files from specified URLs and saving them locally. The downloaded files are saved with the naming convention 'name_encoder.onnx' and 'name_decoder.onnx', where 'name' is the value returned by the 'name' method. @@ -193,7 +194,7 @@ class SamSession(BaseSession): Returns: tuple: A tuple containing the file paths of the downloaded encoder and decoder models. - ''' + """ fname_encoder = f"{cls.name(*args, **kwargs)}_encoder.onnx" fname_decoder = f"{cls.name(*args, **kwargs)}_decoder.onnx" @@ -224,7 +225,7 @@ class SamSession(BaseSession): @classmethod def name(cls, *args, **kwargs): - ''' + """ Class method to return a string value. This method returns the string value 'sam'. @@ -236,5 +237,5 @@ class SamSession(BaseSession): Returns: str: The string value 'sam'. - ''' + """ return "sam" diff --git a/rembg/sessions/silueta.py b/rembg/sessions/silueta.py index 938ccac..afeb484 100644 --- a/rembg/sessions/silueta.py +++ b/rembg/sessions/silueta.py @@ -11,6 +11,7 @@ from .base import BaseSession class SiluetaSession(BaseSession): """This is a class representing a SiluetaSession object.""" + def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]: """ Predict the mask of the input image. diff --git a/rembg/sessions/u2net.py b/rembg/sessions/u2net.py index f817de5..4e50bb7 100644 --- a/rembg/sessions/u2net.py +++ b/rembg/sessions/u2net.py @@ -13,6 +13,7 @@ class U2netSession(BaseSession): """ This class represents a U2net session, which is a subclass of BaseSession. """ + def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]: """ Predicts the output masks for the input image using the inner session. diff --git a/rembg/sessions/u2net_custom.py b/rembg/sessions/u2net_custom.py index 48dd39a..b984208 100644 --- a/rembg/sessions/u2net_custom.py +++ b/rembg/sessions/u2net_custom.py @@ -12,6 +12,7 @@ from .base import BaseSession class U2netCustomSession(BaseSession): """This is a class representing a custom session for the U2net model.""" + def __init__( self, model_name: str, diff --git a/rembg/sessions/u2net_human_seg.py b/rembg/sessions/u2net_human_seg.py index e371e56..ae59802 100644 --- a/rembg/sessions/u2net_human_seg.py +++ b/rembg/sessions/u2net_human_seg.py @@ -13,6 +13,7 @@ class U2netHumanSegSession(BaseSession): """ This class represents a session for performing human segmentation using the U2Net model. """ + def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]: """ Predicts human segmentation masks for the input image. diff --git a/rembg/sessions/u2netp.py b/rembg/sessions/u2netp.py index e9fc6dd..a6c4c53 100644 --- a/rembg/sessions/u2netp.py +++ b/rembg/sessions/u2netp.py @@ -11,6 +11,7 @@ from .base import BaseSession class U2netpSession(BaseSession): """This class represents a session for using the U2netp model.""" + def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]: """ Predicts the mask for the given image using the U2netp model.