├── LICENSE ├── README.md ├── datasets ├── README.md ├── __init__.py ├── cityscapes.py └── voc.py ├── eval └── eval_voc.py ├── models ├── __init__.py ├── config.py ├── duc_hdc.py ├── fcn16s.py ├── fcn32s.py ├── fcn8s.py ├── gcn.py ├── psp_net.py ├── seg_net.py └── u_net.py ├── train ├── cityscapes-fcn (caffe vgg) │ ├── README.md │ ├── static │ │ ├── fcn8s-epoch328.jpg │ │ ├── fcn8s-mean_iu.jpg │ │ ├── fcn8s-train_loss.jpg │ │ └── fcn8s-val_loss.jpg │ └── train.py ├── cityscapes-fcn │ └── train.py ├── cityscapes-psp_net │ ├── static │ │ ├── 0_gt.png │ │ ├── 0_prediction.png │ │ ├── 1_gt.png │ │ ├── 1_prediction.png │ │ ├── 2_gt.png │ │ ├── 2_prediction.png │ │ ├── 3_gt.png │ │ ├── 3_prediction.png │ │ ├── 4_gt.png │ │ ├── 4_prediction.png │ │ ├── 5_gt.png │ │ ├── 5_prediction.png │ │ ├── 6_gt.png │ │ ├── 6_prediction.png │ │ ├── 7_gt.png │ │ ├── 7_prediction.png │ │ ├── 8_gt.png │ │ ├── 8_prediction.png │ │ ├── 9_gt.png │ │ └── 9_prediction.png │ └── train.py ├── voc-fcn (caffe vgg) │ ├── README.md │ ├── static │ │ ├── fcn8s-epoch9.jpg │ │ ├── fcn8s-mean_iu.jpg │ │ ├── fcn8s-train_loss.jpg │ │ └── fcn8s-val_loss.jpg │ └── train.py ├── voc-fcn │ └── train.py └── voc-psp_net │ └── train.py └── utils ├── __init__.py ├── joint_transforms.py ├── misc.py └── transforms.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 ZijunDeng 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch for Semantic Segmentation 2 | This repository contains some models for semantic segmentation and the pipeline of training and testing models, 3 | implemented in PyTorch 4 | 5 | ## Models 6 | 1. Vanilla FCN: FCN32, FCN16, FCN8, in the versions of VGG, ResNet and DenseNet respectively 7 | ([Fully convolutional networks for semantic segmentation](http://www.cv-foundation.org/openaccess/content_cvpr_2015/papers/Long_Fully_Convolutional_Networks_2015_CVPR_paper.pdf)) 8 | 2. U-Net ([U-net: Convolutional networks for biomedical image segmentation](https://arxiv.org/pdf/1505.04597)) 9 | 3. SegNet ([Segnet: A deep convolutional encoder-decoder architecture for image segmentation](https://arxiv.org/pdf/1511.00561)) 10 | 4. PSPNet ([Pyramid scene parsing network](https://arxiv.org/pdf/1612.01105)) 11 | 5. 
GCN ([Large Kernel Matters](https://arxiv.org/pdf/1703.02719)) 12 | 6. DUC, HDC ([Understanding Convolution for Semantic Segmentation](https://arxiv.org/pdf/1702.08502.pdf)) 13 | 14 | ## Requirement 15 | 1. PyTorch 0.2.0 16 | 2. TensorBoard for PyTorch; see [here](https://github.com/lanpa/tensorboard-pytorch) for installation 17 | 3. Some other libraries (install whatever you find missing when running the code :-P) 18 | 19 | ## Preparation 20 | 1. Go to the *models* directory and set the paths of the pretrained models in *config.py* 21 | 2. Go to the *datasets* directory and follow the README there 22 | 23 | ## TODO 24 | 1. DeepLab v3 25 | 2. RefineNet 26 | 3. More datasets (e.g. ADE) 27 | -------------------------------------------------------------------------------- /datasets/README.md: -------------------------------------------------------------------------------- 1 | # Dataset 2 | 3 | ## PASCAL VOC 2012 4 | 1. Visit [this page](https://github.com/shelhamer/fcn.berkeleyvision.org/tree/master/data/pascal) and download SBD and 5 | PASCAL VOC 2012 6 | 2. Extract them; you will get the *benchmark_RELEASE* and *VOCdevkit* folders. 7 | 3. Add the file *seg11valid.txt* ([download]( 8 | https://github.com/shelhamer/fcn.berkeleyvision.org/blob/master/data/pascal/seg11valid.txt)) 9 | to *VOCdevkit/VOC2012/ImageSets/Segmentation* 10 | 4. Put the *benchmark_RELEASE* and *VOCdevkit* folders in a folder called *VOC* 11 | 5. Set the path (*root*) of the *VOC* folder from the last step in *voc.py* 12 | 13 | ## Cityscapes 14 | 1. Download *leftImg8bit_trainvaltest*, *gtFine_trainvaltest*, *leftImg8bit_trainextra*, and *gtCoarse* from the Cityscapes website 15 | 2. Extract them and put them in a folder called *cityscapes* 16 | 3. Set the path (*root*) of the *cityscapes* folder from the last step in *cityscapes.py* 17 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from . import cityscapes 2 | from . 
import voc 3 | -------------------------------------------------------------------------------- /datasets/cityscapes.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | from PIL import Image 6 | from torch.utils import data 7 | 8 | num_classes = 19 9 | ignore_label = 255 10 | root = '/media/b3-542/LIBRARY/Datasets/cityscapes' 11 | 12 | palette = [128, 64, 128, 244, 35, 232, 70, 70, 70, 102, 102, 156, 190, 153, 153, 153, 153, 153, 250, 170, 30, 13 | 220, 220, 0, 107, 142, 35, 152, 251, 152, 70, 130, 180, 220, 20, 60, 255, 0, 0, 0, 0, 142, 0, 0, 70, 14 | 0, 60, 100, 0, 80, 100, 0, 0, 230, 119, 11, 32] 15 | zero_pad = 256 * 3 - len(palette) 16 | for i in range(zero_pad): 17 | palette.append(0) 18 | 19 | 20 | def colorize_mask(mask): 21 | # mask: numpy array of the mask 22 | new_mask = Image.fromarray(mask.astype(np.uint8)).convert('P') 23 | new_mask.putpalette(palette) 24 | 25 | return new_mask 26 | 27 | 28 | def make_dataset(quality, mode): 29 | assert (quality == 'fine' and mode in ['train', 'val']) or \ 30 | (quality == 'coarse' and mode in ['train', 'train_extra', 'val']) 31 | 32 | if quality == 'coarse': 33 | img_dir_name = 'leftImg8bit_trainextra' if mode == 'train_extra' else 'leftImg8bit_trainvaltest' 34 | mask_path = os.path.join(root, 'gtCoarse', 'gtCoarse', mode) 35 | mask_postfix = '_gtCoarse_labelIds.png' 36 | else: 37 | img_dir_name = 'leftImg8bit_trainvaltest' 38 | mask_path = os.path.join(root, 'gtFine_trainvaltest', 'gtFine', mode) 39 | mask_postfix = '_gtFine_labelIds.png' 40 | img_path = os.path.join(root, img_dir_name, 'leftImg8bit', mode) 41 | assert os.listdir(img_path) == os.listdir(mask_path) 42 | items = [] 43 | categories = os.listdir(img_path) 44 | for c in categories: 45 | c_items = [name.split('_leftImg8bit.png')[0] for name in os.listdir(os.path.join(img_path, c))] 46 | for it in c_items: 47 | item = (os.path.join(img_path, c, it + '_leftImg8bit.png'), os.path.join(mask_path, c, it + mask_postfix)) 48 | items.append(item) 49 | return items 50 | 51 | 52 | class CityScapes(data.Dataset): 53 | def __init__(self, quality, mode, joint_transform=None, sliding_crop=None, transform=None, target_transform=None): 54 | self.imgs = make_dataset(quality, mode) 55 | if len(self.imgs) == 0: 56 | raise RuntimeError('Found 0 images, please check the data set') 57 | self.quality = quality 58 | self.mode = mode 59 | self.joint_transform = joint_transform 60 | self.sliding_crop = sliding_crop 61 | self.transform = transform 62 | self.target_transform = target_transform 63 | self.id_to_trainid = {-1: ignore_label, 0: ignore_label, 1: ignore_label, 2: ignore_label, 64 | 3: ignore_label, 4: ignore_label, 5: ignore_label, 6: ignore_label, 65 | 7: 0, 8: 1, 9: ignore_label, 10: ignore_label, 11: 2, 12: 3, 13: 4, 66 | 14: ignore_label, 15: ignore_label, 16: ignore_label, 17: 5, 67 | 18: ignore_label, 19: 6, 20: 7, 21: 8, 22: 9, 23: 10, 24: 11, 25: 12, 26: 13, 27: 14, 68 | 28: 15, 29: ignore_label, 30: ignore_label, 31: 16, 32: 17, 33: 18} 69 | 70 | def __getitem__(self, index): 71 | img_path, mask_path = self.imgs[index] 72 | img, mask = Image.open(img_path).convert('RGB'), Image.open(mask_path) 73 | 74 | mask = np.array(mask) 75 | mask_copy = mask.copy() 76 | for k, v in self.id_to_trainid.items(): 77 | mask_copy[mask == k] = v 78 | mask = Image.fromarray(mask_copy.astype(np.uint8)) 79 | 80 | if self.joint_transform is not None: 81 | img, mask = self.joint_transform(img, mask) 82 | if 
self.sliding_crop is not None: 83 | img_slices, mask_slices, slices_info = self.sliding_crop(img, mask) 84 | if self.transform is not None: 85 | img_slices = [self.transform(e) for e in img_slices] 86 | if self.target_transform is not None: 87 | mask_slices = [self.target_transform(e) for e in mask_slices] 88 | img, mask = torch.stack(img_slices, 0), torch.stack(mask_slices, 0) 89 | return img, mask, torch.LongTensor(slices_info) 90 | else: 91 | if self.transform is not None: 92 | img = self.transform(img) 93 | if self.target_transform is not None: 94 | mask = self.target_transform(mask) 95 | return img, mask 96 | 97 | def __len__(self): 98 | return len(self.imgs) 99 | -------------------------------------------------------------------------------- /datasets/voc.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import scipy.io as sio 5 | import torch 6 | from PIL import Image 7 | from torch.utils import data 8 | 9 | num_classes = 21 10 | ignore_label = 255 11 | root = '/media/b3-542/LIBRARY/Datasets/VOC' 12 | 13 | ''' 14 | color map 15 | 0=background, 1=aeroplane, 2=bicycle, 3=bird, 4=boat, 5=bottle # 6=bus, 7=car, 8=cat, 9=chair, 10=cow, 11=diningtable, 16 | 12=dog, 13=horse, 14=motorbike, 15=person # 16=potted plant, 17=sheep, 18=sofa, 19=train, 20=tv/monitor 17 | ''' 18 | palette = [0, 0, 0, 128, 0, 0, 0, 128, 0, 128, 128, 0, 0, 0, 128, 128, 0, 128, 0, 128, 128, 19 | 128, 128, 128, 64, 0, 0, 192, 0, 0, 64, 128, 0, 192, 128, 0, 64, 0, 128, 192, 0, 128, 20 | 64, 128, 128, 192, 128, 128, 0, 64, 0, 128, 64, 0, 0, 192, 0, 128, 192, 0, 0, 64, 128] 21 | 22 | zero_pad = 256 * 3 - len(palette) 23 | for i in range(zero_pad): 24 | palette.append(0) 25 | 26 | 27 | def colorize_mask(mask): 28 | # mask: numpy array of the mask 29 | new_mask = Image.fromarray(mask.astype(np.uint8)).convert('P') 30 | new_mask.putpalette(palette) 31 | 32 | return new_mask 33 | 34 | 35 | def make_dataset(mode): 36 | assert mode in ['train', 'val', 'test'] 37 | items = [] 38 | if mode == 'train': 39 | img_path = os.path.join(root, 'benchmark_RELEASE', 'dataset', 'img') 40 | mask_path = os.path.join(root, 'benchmark_RELEASE', 'dataset', 'cls') 41 | data_list = [l.strip('\n') for l in open(os.path.join( 42 | root, 'benchmark_RELEASE', 'dataset', 'train.txt')).readlines()] 43 | for it in data_list: 44 | item = (os.path.join(img_path, it + '.jpg'), os.path.join(mask_path, it + '.mat')) 45 | items.append(item) 46 | elif mode == 'val': 47 | img_path = os.path.join(root, 'VOCdevkit', 'VOC2012', 'JPEGImages') 48 | mask_path = os.path.join(root, 'VOCdevkit', 'VOC2012', 'SegmentationClass') 49 | data_list = [l.strip('\n') for l in open(os.path.join( 50 | root, 'VOCdevkit', 'VOC2012', 'ImageSets', 'Segmentation', 'seg11valid.txt')).readlines()] 51 | for it in data_list: 52 | item = (os.path.join(img_path, it + '.jpg'), os.path.join(mask_path, it + '.png')) 53 | items.append(item) 54 | else: 55 | img_path = os.path.join(root, 'VOCdevkit (test)', 'VOC2012', 'JPEGImages') 56 | data_list = [l.strip('\n') for l in open(os.path.join( 57 | root, 'VOCdevkit (test)', 'VOC2012', 'ImageSets', 'Segmentation', 'test.txt')).readlines()] 58 | for it in data_list: 59 | items.append((img_path, it)) 60 | return items 61 | 62 | 63 | class VOC(data.Dataset): 64 | def __init__(self, mode, joint_transform=None, sliding_crop=None, transform=None, target_transform=None): 65 | self.imgs = make_dataset(mode) 66 | if len(self.imgs) == 0: 67 | raise RuntimeError('Found 0 images, 
please check the data set') 68 | self.mode = mode 69 | self.joint_transform = joint_transform 70 | self.sliding_crop = sliding_crop 71 | self.transform = transform 72 | self.target_transform = target_transform 73 | 74 | def __getitem__(self, index): 75 | if self.mode == 'test': 76 | img_path, img_name = self.imgs[index] 77 | img = Image.open(os.path.join(img_path, img_name + '.jpg')).convert('RGB') 78 | if self.transform is not None: 79 | img = self.transform(img) 80 | return img_name, img 81 | 82 | img_path, mask_path = self.imgs[index] 83 | img = Image.open(img_path).convert('RGB') 84 | if self.mode == 'train': 85 | mask = sio.loadmat(mask_path)['GTcls']['Segmentation'][0][0] 86 | mask = Image.fromarray(mask.astype(np.uint8)) 87 | else: 88 | mask = Image.open(mask_path) 89 | 90 | if self.joint_transform is not None: 91 | img, mask = self.joint_transform(img, mask) 92 | 93 | if self.sliding_crop is not None: 94 | img_slices, mask_slices, slices_info = self.sliding_crop(img, mask) 95 | if self.transform is not None: 96 | img_slices = [self.transform(e) for e in img_slices] 97 | if self.target_transform is not None: 98 | mask_slices = [self.target_transform(e) for e in mask_slices] 99 | img, mask = torch.stack(img_slices, 0), torch.stack(mask_slices, 0) 100 | return img, mask, torch.LongTensor(slices_info) 101 | else: 102 | if self.transform is not None: 103 | img = self.transform(img) 104 | if self.target_transform is not None: 105 | mask = self.target_transform(mask) 106 | return img, mask 107 | 108 | def __len__(self): 109 | return len(self.imgs) 110 | -------------------------------------------------------------------------------- /eval/eval_voc.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torchvision.transforms as standard_transforms 4 | from torch.autograd import Variable 5 | from torch.backends import cudnn 6 | from torch.utils.data import DataLoader 7 | 8 | from datasets import voc 9 | from models import * 10 | from utils import check_mkdir 11 | 12 | cudnn.benchmark = True 13 | 14 | ckpt_path = './ckpt' 15 | 16 | args = { 17 | 'exp_name': 'voc-psp_net', 18 | 'snapshot': 'epoch_33_loss_0.31766_acc_0.92188_acc-cls_0.81110_mean-iu_0.70271_fwavacc_0.86757_lr_0.0023769346.pth' 19 | } 20 | 21 | 22 | def main(): 23 | net = PSPNet(num_classes=voc.num_classes).cuda() 24 | print('load model ' + args['snapshot']) 25 | net.load_state_dict(torch.load(os.path.join(ckpt_path, args['exp_name'], args['snapshot']))) 26 | net.eval() 27 | 28 | mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 29 | 30 | val_input_transform = standard_transforms.Compose([ 31 | standard_transforms.ToTensor(), 32 | standard_transforms.Normalize(*mean_std) 33 | ]) 34 | 35 | test_set = voc.VOC('test', transform=val_input_transform) 36 | test_loader = DataLoader(test_set, batch_size=1, num_workers=8, shuffle=False) 37 | 38 | check_mkdir(os.path.join(ckpt_path, args['exp_name'], 'test')) 39 | 40 | for vi, data in enumerate(test_loader): 41 | img_name, img = data 42 | img_name = img_name[0] 43 | 44 | img = Variable(img, volatile=True).cuda() 45 | output = net(img) 46 | 47 | prediction = output.data.max(1)[1].squeeze_(1).squeeze_(0).cpu().numpy() 48 | prediction = voc.colorize_mask(prediction) 49 | prediction.save(os.path.join(ckpt_path, args['exp_name'], 'test', img_name + '.png')) 50 | 51 | print('%d / %d' % (vi + 1, len(test_loader))) 52 | 53 | 54 | if __name__ == '__main__': 55 | main() 56 | 
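# Usage sketch (not part of the original script; the snapshot name in `args` above is only an
# example -- point it at a checkpoint you have actually trained under ./ckpt/voc-psp_net):
#   PYTHONPATH=. python eval/eval_voc.py   # run from the repository root so datasets/, models/ and utils/ are importable
# The script loads the PSPNet weights, iterates over the VOC test split one image at a time,
# and writes one colorized PNG prediction per image to <ckpt_path>/<exp_name>/test/.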
-------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .duc_hdc import * 2 | from .fcn16s import * 3 | from .fcn32s import * 4 | from .fcn8s import * 5 | from .gcn import * 6 | from .psp_net import * 7 | from .seg_net import * 8 | from .u_net import * 9 | -------------------------------------------------------------------------------- /models/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | # here (https://github.com/pytorch/vision/tree/master/torchvision/models) to find the download link of pretrained models 4 | 5 | root = '/media/b3-542/LIBRARY/ZijunDeng/PyTorch Pretrained' 6 | res101_path = os.path.join(root, 'ResNet', 'resnet101-5d3b4d8f.pth') 7 | res152_path = os.path.join(root, 'ResNet', 'resnet152-b121ed2d.pth') 8 | inception_v3_path = os.path.join(root, 'Inception', 'inception_v3_google-1a9a5a14.pth') 9 | vgg19_bn_path = os.path.join(root, 'VggNet', 'vgg19_bn-c79401a0.pth') 10 | vgg16_path = os.path.join(root, 'VggNet', 'vgg16-397923af.pth') 11 | dense201_path = os.path.join(root, 'DenseNet', 'densenet201-4c113574.pth') 12 | 13 | ''' 14 | vgg16 trained using caffe 15 | visit this (https://github.com/jcjohnson/pytorch-vgg) to download the converted vgg16 16 | ''' 17 | vgg16_caffe_path = os.path.join(root, 'VggNet', 'vgg16-caffe.pth') 18 | -------------------------------------------------------------------------------- /models/duc_hdc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torchvision import models 4 | 5 | from .config import res152_path 6 | 7 | 8 | class _DenseUpsamplingConvModule(nn.Module): 9 | def __init__(self, down_factor, in_dim, num_classes): 10 | super(_DenseUpsamplingConvModule, self).__init__() 11 | upsample_dim = (down_factor ** 2) * num_classes 12 | self.conv = nn.Conv2d(in_dim, upsample_dim, kernel_size=3, padding=1) 13 | self.bn = nn.BatchNorm2d(upsample_dim) 14 | self.relu = nn.ReLU(inplace=True) 15 | self.pixel_shuffle = nn.PixelShuffle(down_factor) 16 | 17 | def forward(self, x): 18 | x = self.conv(x) 19 | x = self.bn(x) 20 | x = self.relu(x) 21 | x = self.pixel_shuffle(x) 22 | return x 23 | 24 | 25 | class ResNetDUC(nn.Module): 26 | # the size of image should be multiple of 8 27 | def __init__(self, num_classes, pretrained=True): 28 | super(ResNetDUC, self).__init__() 29 | resnet = models.resnet152() 30 | if pretrained: 31 | resnet.load_state_dict(torch.load(res152_path)) 32 | self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool) 33 | self.layer1 = resnet.layer1 34 | self.layer2 = resnet.layer2 35 | self.layer3 = resnet.layer3 36 | self.layer4 = resnet.layer4 37 | 38 | for n, m in self.layer3.named_modules(): 39 | if 'conv2' in n: 40 | m.dilation = (2, 2) 41 | m.padding = (2, 2) 42 | m.stride = (1, 1) 43 | elif 'downsample.0' in n: 44 | m.stride = (1, 1) 45 | for n, m in self.layer4.named_modules(): 46 | if 'conv2' in n: 47 | m.dilation = (4, 4) 48 | m.padding = (4, 4) 49 | m.stride = (1, 1) 50 | elif 'downsample.0' in n: 51 | m.stride = (1, 1) 52 | 53 | self.duc = _DenseUpsamplingConvModule(8, 2048, num_classes) 54 | 55 | def forward(self, x): 56 | x = self.layer0(x) 57 | x = self.layer1(x) 58 | x = self.layer2(x) 59 | x = self.layer3(x) 60 | x = self.layer4(x) 61 | x = self.duc(x) 62 | return x 63 | 64 | 65 | class ResNetDUCHDC(nn.Module): 
66 | # the size of image should be multiple of 8 67 | def __init__(self, num_classes, pretrained=True): 68 | super(ResNetDUCHDC, self).__init__() 69 | resnet = models.resnet152() 70 | if pretrained: 71 | resnet.load_state_dict(torch.load(res152_path)) 72 | self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool) 73 | self.layer1 = resnet.layer1 74 | self.layer2 = resnet.layer2 75 | self.layer3 = resnet.layer3 76 | self.layer4 = resnet.layer4 77 | 78 | for n, m in self.layer3.named_modules(): 79 | if 'conv2' in n or 'downsample.0' in n: 80 | m.stride = (1, 1) 81 | for n, m in self.layer4.named_modules(): 82 | if 'conv2' in n or 'downsample.0' in n: 83 | m.stride = (1, 1) 84 | layer3_group_config = [1, 2, 5, 9] 85 | for idx in range(len(self.layer3)): 86 | self.layer3[idx].conv2.dilation = (layer3_group_config[idx % 4], layer3_group_config[idx % 4]) 87 | self.layer3[idx].conv2.padding = (layer3_group_config[idx % 4], layer3_group_config[idx % 4]) 88 | layer4_group_config = [5, 9, 17] 89 | for idx in range(len(self.layer4)): 90 | self.layer4[idx].conv2.dilation = (layer4_group_config[idx], layer4_group_config[idx]) 91 | self.layer4[idx].conv2.padding = (layer4_group_config[idx], layer4_group_config[idx]) 92 | 93 | self.duc = _DenseUpsamplingConvModule(8, 2048, num_classes) 94 | 95 | def forward(self, x): 96 | x = self.layer0(x) 97 | x = self.layer1(x) 98 | x = self.layer2(x) 99 | x = self.layer3(x) 100 | x = self.layer4(x) 101 | x = self.duc(x) 102 | return x 103 | -------------------------------------------------------------------------------- /models/fcn16s.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torchvision import models 4 | 5 | from ..utils import get_upsampling_weight 6 | from .config import vgg16_caffe_path 7 | 8 | 9 | class FCN16VGG(nn.Module): 10 | def __init__(self, num_classes, pretrained=True): 11 | super(FCN16VGG, self).__init__() 12 | vgg = models.vgg16() 13 | if pretrained: 14 | vgg.load_state_dict(torch.load(vgg16_caffe_path)) 15 | features, classifier = list(vgg.features.children()), list(vgg.classifier.children()) 16 | 17 | features[0].padding = (100, 100) 18 | 19 | for f in features: 20 | if 'MaxPool' in f.__class__.__name__: 21 | f.ceil_mode = True 22 | elif 'ReLU' in f.__class__.__name__: 23 | f.inplace = True 24 | 25 | self.features4 = nn.Sequential(*features[: 24]) 26 | self.features5 = nn.Sequential(*features[24:]) 27 | 28 | self.score_pool4 = nn.Conv2d(512, num_classes, kernel_size=1) 29 | self.score_pool4.weight.data.zero_() 30 | self.score_pool4.bias.data.zero_() 31 | 32 | fc6 = nn.Conv2d(512, 4096, kernel_size=7) 33 | fc6.weight.data.copy_(classifier[0].weight.data.view(4096, 512, 7, 7)) 34 | fc6.bias.data.copy_(classifier[0].bias.data) 35 | fc7 = nn.Conv2d(4096, 4096, kernel_size=1) 36 | fc7.weight.data.copy_(classifier[3].weight.data.view(4096, 4096, 1, 1)) 37 | fc7.bias.data.copy_(classifier[3].bias.data) 38 | score_fr = nn.Conv2d(4096, num_classes, kernel_size=1) 39 | score_fr.weight.data.zero_() 40 | score_fr.bias.data.zero_() 41 | self.score_fr = nn.Sequential( 42 | fc6, nn.ReLU(inplace=True), nn.Dropout(), fc7, nn.ReLU(inplace=True), nn.Dropout(), score_fr 43 | ) 44 | 45 | self.upscore2 = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=4, stride=2, bias=False) 46 | self.upscore16 = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=32, stride=16, bias=False) 47 | 
self.upscore2.weight.data.copy_(get_upsampling_weight(num_classes, num_classes, 4)) 48 | self.upscore16.weight.data.copy_(get_upsampling_weight(num_classes, num_classes, 32)) 49 | 50 | def forward(self, x): 51 | x_size = x.size() 52 | pool4 = self.features4(x) 53 | pool5 = self.features5(pool4) 54 | 55 | score_fr = self.score_fr(pool5) 56 | upscore2 = self.upscore2(score_fr) 57 | 58 | score_pool4 = self.score_pool4(0.01 * pool4) 59 | upscore16 = self.upscore16(score_pool4[:, :, 5: (5 + upscore2.size()[2]), 5: (5 + upscore2.size()[3])] 60 | + upscore2) 61 | return upscore16[:, :, 27: (27 + x_size[2]), 27: (27 + x_size[3])].contiguous() 62 | -------------------------------------------------------------------------------- /models/fcn32s.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torchvision import models 4 | 5 | from ..utils import get_upsampling_weight 6 | from .config import vgg16_caffe_path 7 | 8 | 9 | class FCN32VGG(nn.Module): 10 | def __init__(self, num_classes, pretrained=True): 11 | super(FCN32VGG, self).__init__() 12 | vgg = models.vgg16() 13 | if pretrained: 14 | vgg.load_state_dict(torch.load(vgg16_caffe_path)) 15 | features, classifier = list(vgg.features.children()), list(vgg.classifier.children()) 16 | 17 | features[0].padding = (100, 100) 18 | 19 | for f in features: 20 | if 'MaxPool' in f.__class__.__name__: 21 | f.ceil_mode = True 22 | elif 'ReLU' in f.__class__.__name__: 23 | f.inplace = True 24 | 25 | self.features5 = nn.Sequential(*features) 26 | 27 | fc6 = nn.Conv2d(512, 4096, kernel_size=7) 28 | fc6.weight.data.copy_(classifier[0].weight.data.view(4096, 512, 7, 7)) 29 | fc6.bias.data.copy_(classifier[0].bias.data) 30 | fc7 = nn.Conv2d(4096, 4096, kernel_size=1) 31 | fc7.weight.data.copy_(classifier[3].weight.data.view(4096, 4096, 1, 1)) 32 | fc7.bias.data.copy_(classifier[3].bias.data) 33 | score_fr = nn.Conv2d(4096, num_classes, kernel_size=1) 34 | score_fr.weight.data.zero_() 35 | score_fr.bias.data.zero_() 36 | self.score_fr = nn.Sequential( 37 | fc6, nn.ReLU(inplace=True), nn.Dropout(), fc7, nn.ReLU(inplace=True), nn.Dropout(), score_fr 38 | ) 39 | 40 | self.upscore = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=64, stride=32, bias=False) 41 | self.upscore.weight.data.copy_(get_upsampling_weight(num_classes, num_classes, 64)) 42 | 43 | def forward(self, x): 44 | x_size = x.size() 45 | pool5 = self.features5(x) 46 | score_fr = self.score_fr(pool5) 47 | upscore = self.upscore(score_fr) 48 | return upscore[:, :, 19: (19 + x_size[2]), 19: (19 + x_size[3])].contiguous() 49 | -------------------------------------------------------------------------------- /models/fcn8s.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torchvision import models 4 | 5 | from ..utils import get_upsampling_weight 6 | from .config import vgg16_path, vgg16_caffe_path 7 | 8 | 9 | # This is implemented in full accordance with the original one (https://github.com/shelhamer/fcn.berkeleyvision.org) 10 | class FCN8s(nn.Module): 11 | def __init__(self, num_classes, pretrained=True, caffe=False): 12 | super(FCN8s, self).__init__() 13 | vgg = models.vgg16() 14 | if pretrained: 15 | if caffe: 16 | # load the pretrained vgg16 used by the paper's author 17 | vgg.load_state_dict(torch.load(vgg16_caffe_path)) 18 | else: 19 | vgg.load_state_dict(torch.load(vgg16_path)) 20 | features, classifier = 
list(vgg.features.children()), list(vgg.classifier.children()) 21 | 22 | ''' 23 | 100 padding for 2 reasons: 24 | 1) support very small input size 25 | 2) allow cropping in order to match size of different layers' feature maps 26 | Note that the cropped part corresponds to a part of the 100 padding 27 | Spatial information of different layers' feature maps cannot be align exactly because of cropping, which is bad 28 | ''' 29 | features[0].padding = (100, 100) 30 | 31 | for f in features: 32 | if 'MaxPool' in f.__class__.__name__: 33 | f.ceil_mode = True 34 | elif 'ReLU' in f.__class__.__name__: 35 | f.inplace = True 36 | 37 | self.features3 = nn.Sequential(*features[: 17]) 38 | self.features4 = nn.Sequential(*features[17: 24]) 39 | self.features5 = nn.Sequential(*features[24:]) 40 | 41 | self.score_pool3 = nn.Conv2d(256, num_classes, kernel_size=1) 42 | self.score_pool4 = nn.Conv2d(512, num_classes, kernel_size=1) 43 | self.score_pool3.weight.data.zero_() 44 | self.score_pool3.bias.data.zero_() 45 | self.score_pool4.weight.data.zero_() 46 | self.score_pool4.bias.data.zero_() 47 | 48 | fc6 = nn.Conv2d(512, 4096, kernel_size=7) 49 | fc6.weight.data.copy_(classifier[0].weight.data.view(4096, 512, 7, 7)) 50 | fc6.bias.data.copy_(classifier[0].bias.data) 51 | fc7 = nn.Conv2d(4096, 4096, kernel_size=1) 52 | fc7.weight.data.copy_(classifier[3].weight.data.view(4096, 4096, 1, 1)) 53 | fc7.bias.data.copy_(classifier[3].bias.data) 54 | score_fr = nn.Conv2d(4096, num_classes, kernel_size=1) 55 | score_fr.weight.data.zero_() 56 | score_fr.bias.data.zero_() 57 | self.score_fr = nn.Sequential( 58 | fc6, nn.ReLU(inplace=True), nn.Dropout(), fc7, nn.ReLU(inplace=True), nn.Dropout(), score_fr 59 | ) 60 | 61 | self.upscore2 = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=4, stride=2, bias=False) 62 | self.upscore_pool4 = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=4, stride=2, bias=False) 63 | self.upscore8 = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=16, stride=8, bias=False) 64 | self.upscore2.weight.data.copy_(get_upsampling_weight(num_classes, num_classes, 4)) 65 | self.upscore_pool4.weight.data.copy_(get_upsampling_weight(num_classes, num_classes, 4)) 66 | self.upscore8.weight.data.copy_(get_upsampling_weight(num_classes, num_classes, 16)) 67 | 68 | def forward(self, x): 69 | x_size = x.size() 70 | pool3 = self.features3(x) 71 | pool4 = self.features4(pool3) 72 | pool5 = self.features5(pool4) 73 | 74 | score_fr = self.score_fr(pool5) 75 | upscore2 = self.upscore2(score_fr) 76 | 77 | score_pool4 = self.score_pool4(0.01 * pool4) 78 | upscore_pool4 = self.upscore_pool4(score_pool4[:, :, 5: (5 + upscore2.size()[2]), 5: (5 + upscore2.size()[3])] 79 | + upscore2) 80 | 81 | score_pool3 = self.score_pool3(0.0001 * pool3) 82 | upscore8 = self.upscore8(score_pool3[:, :, 9: (9 + upscore_pool4.size()[2]), 9: (9 + upscore_pool4.size()[3])] 83 | + upscore_pool4) 84 | return upscore8[:, :, 31: (31 + x_size[2]), 31: (31 + x_size[3])].contiguous() 85 | -------------------------------------------------------------------------------- /models/gcn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | from torchvision import models 5 | 6 | from ..utils import initialize_weights 7 | from .config import res152_path 8 | 9 | 10 | # many are borrowed from https://github.com/ycszen/pytorch-ss/blob/master/gcn.py 11 | class _GlobalConvModule(nn.Module): 12 | def __init__(self, 
in_dim, out_dim, kernel_size): 13 | super(_GlobalConvModule, self).__init__() 14 | pad0 = (kernel_size[0] - 1) / 2 15 | pad1 = (kernel_size[1] - 1) / 2 16 | # kernel size had better be odd number so as to avoid alignment error 17 | super(_GlobalConvModule, self).__init__() 18 | self.conv_l1 = nn.Conv2d(in_dim, out_dim, kernel_size=(kernel_size[0], 1), 19 | padding=(pad0, 0)) 20 | self.conv_l2 = nn.Conv2d(out_dim, out_dim, kernel_size=(1, kernel_size[1]), 21 | padding=(0, pad1)) 22 | self.conv_r1 = nn.Conv2d(in_dim, out_dim, kernel_size=(1, kernel_size[1]), 23 | padding=(0, pad1)) 24 | self.conv_r2 = nn.Conv2d(out_dim, out_dim, kernel_size=(kernel_size[0], 1), 25 | padding=(pad0, 0)) 26 | 27 | def forward(self, x): 28 | x_l = self.conv_l1(x) 29 | x_l = self.conv_l2(x_l) 30 | x_r = self.conv_r1(x) 31 | x_r = self.conv_r2(x_r) 32 | x = x_l + x_r 33 | return x 34 | 35 | 36 | class _BoundaryRefineModule(nn.Module): 37 | def __init__(self, dim): 38 | super(_BoundaryRefineModule, self).__init__() 39 | self.relu = nn.ReLU(inplace=True) 40 | self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, padding=1) 41 | self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, padding=1) 42 | 43 | def forward(self, x): 44 | residual = self.conv1(x) 45 | residual = self.relu(residual) 46 | residual = self.conv2(residual) 47 | out = x + residual 48 | return out 49 | 50 | 51 | class GCN(nn.Module): 52 | def __init__(self, num_classes, input_size, pretrained=True): 53 | super(GCN, self).__init__() 54 | self.input_size = input_size 55 | resnet = models.resnet152() 56 | if pretrained: 57 | resnet.load_state_dict(torch.load(res152_path)) 58 | self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu) 59 | self.layer1 = nn.Sequential(resnet.maxpool, resnet.layer1) 60 | self.layer2 = resnet.layer2 61 | self.layer3 = resnet.layer3 62 | self.layer4 = resnet.layer4 63 | 64 | self.gcm1 = _GlobalConvModule(2048, num_classes, (7, 7)) 65 | self.gcm2 = _GlobalConvModule(1024, num_classes, (7, 7)) 66 | self.gcm3 = _GlobalConvModule(512, num_classes, (7, 7)) 67 | self.gcm4 = _GlobalConvModule(256, num_classes, (7, 7)) 68 | 69 | self.brm1 = _BoundaryRefineModule(num_classes) 70 | self.brm2 = _BoundaryRefineModule(num_classes) 71 | self.brm3 = _BoundaryRefineModule(num_classes) 72 | self.brm4 = _BoundaryRefineModule(num_classes) 73 | self.brm5 = _BoundaryRefineModule(num_classes) 74 | self.brm6 = _BoundaryRefineModule(num_classes) 75 | self.brm7 = _BoundaryRefineModule(num_classes) 76 | self.brm8 = _BoundaryRefineModule(num_classes) 77 | self.brm9 = _BoundaryRefineModule(num_classes) 78 | 79 | initialize_weights(self.gcm1, self.gcm2, self.gcm3, self.gcm4, self.brm1, self.brm2, self.brm3, 80 | self.brm4, self.brm5, self.brm6, self.brm7, self.brm8, self.brm9) 81 | 82 | def forward(self, x): 83 | # if x: 512 84 | fm0 = self.layer0(x) # 256 85 | fm1 = self.layer1(fm0) # 128 86 | fm2 = self.layer2(fm1) # 64 87 | fm3 = self.layer3(fm2) # 32 88 | fm4 = self.layer4(fm3) # 16 89 | 90 | gcfm1 = self.brm1(self.gcm1(fm4)) # 16 91 | gcfm2 = self.brm2(self.gcm2(fm3)) # 32 92 | gcfm3 = self.brm3(self.gcm3(fm2)) # 64 93 | gcfm4 = self.brm4(self.gcm4(fm1)) # 128 94 | 95 | fs1 = self.brm5(F.upsample_bilinear(gcfm1, fm3.size()[2:]) + gcfm2) # 32 96 | fs2 = self.brm6(F.upsample_bilinear(fs1, fm2.size()[2:]) + gcfm3) # 64 97 | fs3 = self.brm7(F.upsample_bilinear(fs2, fm1.size()[2:]) + gcfm4) # 128 98 | fs4 = self.brm8(F.upsample_bilinear(fs3, fm0.size()[2:])) # 256 99 | out = self.brm9(F.upsample_bilinear(fs4, self.input_size)) # 512 100 | 101 | return out 
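# A minimal shape-check sketch (an illustration, not part of the original file): with pretrained=False
# no ResNet-152 weights need to be on disk; input_size is the spatial size the final logits are
# upsampled to (normally the size of the input itself), e.g.
#   net = GCN(num_classes=19, input_size=(512, 512), pretrained=False)
#   out = net(torch.autograd.Variable(torch.randn(1, 3, 512, 512)))  # out: 1 x 19 x 512 x 512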
102 | -------------------------------------------------------------------------------- /models/psp_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | from torchvision import models 5 | 6 | from ..utils import initialize_weights 7 | from ..utils.misc import Conv2dDeformable 8 | from .config import res101_path 9 | 10 | 11 | class _PyramidPoolingModule(nn.Module): 12 | def __init__(self, in_dim, reduction_dim, setting): 13 | super(_PyramidPoolingModule, self).__init__() 14 | self.features = [] 15 | for s in setting: 16 | self.features.append(nn.Sequential( 17 | nn.AdaptiveAvgPool2d(s), 18 | nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False), 19 | nn.BatchNorm2d(reduction_dim, momentum=.95), 20 | nn.ReLU(inplace=True) 21 | )) 22 | self.features = nn.ModuleList(self.features) 23 | 24 | def forward(self, x): 25 | x_size = x.size() 26 | out = [x] 27 | for f in self.features: 28 | out.append(F.upsample(f(x), x_size[2:], mode='bilinear')) 29 | out = torch.cat(out, 1) 30 | return out 31 | 32 | 33 | class PSPNet(nn.Module): 34 | def __init__(self, num_classes, pretrained=True, use_aux=True): 35 | super(PSPNet, self).__init__() 36 | self.use_aux = use_aux 37 | resnet = models.resnet101() 38 | if pretrained: 39 | resnet.load_state_dict(torch.load(res101_path)) 40 | self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool) 41 | self.layer1, self.layer2, self.layer3, self.layer4 = resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4 42 | 43 | for n, m in self.layer3.named_modules(): 44 | if 'conv2' in n: 45 | m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1) 46 | elif 'downsample.0' in n: 47 | m.stride = (1, 1) 48 | for n, m in self.layer4.named_modules(): 49 | if 'conv2' in n: 50 | m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1) 51 | elif 'downsample.0' in n: 52 | m.stride = (1, 1) 53 | 54 | self.ppm = _PyramidPoolingModule(2048, 512, (1, 2, 3, 6)) 55 | self.final = nn.Sequential( 56 | nn.Conv2d(4096, 512, kernel_size=3, padding=1, bias=False), 57 | nn.BatchNorm2d(512, momentum=.95), 58 | nn.ReLU(inplace=True), 59 | nn.Dropout(0.1), 60 | nn.Conv2d(512, num_classes, kernel_size=1) 61 | ) 62 | 63 | if use_aux: 64 | self.aux_logits = nn.Conv2d(1024, num_classes, kernel_size=1) 65 | initialize_weights(self.aux_logits) 66 | 67 | initialize_weights(self.ppm, self.final) 68 | 69 | def forward(self, x): 70 | x_size = x.size() 71 | x = self.layer0(x) 72 | x = self.layer1(x) 73 | x = self.layer2(x) 74 | x = self.layer3(x) 75 | if self.training and self.use_aux: 76 | aux = self.aux_logits(x) 77 | x = self.layer4(x) 78 | x = self.ppm(x) 79 | x = self.final(x) 80 | if self.training and self.use_aux: 81 | return F.upsample(x, x_size[2:], mode='bilinear'), F.upsample(aux, x_size[2:], mode='bilinear') 82 | return F.upsample(x, x_size[2:], mode='bilinear') 83 | 84 | 85 | # just a try, not recommend to use 86 | class PSPNetDeform(nn.Module): 87 | def __init__(self, num_classes, input_size, pretrained=True, use_aux=True): 88 | super(PSPNetDeform, self).__init__() 89 | self.input_size = input_size 90 | self.use_aux = use_aux 91 | resnet = models.resnet101() 92 | if pretrained: 93 | resnet.load_state_dict(torch.load(res101_path)) 94 | self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool) 95 | self.layer1 = resnet.layer1 96 | self.layer2 = resnet.layer2 97 | self.layer3 = resnet.layer3 98 | self.layer4 = resnet.layer4 99 | 
100 | for n, m in self.layer3.named_modules(): 101 | if 'conv2' in n: 102 | m.padding = (1, 1) 103 | m.stride = (1, 1) 104 | elif 'downsample.0' in n: 105 | m.stride = (1, 1) 106 | for n, m in self.layer4.named_modules(): 107 | if 'conv2' in n: 108 | m.padding = (1, 1) 109 | m.stride = (1, 1) 110 | elif 'downsample.0' in n: 111 | m.stride = (1, 1) 112 | for idx in range(len(self.layer3)): 113 | self.layer3[idx].conv2 = Conv2dDeformable(self.layer3[idx].conv2) 114 | for idx in range(len(self.layer4)): 115 | self.layer4[idx].conv2 = Conv2dDeformable(self.layer4[idx].conv2) 116 | 117 | self.ppm = _PyramidPoolingModule(2048, 512, (1, 2, 3, 6)) 118 | self.final = nn.Sequential( 119 | nn.Conv2d(4096, 512, kernel_size=3, padding=1, bias=False), 120 | nn.BatchNorm2d(512, momentum=.95), 121 | nn.ReLU(inplace=True), 122 | nn.Dropout(0.1), 123 | nn.Conv2d(512, num_classes, kernel_size=1) 124 | ) 125 | 126 | if use_aux: 127 | self.aux_logits = nn.Conv2d(1024, num_classes, kernel_size=1) 128 | initialize_weights(self.aux_logits) 129 | 130 | initialize_weights(self.ppm, self.final) 131 | 132 | def forward(self, x): 133 | x = self.layer0(x) 134 | x = self.layer1(x) 135 | x = self.layer2(x) 136 | x = self.layer3(x) 137 | if self.training and self.use_aux: 138 | aux = self.aux_logits(x) 139 | x = self.layer4(x) 140 | x = self.ppm(x) 141 | x = self.final(x) 142 | if self.training and self.use_aux: 143 | return F.upsample(x, self.input_size, mode='bilinear'), F.upsample(aux, self.input_size, mode='bilinear') 144 | return F.upsample(x, self.input_size, mode='bilinear') 145 | -------------------------------------------------------------------------------- /models/seg_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torchvision import models 4 | 5 | from ..utils import initialize_weights 6 | from .config import vgg19_bn_path 7 | 8 | 9 | class _DecoderBlock(nn.Module): 10 | def __init__(self, in_channels, out_channels, num_conv_layers): 11 | super(_DecoderBlock, self).__init__() 12 | middle_channels = in_channels / 2 13 | layers = [ 14 | nn.ConvTranspose2d(in_channels, in_channels, kernel_size=2, stride=2), 15 | nn.Conv2d(in_channels, middle_channels, kernel_size=3, padding=1), 16 | nn.BatchNorm2d(middle_channels), 17 | nn.ReLU(inplace=True) 18 | ] 19 | layers += [ 20 | nn.Conv2d(middle_channels, middle_channels, kernel_size=3, padding=1), 21 | nn.BatchNorm2d(middle_channels), 22 | nn.ReLU(inplace=True), 23 | ] * (num_conv_layers - 2) 24 | layers += [ 25 | nn.Conv2d(middle_channels, out_channels, kernel_size=3, padding=1), 26 | nn.BatchNorm2d(out_channels), 27 | nn.ReLU(inplace=True), 28 | ] 29 | self.decode = nn.Sequential(*layers) 30 | 31 | def forward(self, x): 32 | return self.decode(x) 33 | 34 | 35 | class SegNet(nn.Module): 36 | def __init__(self, num_classes, pretrained=True): 37 | super(SegNet, self).__init__() 38 | vgg = models.vgg19_bn() 39 | if pretrained: 40 | vgg.load_state_dict(torch.load(vgg19_bn_path)) 41 | features = list(vgg.features.children()) 42 | self.enc1 = nn.Sequential(*features[0:7]) 43 | self.enc2 = nn.Sequential(*features[7:14]) 44 | self.enc3 = nn.Sequential(*features[14:27]) 45 | self.enc4 = nn.Sequential(*features[27:40]) 46 | self.enc5 = nn.Sequential(*features[40:]) 47 | 48 | self.dec5 = nn.Sequential( 49 | *([nn.ConvTranspose2d(512, 512, kernel_size=2, stride=2)] + 50 | [nn.Conv2d(512, 512, kernel_size=3, padding=1), 51 | nn.BatchNorm2d(512), 52 | nn.ReLU(inplace=True)] * 4) 53 | ) 54 | 
self.dec4 = _DecoderBlock(1024, 256, 4) 55 | self.dec3 = _DecoderBlock(512, 128, 4) 56 | self.dec2 = _DecoderBlock(256, 64, 2) 57 | self.dec1 = _DecoderBlock(128, num_classes, 2) 58 | initialize_weights(self.dec5, self.dec4, self.dec3, self.dec2, self.dec1) 59 | 60 | def forward(self, x): 61 | enc1 = self.enc1(x) 62 | enc2 = self.enc2(enc1) 63 | enc3 = self.enc3(enc2) 64 | enc4 = self.enc4(enc3) 65 | enc5 = self.enc5(enc4) 66 | 67 | dec5 = self.dec5(enc5) 68 | dec4 = self.dec4(torch.cat([enc4, dec5], 1)) 69 | dec3 = self.dec3(torch.cat([enc3, dec4], 1)) 70 | dec2 = self.dec2(torch.cat([enc2, dec3], 1)) 71 | dec1 = self.dec1(torch.cat([enc1, dec2], 1)) 72 | return dec1 73 | -------------------------------------------------------------------------------- /models/u_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | from ..utils import initialize_weights 6 | 7 | 8 | class _EncoderBlock(nn.Module): 9 | def __init__(self, in_channels, out_channels, dropout=False): 10 | super(_EncoderBlock, self).__init__() 11 | layers = [ 12 | nn.Conv2d(in_channels, out_channels, kernel_size=3), 13 | nn.BatchNorm2d(out_channels), 14 | nn.ReLU(inplace=True), 15 | nn.Conv2d(out_channels, out_channels, kernel_size=3), 16 | nn.BatchNorm2d(out_channels), 17 | nn.ReLU(inplace=True), 18 | ] 19 | if dropout: 20 | layers.append(nn.Dropout()) 21 | layers.append(nn.MaxPool2d(kernel_size=2, stride=2)) 22 | self.encode = nn.Sequential(*layers) 23 | 24 | def forward(self, x): 25 | return self.encode(x) 26 | 27 | 28 | class _DecoderBlock(nn.Module): 29 | def __init__(self, in_channels, middle_channels, out_channels): 30 | super(_DecoderBlock, self).__init__() 31 | self.decode = nn.Sequential( 32 | nn.Conv2d(in_channels, middle_channels, kernel_size=3), 33 | nn.BatchNorm2d(middle_channels), 34 | nn.ReLU(inplace=True), 35 | nn.Conv2d(middle_channels, middle_channels, kernel_size=3), 36 | nn.BatchNorm2d(middle_channels), 37 | nn.ReLU(inplace=True), 38 | nn.ConvTranspose2d(middle_channels, out_channels, kernel_size=2, stride=2), 39 | ) 40 | 41 | def forward(self, x): 42 | return self.decode(x) 43 | 44 | 45 | class UNet(nn.Module): 46 | def __init__(self, num_classes): 47 | super(UNet, self).__init__() 48 | self.enc1 = _EncoderBlock(3, 64) 49 | self.enc2 = _EncoderBlock(64, 128) 50 | self.enc3 = _EncoderBlock(128, 256) 51 | self.enc4 = _EncoderBlock(256, 512, dropout=True) 52 | self.center = _DecoderBlock(512, 1024, 512) 53 | self.dec4 = _DecoderBlock(1024, 512, 256) 54 | self.dec3 = _DecoderBlock(512, 256, 128) 55 | self.dec2 = _DecoderBlock(256, 128, 64) 56 | self.dec1 = nn.Sequential( 57 | nn.Conv2d(128, 64, kernel_size=3), 58 | nn.BatchNorm2d(64), 59 | nn.ReLU(inplace=True), 60 | nn.Conv2d(64, 64, kernel_size=3), 61 | nn.BatchNorm2d(64), 62 | nn.ReLU(inplace=True), 63 | ) 64 | self.final = nn.Conv2d(64, num_classes, kernel_size=1) 65 | initialize_weights(self) 66 | 67 | def forward(self, x): 68 | enc1 = self.enc1(x) 69 | enc2 = self.enc2(enc1) 70 | enc3 = self.enc3(enc2) 71 | enc4 = self.enc4(enc3) 72 | center = self.center(enc4) 73 | dec4 = self.dec4(torch.cat([center, F.upsample(enc4, center.size()[2:], mode='bilinear')], 1)) 74 | dec3 = self.dec3(torch.cat([dec4, F.upsample(enc3, dec4.size()[2:], mode='bilinear')], 1)) 75 | dec2 = self.dec2(torch.cat([dec3, F.upsample(enc2, dec3.size()[2:], mode='bilinear')], 1)) 76 | dec1 = self.dec1(torch.cat([dec2, F.upsample(enc1, dec2.size()[2:], mode='bilinear')], 
1)) 77 | final = self.final(dec1) 78 | return F.upsample(final, x.size()[2:], mode='bilinear') 79 | -------------------------------------------------------------------------------- /train/cityscapes-fcn (caffe vgg)/README.md: -------------------------------------------------------------------------------- 1 | # Results 2 | 3 | ## Metrics 4 | train only on cityscapes fine, training batch size: 12, iter num per epoch: 248, lr: 1e-10, sum the pixel loss 5 | ![](static/fcn8s-train_loss.jpg) 6 | 7 | validate the loss and mean_iu after training of one epoch 8 | ![](static/fcn8s-val_loss.jpg) 9 | 10 | ![](static/fcn8s-mean_iu.jpg) 11 | 12 | ## Visualization 13 | ![](static/fcn8s-epoch328.jpg) 14 | -------------------------------------------------------------------------------- /train/cityscapes-fcn (caffe vgg)/static/fcn8s-epoch328.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zijundeng/pytorch-semantic-segmentation/4a1721f9a3284788336430efb140288096c6dd09/train/cityscapes-fcn (caffe vgg)/static/fcn8s-epoch328.jpg -------------------------------------------------------------------------------- /train/cityscapes-fcn (caffe vgg)/static/fcn8s-mean_iu.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zijundeng/pytorch-semantic-segmentation/4a1721f9a3284788336430efb140288096c6dd09/train/cityscapes-fcn (caffe vgg)/static/fcn8s-mean_iu.jpg -------------------------------------------------------------------------------- /train/cityscapes-fcn (caffe vgg)/static/fcn8s-train_loss.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zijundeng/pytorch-semantic-segmentation/4a1721f9a3284788336430efb140288096c6dd09/train/cityscapes-fcn (caffe vgg)/static/fcn8s-train_loss.jpg -------------------------------------------------------------------------------- /train/cityscapes-fcn (caffe vgg)/static/fcn8s-val_loss.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zijundeng/pytorch-semantic-segmentation/4a1721f9a3284788336430efb140288096c6dd09/train/cityscapes-fcn (caffe vgg)/static/fcn8s-val_loss.jpg -------------------------------------------------------------------------------- /train/cityscapes-fcn (caffe vgg)/train.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | import random 4 | 5 | import numpy as np 6 | import torchvision.transforms as standard_transforms 7 | import torchvision.utils as vutils 8 | from tensorboard import SummaryWriter 9 | from torch import optim 10 | from torch.autograd import Variable 11 | from torch.backends import cudnn 12 | from torch.optim.lr_scheduler import ReduceLROnPlateau 13 | from torch.utils.data import DataLoader 14 | 15 | import utils.joint_transforms as joint_transforms 16 | import utils.transforms as extended_transforms 17 | from datasets import cityscapes 18 | from models import * 19 | from utils import check_mkdir, evaluate, AverageMeter, CrossEntropyLoss2d 20 | 21 | cudnn.benchmark = True 22 | 23 | ckpt_path = '../../ckpt' 24 | exp_name = 'cityscapes-fcn8s (caffe vgg)' 25 | writer = SummaryWriter(os.path.join(ckpt_path, 'exp', exp_name)) 26 | 27 | args = { 28 | 'train_batch_size': 12, 29 | 'epoch_num': 500, 30 | 'lr': 1e-10, 31 | 'weight_decay': 5e-4, 32 | 'input_size': (256, 512), 33 | 'momentum': 0.99, 34 | 'lr_patience': 
100, # large patience denotes fixed lr 35 | 'snapshot': '', # empty string denotes no snapshot 36 | 'print_freq': 20, 37 | 'val_batch_size': 16, 38 | 'val_save_to_img_file': False, 39 | 'val_img_sample_rate': 0.05 # randomly sample some validation results to display 40 | } 41 | 42 | 43 | def main(): 44 | net = FCN8s(num_classes=cityscapes.num_classes, caffe=True).cuda() 45 | 46 | if len(args['snapshot']) == 0: 47 | curr_epoch = 1 48 | args['best_record'] = {'epoch': 0, 'val_loss': 1e10, 'acc': 0, 'acc_cls': 0, 'mean_iu': 0, 'fwavacc': 0} 49 | else: 50 | print('training resumes from ' + args['snapshot']) 51 | net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, args['snapshot']))) 52 | split_snapshot = args['snapshot'].split('_') 53 | curr_epoch = int(split_snapshot[1]) + 1 54 | args['best_record'] = {'epoch': int(split_snapshot[1]), 'val_loss': float(split_snapshot[3]), 55 | 'acc': float(split_snapshot[5]), 'acc_cls': float(split_snapshot[7]), 56 | 'mean_iu': float(split_snapshot[9]), 'fwavacc': float(split_snapshot[11])} 57 | 58 | net.train() 59 | 60 | mean_std = ([103.939, 116.779, 123.68], [1.0, 1.0, 1.0]) 61 | 62 | short_size = int(min(args['input_size']) / 0.875) 63 | train_joint_transform = joint_transforms.Compose([ 64 | joint_transforms.Scale(short_size), 65 | joint_transforms.RandomCrop(args['input_size']), 66 | joint_transforms.RandomHorizontallyFlip() 67 | ]) 68 | val_joint_transform = joint_transforms.Compose([ 69 | joint_transforms.Scale(short_size), 70 | joint_transforms.CenterCrop(args['input_size']) 71 | ]) 72 | input_transform = standard_transforms.Compose([ 73 | extended_transforms.FlipChannels(), 74 | standard_transforms.ToTensor(), 75 | standard_transforms.Lambda(lambda x: x.mul_(255)), 76 | standard_transforms.Normalize(*mean_std) 77 | ]) 78 | target_transform = extended_transforms.MaskToTensor() 79 | restore_transform = standard_transforms.Compose([ 80 | extended_transforms.DeNormalize(*mean_std), 81 | standard_transforms.Lambda(lambda x: x.div_(255)), 82 | standard_transforms.ToPILImage(), 83 | extended_transforms.FlipChannels() 84 | ]) 85 | visualize = standard_transforms.ToTensor() 86 | 87 | train_set = cityscapes.CityScapes('fine', 'train', joint_transform=train_joint_transform, 88 | transform=input_transform, target_transform=target_transform) 89 | train_loader = DataLoader(train_set, batch_size=args['train_batch_size'], num_workers=8, shuffle=True) 90 | val_set = cityscapes.CityScapes('fine', 'val', joint_transform=val_joint_transform, transform=input_transform, 91 | target_transform=target_transform) 92 | val_loader = DataLoader(val_set, batch_size=args['val_batch_size'], num_workers=8, shuffle=False) 93 | 94 | criterion = CrossEntropyLoss2d(size_average=False, ignore_index=cityscapes.ignore_label).cuda() 95 | 96 | optimizer = optim.Adam([ 97 | {'params': [param for name, param in net.named_parameters() if name[-4:] == 'bias'], 98 | 'lr': 2 * args['lr']}, 99 | {'params': [param for name, param in net.named_parameters() if name[-4:] != 'bias'], 100 | 'lr': args['lr'], 'weight_decay': args['weight_decay']} 101 | ], betas=(args['momentum'], 0.999)) 102 | 103 | if len(args['snapshot']) > 0: 104 | optimizer.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, 'opt_' + args['snapshot']))) 105 | optimizer.param_groups[0]['lr'] = 2 * args['lr'] 106 | optimizer.param_groups[1]['lr'] = args['lr'] 107 | 108 | check_mkdir(ckpt_path) 109 | check_mkdir(os.path.join(ckpt_path, exp_name)) 110 | open(os.path.join(ckpt_path, exp_name, 
str(datetime.datetime.now()) + '.txt'), 'w').write(str(args) + '\n\n') 111 | 112 | scheduler = ReduceLROnPlateau(optimizer, 'min', patience=args['lr_patience'], min_lr=1e-10, verbose=True) 113 | for epoch in range(curr_epoch, args['epoch_num'] + 1): 114 | train(train_loader, net, criterion, optimizer, epoch, args) 115 | val_loss = validate(val_loader, net, criterion, optimizer, epoch, args, restore_transform, visualize) 116 | scheduler.step(val_loss) 117 | 118 | 119 | def train(train_loader, net, criterion, optimizer, epoch, train_args): 120 | train_loss = AverageMeter() 121 | curr_iter = (epoch - 1) * len(train_loader) 122 | for i, data in enumerate(train_loader): 123 | inputs, labels = data 124 | assert inputs.size()[2:] == labels.size()[1:] 125 | N = inputs.size(0) 126 | inputs = Variable(inputs).cuda() 127 | labels = Variable(labels).cuda() 128 | 129 | optimizer.zero_grad() 130 | outputs = net(inputs) 131 | assert outputs.size()[2:] == labels.size()[1:] 132 | assert outputs.size()[1] == cityscapes.num_classes 133 | 134 | loss = criterion(outputs, labels) / N 135 | loss.backward() 136 | optimizer.step() 137 | 138 | train_loss.update(loss.data[0], N) 139 | 140 | curr_iter += 1 141 | writer.add_scalar('train_loss', train_loss.avg, curr_iter) 142 | 143 | if (i + 1) % train_args['print_freq'] == 0: 144 | print('[epoch %d], [iter %d / %d], [train loss %.5f]' % ( 145 | epoch, i + 1, len(train_loader), train_loss.avg)) 146 | 147 | 148 | def validate(val_loader, net, criterion, optimizer, epoch, train_args, restore, visualize): 149 | net.eval() 150 | 151 | val_loss = AverageMeter() 152 | inputs_all, gts_all, predictions_all = [], [], [] 153 | 154 | for vi, data in enumerate(val_loader): 155 | inputs, gts = data 156 | N = inputs.size(0) 157 | inputs = Variable(inputs, volatile=True).cuda() 158 | gts = Variable(gts, volatile=True).cuda() 159 | 160 | outputs = net(inputs) 161 | predictions = outputs.data.max(1)[1].squeeze_(1).cpu().numpy() 162 | 163 | val_loss.update(criterion(outputs, gts).data[0] / N, N) 164 | 165 | for i in inputs: 166 | if random.random() > train_args['val_img_sample_rate']: 167 | inputs_all.append(None) 168 | else: 169 | inputs_all.append(i.data.cpu()) 170 | gts_all.append(gts.data.cpu().numpy()) 171 | predictions_all.append(predictions) 172 | 173 | gts_all = np.concatenate(gts_all) 174 | predictions_all = np.concatenate(predictions_all) 175 | 176 | acc, acc_cls, mean_iu, fwavacc = evaluate(predictions_all, gts_all, cityscapes.num_classes) 177 | 178 | if mean_iu > train_args['best_record']['mean_iu']: 179 | train_args['best_record']['val_loss'] = val_loss.avg 180 | train_args['best_record']['epoch'] = epoch 181 | train_args['best_record']['acc'] = acc 182 | train_args['best_record']['acc_cls'] = acc_cls 183 | train_args['best_record']['mean_iu'] = mean_iu 184 | train_args['best_record']['fwavacc'] = fwavacc 185 | snapshot_name = 'epoch_%d_loss_%.5f_acc_%.5f_acc-cls_%.5f_mean-iu_%.5f_fwavacc_%.5f_lr_%.10f' % ( 186 | epoch, val_loss.avg, acc, acc_cls, mean_iu, fwavacc, optimizer.param_groups[1]['lr'] 187 | ) 188 | torch.save(net.state_dict(), os.path.join(ckpt_path, exp_name, snapshot_name + '.pth')) 189 | torch.save(optimizer.state_dict(), os.path.join(ckpt_path, exp_name, 'opt_' + snapshot_name + '.pth')) 190 | 191 | if train_args['val_save_to_img_file']: 192 | to_save_dir = os.path.join(ckpt_path, exp_name, str(epoch)) 193 | check_mkdir(to_save_dir) 194 | 195 | val_visual = [] 196 | for idx, data in enumerate(zip(inputs_all, gts_all, predictions_all)): 197 | if data[0] is 
None: 198 | continue 199 | input_pil = restore(data[0]) 200 | gt_pil = cityscapes.colorize_mask(data[1]) 201 | predictions_pil = cityscapes.colorize_mask(data[2]) 202 | if train_args['val_save_to_img_file']: 203 | input_pil.save(os.path.join(to_save_dir, '%d_input.png' % idx)) 204 | predictions_pil.save(os.path.join(to_save_dir, '%d_prediction.png' % idx)) 205 | gt_pil.save(os.path.join(to_save_dir, '%d_gt.png' % idx)) 206 | val_visual.extend([visualize(input_pil.convert('RGB')), visualize(gt_pil.convert('RGB')), 207 | visualize(predictions_pil.convert('RGB'))]) 208 | val_visual = torch.stack(val_visual, 0) 209 | val_visual = vutils.make_grid(val_visual, nrow=3, padding=5) 210 | writer.add_image(snapshot_name, val_visual) 211 | 212 | print('-----------------------------------------------------------------------------------------------------------') 213 | print('[epoch %d], [val loss %.5f], [acc %.5f], [acc_cls %.5f], [mean_iu %.5f], [fwavacc %.5f]' % ( 214 | epoch, val_loss.avg, acc, acc_cls, mean_iu, fwavacc)) 215 | 216 | print('best record: [val loss %.5f], [acc %.5f], [acc_cls %.5f], [mean_iu %.5f], [fwavacc %.5f], [epoch %d]' % ( 217 | train_args['best_record']['val_loss'], train_args['best_record']['acc'], train_args['best_record']['acc_cls'], 218 | train_args['best_record']['mean_iu'], train_args['best_record']['fwavacc'], train_args['best_record']['epoch'])) 219 | 220 | print('-----------------------------------------------------------------------------------------------------------') 221 | 222 | writer.add_scalar('val_loss', val_loss.avg, epoch) 223 | writer.add_scalar('acc', acc, epoch) 224 | writer.add_scalar('acc_cls', acc_cls, epoch) 225 | writer.add_scalar('mean_iu', mean_iu, epoch) 226 | writer.add_scalar('fwavacc', fwavacc, epoch) 227 | writer.add_scalar('lr', optimizer.param_groups[1]['lr'], epoch) 228 | 229 | net.train() 230 | return val_loss.avg 231 | 232 | 233 | if __name__ == '__main__': 234 | main() 235 | -------------------------------------------------------------------------------- /train/cityscapes-fcn/train.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | import random 4 | 5 | import numpy as np 6 | import torchvision.transforms as standard_transforms 7 | import torchvision.utils as vutils 8 | from tensorboard import SummaryWriter 9 | from torch import optim 10 | from torch.autograd import Variable 11 | from torch.backends import cudnn 12 | from torch.optim.lr_scheduler import ReduceLROnPlateau 13 | from torch.utils.data import DataLoader 14 | 15 | import utils.joint_transforms as joint_transforms 16 | import utils.transforms as extended_transforms 17 | from datasets import cityscapes 18 | from models import * 19 | from utils import check_mkdir, evaluate, AverageMeter, CrossEntropyLoss2d 20 | 21 | cudnn.benchmark = True 22 | 23 | ckpt_path = '../../ckpt' 24 | exp_name = 'cityscapes-fcn8s' 25 | writer = SummaryWriter(os.path.join(ckpt_path, 'exp', exp_name)) 26 | 27 | args = { 28 | 'train_batch_size': 16, 29 | 'epoch_num': 500, 30 | 'lr': 1e-10, 31 | 'weight_decay': 5e-4, 32 | 'input_size': (256, 512), 33 | 'momentum': 0.95, 34 | 'lr_patience': 100, # large patience denotes fixed lr 35 | 'snapshot': '', # empty string denotes no snapshot 36 | 'print_freq': 20, 37 | 'val_batch_size': 16, 38 | 'val_save_to_img_file': False, 39 | 'val_img_sample_rate': 0.05 # randomly sample some validation results to display 40 | } 41 | 42 | 43 | def main(): 44 | net = 
FCN8s(num_classes=cityscapes.num_classes).cuda() 45 | 46 | if len(args['snapshot']) == 0: 47 | curr_epoch = 1 48 | args['best_record'] = {'epoch': 0, 'val_loss': 1e10, 'acc': 0, 'acc_cls': 0, 'mean_iu': 0, 'fwavacc': 0} 49 | else: 50 | print('training resumes from ' + args['snapshot']) 51 | net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, args['snapshot']))) 52 | split_snapshot = args['snapshot'].split('_') 53 | curr_epoch = int(split_snapshot[1]) + 1 54 | args['best_record'] = {'epoch': int(split_snapshot[1]), 'val_loss': float(split_snapshot[3]), 55 | 'acc': float(split_snapshot[5]), 'acc_cls': float(split_snapshot[7]), 56 | 'mean_iu': float(split_snapshot[9]), 'fwavacc': float(split_snapshot[11])} 57 | net.train() 58 | 59 | mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 60 | short_size = int(min(args['input_size']) / 0.875) 61 | train_joint_transform = joint_transforms.Compose([ 62 | joint_transforms.Scale(short_size), 63 | joint_transforms.RandomCrop(args['input_size']), 64 | joint_transforms.RandomHorizontallyFlip() 65 | ]) 66 | val_joint_transform = joint_transforms.Compose([ 67 | joint_transforms.Scale(short_size), 68 | joint_transforms.CenterCrop(args['input_size']) 69 | ]) 70 | input_transform = standard_transforms.Compose([ 71 | standard_transforms.ToTensor(), 72 | standard_transforms.Normalize(*mean_std) 73 | ]) 74 | target_transform = extended_transforms.MaskToTensor() 75 | restore_transform = standard_transforms.Compose([ 76 | extended_transforms.DeNormalize(*mean_std), 77 | standard_transforms.ToPILImage() 78 | ]) 79 | visualize = standard_transforms.ToTensor() 80 | 81 | train_set = cityscapes.CityScapes('fine', 'train', joint_transform=train_joint_transform, 82 | transform=input_transform, target_transform=target_transform) 83 | train_loader = DataLoader(train_set, batch_size=args['train_batch_size'], num_workers=8, shuffle=True) 84 | val_set = cityscapes.CityScapes('fine', 'val', joint_transform=val_joint_transform, transform=input_transform, 85 | target_transform=target_transform) 86 | val_loader = DataLoader(val_set, batch_size=args['val_batch_size'], num_workers=8, shuffle=False) 87 | 88 | criterion = CrossEntropyLoss2d(size_average=False, ignore_index=cityscapes.ignore_label).cuda() 89 | 90 | optimizer = optim.SGD([ 91 | {'params': [param for name, param in net.named_parameters() if name[-4:] == 'bias'], 92 | 'lr': 2 * args['lr']}, 93 | {'params': [param for name, param in net.named_parameters() if name[-4:] != 'bias'], 94 | 'lr': args['lr'], 'weight_decay': args['weight_decay']} 95 | ], momentum=args['momentum']) 96 | 97 | if len(args['snapshot']) > 0: 98 | optimizer.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, 'opt_' + args['snapshot']))) 99 | optimizer.param_groups[0]['lr'] = 2 * args['lr'] 100 | optimizer.param_groups[1]['lr'] = args['lr'] 101 | 102 | check_mkdir(ckpt_path) 103 | check_mkdir(os.path.join(ckpt_path, exp_name)) 104 | open(os.path.join(ckpt_path, exp_name, str(datetime.datetime.now()) + '.txt'), 'w').write(str(args) + '\n\n') 105 | 106 | scheduler = ReduceLROnPlateau(optimizer, 'min', patience=args['lr_patience'], min_lr=1e-10) 107 | for epoch in range(curr_epoch, args['epoch_num'] + 1): 108 | train(train_loader, net, criterion, optimizer, epoch, args) 109 | val_loss = validate(val_loader, net, criterion, optimizer, epoch, args, restore_transform, visualize) 110 | scheduler.step(val_loss) 111 | 112 | 113 | def train(train_loader, net, criterion, optimizer, epoch, train_args): 114 | train_loss = AverageMeter() 
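    # The criterion built in main() is CrossEntropyLoss2d(size_average=False), so
    # criterion(outputs, labels) below returns the cross-entropy summed over every labelled
    # pixel in the batch; dividing by N (the batch size) turns it into a per-image sum of
    # pixel losses. That summed loss is why args['lr'] is as small as 1e-10. The SGD optimizer
    # set up in main() follows the usual Caffe FCN convention: bias parameters use twice the
    # base learning rate with no weight decay, while all other parameters use the base rate
    # with weight decay 5e-4.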
115 | curr_iter = (epoch - 1) * len(train_loader) 116 | for i, data in enumerate(train_loader): 117 | inputs, labels = data 118 | assert inputs.size()[2:] == labels.size()[1:] 119 | N = inputs.size(0) 120 | inputs = Variable(inputs).cuda() 121 | labels = Variable(labels).cuda() 122 | 123 | optimizer.zero_grad() 124 | outputs = net(inputs) 125 | assert outputs.size()[2:] == labels.size()[1:] 126 | assert outputs.size()[1] == cityscapes.num_classes 127 | 128 | loss = criterion(outputs, labels) / N 129 | loss.backward() 130 | optimizer.step() 131 | 132 | train_loss.update(loss.data[0], N) 133 | 134 | curr_iter += 1 135 | writer.add_scalar('train_loss', train_loss.avg, curr_iter) 136 | 137 | if (i + 1) % train_args['print_freq'] == 0: 138 | print('[epoch %d], [iter %d / %d], [train loss %.5f]' % ( 139 | epoch, i + 1, len(train_loader), train_loss.avg)) 140 | 141 | 142 | def validate(val_loader, net, criterion, optimizer, epoch, train_args, restore, visualize): 143 | net.eval() 144 | 145 | val_loss = AverageMeter() 146 | inputs_all, gts_all, predictions_all = [], [], [] 147 | 148 | for vi, data in enumerate(val_loader): 149 | inputs, gts = data 150 | N = inputs.size(0) 151 | inputs = Variable(inputs, volatile=True).cuda() 152 | gts = Variable(gts, volatile=True).cuda() 153 | 154 | outputs = net(inputs) 155 | predictions = outputs.data.max(1)[1].squeeze_(1).cpu().numpy() 156 | 157 | val_loss.update(criterion(outputs, gts).data[0] / N, N) 158 | 159 | for i in inputs: 160 | if random.random() > train_args['val_img_sample_rate']: 161 | inputs_all.append(None) 162 | else: 163 | inputs_all.append(i.data.cpu()) 164 | gts_all.append(gts.data.cpu().numpy()) 165 | predictions_all.append(predictions) 166 | 167 | gts_all = np.concatenate(gts_all) 168 | predictions_all = np.concatenate(predictions_all) 169 | 170 | acc, acc_cls, mean_iu, fwavacc = evaluate(predictions_all, gts_all, cityscapes.num_classes) 171 | 172 | if mean_iu > train_args['best_record']['mean_iu']: 173 | train_args['best_record']['val_loss'] = val_loss.avg 174 | train_args['best_record']['epoch'] = epoch 175 | train_args['best_record']['acc'] = acc 176 | train_args['best_record']['acc_cls'] = acc_cls 177 | train_args['best_record']['mean_iu'] = mean_iu 178 | train_args['best_record']['fwavacc'] = fwavacc 179 | snapshot_name = 'epoch_%d_loss_%.5f_acc_%.5f_acc-cls_%.5f_mean-iu_%.5f_fwavacc_%.5f_lr_%.10f' % ( 180 | epoch, val_loss.avg, acc, acc_cls, mean_iu, fwavacc, optimizer.param_groups[1]['lr'] 181 | ) 182 | torch.save(net.state_dict(), os.path.join(ckpt_path, exp_name, snapshot_name + '.pth')) 183 | torch.save(optimizer.state_dict(), os.path.join(ckpt_path, exp_name, 'opt_' + snapshot_name + '.pth')) 184 | 185 | if train_args['val_save_to_img_file']: 186 | to_save_dir = os.path.join(ckpt_path, exp_name, str(epoch)) 187 | check_mkdir(to_save_dir) 188 | 189 | val_visual = [] 190 | for idx, data in enumerate(zip(inputs_all, gts_all, predictions_all)): 191 | if data[0] is None: 192 | continue 193 | input_pil = restore(data[0]) 194 | gt_pil = cityscapes.colorize_mask(data[1]) 195 | predictions_pil = cityscapes.colorize_mask(data[2]) 196 | if train_args['val_save_to_img_file']: 197 | input_pil.save(os.path.join(to_save_dir, '%d_input.png' % idx)) 198 | predictions_pil.save(os.path.join(to_save_dir, '%d_prediction.png' % idx)) 199 | gt_pil.save(os.path.join(to_save_dir, '%d_gt.png' % idx)) 200 | val_visual.extend([visualize(input_pil.convert('RGB')), visualize(gt_pil.convert('RGB')), 201 | visualize(predictions_pil.convert('RGB'))]) 202 | 
val_visual = torch.stack(val_visual, 0) 203 | val_visual = vutils.make_grid(val_visual, nrow=3, padding=5) 204 | writer.add_image(snapshot_name, val_visual) 205 | 206 | print('-----------------------------------------------------------------------------------------------------------') 207 | print('[epoch %d], [val loss %.5f], [acc %.5f], [acc_cls %.5f], [mean_iu %.5f], [fwavacc %.5f]' % ( 208 | epoch, val_loss.avg, acc, acc_cls, mean_iu, fwavacc)) 209 | 210 | print('best record: [val loss %.5f], [acc %.5f], [acc_cls %.5f], [mean_iu %.5f], [fwavacc %.5f], [epoch %d]' % ( 211 | train_args['best_record']['val_loss'], train_args['best_record']['acc'], train_args['best_record']['acc_cls'], 212 | train_args['best_record']['mean_iu'], train_args['best_record']['fwavacc'], train_args['best_record']['epoch'])) 213 | 214 | print('-----------------------------------------------------------------------------------------------------------') 215 | 216 | writer.add_scalar('val_loss', val_loss.avg, epoch) 217 | writer.add_scalar('acc', acc, epoch) 218 | writer.add_scalar('acc_cls', acc_cls, epoch) 219 | writer.add_scalar('mean_iu', mean_iu, epoch) 220 | writer.add_scalar('fwavacc', fwavacc, epoch) 221 | writer.add_scalar('lr', optimizer.param_groups[1]['lr'], epoch) 222 | 223 | net.train() 224 | return val_loss.avg 225 | 226 | 227 | if __name__ == '__main__': 228 | main() 229 | -------------------------------------------------------------------------------- /train/cityscapes-psp_net/static/0_gt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zijundeng/pytorch-semantic-segmentation/4a1721f9a3284788336430efb140288096c6dd09/train/cityscapes-psp_net/static/0_gt.png -------------------------------------------------------------------------------- /train/cityscapes-psp_net/static/0_prediction.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zijundeng/pytorch-semantic-segmentation/4a1721f9a3284788336430efb140288096c6dd09/train/cityscapes-psp_net/static/0_prediction.png -------------------------------------------------------------------------------- /train/cityscapes-psp_net/static/1_gt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zijundeng/pytorch-semantic-segmentation/4a1721f9a3284788336430efb140288096c6dd09/train/cityscapes-psp_net/static/1_gt.png -------------------------------------------------------------------------------- /train/cityscapes-psp_net/static/1_prediction.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zijundeng/pytorch-semantic-segmentation/4a1721f9a3284788336430efb140288096c6dd09/train/cityscapes-psp_net/static/1_prediction.png -------------------------------------------------------------------------------- /train/cityscapes-psp_net/static/2_gt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zijundeng/pytorch-semantic-segmentation/4a1721f9a3284788336430efb140288096c6dd09/train/cityscapes-psp_net/static/2_gt.png -------------------------------------------------------------------------------- /train/cityscapes-psp_net/static/2_prediction.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zijundeng/pytorch-semantic-segmentation/4a1721f9a3284788336430efb140288096c6dd09/train/cityscapes-psp_net/static/2_prediction.png -------------------------------------------------------------------------------- /train/cityscapes-psp_net/static/3_gt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zijundeng/pytorch-semantic-segmentation/4a1721f9a3284788336430efb140288096c6dd09/train/cityscapes-psp_net/static/3_gt.png -------------------------------------------------------------------------------- /train/cityscapes-psp_net/static/3_prediction.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zijundeng/pytorch-semantic-segmentation/4a1721f9a3284788336430efb140288096c6dd09/train/cityscapes-psp_net/static/3_prediction.png -------------------------------------------------------------------------------- /train/cityscapes-psp_net/static/4_gt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zijundeng/pytorch-semantic-segmentation/4a1721f9a3284788336430efb140288096c6dd09/train/cityscapes-psp_net/static/4_gt.png -------------------------------------------------------------------------------- /train/cityscapes-psp_net/static/4_prediction.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zijundeng/pytorch-semantic-segmentation/4a1721f9a3284788336430efb140288096c6dd09/train/cityscapes-psp_net/static/4_prediction.png -------------------------------------------------------------------------------- /train/cityscapes-psp_net/static/5_gt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zijundeng/pytorch-semantic-segmentation/4a1721f9a3284788336430efb140288096c6dd09/train/cityscapes-psp_net/static/5_gt.png -------------------------------------------------------------------------------- /train/cityscapes-psp_net/static/5_prediction.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zijundeng/pytorch-semantic-segmentation/4a1721f9a3284788336430efb140288096c6dd09/train/cityscapes-psp_net/static/5_prediction.png -------------------------------------------------------------------------------- /train/cityscapes-psp_net/static/6_gt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zijundeng/pytorch-semantic-segmentation/4a1721f9a3284788336430efb140288096c6dd09/train/cityscapes-psp_net/static/6_gt.png -------------------------------------------------------------------------------- /train/cityscapes-psp_net/static/6_prediction.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zijundeng/pytorch-semantic-segmentation/4a1721f9a3284788336430efb140288096c6dd09/train/cityscapes-psp_net/static/6_prediction.png -------------------------------------------------------------------------------- /train/cityscapes-psp_net/static/7_gt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zijundeng/pytorch-semantic-segmentation/4a1721f9a3284788336430efb140288096c6dd09/train/cityscapes-psp_net/static/7_gt.png 
-------------------------------------------------------------------------------- /train/cityscapes-psp_net/static/7_prediction.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zijundeng/pytorch-semantic-segmentation/4a1721f9a3284788336430efb140288096c6dd09/train/cityscapes-psp_net/static/7_prediction.png -------------------------------------------------------------------------------- /train/cityscapes-psp_net/static/8_gt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zijundeng/pytorch-semantic-segmentation/4a1721f9a3284788336430efb140288096c6dd09/train/cityscapes-psp_net/static/8_gt.png -------------------------------------------------------------------------------- /train/cityscapes-psp_net/static/8_prediction.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zijundeng/pytorch-semantic-segmentation/4a1721f9a3284788336430efb140288096c6dd09/train/cityscapes-psp_net/static/8_prediction.png -------------------------------------------------------------------------------- /train/cityscapes-psp_net/static/9_gt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zijundeng/pytorch-semantic-segmentation/4a1721f9a3284788336430efb140288096c6dd09/train/cityscapes-psp_net/static/9_gt.png -------------------------------------------------------------------------------- /train/cityscapes-psp_net/static/9_prediction.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zijundeng/pytorch-semantic-segmentation/4a1721f9a3284788336430efb140288096c6dd09/train/cityscapes-psp_net/static/9_prediction.png -------------------------------------------------------------------------------- /train/cityscapes-psp_net/train.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | from math import sqrt 4 | 5 | import numpy as np 6 | import torchvision.transforms as standard_transforms 7 | from tensorboard import SummaryWriter 8 | from torch import optim 9 | from torch.autograd import Variable 10 | from torch.utils.data import DataLoader 11 | 12 | import utils.joint_transforms as joint_transforms 13 | import utils.transforms as extended_transforms 14 | from datasets import cityscapes 15 | from models import * 16 | from utils import check_mkdir, evaluate, AverageMeter, CrossEntropyLoss2d 17 | 18 | ckpt_path = '../../ckpt' 19 | exp_name = 'cityscapes (fine)-psp_net' 20 | writer = SummaryWriter(os.path.join(ckpt_path, 'exp', exp_name)) 21 | 22 | args = { 23 | 'train_batch_size': 2, 24 | 'lr': 1e-2 / sqrt(16 / 2), 25 | 'lr_decay': 0.9, 26 | 'max_iter': 9e4, 27 | 'longer_size': 2048, 28 | 'crop_size': 713, 29 | 'stride_rate': 2 / 3., 30 | 'weight_decay': 1e-4, 31 | 'momentum': 0.9, 32 | 'snapshot': '', 33 | 'print_freq': 10, 34 | 'val_save_to_img_file': False, 35 | 'val_img_sample_rate': 0.01, # randomly sample some validation results to display, 36 | 'val_img_display_size': 384, 37 | 'val_freq': 400 38 | } 39 | 40 | 41 | def main(): 42 | net = PSPNet(num_classes=cityscapes.num_classes) 43 | 44 | if len(args['snapshot']) == 0: 45 | # net.load_state_dict(torch.load(os.path.join(ckpt_path, 'cityscapes (coarse)-psp_net', 'xx.pth'))) 46 | curr_epoch = 1 47 | args['best_record'] = {'epoch': 0, 'iter': 0, 'val_loss': 1e10, 'acc': 
0, 'acc_cls': 0, 'mean_iu': 0, 48 | 'fwavacc': 0} 49 | else: 50 | print('training resumes from ' + args['snapshot']) 51 | net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, args['snapshot']))) 52 | split_snapshot = args['snapshot'].split('_') 53 | curr_epoch = int(split_snapshot[1]) + 1 54 | args['best_record'] = {'epoch': int(split_snapshot[1]), 'iter': int(split_snapshot[3]), 55 | 'val_loss': float(split_snapshot[5]), 'acc': float(split_snapshot[7]), 56 | 'acc_cls': float(split_snapshot[9]),'mean_iu': float(split_snapshot[11]), 57 | 'fwavacc': float(split_snapshot[13])} 58 | net.cuda().train() 59 | 60 | mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 61 | 62 | train_joint_transform = joint_transforms.Compose([ 63 | joint_transforms.Scale(args['longer_size']), 64 | joint_transforms.RandomRotate(10), 65 | joint_transforms.RandomHorizontallyFlip() 66 | ]) 67 | sliding_crop = joint_transforms.SlidingCrop(args['crop_size'], args['stride_rate'], cityscapes.ignore_label) 68 | train_input_transform = standard_transforms.Compose([ 69 | standard_transforms.ToTensor(), 70 | standard_transforms.Normalize(*mean_std) 71 | ]) 72 | val_input_transform = standard_transforms.Compose([ 73 | standard_transforms.ToTensor(), 74 | standard_transforms.Normalize(*mean_std) 75 | ]) 76 | target_transform = extended_transforms.MaskToTensor() 77 | visualize = standard_transforms.Compose([ 78 | standard_transforms.Scale(args['val_img_display_size']), 79 | standard_transforms.ToTensor() 80 | ]) 81 | 82 | train_set = cityscapes.CityScapes('fine', 'train', joint_transform=train_joint_transform, sliding_crop=sliding_crop, 83 | transform=train_input_transform, target_transform=target_transform) 84 | train_loader = DataLoader(train_set, batch_size=args['train_batch_size'], num_workers=8, shuffle=True) 85 | val_set = cityscapes.CityScapes('fine', 'val', transform=val_input_transform, sliding_crop=sliding_crop, 86 | target_transform=target_transform) 87 | val_loader = DataLoader(val_set, batch_size=1, num_workers=8, shuffle=False) 88 | 89 | criterion = CrossEntropyLoss2d(size_average=True, ignore_index=cityscapes.ignore_label).cuda() 90 | 91 | optimizer = optim.SGD([ 92 | {'params': [param for name, param in net.named_parameters() if name[-4:] == 'bias'], 93 | 'lr': 2 * args['lr']}, 94 | {'params': [param for name, param in net.named_parameters() if name[-4:] != 'bias'], 95 | 'lr': args['lr'], 'weight_decay': args['weight_decay']} 96 | ], momentum=args['momentum'], nesterov=True) 97 | 98 | if len(args['snapshot']) > 0: 99 | optimizer.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, 'opt_' + args['snapshot']))) 100 | optimizer.param_groups[0]['lr'] = 2 * args['lr'] 101 | optimizer.param_groups[1]['lr'] = args['lr'] 102 | 103 | check_mkdir(ckpt_path) 104 | check_mkdir(os.path.join(ckpt_path, exp_name)) 105 | open(os.path.join(ckpt_path, exp_name, str(datetime.datetime.now()) + '.txt'), 'w').write(str(args) + '\n\n') 106 | 107 | train(train_loader, net, criterion, optimizer, curr_epoch, args, val_loader, visualize) 108 | 109 | 110 | def train(train_loader, net, criterion, optimizer, curr_epoch, train_args, val_loader, visualize): 111 | while True: 112 | train_main_loss = AverageMeter() 113 | train_aux_loss = AverageMeter() 114 | curr_iter = (curr_epoch - 1) * len(train_loader) 115 | for i, data in enumerate(train_loader): 116 | optimizer.param_groups[0]['lr'] = 2 * train_args['lr'] * (1 - float(curr_iter) / train_args['max_iter'] 117 | ) ** train_args['lr_decay'] 118 | 
optimizer.param_groups[1]['lr'] = train_args['lr'] * (1 - float(curr_iter) / train_args['max_iter'] 119 | ) ** train_args['lr_decay'] 120 | 121 | inputs, gts, _ = data 122 | assert len(inputs.size()) == 5 and len(gts.size()) == 4 123 | inputs.transpose_(0, 1) 124 | gts.transpose_(0, 1) 125 | 126 | assert inputs.size()[3:] == gts.size()[2:] 127 | slice_batch_pixel_size = inputs.size(1) * inputs.size(3) * inputs.size(4) 128 | 129 | for inputs_slice, gts_slice in zip(inputs, gts): 130 | inputs_slice = Variable(inputs_slice).cuda() 131 | gts_slice = Variable(gts_slice).cuda() 132 | 133 | optimizer.zero_grad() 134 | outputs, aux = net(inputs_slice) 135 | assert outputs.size()[2:] == gts_slice.size()[1:] 136 | assert outputs.size()[1] == cityscapes.num_classes 137 | 138 | main_loss = criterion(outputs, gts_slice) 139 | aux_loss = criterion(aux, gts_slice) 140 | loss = main_loss + 0.4 * aux_loss 141 | loss.backward() 142 | optimizer.step() 143 | 144 | train_main_loss.update(main_loss.data[0], slice_batch_pixel_size) 145 | train_aux_loss.update(aux_loss.data[0], slice_batch_pixel_size) 146 | 147 | curr_iter += 1 148 | writer.add_scalar('train_main_loss', train_main_loss.avg, curr_iter) 149 | writer.add_scalar('train_aux_loss', train_aux_loss.avg, curr_iter) 150 | writer.add_scalar('lr', optimizer.param_groups[1]['lr'], curr_iter) 151 | 152 | if (i + 1) % train_args['print_freq'] == 0: 153 | print('[epoch %d], [iter %d / %d], [train main loss %.5f], [train aux loss %.5f]. [lr %.10f]' % ( 154 | curr_epoch, i + 1, len(train_loader), train_main_loss.avg, train_aux_loss.avg, 155 | optimizer.param_groups[1]['lr'])) 156 | if curr_iter >= train_args['max_iter']: 157 | return 158 | if curr_iter % train_args['val_freq'] == 0: 159 | validate(val_loader, net, criterion, optimizer, curr_epoch, i + 1, train_args, visualize) 160 | curr_epoch += 1 161 | 162 | 163 | def validate(val_loader, net, criterion, optimizer, epoch, iter_num, train_args, visualize): 164 | # the following code is written assuming that batch size is 1 165 | net.eval() 166 | 167 | val_loss = AverageMeter() 168 | 169 | gts_all = np.zeros((len(val_loader), args['longer_size'] / 2, args['longer_size']), dtype=int) 170 | predictions_all = np.zeros((len(val_loader), args['longer_size'] / 2, args['longer_size']), dtype=int) 171 | for vi, data in enumerate(val_loader): 172 | input, gt, slices_info = data 173 | assert len(input.size()) == 5 and len(gt.size()) == 4 and len(slices_info.size()) == 3 174 | input.transpose_(0, 1) 175 | gt.transpose_(0, 1) 176 | slices_info.squeeze_(0) 177 | assert input.size()[3:] == gt.size()[2:] 178 | 179 | count = torch.zeros(args['longer_size'] / 2, args['longer_size']).cuda() 180 | output = torch.zeros(cityscapes.num_classes, args['longer_size'] / 2, args['longer_size']).cuda() 181 | 182 | slice_batch_pixel_size = input.size(1) * input.size(3) * input.size(4) 183 | 184 | for input_slice, gt_slice, info in zip(input, gt, slices_info): 185 | input_slice = Variable(input_slice).cuda() 186 | gt_slice = Variable(gt_slice).cuda() 187 | 188 | output_slice = net(input_slice) 189 | assert output_slice.size()[2:] == gt_slice.size()[1:] 190 | assert output_slice.size()[1] == cityscapes.num_classes 191 | output[:, info[0]: info[1], info[2]: info[3]] += output_slice[0, :, :info[4], :info[5]].data 192 | gts_all[vi, info[0]: info[1], info[2]: info[3]] += gt_slice[0, :info[4], :info[5]].data.cpu().numpy() 193 | 194 | count[info[0]: info[1], info[2]: info[3]] += 1 195 | 196 | val_loss.update(criterion(output_slice, gt_slice).data[0], 
slice_batch_pixel_size) 197 | 198 | output /= count 199 | gts_all[vi, :, :] /= count.cpu().numpy().astype(int) 200 | predictions_all[vi, :, :] = output.max(0)[1].squeeze_(0).cpu().numpy() 201 | 202 | print('validating: %d / %d' % (vi + 1, len(val_loader))) 203 | 204 | acc, acc_cls, mean_iu, fwavacc = evaluate(predictions_all, gts_all, cityscapes.num_classes) 205 | if val_loss.avg < train_args['best_record']['val_loss']: 206 | train_args['best_record']['val_loss'] = val_loss.avg 207 | train_args['best_record']['epoch'] = epoch 208 | train_args['best_record']['iter'] = iter_num 209 | train_args['best_record']['acc'] = acc 210 | train_args['best_record']['acc_cls'] = acc_cls 211 | train_args['best_record']['mean_iu'] = mean_iu 212 | train_args['best_record']['fwavacc'] = fwavacc 213 | snapshot_name = 'epoch_%d_iter_%d_loss_%.5f_acc_%.5f_acc-cls_%.5f_mean-iu_%.5f_fwavacc_%.5f_lr_%.10f' % ( 214 | epoch, iter_num, val_loss.avg, acc, acc_cls, mean_iu, fwavacc, optimizer.param_groups[1]['lr']) 215 | torch.save(net.state_dict(), os.path.join(ckpt_path, exp_name, snapshot_name + '.pth')) 216 | torch.save(optimizer.state_dict(), os.path.join(ckpt_path, exp_name, 'opt_' + snapshot_name + '.pth')) 217 | 218 | if train_args['val_save_to_img_file']: 219 | to_save_dir = os.path.join(ckpt_path, exp_name, '%d_%d' % (epoch, iter_num)) 220 | check_mkdir(to_save_dir) 221 | 222 | val_visual = [] 223 | for idx, data in enumerate(zip(gts_all, predictions_all)): 224 | gt_pil = cityscapes.colorize_mask(data[0]) 225 | predictions_pil = cityscapes.colorize_mask(data[1]) 226 | if train_args['val_save_to_img_file']: 227 | predictions_pil.save(os.path.join(to_save_dir, '%d_prediction.png' % idx)) 228 | gt_pil.save(os.path.join(to_save_dir, '%d_gt.png' % idx)) 229 | val_visual.extend([visualize(gt_pil.convert('RGB')), 230 | visualize(predictions_pil.convert('RGB'))]) 231 | val_visual = torch.stack(val_visual, 0) 232 | val_visual = vutils.make_grid(val_visual, nrow=2, padding=5) 233 | writer.add_image(snapshot_name, val_visual) 234 | 235 | print('-----------------------------------------------------------------------------------------------------------') 236 | print('[epoch %d], [iter %d], [val loss %.5f], [acc %.5f], [acc_cls %.5f], [mean_iu %.5f], [fwavacc %.5f]' % ( 237 | epoch, iter_num, val_loss.avg, acc, acc_cls, mean_iu, fwavacc)) 238 | 239 | print('best record: [val loss %.5f], [acc %.5f], [acc_cls %.5f], [mean_iu %.5f], [fwavacc %.5f], [epoch %d], ' 240 | '[iter %d]' % (train_args['best_record']['val_loss'], train_args['best_record']['acc'], 241 | train_args['best_record']['acc_cls'], train_args['best_record']['mean_iu'], 242 | train_args['best_record']['fwavacc'], train_args['best_record']['epoch'], 243 | train_args['best_record']['iter'])) 244 | 245 | print('-----------------------------------------------------------------------------------------------------------') 246 | 247 | writer.add_scalar('val_loss', val_loss.avg, epoch) 248 | writer.add_scalar('acc', acc, epoch) 249 | writer.add_scalar('acc_cls', acc_cls, epoch) 250 | writer.add_scalar('mean_iu', mean_iu, epoch) 251 | writer.add_scalar('fwavacc', fwavacc, epoch) 252 | 253 | net.train() 254 | return val_loss.avg 255 | 256 | 257 | if __name__ == '__main__': 258 | main() 259 | -------------------------------------------------------------------------------- /train/voc-fcn (caffe vgg)/README.md: -------------------------------------------------------------------------------- 1 | # Results 2 | 3 | ## Metrics 4 | training batch size: 1, iter num per epoch: 8k, 
lr: 1e-10, sum the pixel loss 5 | ![](static/fcn8s-train_loss.jpg) 6 | 7 | validate the loss and mean_iu after training of one epoch 8 | ![](static/fcn8s-val_loss.jpg) 9 | 10 | ![](static/fcn8s-mean_iu.jpg) 11 | 12 | ## Visualization 13 | ![](static/fcn8s-epoch9.jpg) 14 | -------------------------------------------------------------------------------- /train/voc-fcn (caffe vgg)/static/fcn8s-epoch9.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zijundeng/pytorch-semantic-segmentation/4a1721f9a3284788336430efb140288096c6dd09/train/voc-fcn (caffe vgg)/static/fcn8s-epoch9.jpg -------------------------------------------------------------------------------- /train/voc-fcn (caffe vgg)/static/fcn8s-mean_iu.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zijundeng/pytorch-semantic-segmentation/4a1721f9a3284788336430efb140288096c6dd09/train/voc-fcn (caffe vgg)/static/fcn8s-mean_iu.jpg -------------------------------------------------------------------------------- /train/voc-fcn (caffe vgg)/static/fcn8s-train_loss.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zijundeng/pytorch-semantic-segmentation/4a1721f9a3284788336430efb140288096c6dd09/train/voc-fcn (caffe vgg)/static/fcn8s-train_loss.jpg -------------------------------------------------------------------------------- /train/voc-fcn (caffe vgg)/static/fcn8s-val_loss.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zijundeng/pytorch-semantic-segmentation/4a1721f9a3284788336430efb140288096c6dd09/train/voc-fcn (caffe vgg)/static/fcn8s-val_loss.jpg -------------------------------------------------------------------------------- /train/voc-fcn (caffe vgg)/train.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | import random 4 | 5 | import torchvision.transforms as standard_transforms 6 | import torchvision.utils as vutils 7 | from tensorboard import SummaryWriter 8 | from torch import optim 9 | from torch.autograd import Variable 10 | from torch.backends import cudnn 11 | from torch.optim.lr_scheduler import ReduceLROnPlateau 12 | from torch.utils.data import DataLoader 13 | 14 | import utils.transforms as extended_transforms 15 | from datasets import voc 16 | from models import * 17 | from utils import check_mkdir, evaluate, AverageMeter, CrossEntropyLoss2d 18 | 19 | cudnn.benchmark = True 20 | 21 | ckpt_path = '../../ckpt' 22 | exp_name = 'voc-fcn8s (caffe vgg)' 23 | writer = SummaryWriter(os.path.join(ckpt_path, 'exp', exp_name)) 24 | 25 | args = { 26 | 'epoch_num': 30, 27 | 'lr': 1e-10, 28 | 'weight_decay': 5e-4, 29 | 'momentum': 0.99, 30 | 'lr_patience': 100, # large patience denotes fixed lr 31 | 'snapshot': '', # empty string denotes learning from scratch 32 | 'print_freq': 20, 33 | 'val_save_to_img_file': False, 34 | 'val_img_sample_rate': 0.1 # randomly sample some validation results to display 35 | } 36 | 37 | 38 | def main(train_args): 39 | net = FCN8s(num_classes=voc.num_classes, caffe=True).cuda() 40 | 41 | if len(train_args['snapshot']) == 0: 42 | curr_epoch = 1 43 | train_args['best_record'] = {'epoch': 0, 'val_loss': 1e10, 'acc': 0, 'acc_cls': 0, 'mean_iu': 0, 'fwavacc': 0} 44 | else: 45 | print('training resumes from ' + train_args['snapshot']) 46 | 
net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, train_args['snapshot']))) 47 | split_snapshot = train_args['snapshot'].split('_') 48 | curr_epoch = int(split_snapshot[1]) + 1 49 | train_args['best_record'] = {'epoch': int(split_snapshot[1]), 'val_loss': float(split_snapshot[3]), 50 | 'acc': float(split_snapshot[5]), 'acc_cls': float(split_snapshot[7]), 51 | 'mean_iu': float(split_snapshot[9]), 'fwavacc': float(split_snapshot[11])} 52 | 53 | net.train() 54 | 55 | mean_std = ([103.939, 116.779, 123.68], [1.0, 1.0, 1.0]) 56 | 57 | input_transform = standard_transforms.Compose([ 58 | extended_transforms.FlipChannels(), 59 | standard_transforms.ToTensor(), 60 | standard_transforms.Lambda(lambda x: x.mul_(255)), 61 | standard_transforms.Normalize(*mean_std) 62 | ]) 63 | target_transform = extended_transforms.MaskToTensor() 64 | restore_transform = standard_transforms.Compose([ 65 | extended_transforms.DeNormalize(*mean_std), 66 | standard_transforms.Lambda(lambda x: x.div_(255)), 67 | standard_transforms.ToPILImage(), 68 | extended_transforms.FlipChannels() 69 | ]) 70 | visualize = standard_transforms.Compose([ 71 | standard_transforms.Scale(400), 72 | standard_transforms.CenterCrop(400), 73 | standard_transforms.ToTensor() 74 | ]) 75 | 76 | train_set = voc.VOC('train', transform=input_transform, target_transform=target_transform) 77 | train_loader = DataLoader(train_set, batch_size=1, num_workers=4, shuffle=True) 78 | val_set = voc.VOC('val', transform=input_transform, target_transform=target_transform) 79 | val_loader = DataLoader(val_set, batch_size=1, num_workers=4, shuffle=False) 80 | 81 | criterion = CrossEntropyLoss2d(size_average=False, ignore_index=voc.ignore_label).cuda() 82 | 83 | optimizer = optim.SGD([ 84 | {'params': [param for name, param in net.named_parameters() if name[-4:] == 'bias'], 85 | 'lr': 2 * train_args['lr']}, 86 | {'params': [param for name, param in net.named_parameters() if name[-4:] != 'bias'], 87 | 'lr': train_args['lr'], 'weight_decay': train_args['weight_decay']} 88 | ], momentum=train_args['momentum']) 89 | 90 | if len(train_args['snapshot']) > 0: 91 | optimizer.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, 'opt_' + train_args['snapshot']))) 92 | optimizer.param_groups[0]['lr'] = 2 * train_args['lr'] 93 | optimizer.param_groups[1]['lr'] = train_args['lr'] 94 | 95 | check_mkdir(ckpt_path) 96 | check_mkdir(os.path.join(ckpt_path, exp_name)) 97 | open(os.path.join(ckpt_path, exp_name, str(datetime.datetime.now()) + '.txt'), 'w').write(str(train_args) + '\n\n') 98 | 99 | scheduler = ReduceLROnPlateau(optimizer, 'min', patience=train_args['lr_patience'], min_lr=1e-10, verbose=True) 100 | for epoch in range(curr_epoch, train_args['epoch_num'] + 1): 101 | train(train_loader, net, criterion, optimizer, epoch, train_args) 102 | val_loss = validate(val_loader, net, criterion, optimizer, epoch, train_args, restore_transform, visualize) 103 | scheduler.step(val_loss) 104 | 105 | 106 | def train(train_loader, net, criterion, optimizer, epoch, train_args): 107 | train_loss = AverageMeter() 108 | curr_iter = (epoch - 1) * len(train_loader) 109 | for i, data in enumerate(train_loader): 110 | inputs, labels = data 111 | assert inputs.size()[2:] == labels.size()[1:] 112 | N = inputs.size(0) 113 | inputs = Variable(inputs).cuda() 114 | labels = Variable(labels).cuda() 115 | 116 | optimizer.zero_grad() 117 | outputs = net(inputs) 118 | assert outputs.size()[2:] == labels.size()[1:] 119 | assert outputs.size()[1] == voc.num_classes 120 | 121 | loss = 
criterion(outputs, labels) / N 122 | loss.backward() 123 | optimizer.step() 124 | 125 | train_loss.update(loss.data[0], N) 126 | 127 | curr_iter += 1 128 | writer.add_scalar('train_loss', train_loss.avg, curr_iter) 129 | 130 | if (i + 1) % train_args['print_freq'] == 0: 131 | print('[epoch %d], [iter %d / %d], [train loss %.5f]' % ( 132 | epoch, i + 1, len(train_loader), train_loss.avg 133 | )) 134 | 135 | 136 | def validate(val_loader, net, criterion, optimizer, epoch, train_args, restore, visualize): 137 | net.eval() 138 | 139 | val_loss = AverageMeter() 140 | inputs_all, gts_all, predictions_all = [], [], [] 141 | 142 | for vi, data in enumerate(val_loader): 143 | inputs, gts = data 144 | N = inputs.size(0) 145 | inputs = Variable(inputs, volatile=True).cuda() 146 | gts = Variable(gts, volatile=True).cuda() 147 | 148 | outputs = net(inputs) 149 | predictions = outputs.data.max(1)[1].squeeze_(1).squeeze_(0).cpu().numpy() 150 | 151 | val_loss.update(criterion(outputs, gts).data[0] / N, N) 152 | 153 | if random.random() > train_args['val_img_sample_rate']: 154 | inputs_all.append(None) 155 | else: 156 | inputs_all.append(inputs.data.squeeze_(0).cpu()) 157 | gts_all.append(gts.data.squeeze_(0).cpu().numpy()) 158 | predictions_all.append(predictions) 159 | 160 | acc, acc_cls, mean_iu, fwavacc = evaluate(predictions_all, gts_all, voc.num_classes) 161 | 162 | if mean_iu > train_args['best_record']['mean_iu']: 163 | train_args['best_record']['val_loss'] = val_loss.avg 164 | train_args['best_record']['epoch'] = epoch 165 | train_args['best_record']['acc'] = acc 166 | train_args['best_record']['acc_cls'] = acc_cls 167 | train_args['best_record']['mean_iu'] = mean_iu 168 | train_args['best_record']['fwavacc'] = fwavacc 169 | snapshot_name = 'epoch_%d_loss_%.5f_acc_%.5f_acc-cls_%.5f_mean-iu_%.5f_fwavacc_%.5f_lr_%.10f' % ( 170 | epoch, val_loss.avg, acc, acc_cls, mean_iu, fwavacc, optimizer.param_groups[1]['lr'] 171 | ) 172 | torch.save(net.state_dict(), os.path.join(ckpt_path, exp_name, snapshot_name + '.pth')) 173 | torch.save(optimizer.state_dict(), os.path.join(ckpt_path, exp_name, 'opt_' + snapshot_name + '.pth')) 174 | 175 | if train_args['val_save_to_img_file']: 176 | to_save_dir = os.path.join(ckpt_path, exp_name, str(epoch)) 177 | check_mkdir(to_save_dir) 178 | 179 | val_visual = [] 180 | for idx, data in enumerate(zip(inputs_all, gts_all, predictions_all)): 181 | if data[0] is None: 182 | continue 183 | input_pil = restore(data[0]) 184 | gt_pil = voc.colorize_mask(data[1]) 185 | predictions_pil = voc.colorize_mask(data[2]) 186 | if train_args['val_save_to_img_file']: 187 | input_pil.save(os.path.join(to_save_dir, '%d_input.png' % idx)) 188 | predictions_pil.save(os.path.join(to_save_dir, '%d_prediction.png' % idx)) 189 | gt_pil.save(os.path.join(to_save_dir, '%d_gt.png' % idx)) 190 | val_visual.extend([visualize(input_pil.convert('RGB')), visualize(gt_pil.convert('RGB')), 191 | visualize(predictions_pil.convert('RGB'))]) 192 | val_visual = torch.stack(val_visual, 0) 193 | val_visual = vutils.make_grid(val_visual, nrow=3, padding=5) 194 | writer.add_image(snapshot_name, val_visual) 195 | 196 | print('--------------------------------------------------------------------') 197 | print('[epoch %d], [val loss %.5f], [acc %.5f], [acc_cls %.5f], [mean_iu %.5f], [fwavacc %.5f]' % ( 198 | epoch, val_loss.avg, acc, acc_cls, mean_iu, fwavacc)) 199 | 200 | print('best record: [val loss %.5f], [acc %.5f], [acc_cls %.5f], [mean_iu %.5f], [fwavacc %.5f], [epoch %d]' % ( 201 | 
train_args['best_record']['val_loss'], train_args['best_record']['acc'], train_args['best_record']['acc_cls'], 202 | train_args['best_record']['mean_iu'], train_args['best_record']['fwavacc'], train_args['best_record']['epoch'])) 203 | 204 | print('--------------------------------------------------------------------') 205 | 206 | writer.add_scalar('val_loss', val_loss.avg, epoch) 207 | writer.add_scalar('acc', acc, epoch) 208 | writer.add_scalar('acc_cls', acc_cls, epoch) 209 | writer.add_scalar('mean_iu', mean_iu, epoch) 210 | writer.add_scalar('fwavacc', fwavacc, epoch) 211 | writer.add_scalar('lr', optimizer.param_groups[1]['lr'], epoch) 212 | 213 | net.train() 214 | return val_loss.avg 215 | 216 | 217 | if __name__ == '__main__': 218 | main(args) 219 | -------------------------------------------------------------------------------- /train/voc-fcn/train.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | import random 4 | 5 | import torchvision.transforms as standard_transforms 6 | import torchvision.utils as vutils 7 | from tensorboard import SummaryWriter 8 | from torch import optim 9 | from torch.autograd import Variable 10 | from torch.backends import cudnn 11 | from torch.optim.lr_scheduler import ReduceLROnPlateau 12 | from torch.utils.data import DataLoader 13 | 14 | import utils.transforms as extended_transforms 15 | from datasets import voc 16 | from models import * 17 | from utils import check_mkdir, evaluate, AverageMeter, CrossEntropyLoss2d 18 | 19 | cudnn.benchmark = True 20 | 21 | ckpt_path = '../../ckpt' 22 | exp_name = 'voc-fcn8s' 23 | writer = SummaryWriter(os.path.join(ckpt_path, 'exp', exp_name)) 24 | 25 | args = { 26 | 'epoch_num': 300, 27 | 'lr': 1e-10, 28 | 'weight_decay': 1e-4, 29 | 'momentum': 0.95, 30 | 'lr_patience': 100, # large patience denotes fixed lr 31 | 'snapshot': '', # empty string denotes learning from scratch 32 | 'print_freq': 20, 33 | 'val_save_to_img_file': False, 34 | 'val_img_sample_rate': 0.1 # randomly sample some validation results to display 35 | } 36 | 37 | 38 | def main(train_args): 39 | net = FCN8s(num_classes=voc.num_classes).cuda() 40 | 41 | if len(train_args['snapshot']) == 0: 42 | curr_epoch = 1 43 | train_args['best_record'] = {'epoch': 0, 'val_loss': 1e10, 'acc': 0, 'acc_cls': 0, 'mean_iu': 0, 'fwavacc': 0} 44 | else: 45 | print('training resumes from ' + train_args['snapshot']) 46 | net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, train_args['snapshot']))) 47 | split_snapshot = train_args['snapshot'].split('_') 48 | curr_epoch = int(split_snapshot[1]) + 1 49 | train_args['best_record'] = {'epoch': int(split_snapshot[1]), 'val_loss': float(split_snapshot[3]), 50 | 'acc': float(split_snapshot[5]), 'acc_cls': float(split_snapshot[7]), 51 | 'mean_iu': float(split_snapshot[9]), 'fwavacc': float(split_snapshot[11])} 52 | 53 | net.train() 54 | 55 | mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 56 | 57 | input_transform = standard_transforms.Compose([ 58 | standard_transforms.ToTensor(), 59 | standard_transforms.Normalize(*mean_std) 60 | ]) 61 | target_transform = extended_transforms.MaskToTensor() 62 | restore_transform = standard_transforms.Compose([ 63 | extended_transforms.DeNormalize(*mean_std), 64 | standard_transforms.ToPILImage(), 65 | ]) 66 | visualize = standard_transforms.Compose([ 67 | standard_transforms.Scale(400), 68 | standard_transforms.CenterCrop(400), 69 | standard_transforms.ToTensor() 70 | ]) 71 | 72 | train_set = 
voc.VOC('train', transform=input_transform, target_transform=target_transform) 73 | train_loader = DataLoader(train_set, batch_size=1, num_workers=4, shuffle=True) 74 | val_set = voc.VOC('val', transform=input_transform, target_transform=target_transform) 75 | val_loader = DataLoader(val_set, batch_size=1, num_workers=4, shuffle=False) 76 | 77 | criterion = CrossEntropyLoss2d(size_average=False, ignore_index=voc.ignore_label).cuda() 78 | 79 | optimizer = optim.Adam([ 80 | {'params': [param for name, param in net.named_parameters() if name[-4:] == 'bias'], 81 | 'lr': 2 * train_args['lr']}, 82 | {'params': [param for name, param in net.named_parameters() if name[-4:] != 'bias'], 83 | 'lr': train_args['lr'], 'weight_decay': train_args['weight_decay']} 84 | ], betas=(train_args['momentum'], 0.999)) 85 | 86 | if len(train_args['snapshot']) > 0: 87 | optimizer.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, 'opt_' + train_args['snapshot']))) 88 | optimizer.param_groups[0]['lr'] = 2 * train_args['lr'] 89 | optimizer.param_groups[1]['lr'] = train_args['lr'] 90 | 91 | check_mkdir(ckpt_path) 92 | check_mkdir(os.path.join(ckpt_path, exp_name)) 93 | open(os.path.join(ckpt_path, exp_name, str(datetime.datetime.now()) + '.txt'), 'w').write(str(train_args) + '\n\n') 94 | 95 | scheduler = ReduceLROnPlateau(optimizer, 'min', patience=train_args['lr_patience'], min_lr=1e-10, verbose=True) 96 | for epoch in range(curr_epoch, train_args['epoch_num'] + 1): 97 | train(train_loader, net, criterion, optimizer, epoch, train_args) 98 | val_loss = validate(val_loader, net, criterion, optimizer, epoch, train_args, restore_transform, visualize) 99 | scheduler.step(val_loss) 100 | 101 | 102 | def train(train_loader, net, criterion, optimizer, epoch, train_args): 103 | train_loss = AverageMeter() 104 | curr_iter = (epoch - 1) * len(train_loader) 105 | for i, data in enumerate(train_loader): 106 | inputs, labels = data 107 | assert inputs.size()[2:] == labels.size()[1:] 108 | N = inputs.size(0) 109 | inputs = Variable(inputs).cuda() 110 | labels = Variable(labels).cuda() 111 | 112 | optimizer.zero_grad() 113 | outputs = net(inputs) 114 | assert outputs.size()[2:] == labels.size()[1:] 115 | assert outputs.size()[1] == voc.num_classes 116 | 117 | loss = criterion(outputs, labels) / N 118 | loss.backward() 119 | optimizer.step() 120 | 121 | train_loss.update(loss.data[0], N) 122 | 123 | curr_iter += 1 124 | writer.add_scalar('train_loss', train_loss.avg, curr_iter) 125 | 126 | if (i + 1) % train_args['print_freq'] == 0: 127 | print('[epoch %d], [iter %d / %d], [train loss %.5f]' % ( 128 | epoch, i + 1, len(train_loader), train_loss.avg 129 | )) 130 | 131 | 132 | def validate(val_loader, net, criterion, optimizer, epoch, train_args, restore, visualize): 133 | net.eval() 134 | 135 | val_loss = AverageMeter() 136 | inputs_all, gts_all, predictions_all = [], [], [] 137 | 138 | for vi, data in enumerate(val_loader): 139 | inputs, gts = data 140 | N = inputs.size(0) 141 | inputs = Variable(inputs, volatile=True).cuda() 142 | gts = Variable(gts, volatile=True).cuda() 143 | 144 | outputs = net(inputs) 145 | predictions = outputs.data.max(1)[1].squeeze_(1).squeeze_(0).cpu().numpy() 146 | 147 | val_loss.update(criterion(outputs, gts).data[0] / N, N) 148 | 149 | if random.random() > train_args['val_img_sample_rate']: 150 | inputs_all.append(None) 151 | else: 152 | inputs_all.append(inputs.data.squeeze_(0).cpu()) 153 | gts_all.append(gts.data.squeeze_(0).cpu().numpy()) 154 | predictions_all.append(predictions) 155 | 156 | 
acc, acc_cls, mean_iu, fwavacc = evaluate(predictions_all, gts_all, voc.num_classes) 157 | 158 | if mean_iu > train_args['best_record']['mean_iu']: 159 | train_args['best_record']['val_loss'] = val_loss.avg 160 | train_args['best_record']['epoch'] = epoch 161 | train_args['best_record']['acc'] = acc 162 | train_args['best_record']['acc_cls'] = acc_cls 163 | train_args['best_record']['mean_iu'] = mean_iu 164 | train_args['best_record']['fwavacc'] = fwavacc 165 | snapshot_name = 'epoch_%d_loss_%.5f_acc_%.5f_acc-cls_%.5f_mean-iu_%.5f_fwavacc_%.5f_lr_%.10f' % ( 166 | epoch, val_loss.avg, acc, acc_cls, mean_iu, fwavacc, optimizer.param_groups[1]['lr'] 167 | ) 168 | torch.save(net.state_dict(), os.path.join(ckpt_path, exp_name, snapshot_name + '.pth')) 169 | torch.save(optimizer.state_dict(), os.path.join(ckpt_path, exp_name, 'opt_' + snapshot_name + '.pth')) 170 | 171 | if train_args['val_save_to_img_file']: 172 | to_save_dir = os.path.join(ckpt_path, exp_name, str(epoch)) 173 | check_mkdir(to_save_dir) 174 | 175 | val_visual = [] 176 | for idx, data in enumerate(zip(inputs_all, gts_all, predictions_all)): 177 | if data[0] is None: 178 | continue 179 | input_pil = restore(data[0]) 180 | gt_pil = voc.colorize_mask(data[1]) 181 | predictions_pil = voc.colorize_mask(data[2]) 182 | if train_args['val_save_to_img_file']: 183 | input_pil.save(os.path.join(to_save_dir, '%d_input.png' % idx)) 184 | predictions_pil.save(os.path.join(to_save_dir, '%d_prediction.png' % idx)) 185 | gt_pil.save(os.path.join(to_save_dir, '%d_gt.png' % idx)) 186 | val_visual.extend([visualize(input_pil.convert('RGB')), visualize(gt_pil.convert('RGB')), 187 | visualize(predictions_pil.convert('RGB'))]) 188 | val_visual = torch.stack(val_visual, 0) 189 | val_visual = vutils.make_grid(val_visual, nrow=3, padding=5) 190 | writer.add_image(snapshot_name, val_visual) 191 | 192 | print('--------------------------------------------------------------------') 193 | print('[epoch %d], [val loss %.5f], [acc %.5f], [acc_cls %.5f], [mean_iu %.5f], [fwavacc %.5f]' % ( 194 | epoch, val_loss.avg, acc, acc_cls, mean_iu, fwavacc)) 195 | 196 | print('best record: [val loss %.5f], [acc %.5f], [acc_cls %.5f], [mean_iu %.5f], [fwavacc %.5f], [epoch %d]' % ( 197 | train_args['best_record']['val_loss'], train_args['best_record']['acc'], train_args['best_record']['acc_cls'], 198 | train_args['best_record']['mean_iu'], train_args['best_record']['fwavacc'], train_args['best_record']['epoch'])) 199 | 200 | print('--------------------------------------------------------------------') 201 | 202 | writer.add_scalar('val_loss', val_loss.avg, epoch) 203 | writer.add_scalar('acc', acc, epoch) 204 | writer.add_scalar('acc_cls', acc_cls, epoch) 205 | writer.add_scalar('mean_iu', mean_iu, epoch) 206 | writer.add_scalar('fwavacc', fwavacc, epoch) 207 | writer.add_scalar('lr', optimizer.param_groups[1]['lr'], epoch) 208 | 209 | net.train() 210 | return val_loss.avg 211 | 212 | 213 | if __name__ == '__main__': 214 | main(args) 215 | -------------------------------------------------------------------------------- /train/voc-psp_net/train.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | from math import sqrt 4 | 5 | import numpy as np 6 | import torchvision.transforms as standard_transforms 7 | import torchvision.utils as vutils 8 | from tensorboard import SummaryWriter 9 | from torch import optim 10 | from torch.autograd import Variable 11 | from torch.utils.data import DataLoader 12 | 13 | 
import utils.joint_transforms as joint_transforms 14 | import utils.transforms as extended_transforms 15 | from datasets import voc 16 | from models import * 17 | from utils import check_mkdir, evaluate, AverageMeter, CrossEntropyLoss2d 18 | 19 | ckpt_path = '../../ckpt' 20 | exp_name = 'voc-psp_net' 21 | writer = SummaryWriter(os.path.join(ckpt_path, 'exp', exp_name)) 22 | 23 | args = { 24 | 'train_batch_size': 1, 25 | 'lr': 1e-2 / sqrt(16 / 4), 26 | 'lr_decay': 0.9, 27 | 'max_iter': 3e4, 28 | 'longer_size': 512, 29 | 'crop_size': 473, 30 | 'stride_rate': 2 / 3., 31 | 'weight_decay': 1e-4, 32 | 'momentum': 0.9, 33 | 'snapshot': '', 34 | 'print_freq': 10, 35 | 'val_save_to_img_file': True, 36 | 'val_img_sample_rate': 0.01, # randomly sample some validation results to display, 37 | 'val_img_display_size': 384, 38 | } 39 | 40 | 41 | def main(): 42 | net = PSPNet(num_classes=voc.num_classes).cuda() 43 | 44 | if len(args['snapshot']) == 0: 45 | net.load_state_dict(torch.load(os.path.join(ckpt_path, 'cityscapes (coarse)-psp_net', 'xx.pth'))) 46 | curr_epoch = 1 47 | args['best_record'] = {'epoch': 0, 'val_loss': 1e10, 'acc': 0, 'acc_cls': 0, 'mean_iu': 0, 'fwavacc': 0} 48 | else: 49 | print('training resumes from ' + args['snapshot']) 50 | net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, args['snapshot']))) 51 | split_snapshot = args['snapshot'].split('_') 52 | curr_epoch = int(split_snapshot[1]) + 1 53 | args['best_record'] = {'epoch': int(split_snapshot[1]), 'val_loss': float(split_snapshot[3]), 54 | 'acc': float(split_snapshot[5]), 'acc_cls': float(split_snapshot[7]), 55 | 'mean_iu': float(split_snapshot[9]), 'fwavacc': float(split_snapshot[11])} 56 | net.train() 57 | 58 | mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 59 | 60 | train_joint_transform = joint_transforms.Compose([ 61 | joint_transforms.Scale(args['longer_size']), 62 | joint_transforms.RandomRotate(10), 63 | joint_transforms.RandomHorizontallyFlip() 64 | ]) 65 | sliding_crop = joint_transforms.SlidingCrop(args['crop_size'], args['stride_rate'], voc.ignore_label) 66 | train_input_transform = standard_transforms.Compose([ 67 | standard_transforms.ToTensor(), 68 | standard_transforms.Normalize(*mean_std) 69 | ]) 70 | val_input_transform = standard_transforms.Compose([ 71 | standard_transforms.ToTensor(), 72 | standard_transforms.Normalize(*mean_std) 73 | ]) 74 | target_transform = extended_transforms.MaskToTensor() 75 | visualize = standard_transforms.Compose([ 76 | standard_transforms.Scale(args['val_img_display_size']), 77 | standard_transforms.ToTensor() 78 | ]) 79 | 80 | train_set = voc.VOC('train', joint_transform=train_joint_transform, sliding_crop=sliding_crop, 81 | transform=train_input_transform, target_transform=target_transform) 82 | train_loader = DataLoader(train_set, batch_size=args['train_batch_size'], num_workers=8, shuffle=True) 83 | val_set = voc.VOC('val', transform=val_input_transform, sliding_crop=sliding_crop, 84 | target_transform=target_transform) 85 | val_loader = DataLoader(val_set, batch_size=1, num_workers=8, shuffle=False) 86 | 87 | criterion = CrossEntropyLoss2d(size_average=True, ignore_index=voc.ignore_label).cuda() 88 | 89 | optimizer = optim.SGD([ 90 | {'params': [param for name, param in net.named_parameters() if name[-4:] == 'bias'], 91 | 'lr': 2 * args['lr']}, 92 | {'params': [param for name, param in net.named_parameters() if name[-4:] != 'bias'], 93 | 'lr': args['lr'], 'weight_decay': args['weight_decay']} 94 | ], momentum=args['momentum'], nesterov=True) 95 | 96 | 
if len(args['snapshot']) > 0: 97 | optimizer.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, 'opt_' + args['snapshot']))) 98 | optimizer.param_groups[0]['lr'] = 2 * args['lr'] 99 | optimizer.param_groups[1]['lr'] = args['lr'] 100 | 101 | check_mkdir(ckpt_path) 102 | check_mkdir(os.path.join(ckpt_path, exp_name)) 103 | open(os.path.join(ckpt_path, exp_name, str(datetime.datetime.now()) + '.txt'), 'w').write(str(args) + '\n\n') 104 | 105 | train(train_loader, net, criterion, optimizer, curr_epoch, args, val_loader, visualize) 106 | 107 | 108 | def train(train_loader, net, criterion, optimizer, curr_epoch, train_args, val_loader, visualize): 109 | while True: 110 | train_main_loss = AverageMeter() 111 | train_aux_loss = AverageMeter() 112 | curr_iter = (curr_epoch - 1) * len(train_loader) 113 | for i, data in enumerate(train_loader): 114 | optimizer.param_groups[0]['lr'] = 2 * train_args['lr'] * (1 - float(curr_iter) / train_args['max_iter'] 115 | ) ** train_args['lr_decay'] 116 | optimizer.param_groups[1]['lr'] = train_args['lr'] * (1 - float(curr_iter) / train_args['max_iter'] 117 | ) ** train_args['lr_decay'] 118 | 119 | inputs, gts, _ = data 120 | assert len(inputs.size()) == 5 and len(gts.size()) == 4 121 | inputs.transpose_(0, 1) 122 | gts.transpose_(0, 1) 123 | 124 | assert inputs.size()[3:] == gts.size()[2:] 125 | slice_batch_pixel_size = inputs.size(1) * inputs.size(3) * inputs.size(4) 126 | 127 | for inputs_slice, gts_slice in zip(inputs, gts): 128 | inputs_slice = Variable(inputs_slice).cuda() 129 | gts_slice = Variable(gts_slice).cuda() 130 | 131 | optimizer.zero_grad() 132 | outputs, aux = net(inputs_slice) 133 | assert outputs.size()[2:] == gts_slice.size()[1:] 134 | assert outputs.size()[1] == voc.num_classes 135 | 136 | main_loss = criterion(outputs, gts_slice) 137 | aux_loss = criterion(aux, gts_slice) 138 | loss = main_loss + 0.4 * aux_loss 139 | loss.backward() 140 | optimizer.step() 141 | 142 | train_main_loss.update(main_loss.data[0], slice_batch_pixel_size) 143 | train_aux_loss.update(aux_loss.data[0], slice_batch_pixel_size) 144 | 145 | curr_iter += 1 146 | writer.add_scalar('train_main_loss', train_main_loss.avg, curr_iter) 147 | writer.add_scalar('train_aux_loss', train_aux_loss.avg, curr_iter) 148 | writer.add_scalar('lr', optimizer.param_groups[1]['lr'], curr_iter) 149 | 150 | if (i + 1) % train_args['print_freq'] == 0: 151 | print('[epoch %d], [iter %d / %d], [train main loss %.5f], [train aux loss %.5f]. 
[lr %.10f]' % ( 152 | curr_epoch, i + 1, len(train_loader), train_main_loss.avg, train_aux_loss.avg, 153 | optimizer.param_groups[1]['lr'])) 154 | if curr_iter >= train_args['max_iter']: 155 | return 156 | validate(val_loader, net, criterion, optimizer, curr_epoch, train_args, visualize) 157 | curr_epoch += 1 158 | 159 | 160 | def validate(val_loader, net, criterion, optimizer, epoch, train_args, visualize): 161 | # the following code is written assuming that batch size is 1 162 | net.eval() 163 | 164 | val_loss = AverageMeter() 165 | 166 | gts_all = np.zeros((len(val_loader), args['shorter_size'], 2 * args['shorter_size']), dtype=int) 167 | predictions_all = np.zeros((len(val_loader), args['shorter_size'], 2 * args['shorter_size']), dtype=int) 168 | for vi, data in enumerate(val_loader): 169 | input, gt, slices_info = data 170 | assert len(input.size()) == 5 and len(gt.size()) == 4 and len(slices_info.size()) == 3 171 | input.transpose_(0, 1) 172 | gt.transpose_(0, 1) 173 | slices_info.squeeze_(0) 174 | assert input.size()[3:] == gt.size()[2:] 175 | 176 | count = torch.zeros(args['shorter_size'], 2 * args['shorter_size']).cuda() 177 | output = torch.zeros(voc.num_classes, args['shorter_size'], 2 * args['shorter_size']).cuda() 178 | 179 | slice_batch_pixel_size = input.size(1) * input.size(3) * input.size(4) 180 | 181 | for input_slice, gt_slice, info in zip(input, gt, slices_info): 182 | input_slice = Variable(input_slice).cuda() 183 | gt_slice = Variable(gt_slice).cuda() 184 | 185 | output_slice = net(input_slice) 186 | assert output_slice.size()[2:] == gt_slice.size()[1:] 187 | assert output_slice.size()[1] == voc.num_classes 188 | output[:, info[0]: info[1], info[2]: info[3]] += output_slice[0, :, :info[4], :info[5]].data 189 | gts_all[vi, info[0]: info[1], info[2]: info[3]] += gt_slice[0, :info[4], :info[5]].data.cpu().numpy() 190 | 191 | count[info[0]: info[1], info[2]: info[3]] += 1 192 | 193 | val_loss.update(criterion(output_slice, gt_slice).data[0], slice_batch_pixel_size) 194 | 195 | output /= count 196 | gts_all[vi, :, :] /= count.cpu().numpy().astype(int) 197 | predictions_all[vi, :, :] = output.max(0)[1].squeeze_(0).cpu().numpy() 198 | 199 | print('validating: %d / %d' % (vi + 1, len(val_loader))) 200 | 201 | acc, acc_cls, mean_iu, fwavacc = evaluate(predictions_all, gts_all, voc.num_classes) 202 | 203 | train_args['best_record']['val_loss'] = val_loss.avg 204 | train_args['best_record']['epoch'] = epoch 205 | train_args['best_record']['acc'] = acc 206 | train_args['best_record']['acc_cls'] = acc_cls 207 | train_args['best_record']['mean_iu'] = mean_iu 208 | train_args['best_record']['fwavacc'] = fwavacc 209 | snapshot_name = 'epoch_%d_loss_%.5f_acc_%.5f_acc-cls_%.5f_mean-iu_%.5f_fwavacc_%.5f_lr_%.10f' % ( 210 | epoch, val_loss.avg, acc, acc_cls, mean_iu, fwavacc, optimizer.param_groups[1]['lr']) 211 | torch.save(net.state_dict(), os.path.join(ckpt_path, exp_name, snapshot_name + '.pth')) 212 | torch.save(optimizer.state_dict(), os.path.join(ckpt_path, exp_name, 'opt_' + snapshot_name + '.pth')) 213 | 214 | if train_args['val_save_to_img_file']: 215 | to_save_dir = os.path.join(ckpt_path, exp_name, str(epoch)) 216 | check_mkdir(to_save_dir) 217 | 218 | val_visual = [] 219 | for idx, data in enumerate(zip(gts_all, predictions_all)): 220 | gt_pil = voc.colorize_mask(data[0]) 221 | predictions_pil = voc.colorize_mask(data[1]) 222 | if train_args['val_save_to_img_file']: 223 | predictions_pil.save(os.path.join(to_save_dir, '%d_prediction.png' % idx)) 224 | 
gt_pil.save(os.path.join(to_save_dir, '%d_gt.png' % idx)) 225 | val_visual.extend([visualize(gt_pil.convert('RGB')), 226 | visualize(predictions_pil.convert('RGB'))]) 227 | val_visual = torch.stack(val_visual, 0) 228 | val_visual = vutils.make_grid(val_visual, nrow=2, padding=5) 229 | writer.add_image(snapshot_name, val_visual) 230 | 231 | print('-----------------------------------------------------------------------------------------------------------') 232 | print('[epoch %d], [val loss %.5f], [acc %.5f], [acc_cls %.5f], [mean_iu %.5f], [fwavacc %.5f]' % ( 233 | epoch, val_loss.avg, acc, acc_cls, mean_iu, fwavacc)) 234 | 235 | print('best record: [val loss %.5f], [acc %.5f], [acc_cls %.5f], [mean_iu %.5f], [fwavacc %.5f], [epoch %d]' % ( 236 | train_args['best_record']['val_loss'], train_args['best_record']['acc'], train_args['best_record']['acc_cls'], 237 | train_args['best_record']['mean_iu'], train_args['best_record']['fwavacc'], train_args['best_record']['epoch'])) 238 | 239 | print('-----------------------------------------------------------------------------------------------------------') 240 | 241 | writer.add_scalar('val_loss', val_loss.avg, epoch) 242 | writer.add_scalar('acc', acc, epoch) 243 | writer.add_scalar('acc_cls', acc_cls, epoch) 244 | writer.add_scalar('mean_iu', mean_iu, epoch) 245 | writer.add_scalar('fwavacc', fwavacc, epoch) 246 | 247 | net.train() 248 | return val_loss.avg 249 | 250 | 251 | if __name__ == '__main__': 252 | main() 253 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .misc import * 2 | from .joint_transforms import * 3 | from .transforms import * 4 | -------------------------------------------------------------------------------- /utils/joint_transforms.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numbers 3 | import random 4 | 5 | from PIL import Image, ImageOps 6 | import numpy as np 7 | 8 | 9 | class Compose(object): 10 | def __init__(self, transforms): 11 | self.transforms = transforms 12 | 13 | def __call__(self, img, mask): 14 | assert img.size == mask.size 15 | for t in self.transforms: 16 | img, mask = t(img, mask) 17 | return img, mask 18 | 19 | 20 | class RandomCrop(object): 21 | def __init__(self, size, padding=0): 22 | if isinstance(size, numbers.Number): 23 | self.size = (int(size), int(size)) 24 | else: 25 | self.size = size 26 | self.padding = padding 27 | 28 | def __call__(self, img, mask): 29 | if self.padding > 0: 30 | img = ImageOps.expand(img, border=self.padding, fill=0) 31 | mask = ImageOps.expand(mask, border=self.padding, fill=0) 32 | 33 | assert img.size == mask.size 34 | w, h = img.size 35 | th, tw = self.size 36 | if w == tw and h == th: 37 | return img, mask 38 | if w < tw or h < th: 39 | return img.resize((tw, th), Image.BILINEAR), mask.resize((tw, th), Image.NEAREST) 40 | 41 | x1 = random.randint(0, w - tw) 42 | y1 = random.randint(0, h - th) 43 | return img.crop((x1, y1, x1 + tw, y1 + th)), mask.crop((x1, y1, x1 + tw, y1 + th)) 44 | 45 | 46 | class CenterCrop(object): 47 | def __init__(self, size): 48 | if isinstance(size, numbers.Number): 49 | self.size = (int(size), int(size)) 50 | else: 51 | self.size = size 52 | 53 | def __call__(self, img, mask): 54 | assert img.size == mask.size 55 | w, h = img.size 56 | th, tw = self.size 57 | x1 = int(round((w - tw) / 2.)) 58 | y1 = int(round((h - th) / 2.)) 59 | return 
img.crop((x1, y1, x1 + tw, y1 + th)), mask.crop((x1, y1, x1 + tw, y1 + th)) 60 | 61 | 62 | class RandomHorizontallyFlip(object): 63 | def __call__(self, img, mask): 64 | if random.random() < 0.5: 65 | return img.transpose(Image.FLIP_LEFT_RIGHT), mask.transpose(Image.FLIP_LEFT_RIGHT) 66 | return img, mask 67 | 68 | 69 | class FreeScale(object): 70 | def __init__(self, size): 71 | self.size = tuple(reversed(size)) # size: (h, w) 72 | 73 | def __call__(self, img, mask): 74 | assert img.size == mask.size 75 | return img.resize(self.size, Image.BILINEAR), mask.resize(self.size, Image.NEAREST) 76 | 77 | 78 | class Scale(object): 79 | def __init__(self, size): 80 | self.size = size 81 | 82 | def __call__(self, img, mask): 83 | assert img.size == mask.size 84 | w, h = img.size 85 | if (w >= h and w == self.size) or (h >= w and h == self.size): 86 | return img, mask 87 | if w > h: 88 | ow = self.size 89 | oh = int(self.size * h / w) 90 | return img.resize((ow, oh), Image.BILINEAR), mask.resize((ow, oh), Image.NEAREST) 91 | else: 92 | oh = self.size 93 | ow = int(self.size * w / h) 94 | return img.resize((ow, oh), Image.BILINEAR), mask.resize((ow, oh), Image.NEAREST) 95 | 96 | 97 | class RandomSizedCrop(object): 98 | def __init__(self, size): 99 | self.size = size 100 | 101 | def __call__(self, img, mask): 102 | assert img.size == mask.size 103 | for attempt in range(10): 104 | area = img.size[0] * img.size[1] 105 | target_area = random.uniform(0.45, 1.0) * area 106 | aspect_ratio = random.uniform(0.5, 2) 107 | 108 | w = int(round(math.sqrt(target_area * aspect_ratio))) 109 | h = int(round(math.sqrt(target_area / aspect_ratio))) 110 | 111 | if random.random() < 0.5: 112 | w, h = h, w 113 | 114 | if w <= img.size[0] and h <= img.size[1]: 115 | x1 = random.randint(0, img.size[0] - w) 116 | y1 = random.randint(0, img.size[1] - h) 117 | 118 | img = img.crop((x1, y1, x1 + w, y1 + h)) 119 | mask = mask.crop((x1, y1, x1 + w, y1 + h)) 120 | assert (img.size == (w, h)) 121 | 122 | return img.resize((self.size, self.size), Image.BILINEAR), mask.resize((self.size, self.size), 123 | Image.NEAREST) 124 | 125 | # Fallback 126 | scale = Scale(self.size) 127 | crop = CenterCrop(self.size) 128 | return crop(*scale(img, mask)) 129 | 130 | 131 | class RandomRotate(object): 132 | def __init__(self, degree): 133 | self.degree = degree 134 | 135 | def __call__(self, img, mask): 136 | rotate_degree = random.random() * 2 * self.degree - self.degree 137 | return img.rotate(rotate_degree, Image.BILINEAR), mask.rotate(rotate_degree, Image.NEAREST) 138 | 139 | 140 | class RandomSized(object): 141 | def __init__(self, size): 142 | self.size = size 143 | self.scale = Scale(self.size) 144 | self.crop = RandomCrop(self.size) 145 | 146 | def __call__(self, img, mask): 147 | assert img.size == mask.size 148 | 149 | w = int(random.uniform(0.5, 2) * img.size[0]) 150 | h = int(random.uniform(0.5, 2) * img.size[1]) 151 | 152 | img, mask = img.resize((w, h), Image.BILINEAR), mask.resize((w, h), Image.NEAREST) 153 | 154 | return self.crop(*self.scale(img, mask)) 155 | 156 | 157 | class SlidingCropOld(object): 158 | def __init__(self, crop_size, stride_rate, ignore_label): 159 | self.crop_size = crop_size 160 | self.stride_rate = stride_rate 161 | self.ignore_label = ignore_label 162 | 163 | def _pad(self, img, mask): 164 | h, w = img.shape[: 2] 165 | pad_h = max(self.crop_size - h, 0) 166 | pad_w = max(self.crop_size - w, 0) 167 | img = np.pad(img, ((0, pad_h), (0, pad_w), (0, 0)), 'constant') 168 | mask = np.pad(mask, ((0, pad_h), (0, 
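# ----------------------------------------------------------------------------------
# A small usage sketch of the joint transforms defined above: image and mask are passed
# through the same chain so that every random parameter (crop box, rotation angle, flip)
# is applied to both and the pixels stay aligned. The dummy PIL image/mask and the
# particular transform choices here are illustrative assumptions, not this repository's
# actual training configuration.
import numpy as np
from PIL import Image

joint = Compose([
    RandomSizedCrop(256),
    RandomRotate(10),
    RandomHorizontallyFlip(),
])

img = Image.fromarray(np.random.randint(0, 256, (300, 400, 3), dtype=np.uint8), mode='RGB')
mask = Image.fromarray(np.zeros((300, 400), dtype=np.uint8), mode='L')

img, mask = joint(img, mask)      # both outputs are 256 x 256 and remain pixel-aligned
print(img.size, mask.size)
# ----------------------------------------------------------------------------------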
pad_w)), 'constant', constant_values=self.ignore_label) 169 | return img, mask 170 | 171 | def __call__(self, img, mask): 172 | assert img.size == mask.size 173 | 174 | w, h = img.size 175 | long_size = max(h, w) 176 | 177 | img = np.array(img) 178 | mask = np.array(mask) 179 | 180 | if long_size > self.crop_size: 181 | stride = int(math.ceil(self.crop_size * self.stride_rate)) 182 | h_step_num = int(math.ceil((h - self.crop_size) / float(stride))) + 1 183 | w_step_num = int(math.ceil((w - self.crop_size) / float(stride))) + 1 184 | img_sublist, mask_sublist = [], [] 185 | for yy in xrange(h_step_num): 186 | for xx in xrange(w_step_num): 187 | sy, sx = yy * stride, xx * stride 188 | ey, ex = sy + self.crop_size, sx + self.crop_size 189 | img_sub = img[sy: ey, sx: ex, :] 190 | mask_sub = mask[sy: ey, sx: ex] 191 | img_sub, mask_sub = self._pad(img_sub, mask_sub) 192 | img_sublist.append(Image.fromarray(img_sub.astype(np.uint8)).convert('RGB')) 193 | mask_sublist.append(Image.fromarray(mask_sub.astype(np.uint8)).convert('P')) 194 | return img_sublist, mask_sublist 195 | else: 196 | img, mask = self._pad(img, mask) 197 | img = Image.fromarray(img.astype(np.uint8)).convert('RGB') 198 | mask = Image.fromarray(mask.astype(np.uint8)).convert('P') 199 | return img, mask 200 | 201 | 202 | class SlidingCrop(object): 203 | def __init__(self, crop_size, stride_rate, ignore_label): 204 | self.crop_size = crop_size 205 | self.stride_rate = stride_rate 206 | self.ignore_label = ignore_label 207 | 208 | def _pad(self, img, mask): 209 | h, w = img.shape[: 2] 210 | pad_h = max(self.crop_size - h, 0) 211 | pad_w = max(self.crop_size - w, 0) 212 | img = np.pad(img, ((0, pad_h), (0, pad_w), (0, 0)), 'constant') 213 | mask = np.pad(mask, ((0, pad_h), (0, pad_w)), 'constant', constant_values=self.ignore_label) 214 | return img, mask, h, w 215 | 216 | def __call__(self, img, mask): 217 | assert img.size == mask.size 218 | 219 | w, h = img.size 220 | long_size = max(h, w) 221 | 222 | img = np.array(img) 223 | mask = np.array(mask) 224 | 225 | if long_size > self.crop_size: 226 | stride = int(math.ceil(self.crop_size * self.stride_rate)) 227 | h_step_num = int(math.ceil((h - self.crop_size) / float(stride))) + 1 228 | w_step_num = int(math.ceil((w - self.crop_size) / float(stride))) + 1 229 | img_slices, mask_slices, slices_info = [], [], [] 230 | for yy in xrange(h_step_num): 231 | for xx in xrange(w_step_num): 232 | sy, sx = yy * stride, xx * stride 233 | ey, ex = sy + self.crop_size, sx + self.crop_size 234 | img_sub = img[sy: ey, sx: ex, :] 235 | mask_sub = mask[sy: ey, sx: ex] 236 | img_sub, mask_sub, sub_h, sub_w = self._pad(img_sub, mask_sub) 237 | img_slices.append(Image.fromarray(img_sub.astype(np.uint8)).convert('RGB')) 238 | mask_slices.append(Image.fromarray(mask_sub.astype(np.uint8)).convert('P')) 239 | slices_info.append([sy, ey, sx, ex, sub_h, sub_w]) 240 | return img_slices, mask_slices, slices_info 241 | else: 242 | img, mask, sub_h, sub_w = self._pad(img, mask) 243 | img = Image.fromarray(img.astype(np.uint8)).convert('RGB') 244 | mask = Image.fromarray(mask.astype(np.uint8)).convert('P') 245 | return [img], [mask], [[0, sub_h, 0, sub_w, sub_h, sub_w]] 246 | -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | import os 2 | from math import ceil 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | from 
torch.autograd import Variable 9 | 10 | 11 | def check_mkdir(dir_name): 12 | if not os.path.exists(dir_name): 13 | os.mkdir(dir_name) 14 | 15 | 16 | def initialize_weights(*models): 17 | for model in models: 18 | for module in model.modules(): 19 | if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear): 20 | nn.init.kaiming_normal(module.weight) 21 | if module.bias is not None: 22 | module.bias.data.zero_() 23 | elif isinstance(module, nn.BatchNorm2d): 24 | module.weight.data.fill_(1) 25 | module.bias.data.zero_() 26 | 27 | 28 | def get_upsampling_weight(in_channels, out_channels, kernel_size): 29 | factor = (kernel_size + 1) // 2 30 | if kernel_size % 2 == 1: 31 | center = factor - 1 32 | else: 33 | center = factor - 0.5 34 | og = np.ogrid[:kernel_size, :kernel_size] 35 | filt = (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor) 36 | weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size), dtype=np.float64) 37 | weight[list(range(in_channels)), list(range(out_channels)), :, :] = filt 38 | return torch.from_numpy(weight).float() 39 | 40 | 41 | class CrossEntropyLoss2d(nn.Module): 42 | def __init__(self, weight=None, size_average=True, ignore_index=255): 43 | super(CrossEntropyLoss2d, self).__init__() 44 | self.nll_loss = nn.NLLLoss2d(weight, size_average, ignore_index) 45 | 46 | def forward(self, inputs, targets): 47 | return self.nll_loss(F.log_softmax(inputs), targets) 48 | 49 | 50 | class FocalLoss2d(nn.Module): 51 | def __init__(self, gamma=2, weight=None, size_average=True, ignore_index=255): 52 | super(FocalLoss2d, self).__init__() 53 | self.gamma = gamma 54 | self.nll_loss = nn.NLLLoss2d(weight, size_average, ignore_index) 55 | 56 | def forward(self, inputs, targets): 57 | return self.nll_loss((1 - F.softmax(inputs)) ** self.gamma * F.log_softmax(inputs), targets) 58 | 59 | 60 | def _fast_hist(label_pred, label_true, num_classes): 61 | mask = (label_true >= 0) & (label_true < num_classes) 62 | hist = np.bincount( 63 | num_classes * label_true[mask].astype(int) + 64 | label_pred[mask], minlength=num_classes ** 2).reshape(num_classes, num_classes) 65 | return hist 66 | 67 | 68 | def evaluate(predictions, gts, num_classes): 69 | hist = np.zeros((num_classes, num_classes)) 70 | for lp, lt in zip(predictions, gts): 71 | hist += _fast_hist(lp.flatten(), lt.flatten(), num_classes) 72 | # axis 0: gt, axis 1: prediction 73 | acc = np.diag(hist).sum() / hist.sum() 74 | acc_cls = np.diag(hist) / hist.sum(axis=1) 75 | acc_cls = np.nanmean(acc_cls) 76 | iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist)) 77 | mean_iu = np.nanmean(iu) 78 | freq = hist.sum(axis=1) / hist.sum() 79 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum() 80 | return acc, acc_cls, mean_iu, fwavacc 81 | 82 | 83 | class AverageMeter(object): 84 | def __init__(self): 85 | self.reset() 86 | 87 | def reset(self): 88 | self.val = 0 89 | self.avg = 0 90 | self.sum = 0 91 | self.count = 0 92 | 93 | def update(self, val, n=1): 94 | self.val = val 95 | self.sum += val * n 96 | self.count += n 97 | self.avg = self.sum / self.count 98 | 99 | 100 | class PolyLR(object): 101 | def __init__(self, optimizer, curr_iter, max_iter, lr_decay): 102 | self.max_iter = float(max_iter) 103 | self.init_lr_groups = [] 104 | for p in optimizer.param_groups: 105 | self.init_lr_groups.append(p['lr']) 106 | self.param_groups = optimizer.param_groups 107 | self.curr_iter = curr_iter 108 | self.lr_decay = lr_decay 109 | 110 | def step(self): 111 | for idx, p in 
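# ----------------------------------------------------------------------------------
# Usage sketches for two of the helpers above, with all concrete numbers chosen only for
# illustration. get_upsampling_weight() builds a fixed bilinear kernel that can initialise
# a ConvTranspose2d used as an x2 upsampler (the usual FCN-style trick), and evaluate()
# reduces a confusion matrix to overall accuracy, per-class accuracy, mean IoU and
# frequency-weighted accuracy.
import numpy as np
import torch
from torch import nn

# bilinear initialisation of a stride-2 deconvolution (kernel_size = 2 * stride)
num_classes = 21
up = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=4, stride=2, bias=False)
up.weight.data.copy_(get_upsampling_weight(num_classes, num_classes, 4))

# toy metric computation: 2 classes on a 2 x 3 "image"
pred = np.array([[0, 0, 1],
                 [1, 1, 1]])
gt   = np.array([[0, 1, 1],
                 [1, 1, 0]])
acc, acc_cls, mean_iu, fwavacc = evaluate([pred], [gt], 2)
# 4 of 6 pixels agree -> acc = 0.667; IoU(class 0) = 1/3, IoU(class 1) = 3/5
print(acc, acc_cls, mean_iu, fwavacc)
# ----------------------------------------------------------------------------------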
enumerate(self.param_groups): 112 | p['lr'] = self.init_lr_groups[idx] * (1 - self.curr_iter / self.max_iter) ** self.lr_decay 113 | 114 | 115 | # just a try, not recommend to use 116 | class Conv2dDeformable(nn.Module): 117 | def __init__(self, regular_filter, cuda=True): 118 | super(Conv2dDeformable, self).__init__() 119 | assert isinstance(regular_filter, nn.Conv2d) 120 | self.regular_filter = regular_filter 121 | self.offset_filter = nn.Conv2d(regular_filter.in_channels, 2 * regular_filter.in_channels, kernel_size=3, 122 | padding=1, bias=False) 123 | self.offset_filter.weight.data.normal_(0, 0.0005) 124 | self.input_shape = None 125 | self.grid_w = None 126 | self.grid_h = None 127 | self.cuda = cuda 128 | 129 | def forward(self, x): 130 | x_shape = x.size() # (b, c, h, w) 131 | offset = self.offset_filter(x) # (b, 2*c, h, w) 132 | offset_w, offset_h = torch.split(offset, self.regular_filter.in_channels, 1) # (b, c, h, w) 133 | offset_w = offset_w.contiguous().view(-1, int(x_shape[2]), int(x_shape[3])) # (b*c, h, w) 134 | offset_h = offset_h.contiguous().view(-1, int(x_shape[2]), int(x_shape[3])) # (b*c, h, w) 135 | if not self.input_shape or self.input_shape != x_shape: 136 | self.input_shape = x_shape 137 | grid_w, grid_h = np.meshgrid(np.linspace(-1, 1, x_shape[3]), np.linspace(-1, 1, x_shape[2])) # (h, w) 138 | grid_w = torch.Tensor(grid_w) 139 | grid_h = torch.Tensor(grid_h) 140 | if self.cuda: 141 | grid_w = grid_w.cuda() 142 | grid_h = grid_h.cuda() 143 | self.grid_w = nn.Parameter(grid_w) 144 | self.grid_h = nn.Parameter(grid_h) 145 | offset_w = offset_w + self.grid_w # (b*c, h, w) 146 | offset_h = offset_h + self.grid_h # (b*c, h, w) 147 | x = x.contiguous().view(-1, int(x_shape[2]), int(x_shape[3])).unsqueeze(1) # (b*c, 1, h, w) 148 | x = F.grid_sample(x, torch.stack((offset_h, offset_w), 3)) # (b*c, h, w) 149 | x = x.contiguous().view(-1, int(x_shape[1]), int(x_shape[2]), int(x_shape[3])) # (b, c, h, w) 150 | x = self.regular_filter(x) 151 | return x 152 | 153 | 154 | def sliced_forward(single_forward): 155 | def _pad(x, crop_size): 156 | h, w = x.size()[2:] 157 | pad_h = max(crop_size - h, 0) 158 | pad_w = max(crop_size - w, 0) 159 | x = F.pad(x, (0, pad_w, 0, pad_h)) 160 | return x, pad_h, pad_w 161 | 162 | def wrapper(self, x): 163 | batch_size, _, ori_h, ori_w = x.size() 164 | if self.training and self.use_aux: 165 | outputs_all_scales = Variable(torch.zeros((batch_size, self.num_classes, ori_h, ori_w))).cuda() 166 | aux_all_scales = Variable(torch.zeros((batch_size, self.num_classes, ori_h, ori_w))).cuda() 167 | for s in self.scales: 168 | new_size = (int(ori_h * s), int(ori_w * s)) 169 | scaled_x = F.upsample(x, size=new_size, mode='bilinear') 170 | scaled_x = Variable(scaled_x).cuda() 171 | scaled_h, scaled_w = scaled_x.size()[2:] 172 | long_size = max(scaled_h, scaled_w) 173 | print(scaled_x.size()) 174 | 175 | if long_size > self.crop_size: 176 | count = torch.zeros((scaled_h, scaled_w)) 177 | outputs = Variable(torch.zeros((batch_size, self.num_classes, scaled_h, scaled_w))).cuda() 178 | aux_outputs = Variable(torch.zeros((batch_size, self.num_classes, scaled_h, scaled_w))).cuda() 179 | stride = int(ceil(self.crop_size * self.stride_rate)) 180 | h_step_num = int(ceil((scaled_h - self.crop_size) / stride)) + 1 181 | w_step_num = int(ceil((scaled_w - self.crop_size) / stride)) + 1 182 | for yy in range(h_step_num): 183 | for xx in range(w_step_num): 184 | sy, sx = yy * stride, xx * stride 185 | ey, ex = sy + self.crop_size, sx + self.crop_size 186 | x_sub = 
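# ----------------------------------------------------------------------------------
# A short sketch of driving PolyLR (defined above) instead of rewriting param_groups by
# hand the way the training scripts do. The optimizer, base lr, max_iter and lr_decay
# values are illustrative assumptions. Note that PolyLR does not advance curr_iter
# itself, so the caller updates it before each step().
import torch
from torch import nn, optim

net = nn.Conv2d(3, 21, kernel_size=3, padding=1)
optimizer = optim.SGD(net.parameters(), lr=1e-2, momentum=0.9)
scheduler = PolyLR(optimizer, curr_iter=0, max_iter=30000, lr_decay=0.9)

for it in range(5):
    scheduler.curr_iter = it      # advanced by the caller, read when step() runs
    scheduler.step()              # lr = init_lr * (1 - it / max_iter) ** lr_decay
    # ... forward / backward / optimizer.step() would go here ...
    print(it, optimizer.param_groups[0]['lr'])
# ----------------------------------------------------------------------------------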
scaled_x[:, :, sy: ey, sx: ex] 187 | x_sub, pad_h, pad_w = _pad(x_sub, self.crop_size) 188 | print(x_sub.size()) 189 | outputs_sub, aux_sub = single_forward(self, x_sub) 190 | 191 | if sy + self.crop_size > scaled_h: 192 | outputs_sub = outputs_sub[:, :, : -pad_h, :] 193 | aux_sub = aux_sub[:, :, : -pad_h, :] 194 | 195 | if sx + self.crop_size > scaled_w: 196 | outputs_sub = outputs_sub[:, :, :, : -pad_w] 197 | aux_sub = aux_sub[:, :, :, : -pad_w] 198 | 199 | outputs[:, :, sy: ey, sx: ex] = outputs_sub 200 | aux_outputs[:, :, sy: ey, sx: ex] = aux_sub 201 | 202 | count[sy: ey, sx: ex] += 1 203 | count = Variable(count).cuda() 204 | outputs = (outputs / count) 205 | aux_outputs = (outputs / count) 206 | else: 207 | scaled_x, pad_h, pad_w = _pad(scaled_x, self.crop_size) 208 | outputs, aux_outputs = single_forward(self, scaled_x) 209 | outputs = outputs[:, :, : -pad_h, : -pad_w] 210 | aux_outputs = aux_outputs[:, :, : -pad_h, : -pad_w] 211 | outputs_all_scales += outputs 212 | aux_all_scales += aux_outputs 213 | return outputs_all_scales / len(self.scales), aux_all_scales 214 | else: 215 | outputs_all_scales = Variable(torch.zeros((batch_size, self.num_classes, ori_h, ori_w))).cuda() 216 | for s in self.scales: 217 | new_size = (int(ori_h * s), int(ori_w * s)) 218 | scaled_x = F.upsample(x, size=new_size, mode='bilinear') 219 | scaled_h, scaled_w = scaled_x.size()[2:] 220 | long_size = max(scaled_h, scaled_w) 221 | 222 | if long_size > self.crop_size: 223 | count = torch.zeros((scaled_h, scaled_w)) 224 | outputs = Variable(torch.zeros((batch_size, self.num_classes, scaled_h, scaled_w))).cuda() 225 | stride = int(ceil(self.crop_size * self.stride_rate)) 226 | h_step_num = int(ceil((scaled_h - self.crop_size) / stride)) + 1 227 | w_step_num = int(ceil((scaled_w - self.crop_size) / stride)) + 1 228 | for yy in range(h_step_num): 229 | for xx in range(w_step_num): 230 | sy, sx = yy * stride, xx * stride 231 | ey, ex = sy + self.crop_size, sx + self.crop_size 232 | x_sub = scaled_x[:, :, sy: ey, sx: ex] 233 | x_sub, pad_h, pad_w = _pad(x_sub, self.crop_size) 234 | 235 | outputs_sub = single_forward(self, x_sub) 236 | 237 | if sy + self.crop_size > scaled_h: 238 | outputs_sub = outputs_sub[:, :, : -pad_h, :] 239 | 240 | if sx + self.crop_size > scaled_w: 241 | outputs_sub = outputs_sub[:, :, :, : -pad_w] 242 | 243 | outputs[:, :, sy: ey, sx: ex] = outputs_sub 244 | 245 | count[sy: ey, sx: ex] += 1 246 | count = Variable(count).cuda() 247 | outputs = (outputs / count) 248 | else: 249 | scaled_x, pad_h, pad_w = _pad(scaled_x, self.crop_size) 250 | outputs = single_forward(self, scaled_x) 251 | outputs = outputs[:, :, : -pad_h, : -pad_w] 252 | outputs_all_scales += outputs 253 | return outputs_all_scales 254 | 255 | return wrapper 256 | -------------------------------------------------------------------------------- /utils/transforms.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | from skimage.filters import gaussian 5 | import torch 6 | from PIL import Image, ImageFilter 7 | 8 | 9 | class RandomVerticalFlip(object): 10 | def __call__(self, img): 11 | if random.random() < 0.5: 12 | return img.transpose(Image.FLIP_TOP_BOTTOM) 13 | return img 14 | 15 | 16 | class DeNormalize(object): 17 | def __init__(self, mean, std): 18 | self.mean = mean 19 | self.std = std 20 | 21 | def __call__(self, tensor): 22 | for t, m, s in zip(tensor, self.mean, self.std): 23 | t.mul_(s).add_(m) 24 | return tensor 25 | 26 | 27 | class 
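# ----------------------------------------------------------------------------------
# A small sketch of wrapping a model's forward with sliced_forward (defined above) for
# sliding-window inference; PSPNet is the intended user in this repo, but the tiny module
# below is an illustrative assumption. The wrapper moves data to the GPU itself, so a
# CUDA device is required, and a single scale keeps the sketch simple. (Incidentally, in
# the training branch above, `aux_outputs = (outputs / count)` looks like it was meant
# to read `aux_outputs / count`.)
import torch
from torch import nn
from torch.autograd import Variable

class TinySegNet(nn.Module):
    def __init__(self, num_classes=21):
        super(TinySegNet, self).__init__()
        self.num_classes = num_classes
        self.scales = [1.0]            # scale factors averaged by the wrapper
        self.crop_size = 64            # inputs larger than this get tiled
        self.stride_rate = 2 / 3.
        self.use_aux = False
        self.conv = nn.Conv2d(3, num_classes, kernel_size=3, padding=1)

    @sliced_forward
    def forward(self, x):
        return self.conv(x)            # per-tile forward; the decorator handles the tiling

net = TinySegNet().cuda().eval()
x = Variable(torch.randn(1, 3, 100, 150)).cuda()
out = net(x)                           # (1, 21, 100, 150), stitched back from the tiles
print(out.size())
# ----------------------------------------------------------------------------------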
MaskToTensor(object): 28 | def __call__(self, img): 29 | return torch.from_numpy(np.array(img, dtype=np.int32)).long() 30 | 31 | 32 | class FreeScale(object): 33 | def __init__(self, size, interpolation=Image.BILINEAR): 34 | self.size = tuple(reversed(size)) # size: (h, w) 35 | self.interpolation = interpolation 36 | 37 | def __call__(self, img): 38 | return img.resize(self.size, self.interpolation) 39 | 40 | 41 | class FlipChannels(object): 42 | def __call__(self, img): 43 | img = np.array(img)[:, :, ::-1] 44 | return Image.fromarray(img.astype(np.uint8)) 45 | 46 | 47 | class RandomGaussianBlur(object): 48 | def __call__(self, img): 49 | sigma = 0.15 + random.random() * 1.15 50 | blurred_img = gaussian(np.array(img), sigma=sigma, multichannel=True) 51 | blurred_img *= 255 52 | return Image.fromarray(blurred_img.astype(np.uint8)) 53 | --------------------------------------------------------------------------------
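# ----------------------------------------------------------------------------------
# A sketch of how the single-image transforms above are typically combined with
# torchvision in this kind of pipeline: a tensor/normalise transform for the image,
# MaskToTensor for the label map, and DeNormalize + ToPILImage to restore an image for
# visualisation. The mean/std values and the exact composition are illustrative
# assumptions, not this repository's training configuration; FlipChannels would
# additionally convert RGB to BGR when Caffe-VGG pretrained weights are used.
import numpy as np
from PIL import Image
import torchvision.transforms as standard_transforms

mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])   # common ImageNet statistics

input_transform = standard_transforms.Compose([
    standard_transforms.ToTensor(),
    standard_transforms.Normalize(*mean_std),
])
target_transform = MaskToTensor()
restore_transform = standard_transforms.Compose([
    DeNormalize(*mean_std),
    standard_transforms.ToPILImage(),
])

img = Image.fromarray(np.random.randint(0, 256, (60, 80, 3), dtype=np.uint8), mode='RGB')
mask = Image.fromarray(np.random.randint(0, 21, (60, 80), dtype=np.uint8), mode='L')

img_t = input_transform(img)                  # float tensor, C x H x W, normalised
mask_t = target_transform(mask)               # long tensor, H x W, raw class indices
img_back = restore_transform(img_t.clone())   # PIL image again, e.g. for TensorBoard previews
print(img_t.size(), mask_t.size(), img_back.size)
# ----------------------------------------------------------------------------------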