├── LICENSE ├── README.md ├── app.py ├── assets ├── 1.png └── 2.png ├── input ├── 03615_00.jpg └── 08909_00.jpg ├── network.py ├── options.py ├── process.py └── requirements.txt /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Alok Pandey 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Huggingface cloth segmentation using U2NET 2 | 3 | ![Python 3.8](https://img.shields.io/badge/python-3.8-green.svg) 4 | [![License: MIT](https://img.shields.io/badge/License-MIT-green.svg)](https://opensource.org/licenses/MIT) 5 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1LGgLiHiWcmpQalgazLgq4uQuVUm9ZM4M?usp=sharing) 6 | 7 | This repo contains inference code and gradio demo script using pre-trained U2NET model for Cloths Parsing from human portrait.
8 | Here clothes are parsed into 3 category: Upper body(red), Lower body(green) and Full body(yellow). The provided script also generates alpha images for each class. 9 | 10 | 11 | # Inference 12 | - clone the repo `git clone https://github.com/wildoctopus/huggingface-cloth-segmentation.git`. 13 | - Install dependencies `pip install -r requirements.txt` 14 | - Run `python process.py --image 'input/03615_00.jpg'` . **Script will automatically download the pretrained model**. 15 | - Outputs will be saved in `output` folder. 16 | - `output/alpha/..` contains alpha images corresponding to each class. 17 | - `output/cloth_seg` contains final segmentation. 18 | - 19 | 20 | # Gradio Demo 21 | - Run `python app.py` 22 | - Navigate to local or public url provided by app on successfull execution. 23 | ### OR 24 | - Inference in colab from here [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1LGgLiHiWcmpQalgazLgq4uQuVUm9ZM4M?usp=sharing) 25 | 26 | # Huggingface Demo 27 | - Check gradio demo on Huggingface space from here [huggingface-cloth-segmentation](https://huggingface.co/spaces/wildoctopus/cloth-segmentation). 28 | 29 | # Output samples 30 | ![Sample 000](assets/1.png) 31 | ![Sample 024](assets/2.png) 32 | 33 | 34 | This model works well with any background and almost all poses. 35 | 36 | # Acknowledgements 37 | - U2net model is from original [u2net repo](https://github.com/xuebinqin/U-2-Net). Thanks to Xuebin Qin for amazing repo. 38 | - Most of the code is taken and modified from [levindabhi/cloth-segmentation](https://github.com/levindabhi/cloth-segmentation) 39 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import gradio as gr 4 | from PIL import Image 5 | from process import load_seg_model, get_palette, generate_mask 6 | 7 | # Automatically select device 8 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 9 | 10 | # Ensure model directory exists 11 | model_dir = 'model' 12 | os.makedirs(model_dir, exist_ok=True) 13 | checkpoint_path = os.path.join(model_dir, 'cloth_segm.pth') 14 | 15 | # Download the model if not present 16 | if not os.path.exists(checkpoint_path): 17 | import gdown 18 | url = 'https://drive.google.com/uc?id=1w0nZzH9g6n5l3xQ8Z8Z8Z8Z8Z8Z8Z8Z' # Replace with actual URL 19 | gdown.download(url, checkpoint_path, quiet=False) 20 | 21 | # Load model 22 | net = load_seg_model(checkpoint_path, device=device) 23 | palette = get_palette(4) 24 | 25 | def run(img: Image.Image) -> Image.Image: 26 | try: 27 | cloth_seg = generate_mask(img, net=net, palette=palette, device=device) 28 | return cloth_seg 29 | except Exception as e: 30 | print(f"Error during processing: {e}") 31 | return None 32 | 33 | # Define Gradio interface 34 | title = "Demo for Cloth Segmentation" 35 | description = "An app for Cloth Segmentation using U2NET." 36 | 37 | iface = gr.Interface( 38 | fn=run, 39 | inputs=gr.Image(type="pil", label="Input Image"), 40 | outputs=gr.Image(type="pil", label="Cloth Segmentation"), 41 | title=title, 42 | description=description 43 | ) 44 | 45 | if __name__ == "__main__": 46 | iface.launch(share=True) 47 | -------------------------------------------------------------------------------- /assets/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wildoctopus/huggingface-cloth-segmentation/3f9d64d88a820b89bb140e0a377f2e3eaf4e7147/assets/1.png -------------------------------------------------------------------------------- /assets/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wildoctopus/huggingface-cloth-segmentation/3f9d64d88a820b89bb140e0a377f2e3eaf4e7147/assets/2.png -------------------------------------------------------------------------------- /input/03615_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wildoctopus/huggingface-cloth-segmentation/3f9d64d88a820b89bb140e0a377f2e3eaf4e7147/input/03615_00.jpg -------------------------------------------------------------------------------- /input/08909_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wildoctopus/huggingface-cloth-segmentation/3f9d64d88a820b89bb140e0a377f2e3eaf4e7147/input/08909_00.jpg -------------------------------------------------------------------------------- /network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class REBNCONV(nn.Module): 7 | def __init__(self, in_ch=3, out_ch=3, dirate=1): 8 | super(REBNCONV, self).__init__() 9 | 10 | self.conv_s1 = nn.Conv2d( 11 | in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate 12 | ) 13 | self.bn_s1 = nn.BatchNorm2d(out_ch) 14 | self.relu_s1 = nn.ReLU(inplace=True) 15 | 16 | def forward(self, x): 17 | 18 | hx = x 19 | xout = self.relu_s1(self.bn_s1(self.conv_s1(hx))) 20 | 21 | return xout 22 | 23 | 24 | ## upsample tensor 'src' to have the same spatial size with tensor 'tar' 25 | def _upsample_like(src, tar): 26 | 27 | src = F.upsample(src, size=tar.shape[2:], mode="bilinear") 28 | 29 | return src 30 | 31 | 32 | ### RSU-7 ### 33 | class RSU7(nn.Module): # UNet07DRES(nn.Module): 34 | def __init__(self, in_ch=3, mid_ch=12, out_ch=3): 35 | super(RSU7, self).__init__() 36 | 37 | self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) 38 | 39 | self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) 40 | self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 41 | 42 | self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1) 43 | self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 44 | 45 | self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1) 46 | self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 47 | 48 | self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1) 49 | self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 50 | 51 | self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1) 52 | self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 53 | 54 | self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1) 55 | 56 | self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2) 57 | 58 | self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 59 | self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 60 | self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 61 | self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 62 | self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 63 | self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) 64 | 65 | def forward(self, x): 66 | 67 | hx = x 68 | hxin = self.rebnconvin(hx) 69 | 70 | hx1 = self.rebnconv1(hxin) 71 | hx = self.pool1(hx1) 72 | 73 | hx2 = self.rebnconv2(hx) 74 | hx = self.pool2(hx2) 75 | 76 | hx3 = self.rebnconv3(hx) 77 | hx = self.pool3(hx3) 78 | 79 | hx4 = self.rebnconv4(hx) 80 | hx = self.pool4(hx4) 81 | 82 | hx5 = self.rebnconv5(hx) 83 | hx = self.pool5(hx5) 84 | 85 | hx6 = self.rebnconv6(hx) 86 | 87 | hx7 = self.rebnconv7(hx6) 88 | 89 | hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1)) 90 | hx6dup = _upsample_like(hx6d, hx5) 91 | 92 | hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1)) 93 | hx5dup = _upsample_like(hx5d, hx4) 94 | 95 | hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1)) 96 | hx4dup = _upsample_like(hx4d, hx3) 97 | 98 | hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1)) 99 | hx3dup = _upsample_like(hx3d, hx2) 100 | 101 | hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1)) 102 | hx2dup = _upsample_like(hx2d, hx1) 103 | 104 | hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1)) 105 | 106 | """ 107 | del hx1, hx2, hx3, hx4, hx5, hx6, hx7 108 | del hx6d, hx5d, hx3d, hx2d 109 | del hx2dup, hx3dup, hx4dup, hx5dup, hx6dup 110 | """ 111 | 112 | return hx1d + hxin 113 | 114 | 115 | ### RSU-6 ### 116 | class RSU6(nn.Module): # UNet06DRES(nn.Module): 117 | def __init__(self, in_ch=3, mid_ch=12, out_ch=3): 118 | super(RSU6, self).__init__() 119 | 120 | self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) 121 | 122 | self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) 123 | self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 124 | 125 | self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1) 126 | self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 127 | 128 | self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1) 129 | self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 130 | 131 | self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1) 132 | self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 133 | 134 | self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1) 135 | 136 | self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2) 137 | 138 | self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 139 | self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 140 | self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 141 | self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 142 | self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) 143 | 144 | def forward(self, x): 145 | 146 | hx = x 147 | 148 | hxin = self.rebnconvin(hx) 149 | 150 | hx1 = self.rebnconv1(hxin) 151 | hx = self.pool1(hx1) 152 | 153 | hx2 = self.rebnconv2(hx) 154 | hx = self.pool2(hx2) 155 | 156 | hx3 = self.rebnconv3(hx) 157 | hx = self.pool3(hx3) 158 | 159 | hx4 = self.rebnconv4(hx) 160 | hx = self.pool4(hx4) 161 | 162 | hx5 = self.rebnconv5(hx) 163 | 164 | hx6 = self.rebnconv6(hx5) 165 | 166 | hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1)) 167 | hx5dup = _upsample_like(hx5d, hx4) 168 | 169 | hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1)) 170 | hx4dup = _upsample_like(hx4d, hx3) 171 | 172 | hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1)) 173 | hx3dup = _upsample_like(hx3d, hx2) 174 | 175 | hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1)) 176 | hx2dup = _upsample_like(hx2d, hx1) 177 | 178 | hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1)) 179 | 180 | """ 181 | del hx1, hx2, hx3, hx4, hx5, hx6 182 | del hx5d, hx4d, hx3d, hx2d 183 | del hx2dup, hx3dup, hx4dup, hx5dup 184 | """ 185 | 186 | return hx1d + hxin 187 | 188 | 189 | ### RSU-5 ### 190 | class RSU5(nn.Module): # UNet05DRES(nn.Module): 191 | def __init__(self, in_ch=3, mid_ch=12, out_ch=3): 192 | super(RSU5, self).__init__() 193 | 194 | self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) 195 | 196 | self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) 197 | self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 198 | 199 | self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1) 200 | self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 201 | 202 | self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1) 203 | self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 204 | 205 | self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1) 206 | 207 | self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2) 208 | 209 | self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 210 | self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 211 | self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 212 | self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) 213 | 214 | def forward(self, x): 215 | 216 | hx = x 217 | 218 | hxin = self.rebnconvin(hx) 219 | 220 | hx1 = self.rebnconv1(hxin) 221 | hx = self.pool1(hx1) 222 | 223 | hx2 = self.rebnconv2(hx) 224 | hx = self.pool2(hx2) 225 | 226 | hx3 = self.rebnconv3(hx) 227 | hx = self.pool3(hx3) 228 | 229 | hx4 = self.rebnconv4(hx) 230 | 231 | hx5 = self.rebnconv5(hx4) 232 | 233 | hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1)) 234 | hx4dup = _upsample_like(hx4d, hx3) 235 | 236 | hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1)) 237 | hx3dup = _upsample_like(hx3d, hx2) 238 | 239 | hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1)) 240 | hx2dup = _upsample_like(hx2d, hx1) 241 | 242 | hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1)) 243 | 244 | """ 245 | del hx1, hx2, hx3, hx4, hx5 246 | del hx4d, hx3d, hx2d 247 | del hx2dup, hx3dup, hx4dup 248 | """ 249 | 250 | return hx1d + hxin 251 | 252 | 253 | ### RSU-4 ### 254 | class RSU4(nn.Module): # UNet04DRES(nn.Module): 255 | def __init__(self, in_ch=3, mid_ch=12, out_ch=3): 256 | super(RSU4, self).__init__() 257 | 258 | self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) 259 | 260 | self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) 261 | self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 262 | 263 | self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1) 264 | self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 265 | 266 | self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1) 267 | 268 | self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2) 269 | 270 | self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 271 | self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 272 | self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) 273 | 274 | def forward(self, x): 275 | 276 | hx = x 277 | 278 | hxin = self.rebnconvin(hx) 279 | 280 | hx1 = self.rebnconv1(hxin) 281 | hx = self.pool1(hx1) 282 | 283 | hx2 = self.rebnconv2(hx) 284 | hx = self.pool2(hx2) 285 | 286 | hx3 = self.rebnconv3(hx) 287 | 288 | hx4 = self.rebnconv4(hx3) 289 | 290 | hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1)) 291 | hx3dup = _upsample_like(hx3d, hx2) 292 | 293 | hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1)) 294 | hx2dup = _upsample_like(hx2d, hx1) 295 | 296 | hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1)) 297 | 298 | """ 299 | del hx1, hx2, hx3, hx4 300 | del hx3d, hx2d 301 | del hx2dup, hx3dup 302 | """ 303 | 304 | return hx1d + hxin 305 | 306 | 307 | ### RSU-4F ### 308 | class RSU4F(nn.Module): # UNet04FRES(nn.Module): 309 | def __init__(self, in_ch=3, mid_ch=12, out_ch=3): 310 | super(RSU4F, self).__init__() 311 | 312 | self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) 313 | 314 | self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) 315 | self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2) 316 | self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4) 317 | 318 | self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8) 319 | 320 | self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4) 321 | self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2) 322 | self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) 323 | 324 | def forward(self, x): 325 | 326 | hx = x 327 | 328 | hxin = self.rebnconvin(hx) 329 | 330 | hx1 = self.rebnconv1(hxin) 331 | hx2 = self.rebnconv2(hx1) 332 | hx3 = self.rebnconv3(hx2) 333 | 334 | hx4 = self.rebnconv4(hx3) 335 | 336 | hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1)) 337 | hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1)) 338 | hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1)) 339 | 340 | """ 341 | del hx1, hx2, hx3, hx4 342 | del hx3d, hx2d 343 | """ 344 | 345 | return hx1d + hxin 346 | 347 | 348 | ##### U^2-Net #### 349 | class U2NET(nn.Module): 350 | def __init__(self, in_ch=3, out_ch=1): 351 | super(U2NET, self).__init__() 352 | 353 | self.stage1 = RSU7(in_ch, 32, 64) 354 | self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 355 | 356 | self.stage2 = RSU6(64, 32, 128) 357 | self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 358 | 359 | self.stage3 = RSU5(128, 64, 256) 360 | self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 361 | 362 | self.stage4 = RSU4(256, 128, 512) 363 | self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 364 | 365 | self.stage5 = RSU4F(512, 256, 512) 366 | self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 367 | 368 | self.stage6 = RSU4F(512, 256, 512) 369 | 370 | # decoder 371 | self.stage5d = RSU4F(1024, 256, 512) 372 | self.stage4d = RSU4(1024, 128, 256) 373 | self.stage3d = RSU5(512, 64, 128) 374 | self.stage2d = RSU6(256, 32, 64) 375 | self.stage1d = RSU7(128, 16, 64) 376 | 377 | self.side1 = nn.Conv2d(64, out_ch, 3, padding=1) 378 | self.side2 = nn.Conv2d(64, out_ch, 3, padding=1) 379 | self.side3 = nn.Conv2d(128, out_ch, 3, padding=1) 380 | self.side4 = nn.Conv2d(256, out_ch, 3, padding=1) 381 | self.side5 = nn.Conv2d(512, out_ch, 3, padding=1) 382 | self.side6 = nn.Conv2d(512, out_ch, 3, padding=1) 383 | 384 | self.outconv = nn.Conv2d(6 * out_ch, out_ch, 1) 385 | 386 | def forward(self, x): 387 | 388 | hx = x 389 | 390 | # stage 1 391 | hx1 = self.stage1(hx) 392 | hx = self.pool12(hx1) 393 | 394 | # stage 2 395 | hx2 = self.stage2(hx) 396 | hx = self.pool23(hx2) 397 | 398 | # stage 3 399 | hx3 = self.stage3(hx) 400 | hx = self.pool34(hx3) 401 | 402 | # stage 4 403 | hx4 = self.stage4(hx) 404 | hx = self.pool45(hx4) 405 | 406 | # stage 5 407 | hx5 = self.stage5(hx) 408 | hx = self.pool56(hx5) 409 | 410 | # stage 6 411 | hx6 = self.stage6(hx) 412 | hx6up = _upsample_like(hx6, hx5) 413 | 414 | # -------------------- decoder -------------------- 415 | hx5d = self.stage5d(torch.cat((hx6up, hx5), 1)) 416 | hx5dup = _upsample_like(hx5d, hx4) 417 | 418 | hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1)) 419 | hx4dup = _upsample_like(hx4d, hx3) 420 | 421 | hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1)) 422 | hx3dup = _upsample_like(hx3d, hx2) 423 | 424 | hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1)) 425 | hx2dup = _upsample_like(hx2d, hx1) 426 | 427 | hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1)) 428 | 429 | # side output 430 | d1 = self.side1(hx1d) 431 | 432 | d2 = self.side2(hx2d) 433 | d2 = _upsample_like(d2, d1) 434 | 435 | d3 = self.side3(hx3d) 436 | d3 = _upsample_like(d3, d1) 437 | 438 | d4 = self.side4(hx4d) 439 | d4 = _upsample_like(d4, d1) 440 | 441 | d5 = self.side5(hx5d) 442 | d5 = _upsample_like(d5, d1) 443 | 444 | d6 = self.side6(hx6) 445 | d6 = _upsample_like(d6, d1) 446 | 447 | d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1)) 448 | 449 | """ 450 | del hx1, hx2, hx3, hx4, hx5, hx6 451 | del hx5d, hx4d, hx3d, hx2d, hx1d 452 | del hx6up, hx5dup, hx4dup, hx3dup, hx2dup 453 | """ 454 | 455 | return d0, d1, d2, d3, d4, d5, d6 456 | 457 | 458 | ### U^2-Net small ### 459 | class U2NETP(nn.Module): 460 | def __init__(self, in_ch=3, out_ch=1): 461 | super(U2NETP, self).__init__() 462 | 463 | self.stage1 = RSU7(in_ch, 16, 64) 464 | self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 465 | 466 | self.stage2 = RSU6(64, 16, 64) 467 | self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 468 | 469 | self.stage3 = RSU5(64, 16, 64) 470 | self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 471 | 472 | self.stage4 = RSU4(64, 16, 64) 473 | self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 474 | 475 | self.stage5 = RSU4F(64, 16, 64) 476 | self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 477 | 478 | self.stage6 = RSU4F(64, 16, 64) 479 | 480 | # decoder 481 | self.stage5d = RSU4F(128, 16, 64) 482 | self.stage4d = RSU4(128, 16, 64) 483 | self.stage3d = RSU5(128, 16, 64) 484 | self.stage2d = RSU6(128, 16, 64) 485 | self.stage1d = RSU7(128, 16, 64) 486 | 487 | self.side1 = nn.Conv2d(64, out_ch, 3, padding=1) 488 | self.side2 = nn.Conv2d(64, out_ch, 3, padding=1) 489 | self.side3 = nn.Conv2d(64, out_ch, 3, padding=1) 490 | self.side4 = nn.Conv2d(64, out_ch, 3, padding=1) 491 | self.side5 = nn.Conv2d(64, out_ch, 3, padding=1) 492 | self.side6 = nn.Conv2d(64, out_ch, 3, padding=1) 493 | 494 | self.outconv = nn.Conv2d(6 * out_ch, out_ch, 1) 495 | 496 | def forward(self, x): 497 | 498 | hx = x 499 | 500 | # stage 1 501 | hx1 = self.stage1(hx) 502 | hx = self.pool12(hx1) 503 | 504 | # stage 2 505 | hx2 = self.stage2(hx) 506 | hx = self.pool23(hx2) 507 | 508 | # stage 3 509 | hx3 = self.stage3(hx) 510 | hx = self.pool34(hx3) 511 | 512 | # stage 4 513 | hx4 = self.stage4(hx) 514 | hx = self.pool45(hx4) 515 | 516 | # stage 5 517 | hx5 = self.stage5(hx) 518 | hx = self.pool56(hx5) 519 | 520 | # stage 6 521 | hx6 = self.stage6(hx) 522 | hx6up = _upsample_like(hx6, hx5) 523 | 524 | # decoder 525 | hx5d = self.stage5d(torch.cat((hx6up, hx5), 1)) 526 | hx5dup = _upsample_like(hx5d, hx4) 527 | 528 | hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1)) 529 | hx4dup = _upsample_like(hx4d, hx3) 530 | 531 | hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1)) 532 | hx3dup = _upsample_like(hx3d, hx2) 533 | 534 | hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1)) 535 | hx2dup = _upsample_like(hx2d, hx1) 536 | 537 | hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1)) 538 | 539 | # side output 540 | d1 = self.side1(hx1d) 541 | 542 | d2 = self.side2(hx2d) 543 | d2 = _upsample_like(d2, d1) 544 | 545 | d3 = self.side3(hx3d) 546 | d3 = _upsample_like(d3, d1) 547 | 548 | d4 = self.side4(hx4d) 549 | d4 = _upsample_like(d4, d1) 550 | 551 | d5 = self.side5(hx5d) 552 | d5 = _upsample_like(d5, d1) 553 | 554 | d6 = self.side6(hx6) 555 | d6 = _upsample_like(d6, d1) 556 | 557 | d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1)) 558 | 559 | 560 | return d0, d1, d2, d3, d4, d5, d6 -------------------------------------------------------------------------------- /options.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import os 3 | 4 | 5 | class parser(object): 6 | def __init__(self): 7 | 8 | self.output = "./output" # output image folder path 9 | self.logs_dir = './logs' 10 | self.device = 'cuda:0' 11 | 12 | opt = parser() -------------------------------------------------------------------------------- /process.py: -------------------------------------------------------------------------------- 1 | from network import U2NET 2 | 3 | import os 4 | from PIL import Image 5 | import cv2 6 | import gdown 7 | import argparse 8 | import numpy as np 9 | 10 | import torch 11 | import torch.nn.functional as F 12 | import torchvision.transforms as transforms 13 | 14 | from collections import OrderedDict 15 | from options import opt 16 | 17 | 18 | def load_checkpoint(model, checkpoint_path): 19 | if not os.path.exists(checkpoint_path): 20 | print("----No checkpoints at given path----") 21 | return 22 | model_state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu")) 23 | new_state_dict = OrderedDict() 24 | for k, v in model_state_dict.items(): 25 | name = k[7:] # remove `module.` 26 | new_state_dict[name] = v 27 | 28 | model.load_state_dict(new_state_dict) 29 | print("----checkpoints loaded from path: {}----".format(checkpoint_path)) 30 | return model 31 | 32 | 33 | def get_palette(num_cls): 34 | """ Returns the color map for visualizing the segmentation mask. 35 | Args: 36 | num_cls: Number of classes 37 | Returns: 38 | The color map 39 | """ 40 | n = num_cls 41 | palette = [0] * (n * 3) 42 | for j in range(0, n): 43 | lab = j 44 | palette[j * 3 + 0] = 0 45 | palette[j * 3 + 1] = 0 46 | palette[j * 3 + 2] = 0 47 | i = 0 48 | while lab: 49 | palette[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i)) 50 | palette[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i)) 51 | palette[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i)) 52 | i += 1 53 | lab >>= 3 54 | return palette 55 | 56 | 57 | class Normalize_image(object): 58 | """Normalize given tensor into given mean and standard dev 59 | 60 | Args: 61 | mean (float): Desired mean to substract from tensors 62 | std (float): Desired std to divide from tensors 63 | """ 64 | 65 | def __init__(self, mean, std): 66 | assert isinstance(mean, (float)) 67 | if isinstance(mean, float): 68 | self.mean = mean 69 | 70 | if isinstance(std, float): 71 | self.std = std 72 | 73 | self.normalize_1 = transforms.Normalize(self.mean, self.std) 74 | self.normalize_3 = transforms.Normalize([self.mean] * 3, [self.std] * 3) 75 | self.normalize_18 = transforms.Normalize([self.mean] * 18, [self.std] * 18) 76 | 77 | def __call__(self, image_tensor): 78 | if image_tensor.shape[0] == 1: 79 | return self.normalize_1(image_tensor) 80 | 81 | elif image_tensor.shape[0] == 3: 82 | return self.normalize_3(image_tensor) 83 | 84 | elif image_tensor.shape[0] == 18: 85 | return self.normalize_18(image_tensor) 86 | 87 | else: 88 | assert "Please set proper channels! Normlization implemented only for 1, 3 and 18" 89 | 90 | 91 | 92 | 93 | def apply_transform(img): 94 | transforms_list = [] 95 | transforms_list += [transforms.ToTensor()] 96 | transforms_list += [Normalize_image(0.5, 0.5)] 97 | transform_rgb = transforms.Compose(transforms_list) 98 | return transform_rgb(img) 99 | 100 | 101 | 102 | def generate_mask(input_image, net, palette, device = 'cpu'): 103 | 104 | #img = Image.open(input_image).convert('RGB') 105 | img = input_image 106 | img_size = img.size 107 | img = img.resize((768, 768), Image.BICUBIC) 108 | image_tensor = apply_transform(img) 109 | image_tensor = torch.unsqueeze(image_tensor, 0) 110 | 111 | alpha_out_dir = os.path.join(opt.output,'alpha') 112 | cloth_seg_out_dir = os.path.join(opt.output,'cloth_seg') 113 | 114 | os.makedirs(alpha_out_dir, exist_ok=True) 115 | os.makedirs(cloth_seg_out_dir, exist_ok=True) 116 | 117 | with torch.no_grad(): 118 | output_tensor = net(image_tensor.to(device)) 119 | output_tensor = F.log_softmax(output_tensor[0], dim=1) 120 | output_tensor = torch.max(output_tensor, dim=1, keepdim=True)[1] 121 | output_tensor = torch.squeeze(output_tensor, dim=0) 122 | output_arr = output_tensor.cpu().numpy() 123 | 124 | classes_to_save = [] 125 | 126 | # Check which classes are present in the image 127 | for cls in range(1, 4): # Exclude background class (0) 128 | if np.any(output_arr == cls): 129 | classes_to_save.append(cls) 130 | 131 | # Save alpha masks 132 | for cls in classes_to_save: 133 | alpha_mask = (output_arr == cls).astype(np.uint8) * 255 134 | alpha_mask = alpha_mask[0] # Selecting the first channel to make it 2D 135 | alpha_mask_img = Image.fromarray(alpha_mask, mode='L') 136 | alpha_mask_img = alpha_mask_img.resize(img_size, Image.BICUBIC) 137 | alpha_mask_img.save(os.path.join(alpha_out_dir, f'{cls}.png')) 138 | 139 | # Save final cloth segmentations 140 | cloth_seg = Image.fromarray(output_arr[0].astype(np.uint8), mode='P') 141 | cloth_seg.putpalette(palette) 142 | cloth_seg = cloth_seg.resize(img_size, Image.BICUBIC) 143 | cloth_seg.save(os.path.join(cloth_seg_out_dir, 'final_seg.png')) 144 | return cloth_seg 145 | 146 | 147 | 148 | def check_or_download_model(file_path): 149 | if not os.path.exists(file_path): 150 | os.makedirs(os.path.dirname(file_path), exist_ok=True) 151 | url = "https://drive.google.com/uc?id=11xTBALOeUkyuaK3l60CpkYHLTmv7k3dY" 152 | gdown.download(url, file_path, quiet=False) 153 | print("Model downloaded successfully.") 154 | else: 155 | print("Model already exists.") 156 | 157 | 158 | def load_seg_model(checkpoint_path, device='cpu'): 159 | net = U2NET(in_ch=3, out_ch=4) 160 | check_or_download_model(checkpoint_path) 161 | net = load_checkpoint(net, checkpoint_path) 162 | net = net.to(device) 163 | net = net.eval() 164 | 165 | return net 166 | 167 | 168 | def main(args): 169 | 170 | device = 'cuda:0' if args.cuda else 'cpu' 171 | 172 | # Create an instance of your model 173 | model = load_seg_model(args.checkpoint_path, device=device) 174 | 175 | palette = get_palette(4) 176 | 177 | img = Image.open(args.image).convert('RGB') 178 | 179 | cloth_seg = generate_mask(img, net=model, palette=palette, device=device) 180 | 181 | 182 | 183 | if __name__ == '__main__': 184 | parser = argparse.ArgumentParser(description='Help to set arguments for Cloth Segmentation.') 185 | parser.add_argument('--image', type=str, help='Path to the input image') 186 | parser.add_argument('--cuda', action='store_true', help='Enable CUDA (default: False)') 187 | parser.add_argument('--checkpoint_path', type=str, default='model/cloth_segm.pth', help='Path to the checkpoint file') 188 | args = parser.parse_args() 189 | 190 | main(args) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | gradio 4 | gdown 5 | Pillow 6 | opencv-python 7 | numpy --------------------------------------------------------------------------------