├── LICENSE
├── README.md
├── cityscapes.py
├── data
│   └── pascal_seg_colormap.mat
├── deeplab.py
├── main.py
├── pascal.py
└── utils.py


/LICENSE:
--------------------------------------------------------------------------------
BSD 3-Clause License

Copyright (c) 2018, Chenxi Liu
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

* Neither the name of the copyright holder nor the names of its
  contributors may be used to endorse or promote products derived from
  this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


--------------------------- LICENSE FOR torchvision --------------------------------
BSD 3-Clause License

Copyright (c) Soumith Chintala 2016,
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

* Neither the name of the copyright holder nor the names of its
  contributors may be used to endorse or promote products derived from
  this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# DeepLabv3.pytorch

This is a PyTorch implementation of [DeepLabv3](https://arxiv.org/abs/1706.05587) that aims to reuse the [ResNet implementation in torchvision](https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py) as much as possible. This means we use the [PyTorch model checkpoint](https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py#L13) when finetuning from ImageNet, instead of [the one provided in TensorFlow](https://github.com/tensorflow/models/blob/master/research/deeplab/g3doc/model_zoo.md).

We try to match every detail in DeepLabv3, except that Multi-Grid other than (1, 1, 1) is not yet supported. On the PASCAL VOC 2012 validation set, using the same hyperparameters, we reproduce the performance reported in the paper (a GPU with 16 GB of memory is required). We also support the combination of Group Normalization + Weight Standardization:

Implementation | Normalization | Multi-Grid | ASPP | Image Pooling | mIoU
--- | --- | --- | --- | --- | ---
Paper | BN | (1, 2, 4) | (6, 12, 18) | Yes | 77.21
Ours | BN | (1, 1, 1) | (6, 12, 18) | Yes | 76.49
Ours | GN+WS | (1, 1, 1) | (6, 12, 18) | Yes | 77.20

To run the BN experiment, prepare the dataset as described below, then run:
```bash
python main.py --train --exp bn_lr7e-3 --epochs 50 --base_lr 0.007
```
To test the trained model, run the same command without `--train`.
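Testing prints the IoU of every class and their mean, which is the mIoU reported in the table above. `utils.inter_and_union` accumulates per-class intersection and union pixel counts over the whole validation set before dividing. The sketch below computes the same quantity from a confusion matrix instead; it is an illustration, not the repository's code, and it assumes predictions and labels are integer arrays with `255` marking ignored pixels. `mean_iou` is a hypothetical name.

```python
import numpy as np


def mean_iou(preds, masks, num_classes, ignore_index=255):
    """Hypothetical helper: dataset-level IoU per class and their mean.

    Intersections and unions are summed over all images before dividing,
    like the AverageMeter sums in main.py.
    """
    conf = np.zeros((num_classes, num_classes), dtype=np.int64)
    for pred, mask in zip(preds, masks):
        pred = np.asarray(pred).ravel()
        mask = np.asarray(mask).ravel()
        valid = mask != ignore_index  # skip "ignore" pixels entirely
        # Row = ground-truth class, column = predicted class.
        conf += np.bincount(
            num_classes * mask[valid].astype(np.int64) + pred[valid],
            minlength=num_classes ** 2,
        ).reshape(num_classes, num_classes)
    inter = np.diag(conf)
    # union = predicted pixels + ground-truth pixels - intersection
    union = conf.sum(axis=0) + conf.sum(axis=1) - inter
    iou = inter / (union + 1e-10)  # a class absent from both gets IoU 0
    return iou, iou.mean()
```

For PASCAL VOC, `num_classes` would be 21, i.e. `len(VOCSegmentation.CLASSES)`. A confusion matrix is one common alternative to the histogram trick in `inter_and_union`; both yield per-class intersection and union counts.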
To skip training and evaluate our trained model (76.49), download it into `data/` and run the test command:
```bash
wget https://cs.jhu.edu/~cxliu/data/deeplab_resnet101_pascal_v3_bn_lr7e-3_epoch50.pth -P data/
```

To run the GN+WS experiment, first download the GN+WS ResNet-101 pretrained on ImageNet, then train:
```bash
wget https://cs.jhu.edu/~syqiao/WeightStandardization/R-101-GN-WS.pth.tar -P data/
python main.py --train --exp gn_ws_lr7e-3 --epochs 50 --base_lr 0.007 --groups 32 --weight_std
```
Again, to test the trained model, run the same command without `--train`. To skip training and evaluate our trained model (77.20), download it into `data/`:
```bash
wget https://cs.jhu.edu/~cxliu/data/deeplab_resnet101_pascal_v3_gn_ws_lr7e-3_epoch50.pth -P data/
```


## Prepare PASCAL VOC 2012 Dataset
```bash
mkdir -p data
cd data
wget http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
tar -xf VOCtrainval_11-May-2012.tar
cd VOCdevkit/VOC2012/
wget http://cs.jhu.edu/~cxliu/data/SegmentationClassAug.zip
wget http://cs.jhu.edu/~cxliu/data/SegmentationClassAug_Visualization.zip
wget http://cs.jhu.edu/~cxliu/data/list.zip
unzip SegmentationClassAug.zip
unzip SegmentationClassAug_Visualization.zip
unzip list.zip
```

## Prepare Cityscapes Dataset
Download `leftImg8bit_trainvaltest.zip` and `gtFine_trainvaltest.zip` from the Cityscapes website into `data/cityscapes` (the dataset root used by `main.py`), then run the following from that directory:
```bash
unzip leftImg8bit_trainvaltest.zip
unzip gtFine_trainvaltest.zip
git clone https://github.com/mcordts/cityscapesScripts.git
mv cityscapesScripts/cityscapesscripts ./
rm -rf cityscapesScripts
python cityscapesscripts/preparation/createTrainIdLabelImgs.py
```

--------------------------------------------------------------------------------
/cityscapes.py:
--------------------------------------------------------------------------------
from __future__ import print_function

import torch.utils.data as data
import os
import glob
from PIL import Image
from utils import preprocess

_FOLDERS_MAP = {
    'image': 'leftImg8bit',
    'label': 'gtFine',
}

_POSTFIX_MAP = {
    'image': '_leftImg8bit',
    'label': '_gtFine_labelTrainIds',
}

_DATA_FORMAT_MAP = {
    'image': 'png',
    'label': 'png',
}


class Cityscapes(data.Dataset):
    CLASSES = [
        'road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 'traffic light',
        'traffic sign', 'vegetation', 'terrain', 'sky', 'person', 'rider', 'car',
        'truck', 'bus', 'train', 'motorcycle', 'bicycle'
    ]

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False, crop_size=None):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        self.train = train
        self.crop_size = crop_size

        if download:
            self.download()

        dataset_split = 'train' if self.train else 'val'
        self.images = self._get_files('image', dataset_split)
        self.masks = self._get_files('label', dataset_split)

    def __getitem__(self, index):
        _img = Image.open(self.images[index]).convert('RGB')
        _target = Image.open(self.masks[index])

        # Training: random flip, random rescale in [0.5, 2.0] and random crop.
        # Evaluation: the full 1024x2048 frame is only padded to 1025x2049.
        _img, _target = preprocess(_img, _target,
                                   flip=self.train,
                                   scale=(0.5, 2.0) if self.train else None,
                                   crop=(self.crop_size, self.crop_size) if self.train else (1025, 2049))

        if self.transform is not None:
            _img = self.transform(_img)

        if self.target_transform is not None:
            _target = self.target_transform(_target)

        return _img, _target

    def _get_files(self, data, dataset_split):
        # e.g. <root>/leftImg8bit/train/<city>/*_leftImg8bit.png
        pattern = '*%s.%s' % (_POSTFIX_MAP[data], _DATA_FORMAT_MAP[data])
        search_files = os.path.join(
            self.root, _FOLDERS_MAP[data], dataset_split, '*', pattern)
        filenames = glob.glob(search_files)
        # Sorting keeps images and label maps paired by index.
        return sorted(filenames)

    def __len__(self):
        return len(self.images)

    def download(self):
        raise NotImplementedError('Automatic download not yet implemented.')

--------------------------------------------------------------------------------
/data/pascal_seg_colormap.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxi116/DeepLabv3.pytorch/046818d755f91169dbad141362b98178dd685447/data/pascal_seg_colormap.mat
--------------------------------------------------------------------------------
/deeplab.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo
from torch.nn import functional as F


__all__ = ['ResNet', 'resnet50', 'resnet101', 'resnet152']


model_urls = {
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}


class Conv2d(nn.Conv2d):
    """Conv2d with Weight Standardization: every output filter is normalized
    to zero mean and unit standard deviation before the convolution."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super(Conv2d, self).__init__(in_channels, out_channels, kernel_size, stride,
                                     padding, dilation, groups, bias)

    def forward(self, x):
        weight = self.weight
        # Per-filter mean over (in_channels, kH, kW).
        weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2,
                                  keepdim=True).mean(dim=3, keepdim=True)
        weight = weight - weight_mean
        # Per-filter standard deviation; 1e-5 avoids division by zero.
        std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5
        weight = weight / std.expand_as(weight)
        return F.conv2d(x, weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)


class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling head: a 1x1 conv, three 3x3 atrous convs
    (rates 6, 12, 18, each multiplied by `mult`) and image-level pooling,
    concatenated, fused by a 1x1 conv and projected to `num_classes` logits."""

    def __init__(self, C, depth, num_classes, conv=nn.Conv2d, norm=nn.BatchNorm2d, momentum=0.0003, mult=1):
        super(ASPP, self).__init__()
        self._C = C
        self._depth = depth
        self._num_classes = num_classes

        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.relu = nn.ReLU(inplace=True)
        self.aspp1 = conv(C, depth, kernel_size=1, stride=1, bias=False)
        self.aspp2 = conv(C, depth, kernel_size=3, stride=1,
                          dilation=int(6*mult), padding=int(6*mult),
                          bias=False)
        self.aspp3 = conv(C, depth, kernel_size=3, stride=1,
                          dilation=int(12*mult), padding=int(12*mult),
                          bias=False)
        self.aspp4 = conv(C, depth, kernel_size=3, stride=1,
                          dilation=int(18*mult), padding=int(18*mult),
                          bias=False)
        self.aspp5 = conv(C, depth, kernel_size=1, stride=1, bias=False)
        # Pass momentum by keyword: the second positional argument of
        # nn.BatchNorm2d (the default `norm`) is eps, not momentum.
        self.aspp1_bn = norm(depth, momentum=momentum)
        self.aspp2_bn = norm(depth, momentum=momentum)
        self.aspp3_bn = norm(depth, momentum=momentum)
        self.aspp4_bn = norm(depth, momentum=momentum)
        self.aspp5_bn = norm(depth, momentum=momentum)
        self.conv2 = conv(depth * 5, depth, kernel_size=1, stride=1,
                          bias=False)
        self.bn2 = norm(depth, momentum=momentum)
        self.conv3 = nn.Conv2d(depth, num_classes, kernel_size=1, stride=1)

    def forward(self, x):
        x1 = self.aspp1(x)
        x1 = self.aspp1_bn(x1)
        x1 = self.relu(x1)
        x2 = self.aspp2(x)
        x2 = self.aspp2_bn(x2)
        x2 = self.relu(x2)
        x3 = self.aspp3(x)
        x3 = self.aspp3_bn(x3)
        x3 = self.relu(x3)
        x4 = self.aspp4(x)
        x4 = self.aspp4_bn(x4)
        x4 = self.relu(x4)
        # Image-level features: global average pooling and a 1x1 conv,
        # then upsampled back to the spatial size of x.
        x5 = self.global_pooling(x)
        x5 = self.aspp5(x5)
        x5 = self.aspp5_bn(x5)
        x5 = self.relu(x5)
        x5 = F.interpolate(x5, size=(x.shape[2], x.shape[3]), mode='bilinear',
                           align_corners=True)
        x = torch.cat((x1, x2, x3, x4, x5), 1)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.conv3(x)

        return x


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1, conv=None, norm=None):
        super(Bottleneck, self).__init__()
        self.conv1 = conv(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = norm(planes)
        self.conv2 = conv(planes, planes, kernel_size=3, stride=stride,
                          dilation=dilation, padding=dilation, bias=False)
        self.bn2 = norm(planes)
        self.conv3 = conv(planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = norm(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes, num_groups=None, weight_std=False, beta=False):
        self.inplanes = 64
        # BatchNorm by default; GroupNorm with `num_groups` groups when given.
        self.norm = lambda planes, momentum=0.05: nn.BatchNorm2d(planes, momentum=momentum) if num_groups is None else nn.GroupNorm(num_groups, planes)
        # Weight-standardized convolutions when weight_std=True.
        self.conv = Conv2d if weight_std else nn.Conv2d

        super(ResNet, self).__init__()
        if not beta:
            self.conv1 = self.conv(3, 64, kernel_size=7, stride=2, padding=3,
                                   bias=False)
        else:
            # "beta" variant: the 7x7 stem conv is replaced by three 3x3 convs.
            self.conv1 = nn.Sequential(
                self.conv(3, 64, 3, stride=2, padding=1, bias=False),
                self.conv(64, 64, 3, stride=1, padding=1, bias=False),
                self.conv(64, 64, 3, stride=1, padding=1, bias=False))
        self.bn1 = self.norm(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        # layer4 uses dilation instead of striding, so the output stride is 16.
        self.layer4 = self._make_layer(block, 512, layers[3], stride=1,
                                       dilation=2)
        self.aspp = ASPP(512 * block.expansion, 256, num_classes, conv=self.conv, norm=self.norm)

        for m in self.modules():
            if isinstance(m, self.conv):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.GroupNorm):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
        downsample = None
        if stride != 1 or dilation != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                self.conv(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, dilation=max(1, dilation // 2), bias=False),
                self.norm(planes * block.expansion),
            )

        layers = []
        # The first block of a stage uses half the dilation rate; integer
        # division keeps the rate an int, as the conv layers expect.
        layers.append(block(self.inplanes, planes, stride, downsample, dilation=max(1, dilation // 2), conv=self.conv, norm=self.norm))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, dilation=dilation, conv=self.conv, norm=self.norm))

        return nn.Sequential(*layers)

    def forward(self, x):
        size = (x.shape[2], x.shape[3])
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.aspp(x)
        # Upsample the logits back to the input resolution.
        x = F.interpolate(x, size=size, mode='bilinear', align_corners=True)
        return x


def resnet50(pretrained=False, **kwargs):
    """Constructs a ResNet-50 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        # The ImageNet checkpoint has an `fc` classifier and no ASPP head, so
        # load non-strictly: only matching keys are copied.
        model.load_state_dict(model_zoo.load_url(model_urls['resnet50']), strict=False)
    return model


def resnet101(pretrained=False, num_groups=None, weight_std=False, **kwargs):
    """Constructs a ResNet-101 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], num_groups=num_groups, weight_std=weight_std, **kwargs)
    if pretrained:
        model_dict = model.state_dict()
        if num_groups and weight_std:
            pretrained_dict = torch.load('data/R-101-GN-WS.pth.tar')
            # Drop the first 7 characters of each key (presumably the
            # 'module.' prefix left by nn.DataParallel).
            overlap_dict = {k[7:]: v for k, v in pretrained_dict.items() if k[7:] in model_dict}
            assert len(overlap_dict) == 312
        elif not num_groups and not weight_std:
            pretrained_dict = model_zoo.load_url(model_urls['resnet101'])
            overlap_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        else:
            raise ValueError('Only BN or GN+WS is currently supported')
        model_dict.update(overlap_dict)
        model.load_state_dict(model_dict)
    return model


def resnet152(pretrained=False, **kwargs):
    """Constructs a ResNet-152 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    if pretrained:
        # Non-strict load: skip the `fc` keys and keep the ASPP head randomly initialized.
        model.load_state_dict(model_zoo.load_url(model_urls['resnet152']), strict=False)
    return model

--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import argparse
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import pdb
from PIL import Image
from scipy.io import loadmat
from torch.autograd import Variable
from torchvision import transforms

import deeplab
from pascal import VOCSegmentation
from cityscapes import Cityscapes
from utils import AverageMeter, inter_and_union

parser = argparse.ArgumentParser()
parser.add_argument('--train', action='store_true', default=False,
                    help='training mode')
parser.add_argument('--exp', type=str, required=True,
                    help='name of experiment')
parser.add_argument('--gpu', type=int, default=0,
                    help='GPU device id used at test time')
parser.add_argument('--backbone', type=str, default='resnet101',
                    help='backbone network (only resnet101 is supported)')
parser.add_argument('--dataset', type=str, default='pascal',
                    help='pascal or cityscapes')
parser.add_argument('--groups', type=int, default=None,
                    help='number of groups for group normalization')
parser.add_argument('--epochs', type=int, default=30,
                    help='number of training epochs')
parser.add_argument('--batch_size', type=int, default=16,
                    help='batch size')
parser.add_argument('--base_lr', type=float, default=0.00025,
                    help='base learning rate')
parser.add_argument('--last_mult', type=float, default=1.0,
                    help='learning rate multiplier for the ASPP head')
parser.add_argument('--scratch', action='store_true',
                    default=False, help='train from scratch')
parser.add_argument('--freeze_bn', action='store_true', default=False,
                    help='freeze batch normalization parameters')
parser.add_argument('--weight_std', action='store_true', default=False,
                    help='use weight standardization')
parser.add_argument('--beta', action='store_true', default=False,
                    help='use the ResNet-101 "beta" stem (three 3x3 convs)')
parser.add_argument('--crop_size', type=int, default=513,
                    help='image crop size')
parser.add_argument('--resume', type=str, default=None,
                    help='path to checkpoint to resume from')
parser.add_argument('--workers', type=int, default=4,
                    help='number of data loading workers')
args = parser.parse_args()


def main():
    assert torch.cuda.is_available()
    torch.backends.cudnn.benchmark = True
    model_fname = 'data/deeplab_{0}_{1}_v3_{2}_epoch%d.pth'.format(
        args.backbone, args.dataset, args.exp)
    if args.dataset == 'pascal':
        dataset = VOCSegmentation('data/VOCdevkit',
                                  train=args.train, crop_size=args.crop_size)
    elif args.dataset == 'cityscapes':
        dataset = Cityscapes('data/cityscapes',
                             train=args.train, crop_size=args.crop_size)
    else:
        raise ValueError('Unknown dataset: {}'.format(args.dataset))
    if args.backbone == 'resnet101':
        model = getattr(deeplab, 'resnet101')(
            pretrained=(not args.scratch),
            num_classes=len(dataset.CLASSES),
            num_groups=args.groups,
            weight_std=args.weight_std,
            beta=args.beta)
    else:
        raise ValueError('Unknown backbone: {}'.format(args.backbone))

    if args.train:
        criterion = nn.CrossEntropyLoss(ignore_index=255)
        model = nn.DataParallel(model).cuda()
        model.train()
        if args.freeze_bn:
            # Keep BN running statistics fixed and stop updating its affine parameters.
            for m in model.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()
                    m.weight.requires_grad = False
                    m.bias.requires_grad = False
        backbone_params = (
            list(model.module.conv1.parameters()) +
            list(model.module.bn1.parameters()) +
            list(model.module.layer1.parameters()) +
            list(model.module.layer2.parameters()) +
            list(model.module.layer3.parameters()) +
            list(model.module.layer4.parameters()))
        last_params = list(model.module.aspp.parameters())
        optimizer = optim.SGD([
            {'params': filter(lambda p: p.requires_grad, backbone_params)},
            {'params': filter(lambda p: p.requires_grad, last_params)}],
            lr=args.base_lr, momentum=0.9, weight_decay=0.0001)
        dataset_loader = torch.utils.data.DataLoader(
            dataset, batch_size=args.batch_size, shuffle=args.train,
            pin_memory=True, num_workers=args.workers)
        max_iter = args.epochs * len(dataset_loader)
        losses = AverageMeter()
        start_epoch = 0

        if args.resume:
            if os.path.isfile(args.resume):
                print('=> loading checkpoint {0}'.format(args.resume))
                checkpoint = torch.load(args.resume)
                start_epoch = checkpoint['epoch']
                model.load_state_dict(checkpoint['state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer'])
                print('=> loaded checkpoint {0} (epoch {1})'.format(
                    args.resume, checkpoint['epoch']))
            else:
                print('=> no checkpoint found at {0}'.format(args.resume))

        for epoch in range(start_epoch, args.epochs):
            for i, (inputs, target) in enumerate(dataset_loader):
                # "Poly" schedule: decay base_lr toward 0 with power 0.9;
                # the ASPP group is additionally scaled by --last_mult.
                cur_iter = epoch * len(dataset_loader) + i
                lr = args.base_lr * (1 - float(cur_iter) / max_iter) ** 0.9
                optimizer.param_groups[0]['lr'] = lr
                optimizer.param_groups[1]['lr'] = lr * args.last_mult

                inputs = Variable(inputs.cuda())
                target = Variable(target.cuda())
                outputs = model(inputs)
                loss = criterion(outputs, target)
                # Drop into the debugger if the loss diverges.
                if np.isnan(loss.item()) or np.isinf(loss.item()):
                    pdb.set_trace()
                losses.update(loss.item(), args.batch_size)

                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

                print('epoch: {0}\t'
                      'iter: {1}/{2}\t'
                      'lr: {3:.6f}\t'
                      'loss: {loss.val:.4f} ({loss.ema:.4f})'.format(
                          epoch + 1, i + 1, len(dataset_loader), lr, loss=losses))

            # Save a checkpoint every 10 epochs.
            if epoch % 10 == 9:
                torch.save({
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }, model_fname % (epoch + 1))

    else:
        torch.cuda.set_device(args.gpu)
        model = model.cuda()
        model.eval()
        checkpoint = torch.load(model_fname % args.epochs)
        # Strip the 'module.' prefix added by nn.DataParallel during training
        # and drop the BatchNorm 'num_batches_tracked' buffers.
        state_dict = {k[7:]: v for k, v in checkpoint['state_dict'].items() if 'tracked' not in k}
        model.load_state_dict(state_dict)
        cmap = loadmat('data/pascal_seg_colormap.mat')['colormap']
        cmap = (cmap * 255).astype(np.uint8).flatten().tolist()
        # Predicted masks are written to data/val; create it if missing.
        os.makedirs('data/val', exist_ok=True)

        inter_meter = AverageMeter()
        union_meter = AverageMeter()
        for i in range(len(dataset)):
            inputs, target = dataset[i]
            inputs = Variable(inputs.cuda())
            with torch.no_grad():
                outputs = model(inputs.unsqueeze(0))
            _, pred = torch.max(outputs, 1)
            pred = pred.data.cpu().numpy().squeeze().astype(np.uint8)
            mask = target.numpy().astype(np.uint8)
            imname = dataset.masks[i].split('/')[-1]
            mask_pred = Image.fromarray(pred)
            mask_pred.putpalette(cmap)
            mask_pred.save(os.path.join('data/val', imname))
            print('eval: {0}/{1}'.format(i + 1, len(dataset)))

            inter, union = inter_and_union(pred, mask, len(dataset.CLASSES))
            inter_meter.update(inter)
            union_meter.update(union)

        iou = inter_meter.sum / (union_meter.sum + 1e-10)
        for i, val in enumerate(iou):
            print('IoU {0}: {1:.2f}'.format(dataset.CLASSES[i], val * 100))
        print('Mean IoU: {0:.2f}'.format(iou.mean() * 100))


if __name__ == "__main__":
    main()

--------------------------------------------------------------------------------
/pascal.py:
--------------------------------------------------------------------------------
from __future__ import print_function

import torch.utils.data as data
import os
from PIL import Image
from utils import preprocess


class VOCSegmentation(data.Dataset):
    CLASSES = [
        'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
        'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
        'motorbike', 'person', 'potted-plant', 'sheep', 'sofa', 'train',
        'tv/monitor'
    ]

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False, crop_size=None):
        self.root = root
        _voc_root = os.path.join(self.root, 'VOC2012')
        _list_dir = os.path.join(_voc_root, 'list')
        self.transform = transform
        self.target_transform = target_transform
        self.train = train
        self.crop_size = crop_size

        if download:
            self.download()

        if self.train:
            _list_f = os.path.join(_list_dir, 'train_aug.txt')
        else:
            _list_f = os.path.join(_list_dir, 'val.txt')
        self.images = []
        self.masks = []
        with open(_list_f, 'r') as lines:
            for line in lines:
                # Each line lists an image path and a mask path, both
                # appended directly to the VOC2012 root.
                _image = _voc_root + line.split()[0]
                _mask = _voc_root + line.split()[1]
                assert os.path.isfile(_image)
                assert os.path.isfile(_mask)
                self.images.append(_image)
                self.masks.append(_mask)

    def __getitem__(self, index):
        _img = Image.open(self.images[index]).convert('RGB')
        _target = Image.open(self.masks[index])

        # Validation images are neither flipped nor rescaled; with the default
        # crop size of 513 they are only padded (PASCAL images are at most 500 px).
        _img, _target = preprocess(_img, _target,
                                   flip=self.train,
                                   scale=(0.5, 2.0) if self.train else None,
                                   crop=(self.crop_size, self.crop_size))

        if self.transform is not None:
            _img = self.transform(_img)

        if self.target_transform is not None:
            _target = self.target_transform(_target)

        return _img, _target

    def __len__(self):
        return len(self.images)

    def download(self):
        raise NotImplementedError('Automatic download not yet implemented.')


--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import math
import random
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image


class AverageMeter(object):
    """Tracks the latest value, a weighted running sum and average, and an
    exponential moving average (decay 0.99) of a quantity."""

    def __init__(self):
        self.val = None
        self.sum = None
        self.cnt = None
        self.avg = None
        self.ema = None
        self.initialized = False

    def update(self, val, n=1):
        if not self.initialized:
            self.initialize(val, n)
        else:
            self.add(val, n)

    def initialize(self, val, n):
        self.val = val
        self.sum = val * n
        self.cnt = n
        self.avg = val
        self.ema = val
        self.initialized = True

    def add(self, val, n):
        self.val = val
        self.sum += val * n
        self.cnt += n
        self.avg = self.sum / self.cnt
        self.ema = self.ema * 0.99 + self.val * 0.01


def inter_and_union(pred, mask, num_class):
    pred = np.asarray(pred, dtype=np.uint8).copy()
    mask = np.asarray(mask, dtype=np.uint8).copy()

    # Shift labels by 1 so that the ignore label 255 wraps around to 0
    # (uint8 overflow) and class k becomes k + 1.
    pred += 1
    mask += 1
    # Discard predictions at pixels whose ground truth is ignored.
    pred = pred * (mask > 0)

    inter = pred * (pred == mask)
    # Per-class pixel counts for labels 1..num_class; 0 (ignored) falls outside the range.
    (area_inter, _) = np.histogram(inter, bins=num_class, range=(1, num_class))
    (area_pred, _) = np.histogram(pred, bins=num_class, range=(1, num_class))
    (area_mask, _) = np.histogram(mask, bins=num_class, range=(1, num_class))
    area_union = area_pred + area_mask - area_inter

    return (area_inter, area_union)


def preprocess(image, mask, flip=False, scale=None, crop=None):
    # Random horizontal flip with probability 0.5.
    if flip:
        if random.random() < 0.5:
            image = image.transpose(Image.FLIP_LEFT_RIGHT)
            mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
    # Random rescaling; the scale factor is drawn uniformly in log2 space.
    if scale:
        w, h = image.size
        rand_log_scale = math.log(scale[0], 2) + random.random() * (math.log(scale[1], 2) - math.log(scale[0], 2))
        random_scale = math.pow(2, rand_log_scale)
        new_size = (int(round(w * random_scale)), int(round(h * random_scale)))
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
        image = image.resize(new_size, Image.LANCZOS)
        mask = mask.resize(new_size, Image.NEAREST)

    data_transforms = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    image = data_transforms(image)
    mask = torch.LongTensor(np.array(mask).astype(np.int64))

    if crop:
        # Pad bottom/right up to the crop size: zeros for the image,
        # the ignore label 255 for the mask.
        h, w = image.shape[1], image.shape[2]
        pad_tb = max(0, crop[0] - h)
        pad_lr = max(0, crop[1] - w)
        image = torch.nn.ZeroPad2d((0, pad_lr, 0, pad_tb))(image)
        mask = torch.nn.ConstantPad2d((0, pad_lr, 0, pad_tb), 255)(mask)

        # Random crop of size crop[0] x crop[1].
        h, w = image.shape[1], image.shape[2]
        i = random.randint(0, h - crop[0])
        j = random.randint(0, w - crop[1])
        image = image[:, i:i + crop[0], j:j + crop[1]]
        mask = mask[i:i + crop[0], j:j + crop[1]]

    return image, mask

--------------------------------------------------------------------------------
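During training, `main.py` recomputes the learning rate at every iteration with the "poly" policy, `base_lr * (1 - cur_iter / max_iter) ** 0.9`, and multiplies it by `--last_mult` for the ASPP parameter group. A minimal standalone sketch of that rule follows; `poly_lr` is a hypothetical helper name, and the 100 iterations per epoch are an arbitrary illustrative value, not a property of either dataset.

```python
def poly_lr(base_lr, cur_iter, max_iter, power=0.9):
    """Hypothetical helper: the 'poly' learning-rate rule used in main.py."""
    return base_lr * (1 - float(cur_iter) / max_iter) ** power


if __name__ == "__main__":
    base_lr, last_mult = 0.007, 1.0  # README's --base_lr, main.py's default --last_mult
    max_iter = 50 * 100              # 50 epochs x 100 iterations (illustrative only)
    for cur_iter in (0, max_iter // 2, max_iter - 1):
        lr = poly_lr(base_lr, cur_iter, max_iter)
        print('iter {0}: backbone lr {1:.6f}, ASPP lr {2:.6f}'.format(
            cur_iter, lr, lr * last_mult))
```

Because the last iteration index is `max_iter - 1`, the rate never reaches exactly zero during training; it only gets close in the final steps.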