mirror of
https://git.mirrors.martin98.com/https://github.com/danielgatis/rembg
synced 2025-08-19 16:39:09 +08:00
fix linters
This commit is contained in:
parent
0eaaa0b393
commit
5c65374c9e
@ -9,6 +9,7 @@ from PIL.Image import Image as PILImage
|
|||||||
|
|
||||||
class BaseSession:
|
class BaseSession:
|
||||||
"""This is a base class for managing a session with a machine learning model."""
|
"""This is a base class for managing a session with a machine learning model."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
|
@ -13,6 +13,7 @@ class DisSession(BaseSession):
|
|||||||
"""
|
"""
|
||||||
This class represents a session for object detection.
|
This class represents a session for object detection.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
||||||
"""
|
"""
|
||||||
Use a pre-trained model to predict the object in the given image.
|
Use a pre-trained model to predict the object in the given image.
|
||||||
|
@ -58,6 +58,7 @@ class SamSession(BaseSession):
|
|||||||
*args: Variable length argument list.
|
*args: Variable length argument list.
|
||||||
**kwargs: Arbitrary keyword arguments.
|
**kwargs: Arbitrary keyword arguments.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, model_name: str, sess_opts: ort.SessionOptions, *args, **kwargs):
|
def __init__(self, model_name: str, sess_opts: ort.SessionOptions, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
Initialize a new SamSession with the given model name and session options.
|
Initialize a new SamSession with the given model name and session options.
|
||||||
@ -181,7 +182,7 @@ class SamSession(BaseSession):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def download_models(cls, *args, **kwargs):
|
def download_models(cls, *args, **kwargs):
|
||||||
'''
|
"""
|
||||||
Class method to download ONNX model files.
|
Class method to download ONNX model files.
|
||||||
|
|
||||||
This method is responsible for downloading two ONNX model files from specified URLs and saving them locally. The downloaded files are saved with the naming convention 'name_encoder.onnx' and 'name_decoder.onnx', where 'name' is the value returned by the 'name' method.
|
This method is responsible for downloading two ONNX model files from specified URLs and saving them locally. The downloaded files are saved with the naming convention 'name_encoder.onnx' and 'name_decoder.onnx', where 'name' is the value returned by the 'name' method.
|
||||||
@ -193,7 +194,7 @@ class SamSession(BaseSession):
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
tuple: A tuple containing the file paths of the downloaded encoder and decoder models.
|
tuple: A tuple containing the file paths of the downloaded encoder and decoder models.
|
||||||
'''
|
"""
|
||||||
fname_encoder = f"{cls.name(*args, **kwargs)}_encoder.onnx"
|
fname_encoder = f"{cls.name(*args, **kwargs)}_encoder.onnx"
|
||||||
fname_decoder = f"{cls.name(*args, **kwargs)}_decoder.onnx"
|
fname_decoder = f"{cls.name(*args, **kwargs)}_decoder.onnx"
|
||||||
|
|
||||||
@ -224,7 +225,7 @@ class SamSession(BaseSession):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def name(cls, *args, **kwargs):
|
def name(cls, *args, **kwargs):
|
||||||
'''
|
"""
|
||||||
Class method to return a string value.
|
Class method to return a string value.
|
||||||
|
|
||||||
This method returns the string value 'sam'.
|
This method returns the string value 'sam'.
|
||||||
@ -236,5 +237,5 @@ class SamSession(BaseSession):
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: The string value 'sam'.
|
str: The string value 'sam'.
|
||||||
'''
|
"""
|
||||||
return "sam"
|
return "sam"
|
||||||
|
@ -11,6 +11,7 @@ from .base import BaseSession
|
|||||||
|
|
||||||
class SiluetaSession(BaseSession):
|
class SiluetaSession(BaseSession):
|
||||||
"""This is a class representing a SiluetaSession object."""
|
"""This is a class representing a SiluetaSession object."""
|
||||||
|
|
||||||
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
||||||
"""
|
"""
|
||||||
Predict the mask of the input image.
|
Predict the mask of the input image.
|
||||||
|
@ -13,6 +13,7 @@ class U2netSession(BaseSession):
|
|||||||
"""
|
"""
|
||||||
This class represents a U2net session, which is a subclass of BaseSession.
|
This class represents a U2net session, which is a subclass of BaseSession.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
||||||
"""
|
"""
|
||||||
Predicts the output masks for the input image using the inner session.
|
Predicts the output masks for the input image using the inner session.
|
||||||
|
@ -12,6 +12,7 @@ from .base import BaseSession
|
|||||||
|
|
||||||
class U2netCustomSession(BaseSession):
|
class U2netCustomSession(BaseSession):
|
||||||
"""This is a class representing a custom session for the U2net model."""
|
"""This is a class representing a custom session for the U2net model."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
|
@ -13,6 +13,7 @@ class U2netHumanSegSession(BaseSession):
|
|||||||
"""
|
"""
|
||||||
This class represents a session for performing human segmentation using the U2Net model.
|
This class represents a session for performing human segmentation using the U2Net model.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
||||||
"""
|
"""
|
||||||
Predicts human segmentation masks for the input image.
|
Predicts human segmentation masks for the input image.
|
||||||
|
@ -11,6 +11,7 @@ from .base import BaseSession
|
|||||||
|
|
||||||
class U2netpSession(BaseSession):
|
class U2netpSession(BaseSession):
|
||||||
"""This class represents a session for using the U2netp model."""
|
"""This class represents a session for using the U2netp model."""
|
||||||
|
|
||||||
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
||||||
"""
|
"""
|
||||||
Predicts the mask for the given image using the U2netp model.
|
Predicts the mask for the given image using the U2netp model.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user