fix linters

This commit is contained in:
Daniel Gatis 2023-10-09 02:45:59 -03:00
parent 0eaaa0b393
commit 5c65374c9e
8 changed files with 12 additions and 4 deletions

View File

@ -9,6 +9,7 @@ from PIL.Image import Image as PILImage
class BaseSession:
"""This is a base class for managing a session with a machine learning model."""
def __init__(
self,
model_name: str,

View File

@ -13,6 +13,7 @@ class DisSession(BaseSession):
"""
This class represents a session for object detection.
"""
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
"""
Use a pre-trained model to predict the object in the given image.

View File

@ -58,6 +58,7 @@ class SamSession(BaseSession):
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.
"""
def __init__(self, model_name: str, sess_opts: ort.SessionOptions, *args, **kwargs):
"""
Initialize a new SamSession with the given model name and session options.
@ -181,7 +182,7 @@ class SamSession(BaseSession):
@classmethod
def download_models(cls, *args, **kwargs):
'''
"""
Class method to download ONNX model files.
This method is responsible for downloading two ONNX model files from specified URLs and saving them locally. The downloaded files are saved with the naming convention 'name_encoder.onnx' and 'name_decoder.onnx', where 'name' is the value returned by the 'name' method.
@ -193,7 +194,7 @@ class SamSession(BaseSession):
Returns:
tuple: A tuple containing the file paths of the downloaded encoder and decoder models.
'''
"""
fname_encoder = f"{cls.name(*args, **kwargs)}_encoder.onnx"
fname_decoder = f"{cls.name(*args, **kwargs)}_decoder.onnx"
@ -224,7 +225,7 @@ class SamSession(BaseSession):
@classmethod
def name(cls, *args, **kwargs):
'''
"""
Class method to return a string value.
This method returns the string value 'sam'.
@ -236,5 +237,5 @@ class SamSession(BaseSession):
Returns:
str: The string value 'sam'.
'''
"""
return "sam"

View File

@ -11,6 +11,7 @@ from .base import BaseSession
class SiluetaSession(BaseSession):
"""This is a class representing a SiluetaSession object."""
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
"""
Predict the mask of the input image.

View File

@ -13,6 +13,7 @@ class U2netSession(BaseSession):
"""
This class represents a U2net session, which is a subclass of BaseSession.
"""
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
"""
Predicts the output masks for the input image using the inner session.

View File

@ -12,6 +12,7 @@ from .base import BaseSession
class U2netCustomSession(BaseSession):
"""This is a class representing a custom session for the U2net model."""
def __init__(
self,
model_name: str,

View File

@ -13,6 +13,7 @@ class U2netHumanSegSession(BaseSession):
"""
This class represents a session for performing human segmentation using the U2Net model.
"""
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
"""
Predicts human segmentation masks for the input image.

View File

@ -11,6 +11,7 @@ from .base import BaseSession
class U2netpSession(BaseSession):
"""This class represents a session for using the U2netp model."""
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
"""
Predicts the mask for the given image using the U2netp model.