diff --git a/rembg/sessions/u2net_cloth_seg.py b/rembg/sessions/u2net_cloth_seg.py index 97179a8..c5b9031 100644 --- a/rembg/sessions/u2net_cloth_seg.py +++ b/rembg/sessions/u2net_cloth_seg.py @@ -75,20 +75,36 @@ class Unet2ClothSession(BaseSession): masks = [] - mask1 = mask.copy() - mask1.putpalette(palette1) - mask1 = mask1.convert("RGB").convert("L") - masks.append(mask1) + cloth_category = kwargs.get("cc") or kwargs.get("cloth_category") - mask2 = mask.copy() - mask2.putpalette(palette2) - mask2 = mask2.convert("RGB").convert("L") - masks.append(mask2) + def upper_cloth(): + mask1 = mask.copy() + mask1.putpalette(palette1) + mask1 = mask1.convert("RGB").convert("L") + masks.append(mask1) + + def lower_cloth(): + mask2 = mask.copy() + mask2.putpalette(palette2) + mask2 = mask2.convert("RGB").convert("L") + masks.append(mask2) + + def full_cloth(): + mask3 = mask.copy() + mask3.putpalette(palette3) + mask3 = mask3.convert("RGB").convert("L") + masks.append(mask3) - mask3 = mask.copy() - mask3.putpalette(palette3) - mask3 = mask3.convert("RGB").convert("L") - masks.append(mask3) + if cloth_category == "upper": + upper_cloth() + elif cloth_category == "lower": + lower_cloth() + elif cloth_category == "full": + full_cloth() + else: + upper_cloth() + lower_cloth() + full_cloth() return masks