Add cloth category selection feature to u2net_cloth_seg (#485)

This commit is contained in:
szriru 2023-07-13 15:09:13 +09:00 committed by GitHub
parent a9c2a39213
commit f0019d723b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -75,20 +75,36 @@ class Unet2ClothSession(BaseSession):
masks = []
mask1 = mask.copy()
mask1.putpalette(palette1)
mask1 = mask1.convert("RGB").convert("L")
masks.append(mask1)
cloth_category = kwargs.get("cc") or kwargs.get("cloth_category")
mask2 = mask.copy()
mask2.putpalette(palette2)
mask2 = mask2.convert("RGB").convert("L")
masks.append(mask2)
def upper_cloth():
mask1 = mask.copy()
mask1.putpalette(palette1)
mask1 = mask1.convert("RGB").convert("L")
masks.append(mask1)
mask3 = mask.copy()
mask3.putpalette(palette3)
mask3 = mask3.convert("RGB").convert("L")
masks.append(mask3)
def lower_cloth():
mask2 = mask.copy()
mask2.putpalette(palette2)
mask2 = mask2.convert("RGB").convert("L")
masks.append(mask2)
def full_cloth():
mask3 = mask.copy()
mask3.putpalette(palette3)
mask3 = mask3.convert("RGB").convert("L")
masks.append(mask3)
if cloth_category == "upper":
upper_cloth()
elif cloth_category == "lower":
lower_cloth()
elif cloth_category == "full":
full_cloth()
else:
upper_cloth()
lower_cloth()
full_cloth()
return masks