mirror of
https://git.mirrors.martin98.com/https://github.com/danielgatis/rembg
synced 2025-08-16 18:45:54 +08:00
Add cloth category selection feature to u2net_cloth_seg (#485)
This commit is contained in:
parent
a9c2a39213
commit
f0019d723b
@ -75,20 +75,36 @@ class Unet2ClothSession(BaseSession):
|
||||
|
||||
masks = []
|
||||
|
||||
mask1 = mask.copy()
|
||||
mask1.putpalette(palette1)
|
||||
mask1 = mask1.convert("RGB").convert("L")
|
||||
masks.append(mask1)
|
||||
cloth_category = kwargs.get("cc") or kwargs.get("cloth_category")
|
||||
|
||||
mask2 = mask.copy()
|
||||
mask2.putpalette(palette2)
|
||||
mask2 = mask2.convert("RGB").convert("L")
|
||||
masks.append(mask2)
|
||||
def upper_cloth():
|
||||
mask1 = mask.copy()
|
||||
mask1.putpalette(palette1)
|
||||
mask1 = mask1.convert("RGB").convert("L")
|
||||
masks.append(mask1)
|
||||
|
||||
mask3 = mask.copy()
|
||||
mask3.putpalette(palette3)
|
||||
mask3 = mask3.convert("RGB").convert("L")
|
||||
masks.append(mask3)
|
||||
def lower_cloth():
|
||||
mask2 = mask.copy()
|
||||
mask2.putpalette(palette2)
|
||||
mask2 = mask2.convert("RGB").convert("L")
|
||||
masks.append(mask2)
|
||||
|
||||
def full_cloth():
|
||||
mask3 = mask.copy()
|
||||
mask3.putpalette(palette3)
|
||||
mask3 = mask3.convert("RGB").convert("L")
|
||||
masks.append(mask3)
|
||||
|
||||
if cloth_category == "upper":
|
||||
upper_cloth()
|
||||
elif cloth_category == "lower":
|
||||
lower_cloth()
|
||||
elif cloth_category == "full":
|
||||
full_cloth()
|
||||
else:
|
||||
upper_cloth()
|
||||
lower_cloth()
|
||||
full_cloth()
|
||||
|
||||
return masks
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user