mirror of
https://git.mirrors.martin98.com/https://github.com/danielgatis/rembg
synced 2025-08-17 11:05:56 +08:00
Add cloth category selection feature to u2net_cloth_seg (#485)
This commit is contained in:
parent
a9c2a39213
commit
f0019d723b
@ -75,21 +75,37 @@ class Unet2ClothSession(BaseSession):
|
|||||||
|
|
||||||
masks = []
|
masks = []
|
||||||
|
|
||||||
|
cloth_category = kwargs.get("cc") or kwargs.get("cloth_category")
|
||||||
|
|
||||||
|
def upper_cloth():
|
||||||
mask1 = mask.copy()
|
mask1 = mask.copy()
|
||||||
mask1.putpalette(palette1)
|
mask1.putpalette(palette1)
|
||||||
mask1 = mask1.convert("RGB").convert("L")
|
mask1 = mask1.convert("RGB").convert("L")
|
||||||
masks.append(mask1)
|
masks.append(mask1)
|
||||||
|
|
||||||
|
def lower_cloth():
|
||||||
mask2 = mask.copy()
|
mask2 = mask.copy()
|
||||||
mask2.putpalette(palette2)
|
mask2.putpalette(palette2)
|
||||||
mask2 = mask2.convert("RGB").convert("L")
|
mask2 = mask2.convert("RGB").convert("L")
|
||||||
masks.append(mask2)
|
masks.append(mask2)
|
||||||
|
|
||||||
|
def full_cloth():
|
||||||
mask3 = mask.copy()
|
mask3 = mask.copy()
|
||||||
mask3.putpalette(palette3)
|
mask3.putpalette(palette3)
|
||||||
mask3 = mask3.convert("RGB").convert("L")
|
mask3 = mask3.convert("RGB").convert("L")
|
||||||
masks.append(mask3)
|
masks.append(mask3)
|
||||||
|
|
||||||
|
if cloth_category == "upper":
|
||||||
|
upper_cloth()
|
||||||
|
elif cloth_category == "lower":
|
||||||
|
lower_cloth()
|
||||||
|
elif cloth_category == "full":
|
||||||
|
full_cloth()
|
||||||
|
else:
|
||||||
|
upper_cloth()
|
||||||
|
lower_cloth()
|
||||||
|
full_cloth()
|
||||||
|
|
||||||
return masks
|
return masks
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
Loading…
x
Reference in New Issue
Block a user