mirror of
https://git.mirrors.martin98.com/https://github.com/danielgatis/rembg
synced 2025-08-19 06:55:59 +08:00
fix linters
This commit is contained in:
parent
0eaaa0b393
commit
5c65374c9e
@ -9,6 +9,7 @@ from PIL.Image import Image as PILImage
|
||||
|
||||
class BaseSession:
|
||||
"""This is a base class for managing a session with a machine learning model."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
|
@ -13,6 +13,7 @@ class DisSession(BaseSession):
|
||||
"""
|
||||
This class represents a session for object detection.
|
||||
"""
|
||||
|
||||
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
||||
"""
|
||||
Use a pre-trained model to predict the object in the given image.
|
||||
|
@ -58,6 +58,7 @@ class SamSession(BaseSession):
|
||||
*args: Variable length argument list.
|
||||
**kwargs: Arbitrary keyword arguments.
|
||||
"""
|
||||
|
||||
def __init__(self, model_name: str, sess_opts: ort.SessionOptions, *args, **kwargs):
|
||||
"""
|
||||
Initialize a new SamSession with the given model name and session options.
|
||||
@ -181,7 +182,7 @@ class SamSession(BaseSession):
|
||||
|
||||
@classmethod
|
||||
def download_models(cls, *args, **kwargs):
|
||||
'''
|
||||
"""
|
||||
Class method to download ONNX model files.
|
||||
|
||||
This method is responsible for downloading two ONNX model files from specified URLs and saving them locally. The downloaded files are saved with the naming convention 'name_encoder.onnx' and 'name_decoder.onnx', where 'name' is the value returned by the 'name' method.
|
||||
@ -193,7 +194,7 @@ class SamSession(BaseSession):
|
||||
|
||||
Returns:
|
||||
tuple: A tuple containing the file paths of the downloaded encoder and decoder models.
|
||||
'''
|
||||
"""
|
||||
fname_encoder = f"{cls.name(*args, **kwargs)}_encoder.onnx"
|
||||
fname_decoder = f"{cls.name(*args, **kwargs)}_decoder.onnx"
|
||||
|
||||
@ -224,7 +225,7 @@ class SamSession(BaseSession):
|
||||
|
||||
@classmethod
|
||||
def name(cls, *args, **kwargs):
|
||||
'''
|
||||
"""
|
||||
Class method to return a string value.
|
||||
|
||||
This method returns the string value 'sam'.
|
||||
@ -236,5 +237,5 @@ class SamSession(BaseSession):
|
||||
|
||||
Returns:
|
||||
str: The string value 'sam'.
|
||||
'''
|
||||
"""
|
||||
return "sam"
|
||||
|
@ -11,6 +11,7 @@ from .base import BaseSession
|
||||
|
||||
class SiluetaSession(BaseSession):
|
||||
"""This is a class representing a SiluetaSession object."""
|
||||
|
||||
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
||||
"""
|
||||
Predict the mask of the input image.
|
||||
|
@ -13,6 +13,7 @@ class U2netSession(BaseSession):
|
||||
"""
|
||||
This class represents a U2net session, which is a subclass of BaseSession.
|
||||
"""
|
||||
|
||||
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
||||
"""
|
||||
Predicts the output masks for the input image using the inner session.
|
||||
|
@ -12,6 +12,7 @@ from .base import BaseSession
|
||||
|
||||
class U2netCustomSession(BaseSession):
|
||||
"""This is a class representing a custom session for the U2net model."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
|
@ -13,6 +13,7 @@ class U2netHumanSegSession(BaseSession):
|
||||
"""
|
||||
This class represents a session for performing human segmentation using the U2Net model.
|
||||
"""
|
||||
|
||||
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
||||
"""
|
||||
Predicts human segmentation masks for the input image.
|
||||
|
@ -11,6 +11,7 @@ from .base import BaseSession
|
||||
|
||||
class U2netpSession(BaseSession):
|
||||
"""This class represents a session for using the U2netp model."""
|
||||
|
||||
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
||||
"""
|
||||
Predicts the mask for the given image using the U2netp model.
|
||||
|
Loading…
x
Reference in New Issue
Block a user