├── .gitignore
├── README.md
├── augmentation.py
├── extractors.py
├── pspnet.py
└── train.py

/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# dotenv
.env

# virtualenv
.venv
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

.idea/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# pspnet-pytorch
PyTorch implementation of the PSPNet segmentation network


### Original paper

[Pyramid Scene Parsing Network](https://arxiv.org/abs/1612.01105)

### Details

This is a slightly different version: instead of a direct 8x upsampling at the end, I use three consecutive 2x upsamplings for stability.

### Feature extraction

Using pretrained weights for the extractors improved quality and convergence dramatically.

Currently supported:

* SqueezeNet
* DenseNet-121
* ResNet-18
* ResNet-34
* ResNet-50
* ResNet-101
* ResNet-152

Planned:

* DenseNet-169
* DenseNet-201

### Usage

To follow the training routine in train.py you need a DataLoader that yields tuples of the following format:

(Bx3xHxW FloatTensor x, BxHxW LongTensor y, BxN LongTensor y\_cls), where

x - batch of input images,

y - batch of ground truth segmentation maps,

y\_cls - batch of 1D indicator vectors of length N, N being the total number of classes:

y\_cls[i, T] = 1 if class T is present in image i, 0 otherwise.

A minimal sketch of such a dataset is shown below.
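
Purely as an illustration (this class is not part of the repository), a skeleton dataset producing tuples in that format could look like the following; `images`, `seg_maps` and `n_classes` are placeholders for your own data:

```python
# Hypothetical sketch, not part of this repo: a Dataset yielding (x, y, y_cls).
import torch
from torch.utils.data import Dataset, DataLoader


class SegDataset(Dataset):
    def __init__(self, images, seg_maps, n_classes):
        # images: list of 3xHxW FloatTensors; seg_maps: list of HxW LongTensors
        self.images = images
        self.seg_maps = seg_maps
        self.n_classes = n_classes

    def __len__(self):
        return len(self.images)

    def __getitem__(self, i):
        x = self.images[i]                 # 3xHxW FloatTensor
        y = self.seg_maps[i]               # HxW LongTensor of class indices
        y_cls = torch.zeros(self.n_classes, dtype=torch.long)
        y_cls[torch.unique(y)] = 1         # mark every class present in the image
        return x, y, y_cls


# loader = DataLoader(SegDataset(images, seg_maps, n_classes=18), batch_size=16)
```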
/ 3) 92 | 93 | w = int(round(math.sqrt(target_area * aspect_ratio))) 94 | h = int(round(math.sqrt(target_area / aspect_ratio))) 95 | 96 | if random.random() < 0.5: 97 | w, h = h, w 98 | 99 | if w <= img.size[0] and h <= img.size[1]: 100 | x1 = random.randint(0, img.size[0] - w) 101 | y1 = random.randint(0, img.size[1] - h) 102 | 103 | img = img.crop((x1, y1, x1 + w, y1 + h)) 104 | target = target.crop((x1, y1, x1 + w, y1 + h)) 105 | assert(img.size == (w, h)) 106 | assert(target.size == (w, h)) 107 | 108 | return img.resize((self.size, self.size), self.interpolation), \ 109 | target.resize((self.size, self.size), self.interpolation) 110 | 111 | # Fallback 112 | scale = Scale(self.size, interpolation=self.interpolation) 113 | crop = CenterCrop(self.size) 114 | return crop(scale((img, target))) 115 | 116 | 117 | class RandomHorizontalFlip: 118 | 119 | def __call__(self, imgmap): 120 | img, target = imgmap 121 | if random.random() < 0.5: 122 | return img.transpose(Image.FLIP_LEFT_RIGHT), target.transpose(Image.FLIP_LEFT_RIGHT) 123 | return img, target 124 | 125 | 126 | class RandomRotation: 127 | 128 | def __call__(self, imgmap, degree=10): 129 | img, target = imgmap 130 | deg = np.random.randint(-degree, degree, 1)[0] 131 | return img.rotate(deg), target.rotate(deg) 132 | -------------------------------------------------------------------------------- /extractors.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import math 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.utils import model_zoo 8 | from torchvision.models.densenet import densenet121, densenet161 9 | from torchvision.models.squeezenet import squeezenet1_1 10 | 11 | 12 | def load_weights_sequential(target, source_state): 13 | new_dict = OrderedDict() 14 | for (k1, v1), (k2, v2) in zip(target.state_dict().items(), source_state.items()): 15 | new_dict[k1] = v2 16 | target.load_state_dict(new_dict) 17 | 18 | ''' 19 | Implementation of dilated ResNet-101 with deep supervision. 
--------------------------------------------------------------------------------
/extractors.py:
--------------------------------------------------------------------------------
from collections import OrderedDict
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import model_zoo
from torchvision.models.densenet import densenet121
from torchvision.models.squeezenet import squeezenet1_1


def load_weights_sequential(target, source_state):
    # Copies weights by position, not by name: assumes the target network
    # enumerates its tensors in the same order as the source state dict.
    new_dict = OrderedDict()
    for (k1, v1), (k2, v2) in zip(target.state_dict().items(), source_state.items()):
        new_dict[k1] = v2
    target.load_state_dict(new_dict)

'''
Implementation of dilated ResNet with deep supervision. Downsampling is changed to 8x.
'''
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}


def conv3x3(in_planes, out_planes, stride=1, dilation=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, dilation=dilation, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride=stride, dilation=dilation)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes, stride=1, dilation=dilation)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, dilation=dilation,
                               padding=dilation, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    def __init__(self, block, layers=(3, 4, 23, 3)):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        # layers 3 and 4 keep stride 1 and use dilation instead,
        # so the total downsampling stays at 8x
        self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        # the first block of each stage keeps dilation=1;
        # the remaining blocks use the requested dilation
        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, dilation=dilation))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x_3 = self.layer3(x)
        x = self.layer4(x_3)

        # the layer3 output is returned as well and feeds the auxiliary classifier
        return x, x_3
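
# Shape sketch (illustrative): conv1, maxpool and layer2 each halve the
# resolution, while layer3/layer4 trade stride for dilation, giving an
# output stride of 8. E.g. for a randomly initialised ResNet-50 backbone:
#
#   feats, aux = ResNet(Bottleneck, [3, 4, 6, 3])(torch.randn(1, 3, 224, 224))
#   # feats: 1x2048x28x28 (fed to the PSP module), aux: 1x1024x28x28 (aux head)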

'''
Implementation of DenseNet with deep supervision. Downsampling is changed to 8x.
'''


class _DenseLayer(nn.Sequential):
    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
        super(_DenseLayer, self).__init__()
        # dots are not allowed in module names in current PyTorch;
        # these names also match torchvision's densenet state dicts
        self.add_module('norm1', nn.BatchNorm2d(num_input_features))
        self.add_module('relu1', nn.ReLU(inplace=True))
        self.add_module('conv1', nn.Conv2d(num_input_features, bn_size *
                        growth_rate, kernel_size=1, stride=1, bias=False))
        self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate))
        self.add_module('relu2', nn.ReLU(inplace=True))
        self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
                        kernel_size=3, stride=1, padding=1, bias=False))
        self.drop_rate = drop_rate

    def forward(self, x):
        new_features = super(_DenseLayer, self).forward(x)
        if self.drop_rate > 0:
            new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
        return torch.cat([x, new_features], 1)


class _DenseBlock(nn.Sequential):
    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate):
        super(_DenseBlock, self).__init__()
        for i in range(num_layers):
            layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate)
            self.add_module('denselayer%d' % (i + 1), layer)


class _Transition(nn.Sequential):
    def __init__(self, num_input_features, num_output_features, downsample=True):
        super(_Transition, self).__init__()
        self.add_module('norm', nn.BatchNorm2d(num_input_features))
        self.add_module('relu', nn.ReLU(inplace=True))
        self.add_module('conv', nn.Conv2d(num_input_features, num_output_features,
                                          kernel_size=1, stride=1, bias=False))
        if downsample:
            self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))
        else:
            self.add_module('pool', nn.AvgPool2d(kernel_size=1, stride=1))  # compatibility hack


class DenseNet(nn.Module):
    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
                 num_init_features=64, bn_size=4, drop_rate=0, pretrained=True):

        super(DenseNet, self).__init__()

        # First convolution
        self.start_features = nn.Sequential(OrderedDict([
            ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
            ('norm0', nn.BatchNorm2d(num_init_features)),
            ('relu0', nn.ReLU(inplace=True)),
            ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
        ]))

        # Each denseblock
        num_features = num_init_features

        # only fetch the torchvision weights when they are actually needed
        init_weights = list(densenet121(pretrained=True).features.children()) if pretrained else None
        start = 0
        for i, c in enumerate(self.start_features.children()):
            if pretrained:
                c.load_state_dict(init_weights[i].state_dict())
            start += 1
        self.blocks = nn.ModuleList()
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(num_layers=num_layers, num_input_features=num_features,
                                bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate)
            if pretrained:
                block.load_state_dict(init_weights[start].state_dict())
            start += 1
            self.blocks.append(block)
            setattr(self, 'denseblock%d' % (i + 1), block)

            num_features = num_features + num_layers * growth_rate
            if i != len(block_config) - 1:
                # only the first transition pools, keeping the total downsampling at 8x
                downsample = i < 1
                trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2,
                                    downsample=downsample)
                if pretrained:
                    trans.load_state_dict(init_weights[start].state_dict())
                start += 1
                self.blocks.append(trans)
                setattr(self, 'transition%d' % (i + 1), trans)
                num_features = num_features // 2

    def forward(self, x):
        out = self.start_features(x)
        deep_features = None
        for i, block in enumerate(self.blocks):
            out = block(out)
            if i == 5:
                # output of transition3, tapped for the auxiliary classifier
                deep_features = out

        return out, deep_features
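
# Shape sketch (illustrative): only transition1 pools (downsample = i < 1),
# so together with conv0/pool0 the total downsampling is again 8x. For the
# default DenseNet-121 configuration:
#
#   out, deep = DenseNet(pretrained=False)(torch.randn(1, 3, 224, 224))
#   # out: 1x1024x28x28, deep: 1x512x28x28 (output of transition3)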

class Fire(nn.Module):

    def __init__(self, inplanes, squeeze_planes,
                 expand1x1_planes, expand3x3_planes, dilation=1):
        super(Fire, self).__init__()
        self.inplanes = inplanes
        self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1)
        self.squeeze_activation = nn.ReLU(inplace=True)
        self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes,
                                   kernel_size=1)
        self.expand1x1_activation = nn.ReLU(inplace=True)
        self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes,
                                   kernel_size=3, padding=dilation, dilation=dilation)
        self.expand3x3_activation = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.squeeze_activation(self.squeeze(x))
        return torch.cat([
            self.expand1x1_activation(self.expand1x1(x)),
            self.expand3x3_activation(self.expand3x3(x))
        ], 1)


class SqueezeNet(nn.Module):

    def __init__(self, pretrained=False):
        super(SqueezeNet, self).__init__()

        self.feat_1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True)
        )
        self.feat_2 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            Fire(64, 16, 64, 64),
            Fire(128, 16, 64, 64)
        )
        self.feat_3 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            Fire(128, 32, 128, 128, 2),
            Fire(256, 32, 128, 128, 2)
        )
        # the last stage uses dilation instead of pooling, keeping 8x downsampling
        self.feat_4 = nn.Sequential(
            Fire(256, 48, 192, 192, 4),
            Fire(384, 48, 192, 192, 4),
            Fire(384, 64, 256, 256, 4),
            Fire(512, 64, 256, 256, 4)
        )
        if pretrained:
            # positionally map the weights of torchvision's squeezenet1_1 features
            weights = squeezenet1_1(pretrained=True).features.state_dict()
            load_weights_sequential(self, weights)

    def forward(self, x):
        f1 = self.feat_1(x)
        f2 = self.feat_2(f1)
        f3 = self.feat_3(f2)
        f4 = self.feat_4(f3)
        return f4, f3
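
# Shape sketch (illustrative): feat_1 plus the two max pools give 8x
# downsampling; feat_4 keeps resolution through dilation:
#
#   f4, f3 = SqueezeNet()(torch.randn(1, 3, 224, 224))
#   # f4: 1x512x28x28, f3: 1x256x28x28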

'''
Handy methods for construction
'''


def squeezenet(pretrained=True):
    return SqueezeNet(pretrained)


def densenet(pretrained=True):
    return DenseNet(pretrained=pretrained)


def resnet18(pretrained=True):
    model = ResNet(BasicBlock, [2, 2, 2, 2])
    if pretrained:
        load_weights_sequential(model, model_zoo.load_url(model_urls['resnet18']))
    return model


def resnet34(pretrained=True):
    model = ResNet(BasicBlock, [3, 4, 6, 3])
    if pretrained:
        load_weights_sequential(model, model_zoo.load_url(model_urls['resnet34']))
    return model


def resnet50(pretrained=True):
    model = ResNet(Bottleneck, [3, 4, 6, 3])
    if pretrained:
        load_weights_sequential(model, model_zoo.load_url(model_urls['resnet50']))
    return model


def resnet101(pretrained=True):
    model = ResNet(Bottleneck, [3, 4, 23, 3])
    if pretrained:
        load_weights_sequential(model, model_zoo.load_url(model_urls['resnet101']))
    return model


def resnet152(pretrained=True):
    model = ResNet(Bottleneck, [3, 8, 36, 3])
    if pretrained:
        load_weights_sequential(model, model_zoo.load_url(model_urls['resnet152']))
    return model
--------------------------------------------------------------------------------
/pspnet.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
from torch.nn import functional as F

import extractors


class PSPModule(nn.Module):
    def __init__(self, features, out_features=1024, sizes=(1, 2, 3, 6)):
        super().__init__()
        self.stages = nn.ModuleList([self._make_stage(features, size) for size in sizes])
        self.bottleneck = nn.Conv2d(features * (len(sizes) + 1), out_features, kernel_size=1)
        self.relu = nn.ReLU()

    def _make_stage(self, features, size):
        # pool the feature map down to a size x size grid, then mix channels with a 1x1 conv
        prior = nn.AdaptiveAvgPool2d(output_size=(size, size))
        conv = nn.Conv2d(features, features, kernel_size=1, bias=False)
        return nn.Sequential(prior, conv)

    def forward(self, feats):
        h, w = feats.size(2), feats.size(3)
        # F.interpolate replaces the deprecated F.upsample
        priors = [F.interpolate(input=stage(feats), size=(h, w), mode='bilinear', align_corners=True)
                  for stage in self.stages] + [feats]
        bottle = self.bottleneck(torch.cat(priors, 1))
        return self.relu(bottle)
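
# Channel sketch (illustrative): each of the len(sizes) pyramid branches keeps
# `features` channels, and the raw feature map is concatenated on top, so the
# bottleneck sees features * (len(sizes) + 1) input channels. With features=2048
# and sizes=(1, 2, 3, 6): 2048 * 5 = 10240 channels in, out_features=1024 out.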

class PSPUpsample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.PReLU()
        )

    def forward(self, x):
        # upsample by exactly 2x, then convolve
        h, w = 2 * x.size(2), 2 * x.size(3)
        p = F.interpolate(input=x, size=(h, w), mode='bilinear', align_corners=True)
        return self.conv(p)


class PSPNet(nn.Module):
    def __init__(self, n_classes=18, sizes=(1, 2, 3, 6), psp_size=2048, deep_features_size=1024, backend='resnet34',
                 pretrained=True):
        super().__init__()
        self.feats = getattr(extractors, backend)(pretrained)
        self.psp = PSPModule(psp_size, 1024, sizes)
        self.drop_1 = nn.Dropout2d(p=0.3)

        # three consecutive 2x upsamplings instead of a single 8x one
        self.up_1 = PSPUpsample(1024, 256)
        self.up_2 = PSPUpsample(256, 64)
        self.up_3 = PSPUpsample(64, 64)

        self.drop_2 = nn.Dropout2d(p=0.15)
        self.final = nn.Sequential(
            nn.Conv2d(64, n_classes, kernel_size=1),
            nn.LogSoftmax(dim=1)  # dim must be given explicitly in current PyTorch
        )

        self.classifier = nn.Sequential(
            nn.Linear(deep_features_size, 256),
            nn.ReLU(),
            nn.Linear(256, n_classes)
        )

    def forward(self, x):
        f, class_f = self.feats(x)
        p = self.psp(f)
        p = self.drop_1(p)

        p = self.up_1(p)
        p = self.drop_2(p)

        p = self.up_2(p)
        p = self.drop_2(p)

        p = self.up_3(p)
        p = self.drop_2(p)

        # global max pooling over the deep features for the auxiliary classification head
        auxiliary = F.adaptive_max_pool2d(input=class_f, output_size=(1, 1)).view(-1, class_f.size(1))

        return self.final(p), self.classifier(auxiliary)
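
# Usage sketch (illustrative): psp_size and deep_features_size must match the
# chosen backend; the values below mirror the resnet50 entry of the models
# dict in train.py:
#
#   net = PSPNet(n_classes=18, psp_size=2048, deep_features_size=1024, backend='resnet50')
#   seg_logp, cls_logits = net(torch.randn(1, 3, 224, 224))
#   # seg_logp: 1x18x224x224 per-pixel log-probabilities, cls_logits: 1x18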
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import os
import logging

import torch
from torch import nn
from torch import optim
from torch.optim.lr_scheduler import MultiStepLR
from torch.autograd import Variable
from torch.utils.data import DataLoader

from tqdm import tqdm
import click
import numpy as np

from pspnet import PSPNet


models = {
    'squeezenet': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=512, deep_features_size=256, backend='squeezenet'),
    'densenet': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=1024, deep_features_size=512, backend='densenet'),
    'resnet18': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=512, deep_features_size=256, backend='resnet18'),
    'resnet34': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=512, deep_features_size=256, backend='resnet34'),
    'resnet50': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=2048, deep_features_size=1024, backend='resnet50'),
    'resnet101': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=2048, deep_features_size=1024, backend='resnet101'),
    'resnet152': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=2048, deep_features_size=1024, backend='resnet152')
}


def build_network(snapshot, backend):
    epoch = 0
    backend = backend.lower()
    net = models[backend]()
    net = nn.DataParallel(net)
    if snapshot is not None:
        # snapshots are saved as "PSPNet_<epoch>", see the end of train()
        _, epoch = os.path.basename(snapshot).split('_')
        epoch = int(epoch)
        net.load_state_dict(torch.load(snapshot))
        logging.info("Snapshot for epoch {} loaded from {}".format(epoch, snapshot))
    net = net.cuda()
    return net, epoch


@click.command()
@click.option('--data-path', type=str, help='Path to dataset folder')
@click.option('--models-path', type=str, help='Path for storing model snapshots')
@click.option('--backend', type=str, default='resnet34', help='Feature extractor')
@click.option('--snapshot', type=str, default=None, help='Path to pretrained weights')
@click.option('--crop_x', type=int, default=256, help='Horizontal random crop size')
@click.option('--crop_y', type=int, default=256, help='Vertical random crop size')
@click.option('--batch-size', type=int, default=16)
@click.option('--alpha', type=float, default=1.0, help='Coefficient for classification loss term')
@click.option('--epochs', type=int, default=20, help='Number of training epochs to run')
@click.option('--gpu', type=str, default='0', help='List of GPUs for parallel training, e.g. 0,1,2,3')
@click.option('--start-lr', type=float, default=0.001)
@click.option('--milestones', type=str, default='10,20,30', help='Milestones for LR decreasing')
def train(data_path, models_path, backend, snapshot, crop_x, crop_y, batch_size, alpha, epochs, start_lr, milestones, gpu):
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    net, starting_epoch = build_network(snapshot, backend)
    data_path = os.path.abspath(os.path.expanduser(data_path))
    models_path = os.path.abspath(os.path.expanduser(models_path))
    os.makedirs(models_path, exist_ok=True)

    '''
    To follow this training routine you need a DataLoader that yields tuples of the following format:
    (Bx3xHxW FloatTensor x, BxHxW LongTensor y, BxN LongTensor y_cls) where
    x - batch of input images,
    y - batch of ground truth seg maps,
    y_cls - batch of 1D tensors of dimensionality N: N total number of classes,
    y_cls[i, T] = 1 if class T is present in image i, 0 otherwise
    '''
    train_loader, class_weights, n_images = None, None, None  # plug in your own DataLoader here

    optimizer = optim.Adam(net.parameters(), lr=start_lr)
    scheduler = MultiStepLR(optimizer, milestones=[int(x) for x in milestones.split(',')])

    for epoch in range(starting_epoch, starting_epoch + epochs):
        # NLLLoss pairs with the LogSoftmax inside PSPNet.final
        # (nn.NLLLoss2d is deprecated and has been folded into nn.NLLLoss)
        seg_criterion = nn.NLLLoss(weight=class_weights)
        cls_criterion = nn.BCEWithLogitsLoss(weight=class_weights)
        epoch_losses = []
        train_iterator = tqdm(train_loader, total=n_images // batch_size + 1)
        net.train()
        for x, y, y_cls in train_iterator:
            optimizer.zero_grad()
            x, y, y_cls = Variable(x).cuda(), Variable(y).cuda(), Variable(y_cls).cuda()
            out, out_cls = net(x)
            # BCEWithLogitsLoss expects float targets
            seg_loss, cls_loss = seg_criterion(out, y), cls_criterion(out_cls, y_cls.float())
            # total loss = segmentation loss + alpha * auxiliary classification loss
            loss = seg_loss + alpha * cls_loss
            epoch_losses.append(loss.item())
            status = '[{0}] loss = {1:0.5f} avg = {2:0.5f}, LR = {3:0.7f}'.format(
                epoch + 1, loss.item(), np.mean(epoch_losses), scheduler.get_lr()[0])
            train_iterator.set_description(status)
            loss.backward()
            optimizer.step()
        scheduler.step()
        torch.save(net.state_dict(), os.path.join(models_path, '_'.join(["PSPNet", str(epoch + 1)])))
        train_loss = np.mean(epoch_losses)


if __name__ == '__main__':
    train()
--------------------------------------------------------------------------------