├── data.py ├── models.py ├── models_initial.py ├── modules ├── .txt ├── bwn.py ├── quantize.py └── rnlu.py ├── readme.md ├── requirements.txt ├── supplementary_material.pdf ├── train_base.py ├── train_sp_integrate_dynamic_quantization.py └── train_sp_integrate_dynamic_quantization_initial.py /data.py: -------------------------------------------------------------------------------- 1 | """prepare CIFAR and SVHN 2 | """ 3 | 4 | from __future__ import print_function 5 | 6 | import torch 7 | import torchvision 8 | import torchvision.transforms as transforms 9 | import numpy as np 10 | 11 | 12 | crop_size = 32 13 | padding = 4 14 | 15 | 16 | def prepare_train_data(dataset='cifar10', batch_size=128, 17 | shuffle=True, num_workers=4): 18 | 19 | if 'cifar' in dataset: 20 | transform_train = transforms.Compose([ 21 | transforms.RandomCrop(crop_size, padding=padding), 22 | transforms.RandomHorizontalFlip(), 23 | transforms.ToTensor(), 24 | transforms.Normalize((0.4914, 0.4822, 0.4465), 25 | (0.2023, 0.1994, 0.2010)), 26 | ]) 27 | 28 | trainset = torchvision.datasets.__dict__[dataset.upper()]( 29 | root='/tmp/data', train=True, download=True, transform=transform_train) 30 | train_loader = torch.utils.data.DataLoader(trainset, 31 | batch_size=batch_size, 32 | shuffle=shuffle, 33 | num_workers=num_workers) 34 | elif 'svhn' in dataset: 35 | transform_train =transforms.Compose([ 36 | transforms.ToTensor(), 37 | transforms.Normalize((0.4377, 0.4438, 0.4728), 38 | (0.1980, 0.2010, 0.1970)), 39 | ]) 40 | trainset = torchvision.datasets.__dict__[dataset.upper()]( 41 | root='/tmp/data', 42 | split='train', 43 | download=True, 44 | transform=transform_train 45 | ) 46 | 47 | transform_extra = transforms.Compose([ 48 | transforms.ToTensor(), 49 | transforms.Normalize((0.4300, 0.4284, 0.4427), 50 | (0.1963, 0.1979, 0.1995)) 51 | 52 | ]) 53 | 54 | extraset = torchvision.datasets.__dict__[dataset.upper()]( 55 | root='/tmp/data', 56 | split='extra', 57 | download=True, 58 | transform = transform_extra 59 | ) 60 | 61 | total_data = torch.utils.data.ConcatDataset([trainset, extraset]) 62 | 63 | train_loader = torch.utils.data.DataLoader(total_data, 64 | batch_size=batch_size, 65 | shuffle=shuffle, 66 | num_workers=num_workers) 67 | else: 68 | train_loader = None 69 | return train_loader 70 | 71 | 72 | def prepare_test_data(dataset='cifar10', batch_size=128, 73 | shuffle=False, num_workers=4): 74 | 75 | if 'cifar' in dataset: 76 | transform_test = transforms.Compose([ 77 | transforms.ToTensor(), 78 | transforms.Normalize((0.4914, 0.4822, 0.4465), 79 | (0.2023, 0.1994, 0.2010)), 80 | ]) 81 | 82 | testset = torchvision.datasets.__dict__[dataset.upper()](root='/tmp/data', 83 | train=False, 84 | download=True, 85 | transform=transform_test) 86 | test_loader = torch.utils.data.DataLoader(testset, 87 | batch_size=batch_size, 88 | shuffle=shuffle, 89 | num_workers=num_workers) 90 | elif 'svhn' in dataset: 91 | transform_test = transforms.Compose([ 92 | transforms.ToTensor(), 93 | transforms.Normalize((0.4524, 0.4525, 0.4690), 94 | (0.2194, 0.2266, 0.2285)), 95 | ]) 96 | testset = torchvision.datasets.__dict__[dataset.upper()]( 97 | root='/tmp/data', 98 | split='test', 99 | download=True, 100 | transform=transform_test) 101 | np.place(testset.labels, testset.labels == 10, 0) 102 | test_loader = torch.utils.data.DataLoader(testset, 103 | batch_size=batch_size, 104 | shuffle=shuffle, 105 | num_workers=num_workers) 106 | else: 107 | test_loader = None 108 | return test_loader 109 | 
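# Minimal usage sketch for the two helpers above (a sketch, assuming the
# CIFAR-10 defaults; the dataset is downloaded to /tmp/data on first use).
if __name__ == '__main__':
    train_loader = prepare_train_data('cifar10', batch_size=128)
    test_loader = prepare_test_data('cifar10', batch_size=128)
    images, labels = next(iter(train_loader))
    # CIFAR batches come out as [128, 3, 32, 32] images and [128] integer labels
    print(images.shape, labels.shape)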
-------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | """ This file contains the model definitions for both original ResNet (6n+2 2 | layers) and SkipNets. 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | import math 8 | from torch.autograd import Variable 9 | import torch.autograd as autograd 10 | from modules.quantize import quantize, quantize_grad, QConv2d, QLinear, RangeBN 11 | import torch.nn.functional as F 12 | 13 | 14 | NUM_BITS = 8 15 | NUM_BITS_WEIGHT = None 16 | NUM_BITS_GRAD = None 17 | BIPRECISION = False 18 | 19 | 20 | def Conv3x3(in_planes, out_planes, stride=1): 21 | "3x3 convolution with padding" 22 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 23 | padding=1, bias=False) 24 | 25 | 26 | def conv3x3(in_planes, out_planes, stride=1, pool_size = None): 27 | "3x3 convolution with padding" 28 | return QConv2d(in_planes, out_planes, kernel_size=3, stride=stride, 29 | padding=1, bias=False, num_bits=NUM_BITS, num_bits_weight=NUM_BITS_WEIGHT, num_bits_grad=NUM_BITS_GRAD, biprecision=BIPRECISION) 30 | 31 | class BasicBlock(nn.Module): 32 | expansion = 1 33 | 34 | def __init__(self, inplanes, planes, stride=1, downsample=None): 35 | super(BasicBlock, self).__init__() 36 | self.conv1 = conv3x3(inplanes, planes, stride) 37 | self.bn1 = nn.BatchNorm2d(planes) 38 | self.relu = nn.ReLU(inplace=True) 39 | self.conv2 = conv3x3(planes, planes) 40 | self.bn2 = nn.BatchNorm2d(planes) 41 | self.bn3 = nn.BatchNorm2d(planes) 42 | self.downsample = downsample 43 | self.stride = stride 44 | 45 | def forward(self, x, bit): 46 | residual = x 47 | 48 | out = self.conv1(x, bit) 49 | out = self.bn1(out) 50 | out = self.relu(out) 51 | 52 | out = self.conv2(out, bit) 53 | out = self.bn2(out) 54 | 55 | if self.downsample is not None: 56 | residual = self.downsample(x, bit) 57 | residual = self.bn3(residual) 58 | 59 | out += residual 60 | out = self.relu(out) 61 | return out 62 | 63 | 64 | ######################################## 65 | # Original ResNet # 66 | ######################################## 67 | 68 | 69 | class _ResNet(nn.Module): 70 | """Original ResNet without routing modules""" 71 | def __init__(self, block, layers, num_classes=10): 72 | self.inplanes = 16 73 | super(ResNet, self).__init__() 74 | self.conv1 = conv3x3(3, 16) 75 | self.bn1 = nn.BatchNorm2d(16) 76 | self.relu = nn.ReLU(inplace=True) 77 | self.layer1 = self._make_layer(block, 16, layers[0]) 78 | self.layer2 = self._make_layer(block, 32, layers[1], stride=2) 79 | self.layer3 = self._make_layer(block, 64, layers[2], stride=2) 80 | self.avgpool = nn.AvgPool2d(8) 81 | self.fc = nn.Linear(64 * block.expansion, num_classes) 82 | 83 | for m in self.modules(): 84 | if isinstance(m, nn.Conv2d): 85 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 86 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 87 | elif isinstance(m, nn.BatchNorm2d): 88 | m.weight.data.fill_(1) 89 | m.bias.data.zero_() 90 | 91 | def _make_layer(self, block, planes, blocks, stride=1): 92 | downsample = None 93 | if stride != 1 or self.inplanes != planes * block.expansion: 94 | downsample = QConv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False, num_bits=NUM_BITS, num_bits_weight=NUM_BITS_WEIGHT, num_bits_grad=NUM_BITS_GRAD, biprecision=BIPRECISION) 95 | 96 | layers = [] 97 | layers.append(block(self.inplanes, planes, stride, downsample)) 98 | self.inplanes = planes * block.expansion 99 | for i in range(1, blocks): 100 | layers.append(block(self.inplanes, planes)) 101 | 102 | return nn.Sequential(*layers) 103 | 104 | def forward(self, x): 105 | x = self.conv1(x, 0) 106 | x = self.bn1(x) 107 | x = self.relu(x) 108 | 109 | x = self.layer1(x) 110 | x = self.layer2(x) 111 | x = self.layer3(x) 112 | 113 | x = self.avgpool(x) 114 | x = x.view(x.size(0), -1) 115 | x = self.fc(x) 116 | return x 117 | 118 | ######################################## 119 | # Original ResNet # 120 | ######################################## 121 | 122 | class ResNet(nn.Module): 123 | """Original ResNet without routing modules""" 124 | def __init__(self, block, layers, num_classes=10): 125 | self.inplanes = 16 126 | super(ResNet, self).__init__() 127 | self.conv1 = conv3x3(3, 16) 128 | self.bn1 = nn.BatchNorm2d(16) 129 | self.relu = nn.ReLU(inplace=True) 130 | 131 | self.num_layers = layers 132 | 133 | self._make_group(block, 16, layers[0], group_id=1, 134 | ) 135 | self._make_group(block, 32, layers[1], group_id=2, 136 | ) 137 | self._make_group(block, 64, layers[2], group_id=3, 138 | ) 139 | 140 | # self.layer1 = self._make_layer(block, 16, layers[0]) 141 | # self.layer2 = self._make_layer(block, 32, layers[1], stride=2) 142 | # self.layer3 = self._make_layer(block, 64, layers[2], stride=2) 143 | self.avgpool = nn.AvgPool2d(8) 144 | self.fc = nn.Linear(64 * block.expansion, num_classes) 145 | 146 | for m in self.modules(): 147 | if isinstance(m, nn.Conv2d): 148 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 149 | m.weight.data.normal_(0, math.sqrt(2. / n)) 150 | elif isinstance(m, nn.BatchNorm2d): 151 | m.weight.data.fill_(1) 152 | m.bias.data.zero_() 153 | elif isinstance(m, nn.Linear): 154 | n = m.weight.size(0) * m.weight.size(1) 155 | m.weight.data.normal_(0, math.sqrt(2. / n)) 156 | 157 | # for m in self.modules(): 158 | # if isinstance(m, nn.Conv2d): 159 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 160 | # m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 161 | # elif isinstance(m, nn.BatchNorm2d): 162 | # m.weight.data.fill_(1) 163 | # m.bias.data.zero_() 164 | 165 | def _make_group(self, block, planes, layers, group_id=1 166 | ): 167 | """ Create the whole group""" 168 | for i in range(layers): 169 | if group_id > 1 and i == 0: 170 | stride = 2 171 | else: 172 | stride = 1 173 | 174 | layer = self._make_layer_v2(block, planes, stride=stride, 175 | ) 176 | 177 | # setattr(self, 'group{}_ds{}'.format(group_id, i), meta[0]) 178 | setattr(self, 'group{}_layer{}'.format(group_id, i), layer) 179 | # setattr(self, 'group{}_gate{}'.format(group_id, i), meta[2]) 180 | 181 | 182 | def _make_layer_v2(self, block, planes, stride=1, 183 | ): 184 | """ create one block and optional a gate module """ 185 | downsample = None 186 | if stride != 1 or self.inplanes != planes * block.expansion: 187 | 188 | downsample = QConv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False, num_bits=NUM_BITS, num_bits_weight=NUM_BITS_WEIGHT, num_bits_grad=NUM_BITS_GRAD, biprecision=BIPRECISION) 189 | 190 | # downsample = nn.Sequential( 191 | # nn.Conv2d(self.inplanes, planes * block.expansion, 192 | # kernel_size=1, stride=stride, bias=False), 193 | # nn.BatchNorm2d(planes * block.expansion), 194 | 195 | # ) 196 | layer = block(self.inplanes, planes, stride, downsample) 197 | self.inplanes = planes * block.expansion 198 | 199 | # if gate_type == 'ffgate1': 200 | # gate_layer = FeedforwardGateI(pool_size=pool_size, 201 | # channel=planes*block.expansion) 202 | # elif gate_type == 'ffgate2': 203 | # gate_layer = FeedforwardGateII(pool_size=pool_size, 204 | # channel=planes*block.expansion) 205 | # elif gate_type == 'softgate1': 206 | # gate_layer = SoftGateI(pool_size=pool_size, 207 | # channel=planes*block.expansion) 208 | # elif gate_type == 'softgate2': 209 | # gate_layer = SoftGateII(pool_size=pool_size, 210 | # channel=planes*block.expansion) 211 | # else: 212 | # gate_layer = None 213 | 214 | # if downsample: 215 | # return downsample, layer, gate_layer 216 | # else: 217 | # return None, layer, gate_layer 218 | 219 | return layer 220 | 221 | def _make_layer(self, block, planes, blocks, stride=1): 222 | downsample = None 223 | if stride != 1 or self.inplanes != planes * block.expansion: 224 | downsample = QConv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False, num_bits=NUM_BITS, num_bits_weight=NUM_BITS_WEIGHT, num_bits_grad=NUM_BITS_GRAD, biprecision=BIPRECISION) 225 | 226 | layers = [] 227 | layers.append(block(self.inplanes, planes, stride, downsample)) 228 | self.inplanes = planes * block.expansion 229 | for i in range(1, blocks): 230 | layers.append(block(self.inplanes, planes)) 231 | 232 | return nn.Sequential(*layers) 233 | 234 | def forward(self, x, bit): 235 | x = self.conv1(x, 0) 236 | x = self.bn1(x) 237 | x = self.relu(x) 238 | 239 | for g in range(3): 240 | for i in range(self.num_layers[g]): 241 | x = getattr(self, 'group{}_layer{}'.format(g+1, i))(x, bit) 242 | 243 | # x = self.layer1(x) 244 | # x = self.layer2(x) 245 | # x = self.layer3(x) 246 | 247 | x = self.avgpool(x) 248 | x = x.view(x.size(0), -1) 249 | x = self.fc(x) 250 | return x 251 | 252 | 253 | 254 | # For CIFAR-10 255 | # ResNet-38 256 | 257 | def cifar10_resnet_20(pretrained=False, **kwargs): 258 | model = ResNet(BasicBlock, [3, 3, 3], **kwargs) 259 | return model 260 | 261 | def cifar10_resnet_31(pretrained=False, **kwargs): 262 | # n = 5 263 | model = ResNet(BasicBlock, [5, 5, 5], **kwargs) 264 | return model 265 | 
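# The factories above and below follow the usual CIFAR ResNet depth rule:
# layers = [n, n, n] gives 6n + 2 weight layers, so [3, 3, 3] is ResNet-20 and
# [6, 6, 6] is ResNet-38. A minimal smoke-test sketch (the helper name is
# illustrative and not referenced elsewhere; it assumes QConv2d accepts the
# extra precision-selector argument exactly as it is threaded through
# BasicBlock in this file, with 0 matching what the stem conv uses):
def _resnet20_forward_sketch():
    model = cifar10_resnet_20()
    logits = model(torch.randn(4, 3, 32, 32), 0)  # (input, precision selector)
    return logits.shape  # torch.Size([4, 10])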
266 | 267 | 268 | def cifar10_resnet_38(pretrained=False, **kwargs): 269 | # n = 6 270 | model = ResNet(BasicBlock, [6, 6, 6], **kwargs) 271 | return model 272 | 273 | 274 | # ResNet-74 275 | def cifar10_resnet_74(pretrained=False, **kwargs): 276 | # n = 12 277 | model = ResNet(BasicBlock, [12, 12, 12], **kwargs) 278 | return model 279 | 280 | 281 | # ResNet-110 282 | def cifar10_resnet_110(pretrained=False, **kwargs): 283 | # n = 18 284 | model = ResNet(BasicBlock, [18, 18, 18], **kwargs) 285 | return model 286 | 287 | 288 | # ResNet-152 289 | def cifar10_resnet_152(pretrained=False, **kwargs): 290 | # n = 25 291 | model = ResNet(BasicBlock, [25, 25, 25], **kwargs) 292 | return model 293 | 294 | 295 | # For CIFAR-100 296 | # ResNet-38 297 | def cifar100_resnet_38(pretrained=False, **kwargs): 298 | # n = 6 299 | model = ResNet(BasicBlock, [6, 6, 6], num_classes=100) 300 | return model 301 | 302 | 303 | # ResNet-74 304 | def cifar100_resnet_74(pretrained=False, **kwargs): 305 | # n = 12 306 | model = ResNet(BasicBlock, [12, 12, 12], num_classes=100) 307 | return model 308 | 309 | 310 | # ResNet-110 311 | def cifar100_resnet_110(pretrained=False, **kwargs): 312 | # n = 18 313 | model = ResNet(BasicBlock, [18, 18, 18], num_classes=100) 314 | return model 315 | 316 | 317 | # ResNet-152 318 | def cifar100_resnet_152(pretrained=False, **kwargs): 319 | # n = 25 320 | model = ResNet(BasicBlock, [25, 25, 25], num_classes=100) 321 | return model 322 | 323 | 324 | ######################################## 325 | # SkipNet+SP with Feedforward Gate # 326 | ######################################## 327 | 328 | 329 | # Feedforward-Gate (FFGate-I) 330 | class FeedforwardGateI(nn.Module): 331 | """ Use Max Pooling First and then apply to multiple 2 conv layers. 332 | The first conv has stride = 1 and second has stride = 2""" 333 | def __init__(self, pool_size=5, channel=10): 334 | super(FeedforwardGateI, self).__init__() 335 | self.pool_size = pool_size 336 | self.channel = channel 337 | 338 | self.maxpool = nn.MaxPool2d(2) 339 | self.conv1 = conv3x3(channel, channel) 340 | self.bn1 = nn.BatchNorm2d(channel) 341 | self.relu1 = nn.ReLU(inplace=True) 342 | 343 | # adding another conv layer 344 | self.conv2 = conv3x3(channel, channel, stride=2) 345 | self.bn2 = nn.BatchNorm2d(channel) 346 | self.relu2 = nn.ReLU(inplace=True) 347 | 348 | pool_size = math.floor(pool_size/2) # for max pooling 349 | pool_size = math.floor(pool_size/2 + 0.5) # for conv stride = 2 350 | 351 | self.avg_layer = nn.AvgPool2d(pool_size) 352 | self.linear_layer = nn.Conv2d(in_channels=channel, out_channels=2, 353 | kernel_size=1, stride=1) 354 | self.prob_layer = nn.Softmax() 355 | self.logprob = nn.LogSoftmax() 356 | 357 | def forward(self, x): 358 | x = self.maxpool(x) 359 | x = self.conv1(x) 360 | x = self.bn1(x) 361 | x = self.relu1(x) 362 | 363 | x = self.conv2(x) 364 | x = self.bn2(x) 365 | x = self.relu2(x) 366 | 367 | x = self.avg_layer(x) 368 | x = self.linear_layer(x).squeeze() 369 | softmax = self.prob_layer(x) 370 | logprob = self.logprob(x) 371 | 372 | # discretize output in forward pass. 373 | # use softmax gradients in backward pass 374 | x = (softmax[:, 1] > 0.5).float().detach() - \ 375 | softmax[:, 1].detach() + softmax[:, 1] 376 | 377 | x = x.view(x.size(0), 1, 1, 1) 378 | return x, logprob 379 | 380 | 381 | # soft gate v3 (matching FFGate-I) 382 | class SoftGateI(nn.Module): 383 | """This module has the same structure as FFGate-I. 384 | In training, adopt continuous gate output. 
In inference phase, 385 | use discrete gate outputs""" 386 | def __init__(self, pool_size=5, channel=10): 387 | super(SoftGateI, self).__init__() 388 | self.pool_size = pool_size 389 | self.channel = channel 390 | 391 | self.maxpool = nn.MaxPool2d(2) 392 | self.conv1 = conv3x3(channel, channel) 393 | self.bn1 = nn.BatchNorm2d(channel) 394 | self.relu1 = nn.ReLU(inplace=True) 395 | 396 | # adding another conv layer 397 | self.conv2 = conv3x3(channel, channel, stride=2) 398 | self.bn2 = nn.BatchNorm2d(channel) 399 | self.relu2 = nn.ReLU(inplace=True) 400 | 401 | pool_size = math.floor(pool_size/2) # for max pooling 402 | pool_size = math.floor(pool_size/2 + 0.5) # for conv stride = 2 403 | 404 | self.avg_layer = nn.AvgPool2d(pool_size) 405 | self.linear_layer = nn.Conv2d(in_channels=channel, out_channels=2, 406 | kernel_size=1, stride=1) 407 | self.prob_layer = nn.Softmax() 408 | self.logprob = nn.LogSoftmax() 409 | 410 | def forward(self, x): 411 | x = self.maxpool(x) 412 | x = self.conv1(x) 413 | x = self.bn1(x) 414 | x = self.relu1(x) 415 | 416 | x = self.conv2(x) 417 | x = self.bn2(x) 418 | x = self.relu2(x) 419 | 420 | x = self.avg_layer(x) 421 | x = self.linear_layer(x).squeeze() 422 | softmax = self.prob_layer(x) 423 | logprob = self.logprob(x) 424 | 425 | x = softmax[:, 1].contiguous() 426 | x = x.view(x.size(0), 1, 1, 1) 427 | 428 | if not self.training: 429 | x = (x > 0.5).float() 430 | return x, logprob 431 | 432 | 433 | # FFGate-II 434 | class FeedforwardGateII(nn.Module): 435 | """ use single conv (stride=2) layer only""" 436 | def __init__(self, pool_size=5, channel=10): 437 | super(FeedforwardGateII, self).__init__() 438 | self.pool_size = pool_size 439 | self.channel = channel 440 | 441 | self.conv1 = conv3x3(channel, channel, stride=2) 442 | self.bn1 = nn.BatchNorm2d(channel) 443 | self.relu1 = nn.ReLU(inplace=True) 444 | 445 | pool_size = math.floor(pool_size/2 + 0.5) # for conv stride = 2 446 | 447 | self.avg_layer = nn.AvgPool2d(pool_size) 448 | self.linear_layer = nn.Conv2d(in_channels=channel, out_channels=2, 449 | kernel_size=1, stride=1) 450 | self.prob_layer = nn.Softmax() 451 | self.logprob = nn.LogSoftmax() 452 | 453 | def forward(self, x): 454 | x = self.conv1(x) 455 | x = self.bn1(x) 456 | x = self.relu1(x) 457 | 458 | x = self.avg_layer(x) 459 | x = self.linear_layer(x).squeeze() 460 | softmax = self.prob_layer(x) 461 | logprob = self.logprob(x) 462 | 463 | # discretize 464 | x = (softmax[:, 1] > 0.5).float().detach() - \ 465 | softmax[:, 1].detach() + softmax[:, 1] 466 | 467 | x = x.view(x.size(0), 1, 1, 1) 468 | return x, logprob 469 | 470 | 471 | class SoftGateII(nn.Module): 472 | """ Soft gating version of FFGate-II""" 473 | def __init__(self, pool_size=5, channel=10): 474 | super(SoftGateII, self).__init__() 475 | self.pool_size = pool_size 476 | self.channel = channel 477 | 478 | self.conv1 = conv3x3(channel, channel, stride=2) 479 | self.bn1 = nn.BatchNorm2d(channel) 480 | self.relu1 = nn.ReLU(inplace=True) 481 | 482 | pool_size = math.floor(pool_size / 2 + 0.5) # for conv stride = 2 483 | 484 | self.avg_layer = nn.AvgPool2d(pool_size) 485 | self.linear_layer = nn.Conv2d(in_channels=channel, out_channels=2, 486 | kernel_size=1, stride=1) 487 | self.prob_layer = nn.Softmax() 488 | self.logprob = nn.LogSoftmax() 489 | 490 | def forward(self, x): 491 | x = self.conv1(x) 492 | x = self.bn1(x) 493 | x = self.relu1(x) 494 | 495 | x = self.avg_layer(x) 496 | x = self.linear_layer(x).squeeze() 497 | softmax = self.prob_layer(x) 498 | logprob = self.logprob(x) 499 
| 500 | x = softmax[:, 1].contiguous() 501 | x = x.view(x.size(0), 1, 1, 1) 502 | if not self.training: 503 | x = (x > 0.5).float() 504 | return x, logprob 505 | 506 | 507 | class ResNetFeedForwardSP(nn.Module): 508 | """ SkipNets with Feed-forward Gates for Supervised Pre-training stage. 509 | Adding one routing module after each basic block.""" 510 | 511 | def __init__(self, block, layers, num_classes=10, 512 | gate_type='fisher', **kwargs): 513 | self.inplanes = 16 514 | super(ResNetFeedForwardSP, self).__init__() 515 | 516 | self.num_layers = layers 517 | self.conv1 = conv3x3(3, 16) 518 | self.bn1 = nn.BatchNorm2d(16) 519 | self.relu = nn.ReLU(inplace=True) 520 | 521 | # going to have 3 groups of layers. For the easiness of skipping, 522 | # We are going to break the sequential of layers into a list of layers. 523 | 524 | self.gate_type = gate_type 525 | self._make_group(block, 16, layers[0], group_id=1, 526 | gate_type=gate_type, pool_size=32) 527 | self._make_group(block, 32, layers[1], group_id=2, 528 | gate_type=gate_type, pool_size=16) 529 | self._make_group(block, 64, layers[2], group_id=3, 530 | gate_type=gate_type, pool_size=8) 531 | 532 | self.avgpool = nn.AvgPool2d(8) 533 | self.fc = nn.Linear(64 * block.expansion, num_classes) 534 | 535 | for m in self.modules(): 536 | if isinstance(m, nn.Conv2d): 537 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 538 | m.weight.data.normal_(0, math.sqrt(2. / n)) 539 | elif isinstance(m, nn.BatchNorm2d): 540 | m.weight.data.fill_(1) 541 | m.bias.data.zero_() 542 | elif isinstance(m, nn.Linear): 543 | n = m.weight.size(0) * m.weight.size(1) 544 | m.weight.data.normal_(0, math.sqrt(2. / n)) 545 | 546 | def _make_group(self, block, planes, layers, group_id=1, 547 | gate_type='fisher', pool_size=16): 548 | """ Create the whole group""" 549 | for i in range(layers): 550 | if group_id > 1 and i == 0: 551 | stride = 2 552 | else: 553 | stride = 1 554 | 555 | meta = self._make_layer_v2(block, planes, stride=stride, 556 | gate_type=gate_type, 557 | pool_size=pool_size) 558 | 559 | setattr(self, 'group{}_ds{}'.format(group_id, i), meta[0]) 560 | setattr(self, 'group{}_layer{}'.format(group_id, i), meta[1]) 561 | setattr(self, 'group{}_gate{}'.format(group_id, i), meta[2]) 562 | 563 | def _make_layer_v2(self, block, planes, stride=1, 564 | gate_type='fisher', pool_size=16): 565 | """ create one block and optional a gate module """ 566 | downsample = None 567 | if stride != 1 or self.inplanes != planes * block.expansion: 568 | downsample = nn.Sequential( 569 | nn.Conv2d(self.inplanes, planes * block.expansion, 570 | kernel_size=1, stride=stride, bias=False), 571 | nn.BatchNorm2d(planes * block.expansion), 572 | 573 | ) 574 | layer = block(self.inplanes, planes, stride, downsample) 575 | self.inplanes = planes * block.expansion 576 | 577 | if gate_type == 'ffgate1': 578 | gate_layer = FeedforwardGateI(pool_size=pool_size, 579 | channel=planes*block.expansion) 580 | elif gate_type == 'ffgate2': 581 | gate_layer = FeedforwardGateII(pool_size=pool_size, 582 | channel=planes*block.expansion) 583 | elif gate_type == 'softgate1': 584 | gate_layer = SoftGateI(pool_size=pool_size, 585 | channel=planes*block.expansion) 586 | elif gate_type == 'softgate2': 587 | gate_layer = SoftGateII(pool_size=pool_size, 588 | channel=planes*block.expansion) 589 | else: 590 | gate_layer = None 591 | 592 | if downsample: 593 | return downsample, layer, gate_layer 594 | else: 595 | return None, layer, gate_layer 596 | 597 | def forward(self, x): 598 | """Return output 
logits, masks(gate ouputs) and probabilities 599 | associated to each gate.""" 600 | 601 | x = self.conv1(x) 602 | x = self.bn1(x) 603 | x = self.relu(x) 604 | 605 | masks = [] 606 | gprobs = [] 607 | # must pass through the first layer in first group 608 | x = getattr(self, 'group1_layer0')(x) 609 | # gate takes the output of the current layer 610 | 611 | mask, gprob = getattr(self, 'group1_gate0')(x) 612 | gprobs.append(gprob) 613 | masks.append(mask.squeeze()) 614 | prev = x # input of next layer 615 | 616 | for g in range(3): 617 | for i in range(0 + int(g == 0), self.num_layers[g]): 618 | if getattr(self, 'group{}_ds{}'.format(g+1, i)) is not None: 619 | prev = getattr(self, 'group{}_ds{}'.format(g+1, i))(prev) 620 | x = getattr(self, 'group{}_layer{}'.format(g+1, i))(x) 621 | prev = x = mask.expand_as(x) * x \ 622 | + (1 - mask).expand_as(prev) * prev 623 | mask, gprob = getattr(self, 'group{}_gate{}'.format(g+1, i))(x) 624 | gprobs.append(gprob) 625 | masks.append(mask.squeeze()) 626 | 627 | del masks[-1] 628 | 629 | x = self.avgpool(x) 630 | x = x.view(x.size(0), -1) 631 | x = self.fc(x) 632 | 633 | return x, masks, gprobs 634 | 635 | 636 | # FeeforwardGate-I 637 | # For CIFAR-10 638 | def cifar10_feedforward_38(pretrained=False, **kwargs): 639 | """SkipNet-38 with FFGate-I""" 640 | model = ResNetFeedForwardSP(BasicBlock, [6, 6, 6], gate_type='ffgate1') 641 | return model 642 | 643 | 644 | def cifar10_feedforward_74(pretrained=False, **kwargs): 645 | """SkipNet-74 with FFGate-I""" 646 | model = ResNetFeedForwardSP(BasicBlock, [12, 12, 12], gate_type='ffgate1') 647 | return model 648 | 649 | 650 | def cifar10_feedforward_110(pretrained=False, **kwargs): 651 | """SkipNet-110 with FFGate-II""" 652 | model = ResNetFeedForwardSP(BasicBlock, [18, 18, 18], gate_type='ffgate2') 653 | return model 654 | 655 | 656 | # For CIFAR-100 657 | def cifar100_feeforward_38(pretrained=False, **kwargs): 658 | """SkipNet-38 with FFGate-I""" 659 | model = ResNetFeedForwardSP(BasicBlock, [6, 6, 6], num_classes=100, 660 | gate_type='ffgate1') 661 | return model 662 | 663 | 664 | def cifar100_feedforward_74(pretrained=False, **kwargs): 665 | """SkipNet-74 with FFGate-I""" 666 | model = ResNetFeedForwardSP(BasicBlock, [12, 12, 12], num_classes=100, 667 | gate_type='ffgate1') 668 | return model 669 | 670 | 671 | def cifar100_feedforward_110(pretrained=False, **kwargs): 672 | """SkipNet-110 with FFGate-II""" 673 | model = ResNetFeedForwardSP(BasicBlock, [18, 18, 18], num_classes=100, 674 | gate_type='ffgate2') 675 | return model 676 | 677 | 678 | ######################################## 679 | # SkipNet+SP with Recurrent Gate # 680 | ######################################## 681 | 682 | 683 | # For Recurrent Gate 684 | def repackage_hidden(h): 685 | """ to reduce memory usage""" 686 | if type(h) == Variable: 687 | return Variable(h.data) 688 | else: 689 | return tuple(repackage_hidden(v) for v in h) 690 | 691 | 692 | class RNNGate(nn.Module): 693 | """Recurrent Gate definition. 
694 | Input is already passed through average pooling and embedding.""" 695 | def __init__(self, input_dim, hidden_dim, rnn_type='lstm'): 696 | super(RNNGate, self).__init__() 697 | self.rnn_type = rnn_type 698 | self.input_dim = input_dim 699 | self.hidden_dim = hidden_dim 700 | 701 | if self.rnn_type == 'lstm': 702 | self.rnn_one = nn.LSTM(input_dim, hidden_dim) 703 | # self.rnn_two = nn.LSTM(hidden_dim, hidden_dim) 704 | else: 705 | self.rnn = None 706 | self.hidden_one = None 707 | # self.hidden_two = None 708 | 709 | # reduce dim 710 | self.proj = nn.Linear(hidden_dim, 7) 711 | # self.proj_two = nn.Linear(hidden_dim, 4) 712 | self.prob = nn.Sigmoid() 713 | self.prob_layer = nn.Softmax() 714 | 715 | def init_hidden(self, batch_size): 716 | # The axes semantics are (num_layers, minibatch_size, hidden_dim) 717 | return (autograd.Variable(torch.zeros(1, batch_size, 718 | self.hidden_dim).cuda()), 719 | autograd.Variable(torch.zeros(1, batch_size, 720 | self.hidden_dim).cuda())) 721 | 722 | def repackage_hidden(self): 723 | self.hidden_one = repackage_hidden(self.hidden_one) 724 | # self.hidden_two = repackage_hidden(self.hidden_two) 725 | def forward(self, x): 726 | # Take the convolution output of each step 727 | batch_size = x.size(0) 728 | self.rnn_one.flatten_parameters() 729 | # self.rnn_two.flatten_parameters() 730 | 731 | out_one, self.hidden_one = self.rnn_one(x.view(1, batch_size, -1), self.hidden_one) 732 | 733 | # out_one = F.dropout(out_one, p = 0.1, training=True) 734 | 735 | # out_two, self.hidden_two = self.rnn_two(out_one.view(1, batch_size, -1), self.hidden_two) 736 | 737 | x_one = self.proj(out_one.squeeze()) 738 | # x_two = self.proj_two(out_two.squeeze()) 739 | 740 | # proj = self.proj(out.squeeze()) 741 | prob = self.prob_layer(x_one) 742 | # prob_two = self.prob_layer(x_two) 743 | 744 | # x_one = (prob > 0.5).float().detach() - \ 745 | # prob.detach() + prob 746 | 747 | # x_two = prob_two.detach().cpu().numpy() 748 | 749 | x_one = prob.detach().cpu().numpy() 750 | 751 | hard = (x_one == x_one.max(axis=1)[:,None]).astype(int) 752 | hard = torch.from_numpy(hard) 753 | hard = hard.cuda() 754 | 755 | # x_two = hard.float().detach() - \ 756 | # prob_two.detach() + prob_two 757 | 758 | x_one = hard.float().detach() - \ 759 | prob.detach() + prob 760 | 761 | # print(x_one) 762 | 763 | x_one = x_one.view(x_one.size(0),x_one.size(1), 1, 1, 1) 764 | 765 | # x_two = x_two.view(x_two.size(0), x_two.size(1), 1, 1, 1) 766 | 767 | return x_one # , x_two 768 | 769 | 770 | class SoftRNNGate(nn.Module): 771 | def __init__(self, input_dim, hidden_dim, rnn_type='lstm'): 772 | super(SoftRNNGate, self).__init__() 773 | self.rnn_type = rnn_type 774 | self.input_dim = input_dim 775 | self.hidden_dim = hidden_dim 776 | 777 | if self.rnn_type == 'lstm': 778 | self.rnn = nn.LSTM(input_dim, hidden_dim) 779 | else: 780 | self.rnn = None 781 | self.hidden = None 782 | 783 | # reduce dim 784 | self.proj = nn.Linear(hidden_dim, 1) 785 | self.prob = nn.Sigmoid() 786 | 787 | def init_hidden(self, batch_size): 788 | return (autograd.Variable(torch.zeros(1, batch_size, 789 | self.hidden_dim).cuda()), 790 | autograd.Variable(torch.zeros(1, batch_size, 791 | self.hidden_dim).cuda())) 792 | 793 | def repackage_hidden(self): 794 | self.hidden = repackage_hidden(self.hidden) 795 | 796 | def forward(self, x): 797 | # Take the convolution output of each step 798 | batch_size = x.size(0) 799 | self.rnn.flatten_parameters() 800 | out, self.hidden = self.rnn(x.view(1, batch_size, -1), self.hidden) 801 | 802 | 
proj = self.proj(out.squeeze()) 803 | prob = self.prob(proj) 804 | 805 | x = prob.view(batch_size, 1, 1, 1) 806 | if not self.training: 807 | x = (x > 0.5).float() 808 | return x, prob 809 | 810 | 811 | class ResNetRecurrentGateSP(nn.Module): 812 | """SkipNet with Recurrent Gate Model""" 813 | def __init__(self, block, layers, num_classes=10, embed_dim=10, 814 | hidden_dim=10, gate_type='rnn'): 815 | self.inplanes = 16 816 | super(ResNetRecurrentGateSP, self).__init__() 817 | 818 | self.num_layers = layers 819 | self.conv1 = conv3x3(3, 16) 820 | self.bn1 = nn.BatchNorm2d(16) 821 | self.relu = nn.ReLU(inplace=True) 822 | 823 | self.embed_dim = embed_dim 824 | self.hidden_dim = hidden_dim 825 | 826 | self._make_group(block, 16, layers[0], group_id=1, pool_size=32) 827 | self._make_group(block, 32, layers[1], group_id=2, pool_size=16) 828 | self._make_group(block, 64, layers[2], group_id=3, pool_size=8) 829 | 830 | # define recurrent gating module 831 | if gate_type == 'rnn': 832 | self.control = RNNGate(embed_dim, hidden_dim, rnn_type='lstm') 833 | elif gate_type == 'soft': 834 | self.control = SoftRNNGate(embed_dim, hidden_dim, rnn_type='lstm') 835 | else: 836 | print('gate type {} not implemented'.format(gate_type)) 837 | self.control = None 838 | 839 | self.avgpool = nn.AvgPool2d(8) 840 | self.fc = nn.Linear(64 * block.expansion, num_classes) 841 | 842 | for m in self.modules(): 843 | if isinstance(m, nn.Conv2d): 844 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 845 | m.weight.data.normal_(0, math.sqrt(2. / n)) 846 | elif isinstance(m, nn.BatchNorm2d): 847 | m.weight.data.fill_(1) 848 | m.bias.data.zero_() 849 | elif isinstance(m, nn.Linear): 850 | n = m.weight.size(0) * m.weight.size(1) 851 | m.weight.data.normal_(0, math.sqrt(2. / n)) 852 | 853 | def _make_group(self, block, planes, layers, group_id=1, pool_size=16): 854 | """ Create the whole group""" 855 | for i in range(layers): 856 | if group_id > 1 and i == 0: 857 | stride = 2 858 | else: 859 | stride = 1 860 | 861 | meta = self._make_layer_v2(block, planes, stride=stride, 862 | pool_size=pool_size) 863 | 864 | setattr(self, 'group{}_ds{}'.format(group_id, i), meta[0]) 865 | setattr(self, 'group{}_layer{}'.format(group_id, i), meta[1]) 866 | setattr(self, 'group{}_gate{}'.format(group_id, i), meta[2]) 867 | setattr(self, 'group{}_bn{}'.format(group_id, i), meta[3]) 868 | 869 | def _make_layer_v2(self, block, planes, stride=1, pool_size=16): 870 | """ create one block and optional a gate module """ 871 | downsample = None 872 | if stride != 1 or self.inplanes != planes * block.expansion: 873 | # downsample = nn.Sequential( 874 | # nn.Conv2d(self.inplanes, planes * block.expansion, 875 | # kernel_size=1, stride=stride, bias=False), 876 | # nn.BatchNorm2d(planes * block.expansion), 877 | 878 | # ) 879 | 880 | downsample = QConv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False, num_bits=NUM_BITS, num_bits_weight=NUM_BITS_WEIGHT, num_bits_grad=NUM_BITS_GRAD, biprecision=BIPRECISION) 881 | 882 | 883 | layer = block(self.inplanes, planes, stride, downsample) 884 | self.inplanes = planes * block.expansion 885 | 886 | 887 | bn = layer.bn3 888 | 889 | gate_layer = nn.Sequential( 890 | nn.AvgPool2d(pool_size), 891 | nn.Conv2d(in_channels=planes * block.expansion, 892 | out_channels=self.embed_dim, 893 | kernel_size=1, 894 | stride=1)) 895 | if downsample: 896 | return downsample, layer, gate_layer, bn 897 | else: 898 | return None, layer, gate_layer, None 899 | 900 | def forward(self, x, bits): 
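        # Routing in this forward pass: `bits` lists the candidate precisions for
        # every residual block. The recurrent gate (`self.control`) emits, per
        # sample, a straight-through hard mask over len(bits) + 1 choices; index 0
        # keeps the (possibly downsampled) skip path `prev`, and index k + 1 keeps
        # the block output computed with precision bits[k]. Each block is evaluated
        # once per candidate precision, the masked candidates are summed, and the
        # per-block mask lists are collected in `masks` and returned with the logits.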
901 | 902 | batch_size = x.size(0) 903 | x = self.conv1(x, 0) 904 | x = self.bn1(x) 905 | x = self.relu(x) 906 | 907 | # reinitialize hidden units 908 | self.control.hidden_one = self.control.init_hidden(batch_size) 909 | # self.control.hidden_two = self.control.init_hidden(batch_size) 910 | 911 | masks = [] 912 | # gprobs = [] 913 | # must pass through the first layer in first group 914 | x = getattr(self, 'group1_layer0')(x, 0) 915 | # gate takes the output of the current layer 916 | 917 | gate_feature = getattr(self, 'group1_gate0')(x) 918 | mask_one = self.control(gate_feature) 919 | 920 | # bits = [4,8,16,0] 921 | # bits = [8,16,0] 922 | 923 | # gprobs.append(gprob) 924 | # masks.append(mask.squeeze()) 925 | # prev = x # input of next layer 926 | 927 | prev = x 928 | 929 | for g in range(3): 930 | for i in range(0 + int(g == 0), self.num_layers[g]): 931 | if getattr(self, 'group{}_ds{}'.format(g+1, i)) is not None: 932 | prev = getattr(self, 'group{}_ds{}'.format(g+1, i))(prev, 0) 933 | prev = getattr(self, 'group{}_bn{}'.format(g+1, i))(prev) 934 | 935 | output_candidates = [] 936 | 937 | output_candidates.append(prev) 938 | 939 | # output_candidates.append(prev) 940 | 941 | for k in range(len(bits)): 942 | out = getattr(self, 'group{}_layer{}'.format(g+1, i))(x, bits[k]) 943 | output_candidates.append(out) 944 | 945 | # x = getattr(self, 'group{}_layer{}'.format(g+1, i))(x, 0) 946 | 947 | mask_list = [] 948 | 949 | # mask_list.append(1 - mask_one) 950 | 951 | # for j in range(4): 952 | # mask_list.append(mask_one * mask_two[:,j,:,:,:]) 953 | 954 | for j in range(len(bits) + 1): 955 | mask_list.append(mask_one[:,j,:,:,:]) 956 | 957 | # prev = x = sum([mask_list[k].expand_as(out) * output_candidates[k] for k in range(5)]) 958 | 959 | prev = x = sum([mask_list[k].expand_as(out) * output_candidates[k] for k in range(len(bits) + 1)]) 960 | 961 | mask_list = [mask.squeeze() for mask in mask_list] 962 | 963 | masks.append(mask_list) 964 | 965 | # new mask is taking the current output 966 | # prev = x = mask.expand_as(x) * x \ 967 | # + (1 - mask).expand_as(prev) * prev 968 | gate_feature = getattr(self, 'group{}_gate{}'.format(g+1, i))(x) 969 | mask_one = self.control(gate_feature) 970 | # gprobs.append(gprob) 971 | 972 | 973 | 974 | 975 | # masks.append(mask.squeeze()) 976 | 977 | # last block doesn't have gate module 978 | # del masks[-1] 979 | 980 | x = self.avgpool(x) 981 | x = x.view(x.size(0), -1) 982 | x = self.fc(x) 983 | 984 | return x, masks 985 | 986 | 987 | # For CIFAR-10 988 | 989 | 990 | def cifar10_rnn_gate_20(pretrained=False, **kwargs): 991 | model = ResNetRecurrentGateSP(BasicBlock, [3, 3, 3], num_classes=10, 992 | embed_dim=10, hidden_dim=10) 993 | return model 994 | 995 | def cifar10_rnn_gate_31(pretrained=False, **kwargs): 996 | model = ResNetRecurrentGateSP(BasicBlock, [5, 5, 5], num_classes=10, 997 | embed_dim=10, hidden_dim=10) 998 | return model 999 | 1000 | 1001 | 1002 | 1003 | def cifar10_rnn_gate_38(pretrained=False, **kwargs): 1004 | """SkipNet-38 with Recurrent Gate""" 1005 | model = ResNetRecurrentGateSP(BasicBlock, [6, 6, 6], num_classes=10, 1006 | embed_dim=10, hidden_dim=10) 1007 | return model 1008 | 1009 | 1010 | def cifar10_rnn_gate_74(pretrained=False, **kwargs): 1011 | """SkipNet-74 with Recurrent Gate""" 1012 | model = ResNetRecurrentGateSP(BasicBlock, [12, 12, 12], num_classes=10, 1013 | embed_dim=10, hidden_dim=10) 1014 | return model 1015 | 1016 | 1017 | def cifar10_rnn_gate_110(pretrained=False, **kwargs): 1018 | """SkipNet-110 with Recurrent
Gate""" 1019 | model = ResNetRecurrentGateSP(BasicBlock, [18, 18, 18], num_classes=10, 1020 | embed_dim=10, hidden_dim=10) 1021 | return model 1022 | 1023 | 1024 | def cifar10_rnn_gate_152(pretrained=False, **kwargs): 1025 | """SkipNet-152 with Recurrent Gate""" 1026 | model = ResNetRecurrentGateSP(BasicBlock, [25, 25, 25], num_classes=10, 1027 | embed_dim=10, hidden_dim=10) 1028 | return model 1029 | 1030 | 1031 | # For CIFAR-100 1032 | def cifar100_rnn_gate_38(pretrained=False, **kwargs): 1033 | """SkipNet-38 with Recurrent Gate""" 1034 | model = ResNetRecurrentGateSP(BasicBlock, [6, 6, 6], num_classes=100, 1035 | embed_dim=10, hidden_dim=10) 1036 | return model 1037 | 1038 | 1039 | def cifar100_rnn_gate_74(pretrained=False, **kwargs): 1040 | """SkipNet-74 with Recurrent Gate""" 1041 | model = ResNetRecurrentGateSP(BasicBlock, [12, 12, 12], num_classes=100, 1042 | embed_dim=10, hidden_dim=10) 1043 | return model 1044 | 1045 | 1046 | def cifar100_rnn_gate_110(pretrained=False, **kwargs): 1047 | """SkipNet-110 with Recurrent Gate """ 1048 | model = ResNetRecurrentGateSP(BasicBlock, [18, 18, 18], num_classes=100, 1049 | embed_dim=10, hidden_dim=10) 1050 | return model 1051 | 1052 | 1053 | def cifar100_rnn_gate_152(pretrained=False, **kwargs): 1054 | """SkipNet-152 with Recurrent Gate""" 1055 | model = ResNetRecurrentGateSP(BasicBlock, [25, 25, 25], num_classes=100, 1056 | embed_dim=10, hidden_dim=10) 1057 | return model 1058 | 1059 | 1060 | ######################################## 1061 | # SkipNet+RL with Feedforward Gate # 1062 | ######################################## 1063 | 1064 | class RLFeedforwardGateI(nn.Module): 1065 | """ FFGate-I with sampling. Use Pytorch 2.0""" 1066 | def __init__(self, pool_size=5, channel=10): 1067 | super(RLFeedforwardGateI, self).__init__() 1068 | self.pool_size = pool_size 1069 | self.channel = channel 1070 | 1071 | self.maxpool = nn.MaxPool2d(2) 1072 | self.conv1 = conv3x3(channel, channel) 1073 | self.bn1 = nn.BatchNorm2d(channel) 1074 | self.relu1 = nn.ReLU(inplace=True) 1075 | 1076 | # adding another conv layer 1077 | self.conv2 = conv3x3(channel, channel, stride=2) 1078 | self.bn2 = nn.BatchNorm2d(channel) 1079 | self.relu2 = nn.ReLU(inplace=True) 1080 | 1081 | pool_size = math.floor(pool_size/2) # for max pooling 1082 | pool_size = math.floor(pool_size/2 + 0.5) # for conv stride = 2 1083 | 1084 | self.avg_layer = nn.AvgPool2d(pool_size) 1085 | self.linear_layer = nn.Conv2d(in_channels=channel, out_channels=2, 1086 | kernel_size=1, stride=1) 1087 | self.prob_layer = nn.Softmax() 1088 | 1089 | # saved actions and rewards 1090 | self.saved_action = [] 1091 | self.rewards = [] 1092 | 1093 | def forward(self, x): 1094 | x = self.maxpool(x) 1095 | x = self.conv1(x) 1096 | x = self.bn1(x) 1097 | x = self.relu1(x) 1098 | 1099 | x = self.conv2(x) 1100 | x = self.bn2(x) 1101 | x = self.relu2(x) 1102 | 1103 | x = self.avg_layer(x) 1104 | x = self.linear_layer(x).squeeze() 1105 | softmax = self.prob_layer(x) 1106 | 1107 | if self.training: 1108 | action = softmax.multinomial() 1109 | self.saved_action = action 1110 | else: 1111 | action = (softmax[:, 1] > 0.5).float() 1112 | self.saved_action = action 1113 | 1114 | action = action.view(action.size(0), 1, 1, 1).float() 1115 | return action, softmax 1116 | 1117 | 1118 | class RLFeedforwardGateII(nn.Module): 1119 | def __init__(self, pool_size=5, channel=10): 1120 | super(RLFeedforwardGateII, self).__init__() 1121 | self.pool_size = pool_size 1122 | self.channel = channel 1123 | 1124 | self.conv1 = conv3x3(channel, 
channel, stride=2) 1125 | self.bn1 = nn.BatchNorm2d(channel) 1126 | self.relu1 = nn.ReLU(inplace=True) 1127 | 1128 | pool_size = math.floor(pool_size/2 + 0.5) # for conv stride = 2 1129 | 1130 | self.avg_layer = nn.AvgPool2d(pool_size) 1131 | self.linear_layer = nn.Conv2d(in_channels=channel, out_channels=2, 1132 | kernel_size=1, stride=1) 1133 | self.prob_layer = nn.Softmax() 1134 | 1135 | # saved actions and rewards 1136 | self.saved_action = None 1137 | self.rewards = [] 1138 | 1139 | def forward(self, x): 1140 | x = self.conv1(x) 1141 | x = self.bn1(x) 1142 | x = self.relu1(x) 1143 | 1144 | x = self.avg_layer(x) 1145 | x = self.linear_layer(x).squeeze() 1146 | softmax = self.prob_layer(x) 1147 | 1148 | if self.training: 1149 | action = softmax.multinomial() 1150 | self.saved_action = action 1151 | else: 1152 | action = (softmax[:, 1] > 0.5).float() 1153 | self.saved_action = action 1154 | 1155 | action = action.view(action.size(0), 1, 1, 1).float() 1156 | return action, softmax 1157 | 1158 | 1159 | class ResNetFeedForwardRL(nn.Module): 1160 | """Adding gating module on every basic block""" 1161 | 1162 | def __init__(self, block, layers, num_classes=10, 1163 | gate_type='ffgate1', **kwargs): 1164 | self.inplanes = 16 1165 | super(ResNetFeedForwardRL, self).__init__() 1166 | 1167 | self.num_layers = layers 1168 | self.conv1 = conv3x3(3, 16) 1169 | self.bn1 = nn.BatchNorm2d(16) 1170 | self.relu = nn.ReLU(inplace=True) 1171 | 1172 | self.gate_instances = [] 1173 | self.gate_type = gate_type 1174 | self._make_group(block, 16, layers[0], group_id=1, 1175 | gate_type=gate_type, pool_size=32) 1176 | self._make_group(block, 32, layers[1], group_id=2, 1177 | gate_type=gate_type, pool_size=16) 1178 | self._make_group(block, 64, layers[2], group_id=3, 1179 | gate_type=gate_type, pool_size=8) 1180 | 1181 | # remove the last gate instance, (not optimized) 1182 | del self.gate_instances[-1] 1183 | 1184 | self.avgpool = nn.AvgPool2d(8) 1185 | self.fc = nn.Linear(64 * block.expansion, num_classes) 1186 | 1187 | self.softmax = nn.Softmax() 1188 | self.saved_actions = [] 1189 | self.rewards = [] 1190 | 1191 | for m in self.modules(): 1192 | if isinstance(m, nn.Conv2d): 1193 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 1194 | m.weight.data.normal_(0, math.sqrt(2. / n)) 1195 | elif isinstance(m, nn.BatchNorm2d): 1196 | m.weight.data.fill_(1) 1197 | m.bias.data.zero_() 1198 | elif isinstance(m, nn.Linear): 1199 | n = m.weight.size(0) * m.weight.size(1) 1200 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 1201 | 1202 | def _make_group(self, block, planes, layers, group_id=1, 1203 | gate_type='fisher', pool_size=16): 1204 | """ Create the whole group""" 1205 | for i in range(layers): 1206 | if group_id > 1 and i == 0: 1207 | stride = 2 1208 | else: 1209 | stride = 1 1210 | 1211 | meta = self._make_layer_v2(block, planes, stride=stride, 1212 | gate_type=gate_type, 1213 | pool_size=pool_size) 1214 | 1215 | setattr(self, 'group{}_ds{}'.format(group_id, i), meta[0]) 1216 | setattr(self, 'group{}_layer{}'.format(group_id, i), meta[1]) 1217 | setattr(self, 'group{}_gate{}'.format(group_id, i), meta[2]) 1218 | 1219 | # add into gate instance collection 1220 | self.gate_instances.append(meta[2]) 1221 | 1222 | def _make_layer_v2(self, block, planes, stride=1, 1223 | gate_type='fisher', pool_size=16): 1224 | """ create one block and optional a gate module """ 1225 | downsample = None 1226 | if stride != 1 or self.inplanes != planes * block.expansion: 1227 | downsample = nn.Sequential( 1228 | nn.Conv2d(self.inplanes, planes * block.expansion, 1229 | kernel_size=1, stride=stride, bias=False), 1230 | nn.BatchNorm2d(planes * block.expansion), 1231 | 1232 | ) 1233 | layer = block(self.inplanes, planes, stride, downsample) 1234 | self.inplanes = planes * block.expansion 1235 | 1236 | if gate_type == 'ffgate1': 1237 | gate_layer = RLFeedforwardGateI(pool_size=pool_size, 1238 | channel=planes*block.expansion) 1239 | elif gate_type == 'ffgate2': 1240 | gate_layer = RLFeedforwardGateII(pool_size=pool_size, 1241 | channel=planes*block.expansion) 1242 | else: 1243 | gate_layer = None 1244 | 1245 | if downsample: 1246 | return downsample, layer, gate_layer 1247 | else: 1248 | return None, layer, gate_layer 1249 | 1250 | def repackage_vars(self): 1251 | self.saved_actions = repackage_hidden(self.saved_actions) 1252 | 1253 | def forward(self, x, reinforce=False): 1254 | x = self.conv1(x) 1255 | x = self.bn1(x) 1256 | x = self.relu(x) 1257 | 1258 | masks = [] 1259 | gprobs = [] 1260 | # must pass through the first layer in first group 1261 | x = getattr(self, 'group1_layer0')(x) 1262 | # gate takes the output of the current layer 1263 | mask, gprob = getattr(self, 'group1_gate0')(x) 1264 | gprobs.append(gprob) 1265 | masks.append(mask.squeeze()) 1266 | prev = x # input of next layer 1267 | 1268 | for g in range(3): 1269 | for i in range(0 + int(g == 0), self.num_layers[g]): 1270 | if getattr(self, 'group{}_ds{}'.format(g+1, i)) is not None: 1271 | prev = getattr(self, 'group{}_ds{}'.format(g+1, i))(prev) 1272 | x = getattr(self, 'group{}_layer{}'.format(g+1, i))(x) 1273 | # new mask is taking the current output 1274 | prev = x = mask.expand_as(x) * x \ 1275 | + (1 - mask).expand_as(prev) * prev 1276 | mask, gprob = getattr(self, 'group{}_gate{}'.format(g+1, i))(x) 1277 | gprobs.append(gprob) 1278 | masks.append(mask.squeeze()) 1279 | 1280 | del masks[-1] 1281 | 1282 | x = self.avgpool(x) 1283 | x = x.view(x.size(0), -1) 1284 | x = self.fc(x) 1285 | 1286 | # collect all actions 1287 | for inst in self.gate_instances: 1288 | self.saved_actions.append(inst.saved_action) 1289 | 1290 | if reinforce: # for pure RL 1291 | softmax = self.softmax(x) 1292 | action = softmax.multinomial() 1293 | self.saved_actions.append(action) 1294 | 1295 | return x, masks, gprobs 1296 | 1297 | 1298 | # FFGate-I 1299 | # For CIFAR-10 1300 | def cifar10_feedfoward_rl_38(pretrained=False, **kwargs): 1301 | """SkipNet-38 + RL with FFGate-I""" 1302 | model = ResNetFeedForwardRL(BasicBlock, [6, 6, 6], 1303 | num_classes=10, 
gate_type='ffgate1') 1304 | return model 1305 | 1306 | 1307 | def cifar10_feedforward_rl_74(pretrained=False, **kwargs): 1308 | """SkipNet-74 + RL with FFGate-I""" 1309 | model = ResNetFeedForwardRL(BasicBlock, [12, 12, 12], 1310 | num_classes=10, gate_type='ffgate1') 1311 | return model 1312 | 1313 | 1314 | def cifar10_feedforward_rl_110(pretrained=False, **kwargs): 1315 | """SkipNet-110 + RL with FFGate-II""" 1316 | model = ResNetFeedForwardRL(BasicBlock, [18, 18, 18], 1317 | num_classes=10, gate_type='ffgate2') 1318 | return model 1319 | 1320 | 1321 | # For CIFAR-100 1322 | def cifar100_feedford_rl_38(pretrained=False, **kwargs): 1323 | """SkipNet-38 + RL with FFGate-I""" 1324 | model = ResNetFeedForwardRL(BasicBlock, [6, 6, 6], 1325 | num_classes=100, gate_type='ffgate1') 1326 | return model 1327 | 1328 | 1329 | def cifar100_feedforward_rl_74(pretrained=False, **kwargs): 1330 | """SkipNet-74 + RL with FFGate-I""" 1331 | model = ResNetFeedForwardRL(BasicBlock, [12, 12, 12], 1332 | num_classes=100, gate_type='ffgate1') 1333 | return model 1334 | 1335 | 1336 | def cifar100_feedforward_rl_110(pretrained=False, **kwargs): 1337 | """SkipNet-110 + RL with FFGate-II""" 1338 | model = ResNetFeedForwardRL(BasicBlock, [18, 18, 18], 1339 | num_classes=100, gate_type='ffgate2') 1340 | return model 1341 | 1342 | 1343 | ######################################## 1344 | # SkipNet+RL with Recurrent Gate # 1345 | ######################################## 1346 | 1347 | class RNNGatePolicy(nn.Module): 1348 | def __init__(self, input_dim, hidden_dim, rnn_type='lstm'): 1349 | super(RNNGatePolicy, self).__init__() 1350 | 1351 | self.rnn_type = rnn_type 1352 | self.input_dim = input_dim 1353 | self.hidden_dim = hidden_dim 1354 | 1355 | if self.rnn_type == 'lstm': 1356 | self.rnn = nn.LSTM(input_dim, hidden_dim) 1357 | else: 1358 | self.rnn = None 1359 | self.hidden = None 1360 | 1361 | # reduce dim. use softmax here for two actions. 
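        # (implemented below as a single logit followed by a sigmoid: sigmoid(z)
        # equals a two-way softmax over logits [0, z], so `prob` is the probability
        # of executing the block and 1 - prob the probability of skipping it)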
1362 | self.proj = nn.Linear(hidden_dim, 1) 1363 | self.prob = nn.Sigmoid() 1364 | 1365 | # saved actions and rewards 1366 | self.saved_actions = [] 1367 | self.rewards = [] 1368 | 1369 | def hotter(self, t): 1370 | self.proj.weight.data /= t 1371 | self.proj.bias.data /= t 1372 | 1373 | def init_hidden(self, batch_size): 1374 | # The axes semantics are (num_layers, minibatch_size, hidden_dim) 1375 | return (autograd.Variable(torch.zeros(1, batch_size, 1376 | self.hidden_dim).cuda()), 1377 | autograd.Variable(torch.zeros(1, batch_size, 1378 | self.hidden_dim).cuda())) 1379 | 1380 | def repackage_hidden(self): 1381 | self.hidden = repackage_hidden(self.hidden) 1382 | 1383 | def forward(self, x): 1384 | batch_size = x.size(0) 1385 | self.rnn.flatten_parameters() 1386 | out, self.hidden = self.rnn(x.view(1, batch_size, -1), self.hidden) 1387 | 1388 | # do action selection in the forward pass 1389 | if self.training: 1390 | proj = self.proj(out.squeeze()) 1391 | prob = self.prob(proj) 1392 | bi_prob = torch.cat([1 - prob, prob], dim=1) 1393 | action = bi_prob.multinomial() 1394 | self.saved_actions.append(action) 1395 | else: 1396 | proj = self.proj(out.squeeze()) 1397 | prob = self.prob(proj) 1398 | bi_prob = torch.cat([1 - prob, prob], dim=1) 1399 | action = (prob > 0.5).float() 1400 | self.saved_actions.append(action) 1401 | action = action.view(action.size(0), 1, 1, 1).float() 1402 | return action, bi_prob 1403 | 1404 | 1405 | class ResNetRecurrentGateRL(nn.Module): 1406 | """Adding gating module on every basic block""" 1407 | 1408 | def __init__(self, block, layers, num_classes=10, 1409 | embed_dim=64, hidden_dim=64): 1410 | self.inplanes = 16 1411 | super(ResNetRecurrentGateRL, self).__init__() 1412 | 1413 | self.num_layers = layers 1414 | self.conv1 = conv3x3(3, 16) 1415 | self.bn1 = nn.BatchNorm2d(16) 1416 | self.relu = nn.ReLU(inplace=True) 1417 | 1418 | self.embed_dim = embed_dim 1419 | self.hidden_dim = hidden_dim 1420 | 1421 | self._make_group(block, 16, layers[0], group_id=1, pool_size=32) 1422 | self._make_group(block, 32, layers[1], group_id=2, pool_size=16) 1423 | self._make_group(block, 64, layers[2], group_id=3, pool_size=8) 1424 | 1425 | self.control = RNNGatePolicy(embed_dim, hidden_dim) 1426 | 1427 | self.avgpool = nn.AvgPool2d(8) 1428 | self.fc = nn.Linear(64 * block.expansion, num_classes) 1429 | 1430 | self.softmax = nn.Softmax() 1431 | 1432 | self.saved_actions = [] 1433 | self.rewards = [] 1434 | 1435 | for m in self.modules(): 1436 | if isinstance(m, nn.Conv2d): 1437 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 1438 | m.weight.data.normal_(0, math.sqrt(2. / n)) 1439 | elif isinstance(m, nn.BatchNorm2d): 1440 | m.weight.data.fill_(1) 1441 | m.bias.data.zero_() 1442 | elif isinstance(m, nn.Linear): 1443 | n = m.weight.size(0) * m.weight.size(1) 1444 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 1445 | m.bias.data.zero_() 1446 | 1447 | def _make_group(self, block, planes, layers, group_id=1, pool_size=16): 1448 | """ Create the whole group""" 1449 | for i in range(layers): 1450 | if group_id > 1 and i == 0: 1451 | stride = 2 1452 | else: 1453 | stride = 1 1454 | 1455 | meta = self._make_layer_v2(block, planes, stride=stride, 1456 | pool_size=pool_size) 1457 | 1458 | setattr(self, 'group{}_ds{}'.format(group_id, i), meta[0]) 1459 | setattr(self, 'group{}_layer{}'.format(group_id, i), meta[1]) 1460 | setattr(self, 'group{}_gate{}'.format(group_id, i), meta[2]) 1461 | 1462 | def _make_layer_v2(self, block, planes, stride=1, pool_size=16): 1463 | """ create one block and optional a gate module """ 1464 | downsample = None 1465 | if stride != 1 or self.inplanes != planes * block.expansion: 1466 | downsample = nn.Sequential( 1467 | nn.Conv2d(self.inplanes, planes * block.expansion, 1468 | kernel_size=1, stride=stride, bias=False), 1469 | nn.BatchNorm2d(planes * block.expansion), 1470 | 1471 | ) 1472 | layer = block(self.inplanes, planes, stride, downsample) 1473 | self.inplanes = planes * block.expansion 1474 | 1475 | gate_layer = nn.Sequential( 1476 | nn.AvgPool2d(pool_size), 1477 | nn.Conv2d(in_channels=planes * block.expansion, 1478 | out_channels=self.embed_dim, 1479 | kernel_size=1, 1480 | stride=1)) 1481 | 1482 | return downsample, layer, gate_layer 1483 | 1484 | def forward(self, x): 1485 | batch_size = x.size(0) 1486 | x = self.conv1(x) 1487 | x = self.bn1(x) 1488 | x = self.relu(x) 1489 | 1490 | # reinitialize hidden units 1491 | self.control.hidden = self.control.init_hidden(batch_size) 1492 | 1493 | masks = [] 1494 | gprobs = [] 1495 | # must pass through the first layer in first group 1496 | x = getattr(self, 'group1_layer0')(x) 1497 | # gate takes the output of the current layer 1498 | gate_feature = getattr(self, 'group1_gate0')(x) 1499 | 1500 | mask, gprob = self.control(gate_feature) 1501 | gprobs.append(gprob) 1502 | masks.append(mask.squeeze()) 1503 | prev = x 1504 | 1505 | for g in range(3): 1506 | for i in range(0 + int(g == 0), self.num_layers[g]): 1507 | if getattr(self, 'group{}_ds{}'.format(g+1, i)) is not None: 1508 | prev = getattr(self, 'group{}_ds{}'.format(g+1, i))(prev) 1509 | x = getattr(self, 'group{}_layer{}'.format(g+1, i))(x) 1510 | prev = x = mask.expand_as(x) * x + \ 1511 | (1 - mask).expand_as(prev)*prev 1512 | if not (g == 2 and (i == self.num_layers[g] -1)): 1513 | gate_feature = getattr(self, 1514 | 'group{}_gate{}'.format(g+1, i))(x) 1515 | mask, gprob = self.control(gate_feature) 1516 | gprobs.append(gprob) 1517 | masks.append(mask.squeeze()) 1518 | 1519 | x = self.avgpool(x) 1520 | x = x.view(x.size(0), -1) 1521 | 1522 | if self.training: 1523 | x = self.fc(x) 1524 | softmax = self.softmax(x) 1525 | pred = softmax.multinomial() 1526 | else: 1527 | x = self.fc(x) 1528 | pred = x.max(1)[1] 1529 | self.saved_actions.append(pred) 1530 | 1531 | return x, masks, gprobs 1532 | 1533 | 1534 | # for CIFAR-10 1535 | def cifar10_rnn_gate_rl_38(pretrained=False, **kwargs): 1536 | """SkipNet-38 + RL with Recurrent Gate""" 1537 | model = ResNetRecurrentGateRL(BasicBlock, [6, 6, 6], num_classes=10, 1538 | embed_dim=10, hidden_dim=10) 1539 | return model 1540 | 1541 | 1542 | def cifar10_rnn_gate_rl_74(pretrained=False, **kwargs): 1543 | """SkipNet-74 + RL with Recurrent Gate""" 1544 | model = ResNetRecurrentGateRL(BasicBlock, [12, 12, 12], num_classes=10, 1545 | embed_dim=10, hidden_dim=10) 1546 | return model 1547 | 1548 | 1549 | def 
cifar10_rnn_gate_rl_110(pretrained=False, **kwargs): 1550 | """SkipNet-110 + RL with Recurrent Gate""" 1551 | model = ResNetRecurrentGateRL(BasicBlock, [18, 18, 18], num_classes=10, 1552 | embed_dim=10, hidden_dim=10) 1553 | return model 1554 | 1555 | 1556 | # for CIFAR-100 1557 | def cifar100_rnn_gate_rl_38(pretrained=False, **kwargs): 1558 | """SkipNet-38 + RL with Recurrent Gate""" 1559 | model = ResNetRecurrentGateRL(BasicBlock, [6, 6, 6], num_classes=100, 1560 | embed_dim=10, hidden_dim=10) 1561 | return model 1562 | 1563 | 1564 | def cifar100_rnn_gate_rl_74(pretrained=False, **kwargs): 1565 | """SkipNet-74 + RL with Recurrent Gate""" 1566 | model = ResNetRecurrentGateRL(BasicBlock, [12, 12, 12], num_classes=100, 1567 | embed_dim=10, hidden_dim=10) 1568 | return model 1569 | 1570 | 1571 | def cifar100_rnn_gate_rl_110(pretrained=False, **kwargs): 1572 | """SkipNet-110 + RL with Recurrent Gate""" 1573 | model = ResNetRecurrentGateRL(BasicBlock, [18, 18, 18], num_classes=100, 1574 | embed_dim=10, hidden_dim=10) 1575 | return model 1576 | 1577 | 1578 | -------------------------------------------------------------------------------- /models_initial.py: -------------------------------------------------------------------------------- 1 | """ This file contains the model definitions for both original ResNet (6n+2 2 | layers) and SkipNets. 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | import math 8 | from torch.autograd import Variable 9 | import torch.autograd as autograd 10 | from modules.quantize import quantize, quantize_grad, QConv2d, QLinear, RangeBN 11 | import torch.nn.functional as F 12 | 13 | 14 | NUM_BITS = 8 15 | NUM_BITS_WEIGHT = None 16 | NUM_BITS_GRAD = None 17 | BIPRECISION = False 18 | 19 | 20 | def Conv3x3(in_planes, out_planes, stride=1): 21 | "3x3 convolution with padding" 22 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 23 | padding=1, bias=False) 24 | 25 | 26 | def conv3x3(in_planes, out_planes, stride=1, pool_size = None): 27 | "3x3 convolution with padding" 28 | return QConv2d(in_planes, out_planes, kernel_size=3, stride=stride, 29 | padding=1, bias=False, num_bits=NUM_BITS, num_bits_weight=NUM_BITS_WEIGHT, num_bits_grad=NUM_BITS_GRAD, biprecision=BIPRECISION) 30 | 31 | class BasicBlock(nn.Module): 32 | expansion = 1 33 | 34 | def __init__(self, inplanes, planes, stride=1, downsample=None): 35 | super(BasicBlock, self).__init__() 36 | self.conv1 = conv3x3(inplanes, planes, stride) 37 | self.bn1 = nn.BatchNorm2d(planes) 38 | self.relu = nn.ReLU(inplace=True) 39 | self.conv2 = conv3x3(planes, planes) 40 | self.bn2 = nn.BatchNorm2d(planes) 41 | self.bn3 = nn.BatchNorm2d(planes) 42 | self.downsample = downsample 43 | self.stride = stride 44 | 45 | def forward(self, x, bit): 46 | residual = x 47 | 48 | out = self.conv1(x, bit) 49 | out = self.bn1(out) 50 | out = self.relu(out) 51 | 52 | out = self.conv2(out, bit) 53 | out = self.bn2(out) 54 | 55 | if self.downsample is not None: 56 | residual = self.downsample(x, bit) 57 | residual = self.bn3(residual) 58 | 59 | out += residual 60 | out = self.relu(out) 61 | return out 62 | 63 | 64 | ######################################## 65 | # Original ResNet # 66 | ######################################## 67 | 68 | 69 | class ResNet(nn.Module): 70 | """Original ResNet without routing modules""" 71 | def __init__(self, block, layers, num_classes=10): 72 | self.inplanes = 16 73 | super(ResNet, self).__init__() 74 | self.conv1 = conv3x3(3, 16) 75 | self.bn1 = nn.BatchNorm2d(16) 76 | self.relu = 
nn.ReLU(inplace=True) 77 | self.layer1 = self._make_layer(block, 16, layers[0]) 78 | self.layer2 = self._make_layer(block, 32, layers[1], stride=2) 79 | self.layer3 = self._make_layer(block, 64, layers[2], stride=2) 80 | self.avgpool = nn.AvgPool2d(8) 81 | self.fc = nn.Linear(64 * block.expansion, num_classes) 82 | 83 | for m in self.modules(): 84 | if isinstance(m, nn.Conv2d): 85 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 86 | m.weight.data.normal_(0, math.sqrt(2. / n)) 87 | elif isinstance(m, nn.BatchNorm2d): 88 | m.weight.data.fill_(1) 89 | m.bias.data.zero_() 90 | 91 | def _make_layer(self, block, planes, blocks, stride=1): 92 | downsample = None 93 | if stride != 1 or self.inplanes != planes * block.expansion: 94 | downsample = QConv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False, num_bits=NUM_BITS, num_bits_weight=NUM_BITS_WEIGHT, num_bits_grad=NUM_BITS_GRAD, biprecision=BIPRECISION) 95 | 96 | layers = [] 97 | layers.append(block(self.inplanes, planes, stride, downsample)) 98 | self.inplanes = planes * block.expansion 99 | for i in range(1, blocks): 100 | layers.append(block(self.inplanes, planes)) 101 | 102 | return nn.Sequential(*layers) 103 | 104 | def forward(self, x): 105 | x = self.conv1(x, 0) 106 | x = self.bn1(x) 107 | x = self.relu(x) 108 | 109 | x = self.layer1(x) 110 | x = self.layer2(x) 111 | x = self.layer3(x) 112 | 113 | x = self.avgpool(x) 114 | x = x.view(x.size(0), -1) 115 | x = self.fc(x) 116 | return x 117 | 118 | # For CIFAR-10 119 | # ResNet-38 120 | def cifar10_resnet_38(pretrained=False, **kwargs): 121 | # n = 6 122 | model = ResNet(BasicBlock, [6, 6, 6], **kwargs) 123 | return model 124 | 125 | 126 | # ResNet-74 127 | def cifar10_resnet_74(pretrained=False, **kwargs): 128 | # n = 12 129 | model = ResNet(BasicBlock, [12, 12, 12], **kwargs) 130 | return model 131 | 132 | 133 | # ResNet-110 134 | def cifar10_resnet_110(pretrained=False, **kwargs): 135 | # n = 18 136 | model = ResNet(BasicBlock, [18, 18, 18], **kwargs) 137 | return model 138 | 139 | 140 | # ResNet-152 141 | def cifar10_resnet_152(pretrained=False, **kwargs): 142 | # n = 25 143 | model = ResNet(BasicBlock, [25, 25, 25], **kwargs) 144 | return model 145 | 146 | 147 | # For CIFAR-100 148 | # ResNet-38 149 | def cifar100_resnet_38(pretrained=False, **kwargs): 150 | # n = 6 151 | model = ResNet(BasicBlock, [6, 6, 6], num_classes=100) 152 | return model 153 | 154 | 155 | # ResNet-74 156 | def cifar100_resnet_74(pretrained=False, **kwargs): 157 | # n = 12 158 | model = ResNet(BasicBlock, [12, 12, 12], num_classes=100) 159 | return model 160 | 161 | 162 | # ResNet-110 163 | def cifar100_resnet_110(pretrained=False, **kwargs): 164 | # n = 18 165 | model = ResNet(BasicBlock, [18, 18, 18], num_classes=100) 166 | return model 167 | 168 | 169 | # ResNet-152 170 | def cifar100_resnet_152(pretrained=False, **kwargs): 171 | # n = 25 172 | model = ResNet(BasicBlock, [25, 25, 25], num_classes=100) 173 | return model 174 | 175 | 176 | ######################################## 177 | # SkipNet+SP with Feedforward Gate # 178 | ######################################## 179 | 180 | 181 | # Feedforward-Gate (FFGate-I) 182 | class FeedforwardGateI(nn.Module): 183 | """ Use Max Pooling First and then apply to multiple 2 conv layers. 
184 | The first conv has stride = 1 and second has stride = 2""" 185 | def __init__(self, pool_size=5, channel=10): 186 | super(FeedforwardGateI, self).__init__() 187 | self.pool_size = pool_size 188 | self.channel = channel 189 | 190 | self.maxpool = nn.MaxPool2d(2) 191 | self.conv1 = conv3x3(channel, channel) 192 | self.bn1 = nn.BatchNorm2d(channel) 193 | self.relu1 = nn.ReLU(inplace=True) 194 | 195 | # adding another conv layer 196 | self.conv2 = conv3x3(channel, channel, stride=2) 197 | self.bn2 = nn.BatchNorm2d(channel) 198 | self.relu2 = nn.ReLU(inplace=True) 199 | 200 | pool_size = math.floor(pool_size/2) # for max pooling 201 | pool_size = math.floor(pool_size/2 + 0.5) # for conv stride = 2 202 | 203 | self.avg_layer = nn.AvgPool2d(pool_size) 204 | self.linear_layer = nn.Conv2d(in_channels=channel, out_channels=2, 205 | kernel_size=1, stride=1) 206 | self.prob_layer = nn.Softmax() 207 | self.logprob = nn.LogSoftmax() 208 | 209 | def forward(self, x): 210 | x = self.maxpool(x) 211 | x = self.conv1(x) 212 | x = self.bn1(x) 213 | x = self.relu1(x) 214 | 215 | x = self.conv2(x) 216 | x = self.bn2(x) 217 | x = self.relu2(x) 218 | 219 | x = self.avg_layer(x) 220 | x = self.linear_layer(x).squeeze() 221 | softmax = self.prob_layer(x) 222 | logprob = self.logprob(x) 223 | 224 | # discretize output in forward pass. 225 | # use softmax gradients in backward pass 226 | x = (softmax[:, 1] > 0.5).float().detach() - \ 227 | softmax[:, 1].detach() + softmax[:, 1] 228 | 229 | x = x.view(x.size(0), 1, 1, 1) 230 | return x, logprob 231 | 232 | 233 | # soft gate v3 (matching FFGate-I) 234 | class SoftGateI(nn.Module): 235 | """This module has the same structure as FFGate-I. 236 | In training, adopt continuous gate output. In inference phase, 237 | use discrete gate outputs""" 238 | def __init__(self, pool_size=5, channel=10): 239 | super(SoftGateI, self).__init__() 240 | self.pool_size = pool_size 241 | self.channel = channel 242 | 243 | self.maxpool = nn.MaxPool2d(2) 244 | self.conv1 = conv3x3(channel, channel) 245 | self.bn1 = nn.BatchNorm2d(channel) 246 | self.relu1 = nn.ReLU(inplace=True) 247 | 248 | # adding another conv layer 249 | self.conv2 = conv3x3(channel, channel, stride=2) 250 | self.bn2 = nn.BatchNorm2d(channel) 251 | self.relu2 = nn.ReLU(inplace=True) 252 | 253 | pool_size = math.floor(pool_size/2) # for max pooling 254 | pool_size = math.floor(pool_size/2 + 0.5) # for conv stride = 2 255 | 256 | self.avg_layer = nn.AvgPool2d(pool_size) 257 | self.linear_layer = nn.Conv2d(in_channels=channel, out_channels=2, 258 | kernel_size=1, stride=1) 259 | self.prob_layer = nn.Softmax() 260 | self.logprob = nn.LogSoftmax() 261 | 262 | def forward(self, x): 263 | x = self.maxpool(x) 264 | x = self.conv1(x) 265 | x = self.bn1(x) 266 | x = self.relu1(x) 267 | 268 | x = self.conv2(x) 269 | x = self.bn2(x) 270 | x = self.relu2(x) 271 | 272 | x = self.avg_layer(x) 273 | x = self.linear_layer(x).squeeze() 274 | softmax = self.prob_layer(x) 275 | logprob = self.logprob(x) 276 | 277 | x = softmax[:, 1].contiguous() 278 | x = x.view(x.size(0), 1, 1, 1) 279 | 280 | if not self.training: 281 | x = (x > 0.5).float() 282 | return x, logprob 283 | 284 | 285 | # FFGate-II 286 | class FeedforwardGateII(nn.Module): 287 | """ use single conv (stride=2) layer only""" 288 | def __init__(self, pool_size=5, channel=10): 289 | super(FeedforwardGateII, self).__init__() 290 | self.pool_size = pool_size 291 | self.channel = channel 292 | 293 | self.conv1 = conv3x3(channel, channel, stride=2) 294 | self.bn1 = 
nn.BatchNorm2d(channel) 295 | self.relu1 = nn.ReLU(inplace=True) 296 | 297 | pool_size = math.floor(pool_size/2 + 0.5) # for conv stride = 2 298 | 299 | self.avg_layer = nn.AvgPool2d(pool_size) 300 | self.linear_layer = nn.Conv2d(in_channels=channel, out_channels=2, 301 | kernel_size=1, stride=1) 302 | self.prob_layer = nn.Softmax() 303 | self.logprob = nn.LogSoftmax() 304 | 305 | def forward(self, x): 306 | x = self.conv1(x) 307 | x = self.bn1(x) 308 | x = self.relu1(x) 309 | 310 | x = self.avg_layer(x) 311 | x = self.linear_layer(x).squeeze() 312 | softmax = self.prob_layer(x) 313 | logprob = self.logprob(x) 314 | 315 | # discretize 316 | x = (softmax[:, 1] > 0.5).float().detach() - \ 317 | softmax[:, 1].detach() + softmax[:, 1] 318 | 319 | x = x.view(x.size(0), 1, 1, 1) 320 | return x, logprob 321 | 322 | 323 | class SoftGateII(nn.Module): 324 | """ Soft gating version of FFGate-II""" 325 | def __init__(self, pool_size=5, channel=10): 326 | super(SoftGateII, self).__init__() 327 | self.pool_size = pool_size 328 | self.channel = channel 329 | 330 | self.conv1 = conv3x3(channel, channel, stride=2) 331 | self.bn1 = nn.BatchNorm2d(channel) 332 | self.relu1 = nn.ReLU(inplace=True) 333 | 334 | pool_size = math.floor(pool_size / 2 + 0.5) # for conv stride = 2 335 | 336 | self.avg_layer = nn.AvgPool2d(pool_size) 337 | self.linear_layer = nn.Conv2d(in_channels=channel, out_channels=2, 338 | kernel_size=1, stride=1) 339 | self.prob_layer = nn.Softmax() 340 | self.logprob = nn.LogSoftmax() 341 | 342 | def forward(self, x): 343 | x = self.conv1(x) 344 | x = self.bn1(x) 345 | x = self.relu1(x) 346 | 347 | x = self.avg_layer(x) 348 | x = self.linear_layer(x).squeeze() 349 | softmax = self.prob_layer(x) 350 | logprob = self.logprob(x) 351 | 352 | x = softmax[:, 1].contiguous() 353 | x = x.view(x.size(0), 1, 1, 1) 354 | if not self.training: 355 | x = (x > 0.5).float() 356 | return x, logprob 357 | 358 | 359 | class ResNetFeedForwardSP(nn.Module): 360 | """ SkipNets with Feed-forward Gates for Supervised Pre-training stage. 361 | Adding one routing module after each basic block.""" 362 | 363 | def __init__(self, block, layers, num_classes=10, 364 | gate_type='fisher', **kwargs): 365 | self.inplanes = 16 366 | super(ResNetFeedForwardSP, self).__init__() 367 | 368 | self.num_layers = layers 369 | self.conv1 = conv3x3(3, 16) 370 | self.bn1 = nn.BatchNorm2d(16) 371 | self.relu = nn.ReLU(inplace=True) 372 | 373 | # going to have 3 groups of layers. For the easiness of skipping, 374 | # We are going to break the sequential of layers into a list of layers. 375 | 376 | self.gate_type = gate_type 377 | self._make_group(block, 16, layers[0], group_id=1, 378 | gate_type=gate_type, pool_size=32) 379 | self._make_group(block, 32, layers[1], group_id=2, 380 | gate_type=gate_type, pool_size=16) 381 | self._make_group(block, 64, layers[2], group_id=3, 382 | gate_type=gate_type, pool_size=8) 383 | 384 | self.avgpool = nn.AvgPool2d(8) 385 | self.fc = nn.Linear(64 * block.expansion, num_classes) 386 | 387 | for m in self.modules(): 388 | if isinstance(m, nn.Conv2d): 389 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 390 | m.weight.data.normal_(0, math.sqrt(2. / n)) 391 | elif isinstance(m, nn.BatchNorm2d): 392 | m.weight.data.fill_(1) 393 | m.bias.data.zero_() 394 | elif isinstance(m, nn.Linear): 395 | n = m.weight.size(0) * m.weight.size(1) 396 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 397 | 398 | def _make_group(self, block, planes, layers, group_id=1, 399 | gate_type='fisher', pool_size=16): 400 | """ Create the whole group""" 401 | for i in range(layers): 402 | if group_id > 1 and i == 0: 403 | stride = 2 404 | else: 405 | stride = 1 406 | 407 | meta = self._make_layer_v2(block, planes, stride=stride, 408 | gate_type=gate_type, 409 | pool_size=pool_size) 410 | 411 | setattr(self, 'group{}_ds{}'.format(group_id, i), meta[0]) 412 | setattr(self, 'group{}_layer{}'.format(group_id, i), meta[1]) 413 | setattr(self, 'group{}_gate{}'.format(group_id, i), meta[2]) 414 | 415 | def _make_layer_v2(self, block, planes, stride=1, 416 | gate_type='fisher', pool_size=16): 417 | """ create one block and optional a gate module """ 418 | downsample = None 419 | if stride != 1 or self.inplanes != planes * block.expansion: 420 | downsample = nn.Sequential( 421 | nn.Conv2d(self.inplanes, planes * block.expansion, 422 | kernel_size=1, stride=stride, bias=False), 423 | nn.BatchNorm2d(planes * block.expansion), 424 | 425 | ) 426 | layer = block(self.inplanes, planes, stride, downsample) 427 | self.inplanes = planes * block.expansion 428 | 429 | if gate_type == 'ffgate1': 430 | gate_layer = FeedforwardGateI(pool_size=pool_size, 431 | channel=planes*block.expansion) 432 | elif gate_type == 'ffgate2': 433 | gate_layer = FeedforwardGateII(pool_size=pool_size, 434 | channel=planes*block.expansion) 435 | elif gate_type == 'softgate1': 436 | gate_layer = SoftGateI(pool_size=pool_size, 437 | channel=planes*block.expansion) 438 | elif gate_type == 'softgate2': 439 | gate_layer = SoftGateII(pool_size=pool_size, 440 | channel=planes*block.expansion) 441 | else: 442 | gate_layer = None 443 | 444 | if downsample: 445 | return downsample, layer, gate_layer 446 | else: 447 | return None, layer, gate_layer 448 | 449 | def forward(self, x): 450 | """Return output logits, masks(gate ouputs) and probabilities 451 | associated to each gate.""" 452 | 453 | x = self.conv1(x) 454 | x = self.bn1(x) 455 | x = self.relu(x) 456 | 457 | masks = [] 458 | gprobs = [] 459 | # must pass through the first layer in first group 460 | x = getattr(self, 'group1_layer0')(x) 461 | # gate takes the output of the current layer 462 | 463 | mask, gprob = getattr(self, 'group1_gate0')(x) 464 | gprobs.append(gprob) 465 | masks.append(mask.squeeze()) 466 | prev = x # input of next layer 467 | 468 | for g in range(3): 469 | for i in range(0 + int(g == 0), self.num_layers[g]): 470 | if getattr(self, 'group{}_ds{}'.format(g+1, i)) is not None: 471 | prev = getattr(self, 'group{}_ds{}'.format(g+1, i))(prev) 472 | x = getattr(self, 'group{}_layer{}'.format(g+1, i))(x) 473 | prev = x = mask.expand_as(x) * x \ 474 | + (1 - mask).expand_as(prev) * prev 475 | mask, gprob = getattr(self, 'group{}_gate{}'.format(g+1, i))(x) 476 | gprobs.append(gprob) 477 | masks.append(mask.squeeze()) 478 | 479 | del masks[-1] 480 | 481 | x = self.avgpool(x) 482 | x = x.view(x.size(0), -1) 483 | x = self.fc(x) 484 | 485 | return x, masks, gprobs 486 | 487 | 488 | # FeeforwardGate-I 489 | # For CIFAR-10 490 | def cifar10_feedforward_38(pretrained=False, **kwargs): 491 | """SkipNet-38 with FFGate-I""" 492 | model = ResNetFeedForwardSP(BasicBlock, [6, 6, 6], gate_type='ffgate1') 493 | return model 494 | 495 | 496 | def cifar10_feedforward_74(pretrained=False, **kwargs): 497 | """SkipNet-74 with FFGate-I""" 498 | model = ResNetFeedForwardSP(BasicBlock, [12, 12, 12], gate_type='ffgate1') 499 | return model 500 | 501 | 502 | def 
cifar10_feedforward_110(pretrained=False, **kwargs): 503 | """SkipNet-110 with FFGate-II""" 504 | model = ResNetFeedForwardSP(BasicBlock, [18, 18, 18], gate_type='ffgate2') 505 | return model 506 | 507 | 508 | # For CIFAR-100 509 | def cifar100_feeforward_38(pretrained=False, **kwargs): 510 | """SkipNet-38 with FFGate-I""" 511 | model = ResNetFeedForwardSP(BasicBlock, [6, 6, 6], num_classes=100, 512 | gate_type='ffgate1') 513 | return model 514 | 515 | 516 | def cifar100_feedforward_74(pretrained=False, **kwargs): 517 | """SkipNet-74 with FFGate-I""" 518 | model = ResNetFeedForwardSP(BasicBlock, [12, 12, 12], num_classes=100, 519 | gate_type='ffgate1') 520 | return model 521 | 522 | 523 | def cifar100_feedforward_110(pretrained=False, **kwargs): 524 | """SkipNet-110 with FFGate-II""" 525 | model = ResNetFeedForwardSP(BasicBlock, [18, 18, 18], num_classes=100, 526 | gate_type='ffgate2') 527 | return model 528 | 529 | 530 | ######################################## 531 | # SkipNet+SP with Recurrent Gate # 532 | ######################################## 533 | 534 | 535 | # For Recurrent Gate 536 | def repackage_hidden(h): 537 | """ to reduce memory usage""" 538 | if type(h) == Variable: 539 | return Variable(h.data) 540 | else: 541 | return tuple(repackage_hidden(v) for v in h) 542 | 543 | 544 | class RNNGate(nn.Module): 545 | """Recurrent Gate definition. 546 | Input is already passed through average pooling and embedding.""" 547 | def __init__(self, input_dim, hidden_dim, rnn_type='lstm'): 548 | super(RNNGate, self).__init__() 549 | self.rnn_type = rnn_type 550 | self.input_dim = input_dim 551 | self.hidden_dim = hidden_dim 552 | 553 | if self.rnn_type == 'lstm': 554 | self.rnn_one = nn.LSTM(input_dim, hidden_dim) 555 | # self.rnn_two = nn.LSTM(hidden_dim, hidden_dim) 556 | else: 557 | self.rnn = None 558 | self.hidden_one = None 559 | # self.hidden_two = None 560 | 561 | # reduce dim 562 | self.proj = nn.Linear(hidden_dim, 7) 563 | # self.proj_two = nn.Linear(hidden_dim, 4) 564 | self.prob = nn.Sigmoid() 565 | self.prob_layer = nn.Softmax() 566 | 567 | def init_hidden(self, batch_size): 568 | # The axes semantics are (num_layers, minibatch_size, hidden_dim) 569 | return (autograd.Variable(torch.zeros(1, batch_size, 570 | self.hidden_dim).cuda()), 571 | autograd.Variable(torch.zeros(1, batch_size, 572 | self.hidden_dim).cuda())) 573 | 574 | def repackage_hidden(self): 575 | self.hidden_one = repackage_hidden(self.hidden_one) 576 | # self.hidden_two = repackage_hidden(self.hidden_two) 577 | def forward(self, x): 578 | # Take the convolution output of each step 579 | batch_size = x.size(0) 580 | self.rnn_one.flatten_parameters() 581 | # self.rnn_two.flatten_parameters() 582 | 583 | out_one, self.hidden_one = self.rnn_one(x.view(1, batch_size, -1), self.hidden_one) 584 | 585 | # out_one = F.dropout(out_one, p = 0.1, training=True) 586 | 587 | # out_two, self.hidden_two = self.rnn_two(out_one.view(1, batch_size, -1), self.hidden_two) 588 | 589 | x_one = self.proj(out_one.squeeze()) 590 | # x_two = self.proj_two(out_two.squeeze()) 591 | 592 | # proj = self.proj(out.squeeze()) 593 | prob = self.prob_layer(x_one) 594 | # prob_two = self.prob_layer(x_two) 595 | 596 | # x_one = (prob > 0.5).float().detach() - \ 597 | # prob.detach() + prob 598 | 599 | # x_two = prob_two.detach().cpu().numpy() 600 | 601 | x_one = prob.detach().cpu().numpy() 602 | 603 | hard = (x_one == x_one.max(axis=1)[:,None]).astype(int) 604 | hard = torch.from_numpy(hard) 605 | hard = hard.cuda() 606 | 607 | # x_two = 
hard.float().detach() - \ 608 | # prob_two.detach() + prob_two 609 | 610 | x_one = hard.float().detach() - \ 611 | prob.detach() + prob 612 | 613 | # print(x_one) 614 | 615 | x_one = x_one.view(x_one.size(0),x_one.size(1), 1, 1, 1) 616 | 617 | # x_two = x_two.view(x_two.size(0), x_two.size(1), 1, 1, 1) 618 | 619 | return x_one # , x_two 620 | 621 | 622 | class SoftRNNGate(nn.Module): 623 | def __init__(self, input_dim, hidden_dim, rnn_type='lstm'): 624 | super(SoftRNNGate, self).__init__() 625 | self.rnn_type = rnn_type 626 | self.input_dim = input_dim 627 | self.hidden_dim = hidden_dim 628 | 629 | if self.rnn_type == 'lstm': 630 | self.rnn = nn.LSTM(input_dim, hidden_dim) 631 | else: 632 | self.rnn = None 633 | self.hidden = None 634 | 635 | # reduce dim 636 | self.proj = nn.Linear(hidden_dim, 1) 637 | self.prob = nn.Sigmoid() 638 | 639 | def init_hidden(self, batch_size): 640 | return (autograd.Variable(torch.zeros(1, batch_size, 641 | self.hidden_dim).cuda()), 642 | autograd.Variable(torch.zeros(1, batch_size, 643 | self.hidden_dim).cuda())) 644 | 645 | def repackage_hidden(self): 646 | self.hidden = repackage_hidden(self.hidden) 647 | 648 | def forward(self, x): 649 | # Take the convolution output of each step 650 | batch_size = x.size(0) 651 | self.rnn.flatten_parameters() 652 | out, self.hidden = self.rnn(x.view(1, batch_size, -1), self.hidden) 653 | 654 | proj = self.proj(out.squeeze()) 655 | prob = self.prob(proj) 656 | 657 | x = prob.view(batch_size, 1, 1, 1) 658 | if not self.training: 659 | x = (x > 0.5).float() 660 | return x, prob 661 | 662 | 663 | class ResNetRecurrentGateSP(nn.Module): 664 | """SkipNet with Recurrent Gate Model""" 665 | def __init__(self, block, layers, num_classes=10, embed_dim=10, 666 | hidden_dim=10, gate_type='rnn'): 667 | self.inplanes = 16 668 | super(ResNetRecurrentGateSP, self).__init__() 669 | 670 | self.num_layers = layers 671 | self.conv1 = conv3x3(3, 16) 672 | self.bn1 = nn.BatchNorm2d(16) 673 | self.relu = nn.ReLU(inplace=True) 674 | 675 | self.embed_dim = embed_dim 676 | self.hidden_dim = hidden_dim 677 | 678 | self._make_group(block, 16, layers[0], group_id=1, pool_size=32) 679 | self._make_group(block, 32, layers[1], group_id=2, pool_size=16) 680 | self._make_group(block, 64, layers[2], group_id=3, pool_size=8) 681 | 682 | # define recurrent gating module 683 | if gate_type == 'rnn': 684 | self.control = RNNGate(embed_dim, hidden_dim, rnn_type='lstm') 685 | elif gate_type == 'soft': 686 | self.control = SoftRNNGate(embed_dim, hidden_dim, rnn_type='lstm') 687 | else: 688 | print('gate type {} not implemented'.format(gate_type)) 689 | self.control = None 690 | 691 | self.avgpool = nn.AvgPool2d(8) 692 | self.fc = nn.Linear(64 * block.expansion, num_classes) 693 | 694 | for m in self.modules(): 695 | if isinstance(m, nn.Conv2d): 696 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 697 | m.weight.data.normal_(0, math.sqrt(2. / n)) 698 | elif isinstance(m, nn.BatchNorm2d): 699 | m.weight.data.fill_(1) 700 | m.bias.data.zero_() 701 | elif isinstance(m, nn.Linear): 702 | n = m.weight.size(0) * m.weight.size(1) 703 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 704 | 705 | def _make_group(self, block, planes, layers, group_id=1, pool_size=16): 706 | """ Create the whole group""" 707 | for i in range(layers): 708 | if group_id > 1 and i == 0: 709 | stride = 2 710 | else: 711 | stride = 1 712 | 713 | meta = self._make_layer_v2(block, planes, stride=stride, 714 | pool_size=pool_size) 715 | 716 | setattr(self, 'group{}_ds{}'.format(group_id, i), meta[0]) 717 | setattr(self, 'group{}_layer{}'.format(group_id, i), meta[1]) 718 | setattr(self, 'group{}_gate{}'.format(group_id, i), meta[2]) 719 | setattr(self, 'group{}_bn{}'.format(group_id, i), meta[3]) 720 | 721 | def _make_layer_v2(self, block, planes, stride=1, pool_size=16): 722 | """ create one block and optional a gate module """ 723 | downsample = None 724 | if stride != 1 or self.inplanes != planes * block.expansion: 725 | # downsample = nn.Sequential( 726 | # nn.Conv2d(self.inplanes, planes * block.expansion, 727 | # kernel_size=1, stride=stride, bias=False), 728 | # nn.BatchNorm2d(planes * block.expansion), 729 | 730 | # ) 731 | 732 | downsample = QConv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False, num_bits=NUM_BITS, num_bits_weight=NUM_BITS_WEIGHT, num_bits_grad=NUM_BITS_GRAD, biprecision=BIPRECISION) 733 | 734 | 735 | layer = block(self.inplanes, planes, stride, downsample) 736 | self.inplanes = planes * block.expansion 737 | 738 | 739 | bn = layer.bn3 740 | 741 | gate_layer = nn.Sequential( 742 | nn.AvgPool2d(pool_size), 743 | nn.Conv2d(in_channels=planes * block.expansion, 744 | out_channels=self.embed_dim, 745 | kernel_size=1, 746 | stride=1)) 747 | if downsample: 748 | return downsample, layer, gate_layer, bn 749 | else: 750 | return None, layer, gate_layer, None 751 | 752 | def forward(self, x, bits): 753 | 754 | batch_size = x.size(0) 755 | x = self.conv1(x, 0) 756 | x = self.bn1(x) 757 | x = self.relu(x) 758 | 759 | # reinitialize hidden units 760 | self.control.hidden_one = self.control.init_hidden(batch_size) 761 | # self.control.hidden_two = self.control.init_hidden(batch_size) 762 | 763 | masks = [] 764 | # gprobs = [] 765 | # must pass through the first layer in first group 766 | x = getattr(self, 'group1_layer0')(x, 0) 767 | # gate takes the output of the current layer 768 | 769 | gate_feature = getattr(self, 'group1_gate0')(x) 770 | mask_one = self.control(gate_feature) 771 | 772 | # bits = [4,8,16,0] 773 | # bits = [8,16,0] 774 | 775 | # gprobs.append(gprob) 776 | # masks.append(mask.squeeze()) 777 | # prev = x # input of next layer 778 | 779 | for g in range(3): 780 | for i in range(0 + int(g == 0), self.num_layers[g]): 781 | # if getattr(self, 'group{}_ds{}'.format(g+1, i)) is not None: 782 | # prev = getattr(self, 'group{}_ds{}'.format(g+1, i))(prev, 0) 783 | # prev = getattr(self, 'group{}_bn{}'.format(g+1, i))(prev) 784 | 785 | output_candidates = [] 786 | 787 | # output_candidates.append(prev) 788 | 789 | # for k in range(2): 790 | # out = getattr(self, 'group{}_layer{}'.format(g+1, i))(x, bits[k]) 791 | # output_candidates.append(out) 792 | 793 | x = getattr(self, 'group{}_layer{}'.format(g+1, i))(x, 0) 794 | 795 | mask_list = [] 796 | 797 | # mask_list.append(1 - mask_one) 798 | 799 | # for j in range(4): 800 | # mask_list.append(mask_one * mask_two[:,j,:,:,:]) 801 | 802 | for j in range(len(bits) + 1): 803 | mask_list.append(mask_one[:,j,:,:,:]) 804 | 805 | # prev = x = sum([mask_list[k].expand_as(out) * output_candidates[k] for k in range(5)]) 806 | 807 | # x = sum([mask_list[k].expand_as(out) * output_candidates[k] for k 
in range(2)]) 808 | 809 | mask_list = [mask.squeeze() for mask in mask_list] 810 | 811 | masks.append(mask_list) 812 | 813 | # new mask is taking the current output 814 | # prev = x = mask.expand_as(x) * x \ 815 | # + (1 - mask).expand_as(prev) * prev 816 | gate_feature = getattr(self, 'group{}_gate{}'.format(g+1, i))(x) 817 | mask_one = self.control(gate_feature) 818 | # gprobs.append(gprob) 819 | 820 | 821 | 822 | 823 | # masks.append(mask.squeeze()) 824 | 825 | # last block doesn't have gate module 826 | # del masks[-1] 827 | 828 | x = self.avgpool(x) 829 | x = x.view(x.size(0), -1) 830 | x = self.fc(x) 831 | 832 | return x, masks 833 | 834 | 835 | # For CIFAR-10 836 | 837 | def cifar10_rnn_gate_20(pretrained=False, **kwargs): 838 | model = ResNetRecurrentGateSP(BasicBlock, [3, 3, 3], num_classes=10, 839 | embed_dim=10, hidden_dim=10) 840 | return model 841 | 842 | 843 | def cifar10_rnn_gate_31(pretrained=False, **kwargs): 844 | 845 | model = ResNetRecurrentGateSP(BasicBlock, [5, 5, 5], num_classes=10, 846 | embed_dim=10, hidden_dim=10) 847 | return model 848 | 849 | 850 | 851 | 852 | def cifar10_rnn_gate_38(pretrained=False, **kwargs): 853 | """SkipNet-38 with Recurrent Gate""" 854 | model = ResNetRecurrentGateSP(BasicBlock, [6, 6, 6], num_classes=10, 855 | embed_dim=10, hidden_dim=10) 856 | return model 857 | 858 | 859 | def cifar10_rnn_gate_74(pretrained=False, **kwargs): 860 | """SkipNet-74 with Recurrent Gate""" 861 | model = ResNetRecurrentGateSP(BasicBlock, [12, 12, 12], num_classes=10, 862 | embed_dim=10, hidden_dim=10) 863 | return model 864 | 865 | 866 | def cifar10_rnn_gate_110(pretrained=False, **kwargs): 867 | """SkipNet-110 with Recurrent Gate""" 868 | model = ResNetRecurrentGateSP(BasicBlock, [18, 18, 18], num_classes=10, 869 | embed_dim=10, hidden_dim=10) 870 | return model 871 | 872 | 873 | def cifar10_rnn_gate_152(pretrained=False, **kwargs): 874 | """SkipNet-152 with Recurrent Gate""" 875 | model = ResNetRecurrentGateSP(BasicBlock, [25, 25, 25], num_classes=10, 876 | embed_dim=10, hidden_dim=10) 877 | return model 878 | 879 | 880 | # For CIFAR-100 881 | def cifar100_rnn_gate_38(pretrained=False, **kwargs): 882 | """SkipNet-38 with Recurrent Gate""" 883 | model = ResNetRecurrentGateSP(BasicBlock, [6, 6, 6], num_classes=100, 884 | embed_dim=10, hidden_dim=10) 885 | return model 886 | 887 | 888 | def cifar100_rnn_gate_74(pretrained=False, **kwargs): 889 | """SkipNet-74 with Recurrent Gate""" 890 | model = ResNetRecurrentGateSP(BasicBlock, [12, 12, 12], num_classes=100, 891 | embed_dim=10, hidden_dim=10) 892 | return model 893 | 894 | 895 | def cifar100_rnn_gate_110(pretrained=False, **kwargs): 896 | """SkipNet-110 with Recurrent Gate """ 897 | model = ResNetRecurrentGateSP(BasicBlock, [18, 18, 18], num_classes=100, 898 | embed_dim=10, hidden_dim=10) 899 | return model 900 | 901 | 902 | def cifar100_rnn_gate_152(pretrained=False, **kwargs): 903 | """SkipNet-152 with Recurrent Gate""" 904 | model = ResNetRecurrentGateSP(BasicBlock, [25, 25, 25], num_classes=100, 905 | embed_dim=10, hidden_dim=10) 906 | return model 907 | 908 | 909 | ######################################## 910 | # SkipNet+RL with Feedforward Gate # 911 | ######################################## 912 | 913 | class RLFeedforwardGateI(nn.Module): 914 | """ FFGate-I with sampling. 
Use Pytorch 2.0""" 915 | def __init__(self, pool_size=5, channel=10): 916 | super(RLFeedforwardGateI, self).__init__() 917 | self.pool_size = pool_size 918 | self.channel = channel 919 | 920 | self.maxpool = nn.MaxPool2d(2) 921 | self.conv1 = conv3x3(channel, channel) 922 | self.bn1 = nn.BatchNorm2d(channel) 923 | self.relu1 = nn.ReLU(inplace=True) 924 | 925 | # adding another conv layer 926 | self.conv2 = conv3x3(channel, channel, stride=2) 927 | self.bn2 = nn.BatchNorm2d(channel) 928 | self.relu2 = nn.ReLU(inplace=True) 929 | 930 | pool_size = math.floor(pool_size/2) # for max pooling 931 | pool_size = math.floor(pool_size/2 + 0.5) # for conv stride = 2 932 | 933 | self.avg_layer = nn.AvgPool2d(pool_size) 934 | self.linear_layer = nn.Conv2d(in_channels=channel, out_channels=2, 935 | kernel_size=1, stride=1) 936 | self.prob_layer = nn.Softmax() 937 | 938 | # saved actions and rewards 939 | self.saved_action = [] 940 | self.rewards = [] 941 | 942 | def forward(self, x): 943 | x = self.maxpool(x) 944 | x = self.conv1(x) 945 | x = self.bn1(x) 946 | x = self.relu1(x) 947 | 948 | x = self.conv2(x) 949 | x = self.bn2(x) 950 | x = self.relu2(x) 951 | 952 | x = self.avg_layer(x) 953 | x = self.linear_layer(x).squeeze() 954 | softmax = self.prob_layer(x) 955 | 956 | if self.training: 957 | action = softmax.multinomial() 958 | self.saved_action = action 959 | else: 960 | action = (softmax[:, 1] > 0.5).float() 961 | self.saved_action = action 962 | 963 | action = action.view(action.size(0), 1, 1, 1).float() 964 | return action, softmax 965 | 966 | 967 | class RLFeedforwardGateII(nn.Module): 968 | def __init__(self, pool_size=5, channel=10): 969 | super(RLFeedforwardGateII, self).__init__() 970 | self.pool_size = pool_size 971 | self.channel = channel 972 | 973 | self.conv1 = conv3x3(channel, channel, stride=2) 974 | self.bn1 = nn.BatchNorm2d(channel) 975 | self.relu1 = nn.ReLU(inplace=True) 976 | 977 | pool_size = math.floor(pool_size/2 + 0.5) # for conv stride = 2 978 | 979 | self.avg_layer = nn.AvgPool2d(pool_size) 980 | self.linear_layer = nn.Conv2d(in_channels=channel, out_channels=2, 981 | kernel_size=1, stride=1) 982 | self.prob_layer = nn.Softmax() 983 | 984 | # saved actions and rewards 985 | self.saved_action = None 986 | self.rewards = [] 987 | 988 | def forward(self, x): 989 | x = self.conv1(x) 990 | x = self.bn1(x) 991 | x = self.relu1(x) 992 | 993 | x = self.avg_layer(x) 994 | x = self.linear_layer(x).squeeze() 995 | softmax = self.prob_layer(x) 996 | 997 | if self.training: 998 | action = softmax.multinomial() 999 | self.saved_action = action 1000 | else: 1001 | action = (softmax[:, 1] > 0.5).float() 1002 | self.saved_action = action 1003 | 1004 | action = action.view(action.size(0), 1, 1, 1).float() 1005 | return action, softmax 1006 | 1007 | 1008 | class ResNetFeedForwardRL(nn.Module): 1009 | """Adding gating module on every basic block""" 1010 | 1011 | def __init__(self, block, layers, num_classes=10, 1012 | gate_type='ffgate1', **kwargs): 1013 | self.inplanes = 16 1014 | super(ResNetFeedForwardRL, self).__init__() 1015 | 1016 | self.num_layers = layers 1017 | self.conv1 = conv3x3(3, 16) 1018 | self.bn1 = nn.BatchNorm2d(16) 1019 | self.relu = nn.ReLU(inplace=True) 1020 | 1021 | self.gate_instances = [] 1022 | self.gate_type = gate_type 1023 | self._make_group(block, 16, layers[0], group_id=1, 1024 | gate_type=gate_type, pool_size=32) 1025 | self._make_group(block, 32, layers[1], group_id=2, 1026 | gate_type=gate_type, pool_size=16) 1027 | self._make_group(block, 64, layers[2], 
group_id=3, 1028 | gate_type=gate_type, pool_size=8) 1029 | 1030 | # remove the last gate instance, (not optimized) 1031 | del self.gate_instances[-1] 1032 | 1033 | self.avgpool = nn.AvgPool2d(8) 1034 | self.fc = nn.Linear(64 * block.expansion, num_classes) 1035 | 1036 | self.softmax = nn.Softmax() 1037 | self.saved_actions = [] 1038 | self.rewards = [] 1039 | 1040 | for m in self.modules(): 1041 | if isinstance(m, nn.Conv2d): 1042 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 1043 | m.weight.data.normal_(0, math.sqrt(2. / n)) 1044 | elif isinstance(m, nn.BatchNorm2d): 1045 | m.weight.data.fill_(1) 1046 | m.bias.data.zero_() 1047 | elif isinstance(m, nn.Linear): 1048 | n = m.weight.size(0) * m.weight.size(1) 1049 | m.weight.data.normal_(0, math.sqrt(2. / n)) 1050 | 1051 | def _make_group(self, block, planes, layers, group_id=1, 1052 | gate_type='fisher', pool_size=16): 1053 | """ Create the whole group""" 1054 | for i in range(layers): 1055 | if group_id > 1 and i == 0: 1056 | stride = 2 1057 | else: 1058 | stride = 1 1059 | 1060 | meta = self._make_layer_v2(block, planes, stride=stride, 1061 | gate_type=gate_type, 1062 | pool_size=pool_size) 1063 | 1064 | setattr(self, 'group{}_ds{}'.format(group_id, i), meta[0]) 1065 | setattr(self, 'group{}_layer{}'.format(group_id, i), meta[1]) 1066 | setattr(self, 'group{}_gate{}'.format(group_id, i), meta[2]) 1067 | 1068 | # add into gate instance collection 1069 | self.gate_instances.append(meta[2]) 1070 | 1071 | def _make_layer_v2(self, block, planes, stride=1, 1072 | gate_type='fisher', pool_size=16): 1073 | """ create one block and optional a gate module """ 1074 | downsample = None 1075 | if stride != 1 or self.inplanes != planes * block.expansion: 1076 | downsample = nn.Sequential( 1077 | nn.Conv2d(self.inplanes, planes * block.expansion, 1078 | kernel_size=1, stride=stride, bias=False), 1079 | nn.BatchNorm2d(planes * block.expansion), 1080 | 1081 | ) 1082 | layer = block(self.inplanes, planes, stride, downsample) 1083 | self.inplanes = planes * block.expansion 1084 | 1085 | if gate_type == 'ffgate1': 1086 | gate_layer = RLFeedforwardGateI(pool_size=pool_size, 1087 | channel=planes*block.expansion) 1088 | elif gate_type == 'ffgate2': 1089 | gate_layer = RLFeedforwardGateII(pool_size=pool_size, 1090 | channel=planes*block.expansion) 1091 | else: 1092 | gate_layer = None 1093 | 1094 | if downsample: 1095 | return downsample, layer, gate_layer 1096 | else: 1097 | return None, layer, gate_layer 1098 | 1099 | def repackage_vars(self): 1100 | self.saved_actions = repackage_hidden(self.saved_actions) 1101 | 1102 | def forward(self, x, reinforce=False): 1103 | x = self.conv1(x) 1104 | x = self.bn1(x) 1105 | x = self.relu(x) 1106 | 1107 | masks = [] 1108 | gprobs = [] 1109 | # must pass through the first layer in first group 1110 | x = getattr(self, 'group1_layer0')(x) 1111 | # gate takes the output of the current layer 1112 | mask, gprob = getattr(self, 'group1_gate0')(x) 1113 | gprobs.append(gprob) 1114 | masks.append(mask.squeeze()) 1115 | prev = x # input of next layer 1116 | 1117 | for g in range(3): 1118 | for i in range(0 + int(g == 0), self.num_layers[g]): 1119 | if getattr(self, 'group{}_ds{}'.format(g+1, i)) is not None: 1120 | prev = getattr(self, 'group{}_ds{}'.format(g+1, i))(prev) 1121 | x = getattr(self, 'group{}_layer{}'.format(g+1, i))(x) 1122 | # new mask is taking the current output 1123 | prev = x = mask.expand_as(x) * x \ 1124 | + (1 - mask).expand_as(prev) * prev 1125 | mask, gprob = getattr(self, 
'group{}_gate{}'.format(g+1, i))(x) 1126 | gprobs.append(gprob) 1127 | masks.append(mask.squeeze()) 1128 | 1129 | del masks[-1] 1130 | 1131 | x = self.avgpool(x) 1132 | x = x.view(x.size(0), -1) 1133 | x = self.fc(x) 1134 | 1135 | # collect all actions 1136 | for inst in self.gate_instances: 1137 | self.saved_actions.append(inst.saved_action) 1138 | 1139 | if reinforce: # for pure RL 1140 | softmax = self.softmax(x) 1141 | action = softmax.multinomial() 1142 | self.saved_actions.append(action) 1143 | 1144 | return x, masks, gprobs 1145 | 1146 | 1147 | # FFGate-I 1148 | # For CIFAR-10 1149 | def cifar10_feedfoward_rl_38(pretrained=False, **kwargs): 1150 | """SkipNet-38 + RL with FFGate-I""" 1151 | model = ResNetFeedForwardRL(BasicBlock, [6, 6, 6], 1152 | num_classes=10, gate_type='ffgate1') 1153 | return model 1154 | 1155 | 1156 | def cifar10_feedforward_rl_74(pretrained=False, **kwargs): 1157 | """SkipNet-74 + RL with FFGate-I""" 1158 | model = ResNetFeedForwardRL(BasicBlock, [12, 12, 12], 1159 | num_classes=10, gate_type='ffgate1') 1160 | return model 1161 | 1162 | 1163 | def cifar10_feedforward_rl_110(pretrained=False, **kwargs): 1164 | """SkipNet-110 + RL with FFGate-II""" 1165 | model = ResNetFeedForwardRL(BasicBlock, [18, 18, 18], 1166 | num_classes=10, gate_type='ffgate2') 1167 | return model 1168 | 1169 | 1170 | # For CIFAR-100 1171 | def cifar100_feedford_rl_38(pretrained=False, **kwargs): 1172 | """SkipNet-38 + RL with FFGate-I""" 1173 | model = ResNetFeedForwardRL(BasicBlock, [6, 6, 6], 1174 | num_classes=100, gate_type='ffgate1') 1175 | return model 1176 | 1177 | 1178 | def cifar100_feedforward_rl_74(pretrained=False, **kwargs): 1179 | """SkipNet-74 + RL with FFGate-I""" 1180 | model = ResNetFeedForwardRL(BasicBlock, [12, 12, 12], 1181 | num_classes=100, gate_type='ffgate1') 1182 | return model 1183 | 1184 | 1185 | def cifar100_feedforward_rl_110(pretrained=False, **kwargs): 1186 | """SkipNet-110 + RL with FFGate-II""" 1187 | model = ResNetFeedForwardRL(BasicBlock, [18, 18, 18], 1188 | num_classes=100, gate_type='ffgate2') 1189 | return model 1190 | 1191 | 1192 | ######################################## 1193 | # SkipNet+RL with Recurrent Gate # 1194 | ######################################## 1195 | 1196 | class RNNGatePolicy(nn.Module): 1197 | def __init__(self, input_dim, hidden_dim, rnn_type='lstm'): 1198 | super(RNNGatePolicy, self).__init__() 1199 | 1200 | self.rnn_type = rnn_type 1201 | self.input_dim = input_dim 1202 | self.hidden_dim = hidden_dim 1203 | 1204 | if self.rnn_type == 'lstm': 1205 | self.rnn = nn.LSTM(input_dim, hidden_dim) 1206 | else: 1207 | self.rnn = None 1208 | self.hidden = None 1209 | 1210 | # reduce dim. use softmax here for two actions. 
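# NOTE (illustrative annotation, not part of this file): despite the comment
# above, this policy head projects to a single logit and applies a sigmoid;
# the two actions (skip / execute) are only formed in forward() below as
# bi_prob = torch.cat([1 - prob, prob], dim=1). During training an action is
# sampled from bi_prob and saved for the REINFORCE update; at test time the
# block is executed whenever prob > 0.5. The no-argument bi_prob.multinomial()
# call relies on the pre-0.4 PyTorch API; with a recent PyTorch the same
# sampling step would typically be written as torch.multinomial(bi_prob, 1)
# or torch.distributions.Categorical(probs=bi_prob).sample().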
1211 | self.proj = nn.Linear(hidden_dim, 1) 1212 | self.prob = nn.Sigmoid() 1213 | 1214 | # saved actions and rewards 1215 | self.saved_actions = [] 1216 | self.rewards = [] 1217 | 1218 | def hotter(self, t): 1219 | self.proj.weight.data /= t 1220 | self.proj.bias.data /= t 1221 | 1222 | def init_hidden(self, batch_size): 1223 | # The axes semantics are (num_layers, minibatch_size, hidden_dim) 1224 | return (autograd.Variable(torch.zeros(1, batch_size, 1225 | self.hidden_dim).cuda()), 1226 | autograd.Variable(torch.zeros(1, batch_size, 1227 | self.hidden_dim).cuda())) 1228 | 1229 | def repackage_hidden(self): 1230 | self.hidden = repackage_hidden(self.hidden) 1231 | 1232 | def forward(self, x): 1233 | batch_size = x.size(0) 1234 | self.rnn.flatten_parameters() 1235 | out, self.hidden = self.rnn(x.view(1, batch_size, -1), self.hidden) 1236 | 1237 | # do action selection in the forward pass 1238 | if self.training: 1239 | proj = self.proj(out.squeeze()) 1240 | prob = self.prob(proj) 1241 | bi_prob = torch.cat([1 - prob, prob], dim=1) 1242 | action = bi_prob.multinomial() 1243 | self.saved_actions.append(action) 1244 | else: 1245 | proj = self.proj(out.squeeze()) 1246 | prob = self.prob(proj) 1247 | bi_prob = torch.cat([1 - prob, prob], dim=1) 1248 | action = (prob > 0.5).float() 1249 | self.saved_actions.append(action) 1250 | action = action.view(action.size(0), 1, 1, 1).float() 1251 | return action, bi_prob 1252 | 1253 | 1254 | class ResNetRecurrentGateRL(nn.Module): 1255 | """Adding gating module on every basic block""" 1256 | 1257 | def __init__(self, block, layers, num_classes=10, 1258 | embed_dim=64, hidden_dim=64): 1259 | self.inplanes = 16 1260 | super(ResNetRecurrentGateRL, self).__init__() 1261 | 1262 | self.num_layers = layers 1263 | self.conv1 = conv3x3(3, 16) 1264 | self.bn1 = nn.BatchNorm2d(16) 1265 | self.relu = nn.ReLU(inplace=True) 1266 | 1267 | self.embed_dim = embed_dim 1268 | self.hidden_dim = hidden_dim 1269 | 1270 | self._make_group(block, 16, layers[0], group_id=1, pool_size=32) 1271 | self._make_group(block, 32, layers[1], group_id=2, pool_size=16) 1272 | self._make_group(block, 64, layers[2], group_id=3, pool_size=8) 1273 | 1274 | self.control = RNNGatePolicy(embed_dim, hidden_dim) 1275 | 1276 | self.avgpool = nn.AvgPool2d(8) 1277 | self.fc = nn.Linear(64 * block.expansion, num_classes) 1278 | 1279 | self.softmax = nn.Softmax() 1280 | 1281 | self.saved_actions = [] 1282 | self.rewards = [] 1283 | 1284 | for m in self.modules(): 1285 | if isinstance(m, nn.Conv2d): 1286 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 1287 | m.weight.data.normal_(0, math.sqrt(2. / n)) 1288 | elif isinstance(m, nn.BatchNorm2d): 1289 | m.weight.data.fill_(1) 1290 | m.bias.data.zero_() 1291 | elif isinstance(m, nn.Linear): 1292 | n = m.weight.size(0) * m.weight.size(1) 1293 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 1294 | m.bias.data.zero_() 1295 | 1296 | def _make_group(self, block, planes, layers, group_id=1, pool_size=16): 1297 | """ Create the whole group""" 1298 | for i in range(layers): 1299 | if group_id > 1 and i == 0: 1300 | stride = 2 1301 | else: 1302 | stride = 1 1303 | 1304 | meta = self._make_layer_v2(block, planes, stride=stride, 1305 | pool_size=pool_size) 1306 | 1307 | setattr(self, 'group{}_ds{}'.format(group_id, i), meta[0]) 1308 | setattr(self, 'group{}_layer{}'.format(group_id, i), meta[1]) 1309 | setattr(self, 'group{}_gate{}'.format(group_id, i), meta[2]) 1310 | 1311 | def _make_layer_v2(self, block, planes, stride=1, pool_size=16): 1312 | """ create one block and optional a gate module """ 1313 | downsample = None 1314 | if stride != 1 or self.inplanes != planes * block.expansion: 1315 | downsample = nn.Sequential( 1316 | nn.Conv2d(self.inplanes, planes * block.expansion, 1317 | kernel_size=1, stride=stride, bias=False), 1318 | nn.BatchNorm2d(planes * block.expansion), 1319 | 1320 | ) 1321 | layer = block(self.inplanes, planes, stride, downsample) 1322 | self.inplanes = planes * block.expansion 1323 | 1324 | gate_layer = nn.Sequential( 1325 | nn.AvgPool2d(pool_size), 1326 | nn.Conv2d(in_channels=planes * block.expansion, 1327 | out_channels=self.embed_dim, 1328 | kernel_size=1, 1329 | stride=1)) 1330 | 1331 | return downsample, layer, gate_layer 1332 | 1333 | def forward(self, x): 1334 | batch_size = x.size(0) 1335 | x = self.conv1(x) 1336 | x = self.bn1(x) 1337 | x = self.relu(x) 1338 | 1339 | # reinitialize hidden units 1340 | self.control.hidden = self.control.init_hidden(batch_size) 1341 | 1342 | masks = [] 1343 | gprobs = [] 1344 | # must pass through the first layer in first group 1345 | x = getattr(self, 'group1_layer0')(x) 1346 | # gate takes the output of the current layer 1347 | gate_feature = getattr(self, 'group1_gate0')(x) 1348 | 1349 | mask, gprob = self.control(gate_feature) 1350 | gprobs.append(gprob) 1351 | masks.append(mask.squeeze()) 1352 | prev = x 1353 | 1354 | for g in range(3): 1355 | for i in range(0 + int(g == 0), self.num_layers[g]): 1356 | if getattr(self, 'group{}_ds{}'.format(g+1, i)) is not None: 1357 | prev = getattr(self, 'group{}_ds{}'.format(g+1, i))(prev) 1358 | x = getattr(self, 'group{}_layer{}'.format(g+1, i))(x) 1359 | prev = x = mask.expand_as(x) * x + \ 1360 | (1 - mask).expand_as(prev)*prev 1361 | if not (g == 2 and (i == self.num_layers[g] -1)): 1362 | gate_feature = getattr(self, 1363 | 'group{}_gate{}'.format(g+1, i))(x) 1364 | mask, gprob = self.control(gate_feature) 1365 | gprobs.append(gprob) 1366 | masks.append(mask.squeeze()) 1367 | 1368 | x = self.avgpool(x) 1369 | x = x.view(x.size(0), -1) 1370 | 1371 | if self.training: 1372 | x = self.fc(x) 1373 | softmax = self.softmax(x) 1374 | pred = softmax.multinomial() 1375 | else: 1376 | x = self.fc(x) 1377 | pred = x.max(1)[1] 1378 | self.saved_actions.append(pred) 1379 | 1380 | return x, masks, gprobs 1381 | 1382 | 1383 | # for CIFAR-10 1384 | def cifar10_rnn_gate_rl_38(pretrained=False, **kwargs): 1385 | """SkipNet-38 + RL with Recurrent Gate""" 1386 | model = ResNetRecurrentGateRL(BasicBlock, [6, 6, 6], num_classes=10, 1387 | embed_dim=10, hidden_dim=10) 1388 | return model 1389 | 1390 | 1391 | def cifar10_rnn_gate_rl_74(pretrained=False, **kwargs): 1392 | """SkipNet-74 + RL with Recurrent Gate""" 1393 | model = ResNetRecurrentGateRL(BasicBlock, [12, 12, 12], num_classes=10, 1394 | embed_dim=10, hidden_dim=10) 1395 | return model 1396 | 1397 | 1398 | def 
cifar10_rnn_gate_rl_110(pretrained=False, **kwargs): 1399 | """SkipNet-110 + RL with Recurrent Gate""" 1400 | model = ResNetRecurrentGateRL(BasicBlock, [18, 18, 18], num_classes=10, 1401 | embed_dim=10, hidden_dim=10) 1402 | return model 1403 | 1404 | 1405 | # for CIFAR-100 1406 | def cifar100_rnn_gate_rl_38(pretrained=False, **kwargs): 1407 | """SkipNet-38 + RL with Recurrent Gate""" 1408 | model = ResNetRecurrentGateRL(BasicBlock, [6, 6, 6], num_classes=100, 1409 | embed_dim=10, hidden_dim=10) 1410 | return model 1411 | 1412 | 1413 | def cifar100_rnn_gate_rl_74(pretrained=False, **kwargs): 1414 | """SkipNet-74 + RL with Recurrent Gate""" 1415 | model = ResNetRecurrentGateRL(BasicBlock, [12, 12, 12], num_classes=100, 1416 | embed_dim=10, hidden_dim=10) 1417 | return model 1418 | 1419 | 1420 | def cifar100_rnn_gate_rl_110(pretrained=False, **kwargs): 1421 | """SkipNet-110 + RL with Recurrent Gate""" 1422 | model = ResNetRecurrentGateRL(BasicBlock, [18, 18, 18], num_classes=100, 1423 | embed_dim=10, hidden_dim=10) 1424 | return model 1425 | 1426 | 1427 | -------------------------------------------------------------------------------- /modules/.txt: -------------------------------------------------------------------------------- 1 | new 2 | -------------------------------------------------------------------------------- /modules/bwn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Bounded weight norm 3 | Weight Normalization from https://arxiv.org/abs/1602.07868 4 | taken and adapted from https://github.com/pytorch/pytorch/blob/master/torch/nn/utils/weight_norm.py 5 | """ 6 | import torch 7 | from torch.nn.parameter import Parameter 8 | from torch.autograd import Variable, Function 9 | import torch.nn as nn 10 | 11 | 12 | def gather_params(self, memo=None, param_func=lambda s: s._parameters.values()): 13 | if memo is None: 14 | memo = set() 15 | for p in param_func(self): 16 | if p is not None and p not in memo: 17 | memo.add(p) 18 | yield p 19 | for m in self.children(): 20 | for p in gather_params(m, memo, param_func): 21 | yield p 22 | 23 | nn.Module.gather_params = gather_params 24 | 25 | 26 | def _norm(x, dim, p=2): 27 | """Computes the norm over all dimensions except dim""" 28 | if p == float('inf'): # infinity norm 29 | func = lambda x, dim: x.abs().max(dim=dim)[0] 30 | else: 31 | func = lambda x, dim: torch.norm(x, dim=dim, p=p) 32 | if dim is None: 33 | return x.norm(p=p) 34 | elif dim == 0: 35 | output_size = (x.size(0),) + (1,) * (x.dim() - 1) 36 | return func(x.contiguous().view(x.size(0), -1), 1).view(*output_size) 37 | elif dim == x.dim() - 1: 38 | output_size = (1,) * (x.dim() - 1) + (x.size(-1),) 39 | return func(x.contiguous().view(-1, x.size(-1)), 0).view(*output_size) 40 | else: 41 | return _norm(x.transpose(0, dim), 0).transpose(0, dim) 42 | 43 | 44 | def _mean(p, dim): 45 | """Computes the mean over all dimensions except dim""" 46 | if dim is None: 47 | return p.mean() 48 | elif dim == 0: 49 | output_size = (p.size(0),) + (1,) * (p.dim() - 1) 50 | return p.contiguous().view(p.size(0), -1).mean(dim=1).view(*output_size) 51 | elif dim == p.dim() - 1: 52 | output_size = (1,) * (p.dim() - 1) + (p.size(-1),) 53 | return p.contiguous().view(-1, p.size(-1)).mean(dim=0).view(*output_size) 54 | else: 55 | return _mean(p.transpose(0, dim), 0).transpose(0, dim) 56 | 57 | 58 | class BoundedWeighNorm(object): 59 | 60 | def __init__(self, name, dim, p): 61 | self.name = name 62 | self.dim = dim 63 | self.p = p 64 | 65 | def 
compute_weight(self, module): 66 | v = getattr(module, self.name + '_v') 67 | pre_norm = getattr(module, self.name + '_prenorm') 68 | return v * (pre_norm / _norm(v, self.dim, p=self.p)) 69 | 70 | @staticmethod 71 | def apply(module, name, dim, p): 72 | fn = BoundedWeighNorm(name, dim, p) 73 | 74 | weight = getattr(module, name) 75 | 76 | # remove w from parameter list 77 | del module._parameters[name] 78 | 79 | prenorm = _norm(weight, dim, p=p).mean() 80 | module.register_buffer(name + '_prenorm', prenorm.detach()) 81 | pre_norm = getattr(module, name + '_prenorm') 82 | print(pre_norm) 83 | module.register_parameter(name + '_v', Parameter(weight.data)) 84 | setattr(module, name, fn.compute_weight(module)) 85 | 86 | # recompute weight before every forward() 87 | module.register_forward_pre_hook(fn) 88 | 89 | def gather_normed_params(self, memo=None, param_func=lambda s: fn.compute_weight(s)): 90 | return gather_params(self, memo, param_func) 91 | module.gather_params = gather_normed_params 92 | return fn 93 | 94 | def remove(self, module): 95 | weight = self.compute_weight(module) 96 | delattr(module, self.name) 97 | del module._parameters[self.name + '_prenorm'] 98 | del module._parameters[self.name + '_v'] 99 | module.register_parameter(self.name, Parameter(weight.data)) 100 | 101 | def __call__(self, module, inputs): 102 | setattr(module, self.name, self.compute_weight(module)) 103 | 104 | 105 | def weight_norm(module, name='weight', dim=0, p=2): 106 | r"""Applies weight normalization to a parameter in the given module. 107 | 108 | .. math:: 109 | \mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|} 110 | 111 | Weight normalization is a reparameterization that decouples the magnitude 112 | of a weight tensor from its direction. This replaces the parameter specified 113 | by `name` (e.g. "weight") with two parameters: one specifying the magnitude 114 | (e.g. "weight_g") and one specifying the direction (e.g. "weight_v"). 115 | Weight normalization is implemented via a hook that recomputes the weight 116 | tensor from the magnitude and direction before every :meth:`~Module.forward` 117 | call. 118 | 119 | By default, with `dim=0`, the norm is computed independently per output 120 | channel/plane. To compute a norm over the entire weight tensor, use 121 | `dim=None`. 122 | 123 | See https://arxiv.org/abs/1602.07868 124 | 125 | Args: 126 | module (nn.Module): containing module 127 | name (str, optional): name of weight parameter 128 | dim (int, optional): dimension over which to compute the norm 129 | 130 | Returns: 131 | The original module with the weight norm hook 132 | 133 | Example:: 134 | 135 | >>> m = weight_norm(nn.Linear(20, 40), name='weight') 136 | Linear (20 -> 40) 137 | >>> m.weight_g.size() 138 | torch.Size([40, 1]) 139 | >>> m.weight_v.size() 140 | torch.Size([40, 20]) 141 | 142 | """ 143 | BoundedWeighNorm.apply(module, name, dim, p) 144 | return module 145 | 146 | 147 | def remove_weight_norm(module, name='weight'): 148 | r"""Removes the weight normalization reparameterization from a module. 
149 | 150 | Args: 151 | module (nn.Module): containing module 152 | name (str, optional): name of weight parameter 153 | 154 | Example: 155 | >>> m = weight_norm(nn.Linear(20, 40)) 156 | >>> remove_weight_norm(m) 157 | """ 158 | for k, hook in module._forward_pre_hooks.items(): 159 | if isinstance(hook, BoundedWeighNorm) and hook.name == name: 160 | hook.remove(module) 161 | del module._forward_pre_hooks[k] 162 | return module 163 | 164 | raise ValueError("weight_norm of '{}' not found in {}" 165 | .format(name, module)) 166 | -------------------------------------------------------------------------------- /modules/quantize.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd.function import InplaceFunction, Function 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import math 6 | 7 | 8 | def _mean(p, dim): 9 | """Computes the mean over all dimensions except dim""" 10 | if dim is None: 11 | return p.mean() 12 | elif dim == 0: 13 | output_size = (p.size(0),) + (1,) * (p.dim() - 1) 14 | return p.contiguous().view(p.size(0), -1).mean(dim=1).view(*output_size) 15 | elif dim == p.dim() - 1: 16 | output_size = (1,) * (p.dim() - 1) + (p.size(-1),) 17 | return p.contiguous().view(-1, p.size(-1)).mean(dim=0).view(*output_size) 18 | else: 19 | return _mean(p.transpose(0, dim), 0).transpose(0, dim) 20 | 21 | 22 | class UniformQuantize(InplaceFunction): 23 | 24 | @classmethod 25 | def forward(cls, ctx, input, num_bits=8, min_value=None, max_value=None, 26 | stochastic=False, inplace=False, enforce_true_zero=False, num_chunks=None, out_half=False): 27 | 28 | num_chunks = num_chunks = input.shape[ 29 | 0] if num_chunks is None else num_chunks 30 | if min_value is None or max_value is None: 31 | B = input.shape[0] 32 | y = input.view(B // num_chunks, -1) 33 | if min_value is None: 34 | min_value = y.min(-1)[0].mean(-1) # C 35 | #min_value = float(input.view(input.size(0), -1).min(-1)[0].mean()) 36 | if max_value is None: 37 | #max_value = float(input.view(input.size(0), -1).max(-1)[0].mean()) 38 | max_value = y.max(-1)[0].mean(-1) # C 39 | ctx.inplace = inplace 40 | ctx.num_bits = num_bits 41 | ctx.min_value = min_value 42 | ctx.max_value = max_value 43 | ctx.stochastic = stochastic 44 | 45 | if ctx.inplace: 46 | ctx.mark_dirty(input) 47 | output = input 48 | else: 49 | output = input.clone() 50 | 51 | qmin = 0. 52 | qmax = 2.**num_bits - 1. 53 | #import pdb; pdb.set_trace() 54 | scale = (max_value - min_value) / (qmax - qmin) 55 | 56 | scale = max(scale, 1e-8) 57 | 58 | if enforce_true_zero: 59 | initial_zero_point = qmin - min_value / scale 60 | zero_point = 0. 
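# Worked example (illustrative aside, not executed): with num_bits=8,
# min_value=-1.0 and max_value=+1.0 we get qmin=0, qmax=255 and
# scale = (max_value - min_value) / (qmax - qmin) = 2/255 ≈ 0.00784.
# In the default (non true-zero) branch an input x = 0.5 is quantized to
# q = round((x - min_value) / scale) = round(191.25) = 191 and dequantized
# back to x_hat = q * scale + min_value ≈ 0.498, so the forward pass returns
# a tensor that already carries the rounding error, while backward() passes
# gradients straight through. The enforce_true_zero branch below instead
# clamps a zero_point into [qmin, qmax] so that the real value 0.0 maps to
# an exactly representable integer.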
61 | # make zero exactly represented 62 | if initial_zero_point < qmin: 63 | zero_point = qmin 64 | elif initial_zero_point > qmax: 65 | zero_point = qmax 66 | else: 67 | zero_point = initial_zero_point 68 | zero_point = int(zero_point) 69 | output.div_(scale).add_(zero_point) 70 | else: 71 | output.add_(-min_value).div_(scale).add_(qmin) 72 | 73 | if ctx.stochastic: 74 | noise = output.new(output.shape).uniform_(-0.5, 0.5) 75 | output.add_(noise) 76 | output.clamp_(qmin, qmax).round_() # quantize 77 | 78 | if enforce_true_zero: 79 | output.add_(-zero_point).mul_(scale) # dequantize 80 | else: 81 | output.add_(-qmin).mul_(scale).add_(min_value) # dequantize 82 | if out_half and num_bits <= 16: 83 | output = output.half() 84 | return output 85 | 86 | @staticmethod 87 | def backward(ctx, grad_output): 88 | # straight-through estimator 89 | grad_input = grad_output 90 | return grad_input, None, None, None, None, None, None 91 | 92 | 93 | class UniformQuantizeGrad(InplaceFunction): 94 | 95 | @classmethod 96 | def forward(cls, ctx, input, num_bits=8, min_value=None, max_value=None, stochastic=True, inplace=False): 97 | ctx.inplace = inplace 98 | ctx.num_bits = num_bits 99 | ctx.min_value = min_value 100 | ctx.max_value = max_value 101 | ctx.stochastic = stochastic 102 | return input 103 | 104 | @staticmethod 105 | def backward(ctx, grad_output): 106 | if ctx.min_value is None: 107 | min_value = float(grad_output.min()) 108 | # min_value = float(grad_output.view( 109 | # grad_output.size(0), -1).min(-1)[0].mean()) 110 | else: 111 | min_value = ctx.min_value 112 | if ctx.max_value is None: 113 | max_value = float(grad_output.max()) 114 | # max_value = float(grad_output.view( 115 | # grad_output.size(0), -1).max(-1)[0].mean()) 116 | else: 117 | max_value = ctx.max_value 118 | grad_input = UniformQuantize().apply(grad_output, ctx.num_bits, 119 | min_value, max_value, ctx.stochastic, ctx.inplace) 120 | return grad_input, None, None, None, None, None 121 | 122 | 123 | def conv2d_biprec(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, num_bits_grad=None): 124 | out1 = F.conv2d(input.detach(), weight, bias, 125 | stride, padding, dilation, groups) 126 | out2 = F.conv2d(input, weight.detach(), bias.detach() if bias is not None else None, 127 | stride, padding, dilation, groups) 128 | out2 = quantize_grad(out2, num_bits=num_bits_grad) 129 | return out1 + out2 - out1.detach() 130 | 131 | 132 | def linear_biprec(input, weight, bias=None, num_bits_grad=None): 133 | out1 = F.linear(input.detach(), weight, bias) 134 | out2 = F.linear(input, weight.detach(), bias.detach() 135 | if bias is not None else None) 136 | out2 = quantize_grad(out2, num_bits=num_bits_grad) 137 | return out1 + out2 - out1.detach() 138 | 139 | 140 | def quantize(x, num_bits=8, min_value=None, max_value=None, num_chunks=None, stochastic=False, inplace=False): 141 | return UniformQuantize().apply(x, num_bits, min_value, max_value, num_chunks, stochastic, inplace) 142 | 143 | 144 | def quantize_grad(x, num_bits=8, min_value=None, max_value=None, stochastic=True, inplace=False): 145 | return UniformQuantizeGrad().apply(x, num_bits, min_value, max_value, stochastic, inplace) 146 | 147 | 148 | class QuantMeasure(nn.Module): 149 | """docstring for QuantMeasure.""" 150 | 151 | def __init__(self, num_bits=8, momentum=0.1): 152 | super(QuantMeasure, self).__init__() 153 | self.register_buffer('running_min', torch.zeros(1)) 154 | self.register_buffer('running_max', torch.zeros(1)) 155 | self.momentum = momentum 156 | # 
self.num_bits = num_bits 157 | 158 | def forward(self, input, num_bits): 159 | if self.training: 160 | min_value = input.detach().view( 161 | input.size(0), -1).min(-1)[0].mean() 162 | max_value = input.detach().view( 163 | input.size(0), -1).max(-1)[0].mean() 164 | self.running_min.mul_(self.momentum).add_( 165 | min_value * (1 - self.momentum)) 166 | self.running_max.mul_(self.momentum).add_( 167 | max_value * (1 - self.momentum)) 168 | else: 169 | min_value = self.running_min 170 | max_value = self.running_max 171 | return quantize(input, num_bits, min_value=float(min_value), max_value=float(max_value), num_chunks=16) 172 | 173 | 174 | class QConv2d(nn.Conv2d): 175 | """docstring for QConv2d.""" 176 | 177 | def __init__(self, in_channels, out_channels, kernel_size, 178 | stride=1, padding=0, dilation=1, groups=1, bias=True, num_bits=8, num_bits_weight=None, num_bits_grad=None, biprecision=False): 179 | super(QConv2d, self).__init__(in_channels, out_channels, kernel_size, 180 | stride, padding, dilation, groups, bias) 181 | self.num_bits = num_bits 182 | self.num_bits_weight = num_bits_weight or num_bits 183 | self.num_bits_grad = num_bits_grad 184 | self.quantize_input = QuantMeasure(self.num_bits) 185 | self.biprecision = biprecision 186 | 187 | self.stride = stride 188 | 189 | def forward(self, input, num_bits): 190 | if num_bits != 0: 191 | qinput = self.quantize_input(input, num_bits) 192 | qweight = quantize(self.weight, num_bits=num_bits, 193 | min_value=float(self.weight.min()), 194 | max_value=float(self.weight.max())) 195 | else: 196 | qinput = input 197 | qweight = self.weight 198 | if self.bias is not None: 199 | if num_bits != 0: 200 | qbias = quantize(self.bias, num_bits=num_bits) 201 | else: 202 | qbias = self.bias 203 | else: 204 | qbias = None 205 | if not self.biprecision or self.num_bits_grad is None: 206 | output = F.conv2d(qinput, qweight, qbias, self.stride, 207 | self.padding, self.dilation, self.groups) 208 | if self.num_bits_grad is not None: 209 | output = quantize_grad(output, num_bits=self.num_bits_grad) 210 | else: 211 | output = conv2d_biprec(qinput, qweight, qbias, self.stride, 212 | self.padding, self.dilation, self.groups, num_bits_grad=self.num_bits_grad) 213 | 214 | return output 215 | 216 | 217 | class QLinear(nn.Linear): 218 | """docstring for QConv2d.""" 219 | 220 | def __init__(self, in_features, out_features, bias=True, num_bits=8, num_bits_weight=None, num_bits_grad=None, biprecision=False): 221 | super(QLinear, self).__init__(in_features, out_features, bias) 222 | self.num_bits = num_bits 223 | self.num_bits_weight = num_bits_weight or num_bits 224 | self.num_bits_grad = num_bits_grad 225 | self.biprecision = biprecision 226 | self.quantize_input = QuantMeasure(self.num_bits) 227 | 228 | def forward(self, input): 229 | qinput = self.quantize_input(input) 230 | qweight = quantize(self.weight, num_bits=self.num_bits_weight, 231 | min_value=float(self.weight.min()), 232 | max_value=float(self.weight.max())) 233 | if self.bias is not None: 234 | qbias = quantize(self.bias, num_bits=self.num_bits_weight) 235 | else: 236 | qbias = None 237 | 238 | if not self.biprecision or self.num_bits_grad is None: 239 | output = F.linear(qinput, qweight, qbias) 240 | if self.num_bits_grad is not None: 241 | output = quantize_grad(output, num_bits=self.num_bits_grad) 242 | else: 243 | output = linear_biprec(qinput, qweight, qbias, self.num_bits_grad) 244 | return output 245 | 246 | 247 | class RangeBN(nn.Module): 248 | # this is normalized RangeBN 249 | 250 | def 
__init__(self, num_features, dim=1, momentum=0.1, affine=True, num_chunks=16, eps=1e-5, num_bits=8, num_bits_grad=8): 251 | super(RangeBN, self).__init__() 252 | self.register_buffer('running_mean', torch.zeros(num_features)) 253 | self.register_buffer('running_var', torch.zeros(num_features)) 254 | 255 | self.momentum = momentum 256 | self.dim = dim 257 | if affine: 258 | self.bias = nn.Parameter(torch.Tensor(num_features)) 259 | self.weight = nn.Parameter(torch.Tensor(num_features)) 260 | self.num_bits = num_bits 261 | self.num_bits_grad = num_bits_grad 262 | self.quantize_input = QuantMeasure(self.num_bits) 263 | self.eps = eps 264 | self.num_chunks = num_chunks 265 | self.reset_params() 266 | 267 | def reset_params(self): 268 | if self.weight is not None: 269 | self.weight.data.uniform_() 270 | if self.bias is not None: 271 | self.bias.data.zero_() 272 | 273 | def forward(self, x): 274 | x = self.quantize_input(x) 275 | if x.dim() == 2: # 1d 276 | x = x.unsqueeze(-1,).unsqueeze(-1) 277 | 278 | if self.training: 279 | B, C, H, W = x.shape 280 | y = x.transpose(0, 1).contiguous() # C x B x H x W 281 | y = y.view(C, self.num_chunks, B * H * W // self.num_chunks) 282 | mean_max = y.max(-1)[0].mean(-1) # C 283 | mean_min = y.min(-1)[0].mean(-1) # C 284 | mean = y.view(C, -1).mean(-1) # C 285 | scale_fix = (0.5 * 0.35) * (1 + (math.pi * math.log(4)) ** 286 | 0.5) / ((2 * math.log(y.size(-1))) ** 0.5) 287 | 288 | scale = 1 / ((mean_max - mean_min) * scale_fix + self.eps) 289 | 290 | self.running_mean.detach().mul_(self.momentum).add_( 291 | mean * (1 - self.momentum)) 292 | 293 | self.running_var.detach().mul_(self.momentum).add_( 294 | scale * (1 - self.momentum)) 295 | else: 296 | mean = self.running_mean 297 | scale = self.running_var 298 | scale = quantize(scale, num_bits=self.num_bits, min_value=float( 299 | scale.min()), max_value=float(scale.max())) 300 | out = (x - mean.view(1, mean.size(0), 1, 1)) * \ 301 | scale.view(1, scale.size(0), 1, 1) 302 | 303 | if self.weight is not None: 304 | qweight = quantize(self.weight, num_bits=self.num_bits, 305 | min_value=float(self.weight.min()), 306 | max_value=float(self.weight.max())) 307 | out = out * qweight.view(1, qweight.size(0), 1, 1) 308 | 309 | if self.bias is not None: 310 | qbias = quantize(self.bias, num_bits=self.num_bits) 311 | out = out + qbias.view(1, qbias.size(0), 1, 1) 312 | if self.num_bits_grad is not None: 313 | out = quantize_grad(out, num_bits=self.num_bits_grad) 314 | 315 | if out.size(3) == 1 and out.size(2) == 1: 316 | out = out.squeeze(-1).squeeze(-1) 317 | return out 318 | -------------------------------------------------------------------------------- /modules/rnlu.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from torch.autograd.function import InplaceFunction 4 | from torch.autograd import Variable 5 | import torch.nn as nn 6 | import math 7 | 8 | 9 | class BiReLUFunction(InplaceFunction): 10 | 11 | @classmethod 12 | def forward(cls, ctx, input, inplace=False): 13 | if input.size(1) % 2 != 0: 14 | raise RuntimeError("dimension 1 of input must be multiple of 2, " 15 | "but got {}".format(input.size(1))) 16 | ctx.inplace = inplace 17 | 18 | if ctx.inplace: 19 | ctx.mark_dirty(input) 20 | output = input 21 | else: 22 | output = input.clone() 23 | 24 | pos, neg = output.chunk(2, dim=1) 25 | pos.clamp_(min=0) 26 | neg.clamp_(max=0) 27 | # scale = (pos - neg).view(pos.size(0), -1).mean(1).div_(2) 28 | # output. 
29 | ctx.save_for_backward(output) 30 | return output 31 | 32 | @staticmethod 33 | def backward(ctx, grad_output): 34 | output, = ctx.saved_variables 35 | grad_input = grad_output.masked_fill(output.eq(0), 0) 36 | return grad_input, None 37 | 38 | 39 | def birelu(x, inplace=False): 40 | return BiReLUFunction().apply(x, inplace) 41 | 42 | 43 | class BiReLU(nn.Module): 44 | """docstring for BiReLU.""" 45 | 46 | def __init__(self, inplace=False): 47 | super(BiReLU, self).__init__() 48 | self.inplace = inplace 49 | 50 | def forward(self, inputs): 51 | return birelu(inputs, inplace=self.inplace) 52 | 53 | 54 | def binorm(x, shift=0, scale_fix=(2 / math.pi) ** 0.5): 55 | pos, neg = (x + shift).split(2, dim=1) 56 | scale = (pos - neg).view(pos.size(0), -1).mean(1).div_(2) * scale_fix 57 | return x / scale 58 | 59 | 60 | def _mean(p, dim): 61 | """Computes the mean over all dimensions except dim""" 62 | if dim is None: 63 | return p.mean() 64 | elif dim == 0: 65 | output_size = (p.size(0),) + (1,) * (p.dim() - 1) 66 | return p.contiguous().view(p.size(0), -1).mean(dim=1).view(*output_size) 67 | elif dim == p.dim() - 1: 68 | output_size = (1,) * (p.dim() - 1) + (p.size(-1),) 69 | return p.contiguous().view(-1, p.size(-1)).mean(dim=0).view(*output_size) 70 | else: 71 | return _mean(p.transpose(0, dim), 0).transpose(0, dim) 72 | 73 | 74 | def rnlu(x, inplace=False, shift=0, scale_fix=(math.pi / 2) ** 0.5): 75 | x = birelu(x, inplace=inplace) 76 | pos, neg = (x + shift).chunk(2, dim=1) 77 | # scale = torch.cat((_mean(pos, 1), -_mean(neg, 1)), 1) * scale_fix + 1e-5 78 | scale = (pos - neg).view(pos.size(0), -1).mean(1) * scale_fix + 1e-8 79 | return x / scale.view(scale.size(0), *([1] * (x.dim() - 1))) 80 | 81 | 82 | class RnLU(nn.Module): 83 | """docstring for RnLU.""" 84 | 85 | def __init__(self, inplace=False): 86 | super(RnLU, self).__init__() 87 | self.inplace = inplace 88 | 89 | def forward(self, x): 90 | return rnlu(x, inplace=self.inplace) 91 | 92 | # output. 93 | if __name__ == "__main__": 94 | x = Variable(torch.randn(2, 16, 5, 5).cuda(), requires_grad=True) 95 | output = rnlu(x) 96 | 97 | output.sum().backward() 98 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # Fractional Skipping: Towards Finer-Grained Dynamic CNN inference [[PDF]](https://arxiv.org/abs/2001.00705) 2 | 3 | Jianghao Shen, Yonggan Fu, Yue Wang, Pengfei Xu, Zhangyang Wang, Yingyan Lin 4 | 5 | In AAAI 2020. 6 | 7 | ## Overview 8 | We present DFS (Dynamic Fractional Skipping), a dynamic inference framework that extends binary layer skipping options with "fractional skipping" ability - by quantizing the layer weights and activations into different bitwidths. 9 | 10 | Highlights: 11 | 12 | - **Novel integration** of two CNN inference mindsets: _dynamic_ _layer_ _skipping_ and _static_ _quantization_ 13 | - Introduced _input_-_adaptive_ _quantization_ at inference for the **first time** 14 | - **Better performance and computational cost tradeoff** than SkipNet and other relevant competitors 15 | 16 | ![performance_skipnet](https://i.ibb.co/kH5cghN/CIFAR10-DFS-Res-Net74-vs-Skip-Net74-1.png ) 17 | 18 | Figure 6: Comparing the accuracy vs. computation percentage of DFS-ResNet74 and SkipNet74 on CIFAR10. 19 | 20 | 21 | ## Method 22 | ![DFS](https://i.ibb.co/yRdw0mL/ezgif-5-ebd7e26308-pdf-1.png) 23 | 24 | Figure1. 
An illustration of the DFS framework, where C1, C2, C3 denote three consecutive convolution layers, each of which consists of a column of filters, represented as cuboids. For each layer, the decision is computed by the corresponding gating network, denoted "Gx". In this example, the first conv layer is executed fractionally at a low bitwidth, the second layer is fully executed, and the third is skipped. 25 | 26 | 27 | 28 | ![Gating](https://i.ibb.co/qkbv66X/ezgif-5-f5d1a89614-pdf-1.png) 29 | 30 | Figure 2. An illustration of the RNN gate used in DFS. The output is a skipping probability vector, where the green arrows denote the layer skip options (skip/keep), and the blue arrows represent the quantization options. During inference, the skip/keep/quantization option corresponding to the largest vector element is selected and executed. 31 | 32 | ## Prerequisites 33 | - Ubuntu 34 | - Python 3 35 | - NVIDIA GPU + CUDA cuDNN 36 | 37 | ## Installation 38 | - Clone this repo: 39 | ```bash 40 | git clone https://github.com/Torment123/DFS.git 41 | cd DFS 42 | ``` 43 | - Install dependencies: 44 | ```bash 45 | pip install -r requirements.txt 46 | ``` 47 | ## Usage 48 | - **Workflow:** pretrain the ResNet backbone → train gate → train DFS 49 | 50 | **0. Data Preparation** 51 | - `data.py` contains the data preparation for the CIFAR-10 and CIFAR-100 datasets (SVHN loaders are also included). 52 | 53 | **1. Pretrain the ResNet backbone** 54 | We first train a base ResNet model as the backbone for the subsequent DFS training stages. 55 | ```bash 56 | CUDA_VISIBLE_DEVICES=0 python3 train_base.py train cifar10_resnet_38 --dataset cifar10 --save-folder save_checkpoints/backbone 57 | ``` 58 | 59 | **2. Train gate** 60 | We then attach the RNN gate to the pretrained ResNet, freeze the ResNet parameters, and train only the RNN gate until it reaches a zero skip ratio (full execution). Set minimum = 100, lr = 0.01, iters = 2000. 61 | 62 | ```bash 63 | CUDA_VISIBLE_DEVICES=0 python3 train_sp_integrate_dynamic_quantization_initial.py train cifar10_rnn_gate_38 --minimum 100 --lr 0.01 --resume save_checkpoints/backbone/model_best.pth.tar --iters 2000 --save-folder save_checkpoints/full_execution 64 | ``` 65 | 66 | **3. Train DFS** 67 | After the gate is trained to reach full execution, we unfreeze the backbone's parameters and jointly train it with the gate toward the specified computation budget. Set minimum = _specified_ _computation_ _percentage_, lr = 0.01.
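As a brief worked note on the budget: `--minimum` is a target computation percentage. In the cost table of `train_sp_integrate_dynamic_quantization.py`, a skipped layer is weighted 0, a layer quantized to bitwidth `b` is weighted `(b/32)^2` of a full-precision layer (so 16-bit execution counts as 25%, 8-bit as about 6.25%), and a fully executed layer is weighted 1; the resulting percentage (`cp_ratio`) is compared against `--minimum` at every iteration to set the sign of the cost regularizer. The concrete value passed below is up to the user (e.g. `--minimum 50` for roughly half of the full-precision computation).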
68 | ```bash 69 | CUDA_VISIBLE_DEVICES=0 python3 train_sp_integrate_dynamic_quantization.py train cifar10_rnn_gate_38 --minimum _specified_ _computation_ _percentage_ --lr 0.01 --resume save_checkpoints/full_execution/checkpoint_latest.pth.tar --save-folder save_checkpoints/DFS 70 | ``` 71 | 72 | ## Acknowledgement 73 | - The sequential formulation of dynamic inference problem from [SkipNet](https://github.com/ucbdrive/skipnet) 74 | - The quantization function from [Scalable Methods](https://github.com/eladhoffer/quantized.pytorch) 75 | 76 | ## License 77 | [MIT](https://choosealicense.com/licenses/mit/) 78 | 79 | 80 | 81 | 82 | 83 | 84 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision -------------------------------------------------------------------------------- /supplementary_material.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Torment123/DFS/9e05ac1e98dbcdf5385048f0b3cf158eb14698a7/supplementary_material.pdf -------------------------------------------------------------------------------- /train_base.py: -------------------------------------------------------------------------------- 1 | """ This file is for training original model without routing modules. 2 | """ 3 | 4 | from __future__ import print_function 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.backends.cudnn as cudnn 9 | from torch.autograd import Variable 10 | 11 | import os 12 | import shutil 13 | import argparse 14 | import time 15 | import logging 16 | 17 | import models 18 | from data import * 19 | 20 | 21 | model_names = sorted(name for name in models.__dict__ 22 | if name.islower() and not name.startswith('__') 23 | and callable(models.__dict__[name]) 24 | ) 25 | 26 | 27 | def parse_args(): 28 | # hyper-parameters are from ResNet paper 29 | parser = argparse.ArgumentParser(description='PyTorch CIFAR10 training') 30 | parser.add_argument('cmd', choices=['train', 'test']) 31 | parser.add_argument('arch', metavar='ARCH', default='cifar10_resnet_110', 32 | choices=model_names, 33 | help='model architecture: ' + 34 | ' | '.join(model_names) + 35 | ' (default: cifar10_resnet_110)') 36 | parser.add_argument('--dataset', '-d', type=str, default='cifar10', 37 | choices=['cifar10', 'cifar100'], 38 | help='dataset choice') 39 | parser.add_argument('--workers', default=8, type=int, metavar='N', 40 | help='number of data loading workers (default: 4 )') 41 | parser.add_argument('--iters', default=64000, type=int, 42 | help='number of total iterations (default: 64,000)') 43 | parser.add_argument('--start-iter', default=0, type=int, 44 | help='manual iter number (useful on restarts)') 45 | parser.add_argument('--batch-size', default=128, type=int, 46 | help='mini-batch size (default: 128)') 47 | parser.add_argument('--lr', default=0.1, type=float, 48 | help='initial learning rate') 49 | parser.add_argument('--momentum', default=0.9, type=float, 50 | help='momentum') 51 | parser.add_argument('--weight-decay', default=1e-4, type=float, 52 | help='weight decay (default: 1e-4)') 53 | parser.add_argument('--print-freq', default=10, type=int, 54 | help='print frequency (default: 10)') 55 | parser.add_argument('--resume', default='', type=str, 56 | help='path to latest checkpoint (default: None)') 57 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 58 | help='use 
pretrained model') 59 | parser.add_argument('--step-ratio', default=0.1, type=float, 60 | help='ratio for learning rate deduction') 61 | parser.add_argument('--warm-up', action='store_true', 62 | help='for n = 18, the model needs to warm up for 400 ' 63 | 'iterations') 64 | parser.add_argument('--save-folder', default='save_checkpoints/', type=str, 65 | help='folder to save the checkpoints') 66 | parser.add_argument('--eval-every', default=500, type=int, 67 | help='evaluate model every (default: 1000) iterations') 68 | parser.add_argument('--precision',default=0,type=int, 69 | help='bitwidth for quantization') 70 | args = parser.parse_args() 71 | return args 72 | 73 | 74 | def main(): 75 | args = parse_args() 76 | save_path = args.save_path = os.path.join(args.save_folder, args.arch) 77 | if not os.path.exists(save_path): 78 | os.makedirs(save_path) 79 | 80 | # config logging file 81 | args.logger_file = os.path.join(save_path, 'log_{}.txt'.format(args.cmd)) 82 | handlers = [logging.FileHandler(args.logger_file, mode='w'), 83 | logging.StreamHandler()] 84 | logging.basicConfig(level=logging.INFO, 85 | datefmt='%m-%d-%y %H:%M', 86 | format='%(asctime)s:%(message)s', 87 | handlers=handlers) 88 | 89 | if args.cmd == 'train': 90 | logging.info('start training {}'.format(args.arch)) 91 | run_training(args) 92 | 93 | elif args.cmd == 'test': 94 | logging.info('start evaluating {} with checkpoints from {}'.format( 95 | args.arch, args.resume)) 96 | test_model(args) 97 | 98 | 99 | def run_training(args): 100 | # create model 101 | model = models.__dict__[args.arch](args.pretrained) 102 | model = torch.nn.DataParallel(model).cuda() 103 | 104 | best_prec1 = 0 105 | 106 | # optionally resume from a checkpoint 107 | if args.resume: 108 | if os.path.isfile(args.resume): 109 | logging.info('=> loading checkpoint `{}`'.format(args.resume)) 110 | checkpoint = torch.load(args.resume) 111 | args.start_iter = checkpoint['iter'] 112 | best_prec1 = checkpoint['best_prec1'] 113 | model.load_state_dict(checkpoint['state_dict']) 114 | logging.info('=> loaded checkpoint `{}` (iter: {})'.format( 115 | args.resume, checkpoint['iter'] 116 | )) 117 | else: 118 | logging.info('=> no checkpoint found at `{}`'.format(args.resume)) 119 | 120 | cudnn.benchmark = False 121 | 122 | train_loader = prepare_train_data(dataset=args.dataset, 123 | batch_size=args.batch_size, 124 | shuffle=True, 125 | num_workers=args.workers) 126 | test_loader = prepare_test_data(dataset=args.dataset, 127 | batch_size=args.batch_size, 128 | shuffle=False, 129 | num_workers=args.workers) 130 | 131 | # define loss function (criterion) and optimizer 132 | criterion = nn.CrossEntropyLoss().cuda() 133 | 134 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 135 | momentum=args.momentum, 136 | weight_decay=args.weight_decay) 137 | 138 | batch_time = AverageMeter() 139 | data_time = AverageMeter() 140 | losses = AverageMeter() 141 | top1 = AverageMeter() 142 | 143 | end = time.time() 144 | for i in range(args.start_iter, args.iters): 145 | model.train() 146 | adjust_learning_rate(args, optimizer, i) 147 | 148 | input, target = next(iter(train_loader)) 149 | # measuring data loading time 150 | data_time.update(time.time() - end) 151 | 152 | target = target.squeeze().long().cuda(async=True) 153 | input_var = Variable(input) 154 | target_var = Variable(target) 155 | 156 | # compute output 157 | output = model(input_var,args.precision) 158 | loss = criterion(output, target_var) 159 | 160 | # measure accuracy and record loss 161 | prec1, = 
accuracy(output.data, target, topk=(1,)) 162 | losses.update(loss.item(), input.size(0)) 163 | top1.update(prec1.item(), input.size(0)) 164 | 165 | # compute gradient and do SGD step 166 | optimizer.zero_grad() 167 | loss.backward() 168 | optimizer.step() 169 | 170 | # measure elapsed time 171 | batch_time.update(time.time() - end) 172 | end = time.time() 173 | 174 | # print log 175 | if i % args.print_freq == 0: 176 | logging.info("Iter: [{0}/{1}]\t" 177 | "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t" 178 | "Data {data_time.val:.3f} ({data_time.avg:.3f})\t" 179 | "Loss {loss.val:.3f} ({loss.avg:.3f})\t" 180 | "Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t".format( 181 | i, 182 | args.iters, 183 | batch_time=batch_time, 184 | data_time=data_time, 185 | loss=losses, 186 | top1=top1) 187 | ) 188 | 189 | # evaluate every 1000 steps 190 | if (i % args.eval_every == 0 and i > 0) or (i == args.iters - 1): 191 | with torch.no_grad(): 192 | prec1 = validate(args, test_loader, model, criterion) 193 | # prec1 = validate(args, test_loader, model, criterion) 194 | is_best = prec1 > best_prec1 195 | best_prec1 = max(prec1, best_prec1) 196 | checkpoint_path = os.path.join(args.save_path, 197 | 'checkpoint_{:05d}.pth.tar'.format( 198 | i)) 199 | save_checkpoint({ 200 | 'iter': i, 201 | 'arch': args.arch, 202 | 'state_dict': model.state_dict(), 203 | 'best_prec1': best_prec1, 204 | }, 205 | is_best, filename=checkpoint_path) 206 | shutil.copyfile(checkpoint_path, os.path.join(args.save_path, 207 | 'checkpoint_latest' 208 | '.pth.tar')) 209 | 210 | 211 | def validate(args, test_loader, model, criterion): 212 | batch_time = AverageMeter() 213 | losses = AverageMeter() 214 | top1 = AverageMeter() 215 | 216 | # switch to evaluation mode 217 | model.eval() 218 | end = time.time() 219 | for i, (input, target) in enumerate(test_loader): 220 | target = target.squeeze().long().cuda(async=True) 221 | input_var = Variable(input, volatile=True) 222 | target_var = Variable(target, volatile=True) 223 | 224 | # compute output 225 | output = model(input_var,args.precision) 226 | loss = criterion(output, target_var) 227 | 228 | # measure accuracy and record loss 229 | prec1, = accuracy(output.data, target, topk=(1,)) 230 | top1.update(prec1.item(), input.size(0)) 231 | losses.update(loss.item(), input.size(0)) 232 | batch_time.update(time.time() - end) 233 | end = time.time() 234 | 235 | if (i % args.print_freq == 0) or (i == len(test_loader) - 1): 236 | logging.info( 237 | 'Test: [{}/{}]\t' 238 | 'Time: {batch_time.val:.4f}({batch_time.avg:.4f})\t' 239 | 'Loss: {loss.val:.3f}({loss.avg:.3f})\t' 240 | 'Prec@1: {top1.val:.3f}({top1.avg:.3f})\t'.format( 241 | i, len(test_loader), batch_time=batch_time, 242 | loss=losses, top1=top1 243 | ) 244 | ) 245 | 246 | logging.info(' * Prec@1 {top1.avg:.3f}'.format(top1=top1)) 247 | return top1.avg 248 | 249 | 250 | def test_model(args): 251 | # create model 252 | model = models.__dict__[args.arch](args.pretrained) 253 | model = torch.nn.DataParallel(model).cuda() 254 | 255 | if args.resume: 256 | if os.path.isfile(args.resume): 257 | logging.info('=> loading checkpoint `{}`'.format(args.resume)) 258 | checkpoint = torch.load(args.resume) 259 | args.start_iter = checkpoint['iter'] 260 | best_prec1 = checkpoint['best_prec1'] 261 | model.load_state_dict(checkpoint['state_dict']) 262 | logging.info('=> loaded checkpoint `{}` (iter: {})'.format( 263 | args.resume, checkpoint['iter'] 264 | )) 265 | else: 266 | logging.info('=> no checkpoint found at `{}`'.format(args.resume)) 267 | 268 | 
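    # With the checkpoint (if any) restored above, the evaluation path below
    # builds the test loader and loss, then runs validate() once under torch.no_grad().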
cudnn.benchmark = False 269 | test_loader = prepare_test_data(dataset=args.dataset, 270 | batch_size=args.batch_size, 271 | shuffle=False, 272 | num_workers=args.workers) 273 | criterion = nn.CrossEntropyLoss().cuda() 274 | 275 | # validate(args, test_loader, model, criterion) 276 | 277 | with torch.no_grad(): 278 | prec1 = validate(args, test_loader, model, criterion) 279 | 280 | 281 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 282 | torch.save(state, filename) 283 | if is_best: 284 | save_path = os.path.dirname(filename) 285 | shutil.copyfile(filename, os.path.join(save_path, 286 | 'model_best.pth.tar')) 287 | 288 | 289 | class AverageMeter(object): 290 | """Computes and stores the average and current value""" 291 | 292 | def __init__(self): 293 | self.reset() 294 | 295 | def reset(self): 296 | self.val = 0 297 | self.avg = 0 298 | self.sum = 0 299 | self.count = 0 300 | 301 | def update(self, val, n=1): 302 | self.val = val 303 | self.sum += val * n 304 | self.count += n 305 | self.avg = self.sum / self.count 306 | 307 | 308 | def adjust_learning_rate(args, optimizer, _iter): 309 | """divide lr by 10 at 32k and 48k """ 310 | if args.warm_up and (_iter < 400): 311 | lr = 0.01 312 | elif 32000 <= _iter < 48000: 313 | lr = args.lr * (args.step_ratio ** 1) 314 | elif _iter >= 48000: 315 | lr = args.lr * (args.step_ratio ** 2) 316 | else: 317 | lr = args.lr 318 | 319 | if _iter % args.eval_every == 0: 320 | logging.info('Iter [{}] learning rate = {}'.format(_iter, lr)) 321 | 322 | for param_group in optimizer.param_groups: 323 | param_group['lr'] = lr 324 | 325 | 326 | def accuracy(output, target, topk=(1,)): 327 | """Computes the precision@k for the specified values of k""" 328 | maxk = max(topk) 329 | batch_size = target.size(0) 330 | 331 | _, pred = output.topk(maxk, 1, True, True) 332 | pred = pred.t() 333 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 334 | 335 | res = [] 336 | for k in topk: 337 | correct_k = correct[:k].view(-1).float().sum(0) 338 | res.append(correct_k.mul_(100.0 / batch_size)) 339 | return res 340 | 341 | 342 | if __name__ == '__main__': 343 | main() 344 | -------------------------------------------------------------------------------- /train_sp_integrate_dynamic_quantization.py: -------------------------------------------------------------------------------- 1 | """ 2 | Training file for training SkipNets for supervised pre-training stage 3 | """ 4 | 5 | from __future__ import print_function 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.backends.cudnn as cudnn 10 | from torch.autograd import Variable 11 | 12 | import os 13 | import shutil 14 | import argparse 15 | import time 16 | import logging 17 | 18 | import models 19 | from data import * 20 | 21 | from functools import reduce 22 | 23 | 24 | model_names = sorted(name for name in models.__dict__ 25 | if name.islower() and not name.startswith('__') 26 | and callable(models.__dict__[name]) 27 | ) 28 | 29 | 30 | def parse_args(): 31 | # hyper-parameters are from ResNet paper 32 | parser = argparse.ArgumentParser( 33 | description='PyTorch CIFAR10 training with gating') 34 | parser.add_argument('cmd', choices=['train', 'test']) 35 | parser.add_argument('arch', metavar='ARCH', 36 | default='cifar10_feedforward_38', 37 | choices=model_names, 38 | help='model architecture: ' + 39 | ' | '.join(model_names) + 40 | ' (default: cifar10_feedforward_38)') 41 | parser.add_argument('--gate-type', type=str, default='ff', 42 | choices=['ff', 'rnn'], help='gate type') 43 | 
parser.add_argument('--dataset', '-d', default='cifar10', type=str, 44 | choices=['cifar10', 'cifar100'], 45 | help='dataset type') 46 | parser.add_argument('--workers', default=1, type=int, metavar='N', 47 | help='number of data loading workers (default: 4 )') 48 | parser.add_argument('--iters', default=64000, type=int, 49 | help='number of total iterations (default: 64,000)') 50 | parser.add_argument('--start-iter', default=0, type=int, 51 | help='manual iter number (useful on restarts)') 52 | parser.add_argument('--batch-size', default=128, type=int, 53 | help='mini-batch size (default: 128)') 54 | parser.add_argument('--lr', default=0.1, type=float, 55 | help='initial learning rate') 56 | parser.add_argument('--momentum', default=0.9, type=float, 57 | help='momentum') 58 | parser.add_argument('--weight-decay', default=1e-4, type=float, 59 | help='weight decay (default: 1e-4)') 60 | parser.add_argument('--print-freq', default=10, type=int, 61 | help='print frequency (default: 10)') 62 | parser.add_argument('--resume', default='', type=str, 63 | help='path to latest checkpoint (default: None)') 64 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 65 | help='use pretrained model') 66 | parser.add_argument('--step-ratio', default=0.1, type=float, 67 | help='ratio for learning rate deduction') 68 | parser.add_argument('--warm-up', action='store_true', 69 | help='for n = 18, the model needs to warm up for 400 ' 70 | 'iterations') 71 | parser.add_argument('--save-folder', default='save_checkpoints', 72 | type=str, 73 | help='folder to save the checkpoints') 74 | parser.add_argument('--eval-every', default=390, type=int, 75 | help='evaluate model every (default: 1000) iterations') 76 | parser.add_argument('--verbose', action="store_true", 77 | help='print layer skipping ratio at training') 78 | parser.add_argument('--minimum', default=100, type=float, 79 | help='minimum') 80 | parser.add_argument('--computation_cost', default=True, type=bool, 81 | help='using computation cost as regularization term') 82 | parser.add_argument('--proceed', default='False', 83 | help='whether this experiment continues from a checkpoint') 84 | parser.add_argument('--beta', default=1e-5, type=float, 85 | help='coefficient') 86 | 87 | args = parser.parse_args() 88 | return args 89 | 90 | 91 | def main(): 92 | args = parse_args() 93 | 94 | save_path = args.save_path = os.path.join(args.save_folder, args.arch) 95 | os.makedirs(save_path, exist_ok=True) 96 | 97 | # config logger file 98 | args.logger_file = os.path.join(save_path, 'log_{}.txt'.format(args.cmd)) 99 | handlers = [logging.FileHandler(args.logger_file, mode='w'), 100 | logging.StreamHandler()] 101 | logging.basicConfig(level=logging.INFO, 102 | datefmt='%m-%d-%y %H:%M', 103 | format='%(asctime)s:%(message)s', 104 | handlers=handlers) 105 | 106 | if args.cmd == 'train': 107 | logging.info('start training {}'.format(args.arch)) 108 | run_training(args) 109 | 110 | elif args.cmd == 'test': 111 | logging.info('start evaluating {} with checkpoints from {}'.format( 112 | args.arch, args.resume)) 113 | test_model(args) 114 | 115 | 116 | def run_training(args): 117 | # create model 118 | model = models.__dict__[args.arch](args.pretrained) 119 | model = torch.nn.DataParallel(model).cuda() 120 | 121 | best_prec1 = 0 122 | 123 | # optionally resume from a checkpoint 124 | if args.resume: 125 | if os.path.isfile(args.resume): 126 | logging.info('=> loading checkpoint `{}`'.format(args.resume)) 127 | checkpoint = torch.load(args.resume) 128 
| if args.proceed == 'True': 129 | args.start_iter = checkpoint['iter'] 130 | else: 131 | args.start_iter = 0 132 | best_prec1 = checkpoint['best_prec1'] 133 | model.load_state_dict(checkpoint['state_dict'],strict=True) 134 | logging.info('=> loaded checkpoint `{}` (iter: {})'.format( 135 | args.resume, checkpoint['iter'] 136 | )) 137 | else: 138 | logging.info('=> no checkpoint found at `{}`'.format(args.resume)) 139 | 140 | cudnn.benchmark = True 141 | 142 | train_loader = prepare_train_data(dataset=args.dataset, 143 | batch_size=args.batch_size, 144 | shuffle=True, 145 | num_workers=args.workers) 146 | test_loader = prepare_test_data(dataset=args.dataset, 147 | batch_size=args.batch_size, 148 | shuffle=False, 149 | num_workers=args.workers) 150 | 151 | # define loss function (criterion) and optimizer 152 | criterion = nn.CrossEntropyLoss().cuda() 153 | 154 | optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, 155 | model.parameters()), 156 | args.lr, 157 | momentum=args.momentum, 158 | weight_decay=args.weight_decay) 159 | 160 | batch_time = AverageMeter() 161 | data_time = AverageMeter() 162 | losses = AverageMeter() 163 | top1 = AverageMeter() 164 | skip_ratios = ListAverageMeter() 165 | cp_record = AverageMeter() 166 | 167 | network_depth = sum(model.module.num_layers) 168 | 169 | # costs = [1/64, 1/16, 1/4, 1] 170 | # costs = [0, 1/16, 1/4, 1] 171 | costs = [0,1/16, 25/256, 9/64, 49/256, 1/4, 1] 172 | # bits = [8,16,0] 173 | bits = [8,10,12,14,16,0] 174 | # bits = [2,4,6,8,10,12,14,16,0] 175 | # costs = [0, 1/256, 1/64, 9/256, 1/16, 25/256, 9/64, 49/256, 1/4, 1] 176 | 177 | # bits = [2,4,6,8,10,12,14,16,0] 178 | 179 | layerwise_decision_statistics = [] 180 | 181 | for k in range(network_depth - 1): 182 | layerwise_decision_statistics.append([]) 183 | for j in range(len(costs)): 184 | ratio = AverageMeter() 185 | layerwise_decision_statistics[k].append(ratio) 186 | 187 | 188 | end = time.time() 189 | for i in range(args.start_iter, args.iters): 190 | model.train() 191 | adjust_learning_rate(args, optimizer, i) 192 | 193 | input, target = next(iter(train_loader)) 194 | # measuring data loading time 195 | data_time.update(time.time() - end) 196 | 197 | target = target.cuda(async = False) 198 | input_var = Variable(input).cuda() 199 | target_var = Variable(target).cuda() 200 | 201 | # compute output 202 | output, masks = model(input_var,bits) 203 | 204 | computation_cost = 0 205 | computation_all = 0 206 | 207 | for layer in range(network_depth - 1): 208 | 209 | full_layer = reduce((lambda x, y: x * y), masks[layer][0].shape) 210 | 211 | computation_all += full_layer 212 | 213 | for k in range(len(costs)): 214 | 215 | dynamic_choice = masks[layer][k].sum() 216 | 217 | ratio = dynamic_choice / full_layer 218 | 219 | layerwise_decision_statistics[layer][k].update(ratio, 1) 220 | 221 | computation_cost += masks[layer][k].sum() * costs[k] 222 | 223 | 224 | # for layer in range(network_depth - 1): 225 | # computation_cost += masks[layer].sum() 226 | 227 | # computation_all += reduce((lambda x, y: x * y), masks[layer].shape) 228 | 229 | cp_ratio = (float(computation_cost) / float(computation_all)) * 100 230 | 231 | # collect skip ratio of each layer 232 | # skips = [mask.data.le(0.5).float().mean() for mask in masks] 233 | # if skip_ratios.len != len(skips): 234 | # skip_ratios.set_len(len(skips)) 235 | 236 | computation_cost *= args.beta 237 | 238 | if cp_ratio <= args.minimum: 239 | reg = -1 240 | else: 241 | reg = 1 242 | 243 | if args.computation_cost: 244 | loss = 
criterion(output, target_var) + computation_cost * reg 245 | else: 246 | loss = criterion(output, target_var) 247 | 248 | # measure accuracy and record loss 249 | prec1, = accuracy(output.data, target, topk=(1,)) 250 | losses.update(loss.item(), input.size(0)) 251 | top1.update(prec1.item(), input.size(0)) 252 | # skip_ratios.update(skips, input.size(0)) 253 | cp_record.update(cp_ratio,1) 254 | 255 | # compute gradient and do SGD step 256 | optimizer.zero_grad() 257 | loss.backward() 258 | optimizer.step() 259 | 260 | # repackage hidden units for RNN Gate 261 | if args.gate_type == 'rnn': 262 | model.module.control.repackage_hidden() 263 | 264 | batch_time.update(time.time() - end) 265 | end = time.time() 266 | 267 | # print log 268 | if i % args.print_freq == 0 or i == (args.iters - 1): 269 | logging.info("Iter: [{0}/{1}]\t" 270 | "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t" 271 | "Data {data_time.val:.3f} ({data_time.avg:.3f})\t" 272 | "Loss {loss.val:.3f} ({loss.avg:.3f})\t" 273 | "Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t" 274 | "Computation_Percentage: {cp_record.val:.3f}({cp_record.avg:.3f})\t".format( 275 | i, 276 | args.iters, 277 | batch_time=batch_time, 278 | data_time=data_time, 279 | loss=losses, 280 | top1=top1, 281 | cp_record=cp_record) 282 | ) 283 | 284 | 285 | 286 | # for idx in range(skip_ratios.len): 287 | # logging.info( 288 | # "{} layer skipping = {:.3f}({:.3f})".format( 289 | # idx, 290 | # skip_ratios.val[idx], 291 | # skip_ratios.avg[idx], 292 | # ) 293 | # ) 294 | 295 | # evaluate every 1000 steps 296 | if (i % args.eval_every == 0 and i > 0) or (i == (args.iters-1)): 297 | 298 | with torch.no_grad(): 299 | prec1 = validate(args, test_loader, model, criterion) 300 | is_best = prec1 > best_prec1 301 | best_prec1 = max(prec1, best_prec1) 302 | checkpoint_path = os.path.join(args.save_path, 303 | 'checkpoint_{:05d}.pth.tar'.format( 304 | i)) 305 | save_checkpoint({ 306 | 'iter': i, 307 | 'arch': args.arch, 308 | 'state_dict': model.state_dict(), 309 | 'best_prec1': best_prec1, 310 | }, 311 | is_best = is_best, filename=checkpoint_path) 312 | shutil.copyfile(checkpoint_path, os.path.join(args.save_path, 313 | 'checkpoint_latest' 314 | '.pth.tar')) 315 | 316 | 317 | def validate(args, test_loader, model, criterion): 318 | batch_time = AverageMeter() 319 | losses = AverageMeter() 320 | top1 = AverageMeter() 321 | skip_ratios = ListAverageMeter() 322 | cp_record = AverageMeter() 323 | 324 | # costs = [1/64, 1/16, 1/4, 1] 325 | # costs = [0,1/16,1/4,1] 326 | # bits = [8,16,0] 327 | bits = [8,10,12,14,16,0] 328 | costs = [0,1/16, 25/256, 9/64, 49/256, 1/4, 1] 329 | # bits = [2,4,6,8,10,12,14,16,0] 330 | # costs = [0, 1/256, 1/64, 9/256, 1/16, 25/256, 9/64, 49/256, 1/4, 1] 331 | # bits = [2,4,6,8,10,12,14,16,0] 332 | 333 | network_depth = sum(model.module.num_layers) 334 | 335 | layerwise_decision_statistics = [] 336 | 337 | for k in range(network_depth - 1): 338 | layerwise_decision_statistics.append([]) 339 | for j in range(len(costs)): 340 | ratio = AverageMeter() 341 | layerwise_decision_statistics[k].append(ratio) 342 | 343 | # switch to evaluation mode 344 | model.eval() 345 | end = time.time() 346 | for i, (input, target) in enumerate(test_loader): 347 | target = target.cuda(async = True) 348 | input_var = Variable(input, volatile=True).cuda() 349 | target_var = Variable(target, volatile=True).cuda() 350 | 351 | 352 | 353 | # compute output 354 | output, masks = model(input_var,bits) 355 | 356 | computation_cost = 0 357 | computation_all = 0 358 | 359 | 360 | 
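        # Note on the accounting below: each masks[layer][k] marks the samples
        # for which the gate selected option k in this layer, and costs[k] is
        # that option's relative cost (0 = skip, (b/32)^2 for execution
        # quantized to b bits from the `bits` list above, 1 = full precision).
        # computation_all counts every decision at full cost, so cp_ratio is
        # the batch's computation percentage.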
for layer in range(network_depth - 1): 361 | 362 | full_layer = reduce((lambda x, y: x * y), masks[layer][0].shape) 363 | 364 | computation_all += full_layer 365 | 366 | for k in range(len(costs)): 367 | 368 | dynamic_choice = masks[layer][k].sum() 369 | 370 | ratio = dynamic_choice / full_layer 371 | 372 | layerwise_decision_statistics[layer][k].update(ratio, 1) 373 | 374 | computation_cost += masks[layer][k].sum() * costs[k] 375 | 376 | # for layer in range(network_depth - 1): 377 | # computation_cost += masks[layer].sum() 378 | 379 | # computation_all += reduce((lambda x, y: x * y), masks[layer].shape) 380 | 381 | cp_ratio = (float(computation_cost) / float(computation_all)) * 100 382 | 383 | 384 | # skips = [mask.data.le(0.5).float().mean() for mask in masks] 385 | # if skip_ratios.len != len(skips): 386 | # skip_ratios.set_len(len(skips)) 387 | loss = criterion(output, target_var) 388 | 389 | # measure accuracy and record loss 390 | prec1, = accuracy(output.data, target, topk=(1,)) 391 | top1.update(prec1.item(), input.size(0)) 392 | # skip_ratios.update(skips, input.size(0)) 393 | losses.update(loss.item(), input.size(0)) 394 | batch_time.update(time.time() - end) 395 | end = time.time() 396 | cp_record.update(cp_ratio,1) 397 | 398 | if i % args.print_freq == 0 or (i == (len(test_loader) - 1)): 399 | logging.info( 400 | 'Test: [{}/{}]\t' 401 | 'Time: {batch_time.val:.4f}({batch_time.avg:.4f})\t' 402 | 'Loss: {loss.val:.3f}({loss.avg:.3f})\t' 403 | 'Prec@1: {top1.val:.3f}({top1.avg:.3f})\t' 404 | 'Computation_Percentage:{cp_record.val:.3f}({cp_record.avg:.3f})\t'.format( 405 | i, len(test_loader), batch_time=batch_time, 406 | loss=losses, 407 | top1=top1, 408 | cp_record=cp_record 409 | ) 410 | ) 411 | 412 | if args.gate_type == 'rnn': 413 | model.module.control.repackage_hidden() 414 | 415 | logging.info(' * Prec@1 {top1.avg:.3f}, Loss {loss.avg:.3f}'.format( 416 | top1=top1, loss=losses)) 417 | 418 | for layer in range(network_depth - 1): 419 | print('layer{}_decision'.format(layer + 2)) 420 | 421 | for g in range(len(costs)): 422 | 423 | print('{}_ratio{}'.format(g,layerwise_decision_statistics[layer][g].avg)) 424 | 425 | # skip_summaries = [] 426 | # for idx in range(skip_ratios.len): 427 | # logging.info( 428 | # "{} layer skipping = {:.3f}".format( 429 | # idx, 430 | # skip_ratios.avg[idx], 431 | # ) 432 | # ) 433 | # skip_summaries.append(1-skip_ratios.avg[idx]) 434 | # compute `computational percentage` 435 | # cp = ((sum(skip_summaries) + 1) / (len(skip_summaries) + 1)) * 100 436 | # logging.info('*** Computation Percentage: {:.3f} %'.format(cp)) 437 | 438 | return top1.avg 439 | 440 | 441 | def test_model(args): 442 | # create model 443 | model = models.__dict__[args.arch](args.pretrained) 444 | model = torch.nn.DataParallel(model).cuda() 445 | 446 | if args.resume: 447 | if os.path.isfile(args.resume): 448 | logging.info('=> loading checkpoint `{}`'.format(args.resume)) 449 | checkpoint = torch.load(args.resume) 450 | args.start_iter = checkpoint['iter'] 451 | best_prec1 = checkpoint['best_prec1'] 452 | model.load_state_dict(checkpoint['state_dict'],strict=True) 453 | logging.info('=> loaded checkpoint `{}` (iter: {})'.format( 454 | args.resume, checkpoint['iter'] 455 | )) 456 | else: 457 | logging.info('=> no checkpoint found at `{}`'.format(args.resume)) 458 | 459 | cudnn.benchmark = False 460 | test_loader = prepare_test_data(dataset=args.dataset, 461 | batch_size=args.batch_size, 462 | shuffle=False, 463 | num_workers=args.workers) 464 | criterion = 
nn.CrossEntropyLoss().cuda() 465 | 466 | with torch.no_grad(): 467 | validate(args, test_loader, model, criterion) 468 | 469 | 470 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 471 | torch.save(state, filename) 472 | if is_best: 473 | save_path = os.path.dirname(filename) 474 | shutil.copyfile(filename, os.path.join(save_path, 475 | 'model_best.pth.tar')) 476 | 477 | 478 | class AverageMeter(object): 479 | """Computes and stores the average and current value""" 480 | 481 | def __init__(self): 482 | self.reset() 483 | 484 | def reset(self): 485 | self.val = 0 486 | self.avg = 0 487 | self.sum = 0 488 | self.count = 0 489 | 490 | def update(self, val, n=1): 491 | self.val = val 492 | self.sum += val * n 493 | self.count += n 494 | self.avg = self.sum / self.count 495 | 496 | 497 | class ListAverageMeter(object): 498 | """Computes and stores the average and current values of a list""" 499 | def __init__(self): 500 | self.len = 10000 # set up the maximum length 501 | self.reset() 502 | 503 | def reset(self): 504 | self.val = [0] * self.len 505 | self.avg = [0] * self.len 506 | self.sum = [0] * self.len 507 | self.count = 0 508 | 509 | def set_len(self, n): 510 | self.len = n 511 | self.reset() 512 | 513 | def update(self, vals, n=1): 514 | assert len(vals) == self.len, 'length of vals not equal to self.len' 515 | self.val = vals 516 | for i in range(self.len): 517 | self.sum[i] += self.val[i] * n 518 | self.count += n 519 | for i in range(self.len): 520 | self.avg[i] = self.sum[i] / self.count 521 | 522 | 523 | def adjust_learning_rate(args, optimizer, _iter): 524 | """ divide lr by 10 at 32k and 48k """ 525 | if args.warm_up and (_iter < 400): 526 | lr = 0.01 527 | elif 32000 <= _iter < 48000: 528 | lr = args.lr * (args.step_ratio ** 1) 529 | elif _iter >= 48000: 530 | lr = args.lr * (args.step_ratio ** 2) 531 | else: 532 | lr = args.lr 533 | 534 | if _iter % args.eval_every == 0: 535 | logging.info('Iter [{}] learning rate = {}'.format(_iter, lr)) 536 | 537 | for param_group in optimizer.param_groups: 538 | param_group['lr'] = lr 539 | 540 | 541 | def accuracy(output, target, topk=(1,)): 542 | """Computes the precision@k for the specified values of k""" 543 | maxk = max(topk) 544 | batch_size = target.size(0) 545 | 546 | _, pred = output.topk(maxk, 1, True, True) 547 | pred = pred.t() 548 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 549 | 550 | res = [] 551 | for k in topk: 552 | correct_k = correct[:k].view(-1).float().sum(0) 553 | res.append(correct_k.mul_(100.0 / batch_size)) 554 | return res 555 | 556 | 557 | if __name__ == '__main__': 558 | main() 559 | -------------------------------------------------------------------------------- /train_sp_integrate_dynamic_quantization_initial.py: -------------------------------------------------------------------------------- 1 | """ 2 | Training file for training SkipNets for supervised pre-training stage 3 | """ 4 | 5 | from __future__ import print_function 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.backends.cudnn as cudnn 10 | from torch.autograd import Variable 11 | 12 | import os 13 | import shutil 14 | import argparse 15 | import time 16 | import logging 17 | 18 | import models_initial 19 | from data import * 20 | 21 | from functools import reduce 22 | 23 | 24 | model_names = sorted(name for name in models_initial.__dict__ 25 | if name.islower() and not name.startswith('__') 26 | and callable(models_initial.__dict__[name]) 27 | ) 28 | 29 | 30 | def parse_args(): 31 | # hyper-parameters 
are from ResNet paper 32 | parser = argparse.ArgumentParser( 33 | description='PyTorch CIFAR10 training with gating') 34 | parser.add_argument('cmd', choices=['train', 'test']) 35 | parser.add_argument('arch', metavar='ARCH', 36 | default='cifar10_feedforward_38', 37 | choices=model_names, 38 | help='model architecture: ' + 39 | ' | '.join(model_names) + 40 | ' (default: cifar10_feedforward_38)') 41 | parser.add_argument('--gate-type', type=str, default='ff', 42 | choices=['ff', 'rnn'], help='gate type') 43 | parser.add_argument('--dataset', '-d', default='cifar10', type=str, 44 | choices=['cifar10', 'cifar100'], 45 | help='dataset type') 46 | parser.add_argument('--workers', default=1, type=int, metavar='N', 47 | help='number of data loading workers (default: 4 )') 48 | parser.add_argument('--iters', default=64000, type=int, 49 | help='number of total iterations (default: 64,000)') 50 | parser.add_argument('--start-iter', default=0, type=int, 51 | help='manual iter number (useful on restarts)') 52 | parser.add_argument('--batch-size', default=128, type=int, 53 | help='mini-batch size (default: 128)') 54 | parser.add_argument('--lr', default=0.1, type=float, 55 | help='initial learning rate') 56 | parser.add_argument('--momentum', default=0.9, type=float, 57 | help='momentum') 58 | parser.add_argument('--weight-decay', default=1e-4, type=float, 59 | help='weight decay (default: 1e-4)') 60 | parser.add_argument('--print-freq', default=10, type=int, 61 | help='print frequency (default: 10)') 62 | parser.add_argument('--resume', default='', type=str, 63 | help='path to latest checkpoint (default: None)') 64 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 65 | help='use pretrained model') 66 | parser.add_argument('--step-ratio', default=0.1, type=float, 67 | help='ratio for learning rate deduction') 68 | parser.add_argument('--warm-up', action='store_true', 69 | help='for n = 18, the model needs to warm up for 400 ' 70 | 'iterations') 71 | parser.add_argument('--save-folder', default='save_checkpoints', 72 | type=str, 73 | help='folder to save the checkpoints') 74 | parser.add_argument('--eval-every', default=390, type=int, 75 | help='evaluate model every (default: 1000) iterations') 76 | parser.add_argument('--verbose', action="store_true", 77 | help='print layer skipping ratio at training') 78 | parser.add_argument('--minimum', default=100, type=float, 79 | help='minimum') 80 | parser.add_argument('--computation_cost', default=True, type=bool, 81 | help='using computation cost as regularization term') 82 | parser.add_argument('--proceed', default='False', 83 | help='whether this experiment continues from a checkpoint') 84 | parser.add_argument('--beta', default=1e-5, type=float, 85 | help='coefficient') 86 | 87 | args = parser.parse_args() 88 | return args 89 | 90 | 91 | def main(): 92 | args = parse_args() 93 | 94 | save_path = args.save_path = os.path.join(args.save_folder, args.arch) 95 | os.makedirs(save_path, exist_ok=True) 96 | 97 | # config logger file 98 | args.logger_file = os.path.join(save_path, 'log_{}.txt'.format(args.cmd)) 99 | handlers = [logging.FileHandler(args.logger_file, mode='w'), 100 | logging.StreamHandler()] 101 | logging.basicConfig(level=logging.INFO, 102 | datefmt='%m-%d-%y %H:%M', 103 | format='%(asctime)s:%(message)s', 104 | handlers=handlers) 105 | 106 | if args.cmd == 'train': 107 | logging.info('start training {}'.format(args.arch)) 108 | run_training(args) 109 | 110 | elif args.cmd == 'test': 111 | logging.info('start evaluating 
{} with checkpoints from {}'.format( 112 | args.arch, args.resume)) 113 | test_model(args) 114 | 115 | 116 | def run_training(args): 117 | # create model 118 | model = models_initial.__dict__[args.arch](args.pretrained) 119 | 120 | for param in model.parameters(): 121 | param.requires_grad = False 122 | 123 | for param in model.control.parameters(): 124 | param.requires_grad = True 125 | 126 | for g in range(3): 127 | for i in range(model.num_layers[g]): 128 | gate_layer = getattr(model,'group{}_gate{}'.format(g + 1,i)) 129 | for param in gate_layer.parameters(): 130 | param.requires_grad = True 131 | 132 | model = torch.nn.DataParallel(model).cuda() 133 | 134 | best_prec1 = 0 135 | 136 | # optionally resume from a checkpoint 137 | if args.resume: 138 | if os.path.isfile(args.resume): 139 | logging.info('=> loading checkpoint `{}`'.format(args.resume)) 140 | checkpoint = torch.load(args.resume) 141 | if args.proceed == 'True': 142 | args.start_iter = checkpoint['iter'] 143 | else: 144 | args.start_iter = 0 145 | best_prec1 = checkpoint['best_prec1'] 146 | model.load_state_dict(checkpoint['state_dict'],strict=False) 147 | logging.info('=> loaded checkpoint `{}` (iter: {})'.format( 148 | args.resume, checkpoint['iter'] 149 | )) 150 | else: 151 | logging.info('=> no checkpoint found at `{}`'.format(args.resume)) 152 | 153 | cudnn.benchmark = True 154 | 155 | train_loader = prepare_train_data(dataset=args.dataset, 156 | batch_size=args.batch_size, 157 | shuffle=True, 158 | num_workers=args.workers) 159 | test_loader = prepare_test_data(dataset=args.dataset, 160 | batch_size=args.batch_size, 161 | shuffle=False, 162 | num_workers=args.workers) 163 | 164 | # define loss function (criterion) and optimizer 165 | criterion = nn.CrossEntropyLoss().cuda() 166 | 167 | optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, 168 | model.parameters()), 169 | args.lr, 170 | momentum=args.momentum, 171 | weight_decay=args.weight_decay) 172 | 173 | batch_time = AverageMeter() 174 | data_time = AverageMeter() 175 | losses = AverageMeter() 176 | top1 = AverageMeter() 177 | skip_ratios = ListAverageMeter() 178 | cp_record = AverageMeter() 179 | 180 | network_depth = sum(model.module.num_layers) 181 | 182 | 183 | 184 | # costs = [1/64, 1/16, 1/4, 1] 185 | # costs = [0,1/16,1/4,1] 186 | 187 | # bits = [8,16,0] 188 | costs = [0, 1/16, 25/256, 9/64, 49/256, 1/4, 1] 189 | bits = [8,10,12,14,16,0] 190 | # costs = [0, 1/256, 1/64, 9/256, 1/16, 25/256, 9/64, 49/256, 1/4, 1] 191 | 192 | # bits = [2,4,6,8,10,12,14,16,0] 193 | 194 | layerwise_decision_statistics = [] 195 | 196 | for k in range(network_depth - 1): 197 | layerwise_decision_statistics.append([]) 198 | for j in range(len(costs)): 199 | ratio = AverageMeter() 200 | layerwise_decision_statistics[k].append(ratio) 201 | 202 | 203 | end = time.time() 204 | for i in range(args.start_iter, args.iters): 205 | model.train() 206 | adjust_learning_rate(args, optimizer, i) 207 | 208 | input, target = next(iter(train_loader)) 209 | # measuring data loading time 210 | data_time.update(time.time() - end) 211 | 212 | target = target.cuda(async=False) 213 | input_var = Variable(input).cuda() 214 | target_var = Variable(target).cuda() 215 | 216 | # compute output 217 | output, masks = model(input_var, bits) 218 | 219 | computation_cost = 0 220 | computation_all = 0 221 | 222 | for layer in range(network_depth - 1): 223 | 224 | full_layer = reduce((lambda x, y: x * y), masks[layer][0].shape) 225 | 226 | computation_all += full_layer 227 | 228 | for k in 
range(len(costs)): 229 | 230 | dynamic_choice = masks[layer][k].sum() 231 | 232 | ratio = dynamic_choice / full_layer 233 | 234 | layerwise_decision_statistics[layer][k].update(ratio, 1) 235 | 236 | computation_cost += masks[layer][k].sum() * costs[k] 237 | 238 | 239 | # for layer in range(network_depth - 1): 240 | # computation_cost += masks[layer].sum() 241 | 242 | # computation_all += reduce((lambda x, y: x * y), masks[layer].shape) 243 | 244 | cp_ratio = (float(computation_cost) / float(computation_all)) * 100 245 | 246 | # collect skip ratio of each layer 247 | # skips = [mask.data.le(0.5).float().mean() for mask in masks] 248 | # if skip_ratios.len != len(skips): 249 | # skip_ratios.set_len(len(skips)) 250 | 251 | computation_cost *= args.beta 252 | 253 | if cp_ratio <= args.minimum: 254 | reg = -1 255 | else: 256 | reg = 1 257 | 258 | if args.computation_cost: 259 | loss = criterion(output, target_var) + computation_cost * reg 260 | else: 261 | loss = criterion(output, target_var) 262 | 263 | # measure accuracy and record loss 264 | prec1, = accuracy(output.data, target, topk=(1,)) 265 | losses.update(loss.item(), input.size(0)) 266 | top1.update(prec1.item(), input.size(0)) 267 | # skip_ratios.update(skips, input.size(0)) 268 | cp_record.update(cp_ratio,1) 269 | 270 | # compute gradient and do SGD step 271 | optimizer.zero_grad() 272 | loss.backward() 273 | optimizer.step() 274 | 275 | # repackage hidden units for RNN Gate 276 | if args.gate_type == 'rnn': 277 | model.module.control.repackage_hidden() 278 | 279 | batch_time.update(time.time() - end) 280 | end = time.time() 281 | 282 | # print log 283 | if i % args.print_freq == 0 or i == (args.iters - 1): 284 | logging.info("Iter: [{0}/{1}]\t" 285 | "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t" 286 | "Data {data_time.val:.3f} ({data_time.avg:.3f})\t" 287 | "Loss {loss.val:.3f} ({loss.avg:.3f})\t" 288 | "Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t" 289 | "Computation_Percentage: {cp_record.val:.3f}({cp_record.avg:.3f})\t".format( 290 | i, 291 | args.iters, 292 | batch_time=batch_time, 293 | data_time=data_time, 294 | loss=losses, 295 | top1=top1, 296 | cp_record=cp_record) 297 | ) 298 | 299 | 300 | 301 | # for idx in range(skip_ratios.len): 302 | # logging.info( 303 | # "{} layer skipping = {:.3f}({:.3f})".format( 304 | # idx, 305 | # skip_ratios.val[idx], 306 | # skip_ratios.avg[idx], 307 | # ) 308 | # ) 309 | 310 | # evaluate every 1000 steps 311 | if (i % args.eval_every == 0 and i > 0) or (i == (args.iters-1)): 312 | 313 | with torch.no_grad(): 314 | prec1 = validate(args, test_loader, model, criterion) 315 | is_best = prec1 > best_prec1 316 | best_prec1 = max(prec1, best_prec1) 317 | checkpoint_path = os.path.join(args.save_path, 318 | 'checkpoint_{:05d}.pth.tar'.format( 319 | i)) 320 | save_checkpoint({ 321 | 'iter': i, 322 | 'arch': args.arch, 323 | 'state_dict': model.state_dict(), 324 | 'best_prec1': best_prec1, 325 | }, 326 | is_best = is_best, filename=checkpoint_path) 327 | shutil.copyfile(checkpoint_path, os.path.join(args.save_path, 328 | 'checkpoint_latest' 329 | '.pth.tar')) 330 | 331 | 332 | def validate(args, test_loader, model, criterion): 333 | batch_time = AverageMeter() 334 | losses = AverageMeter() 335 | top1 = AverageMeter() 336 | skip_ratios = ListAverageMeter() 337 | cp_record = AverageMeter() 338 | 339 | # costs = [1/64, 1/16, 1/4, 1] 340 | # costs = [0,1/16,1/4,1] 341 | # bits = [8,16,0] 342 | 343 | costs = [0, 1/16, 25/256, 9/64, 49/256, 1/4, 1] 344 | bits = [8,10,12,14,16,0] 345 | # costs = [0, 
1/256, 1/64, 9/256, 1/16, 25/256, 9/64, 49/256, 1/4, 1] 346 | 347 | # bits = [2,4,6,8,10,12,14,16,0] 348 | 349 | network_depth = sum(model.module.num_layers) 350 | 351 | layerwise_decision_statistics = [] 352 | 353 | for k in range(network_depth - 1): 354 | layerwise_decision_statistics.append([]) 355 | for j in range(len(costs)): 356 | ratio = AverageMeter() 357 | layerwise_decision_statistics[k].append(ratio) 358 | 359 | # switch to evaluation mode 360 | model.eval() 361 | end = time.time() 362 | for i, (input, target) in enumerate(test_loader): 363 | target = target.cuda(async=True) 364 | input_var = Variable(input, volatile=True).cuda() 365 | target_var = Variable(target, volatile=True).cuda() 366 | 367 | 368 | 369 | # compute output 370 | output, masks = model(input_var,bits) 371 | 372 | computation_cost = 0 373 | computation_all = 0 374 | 375 | 376 | for layer in range(network_depth - 1): 377 | 378 | full_layer = reduce((lambda x, y: x * y), masks[layer][0].shape) 379 | 380 | computation_all += full_layer 381 | 382 | for k in range(len(costs)): 383 | 384 | dynamic_choice = masks[layer][k].sum() 385 | 386 | ratio = dynamic_choice / full_layer 387 | 388 | layerwise_decision_statistics[layer][k].update(ratio, 1) 389 | 390 | computation_cost += masks[layer][k].sum() * costs[k] 391 | 392 | # for layer in range(network_depth - 1): 393 | # computation_cost += masks[layer].sum() 394 | 395 | # computation_all += reduce((lambda x, y: x * y), masks[layer].shape) 396 | 397 | cp_ratio = (float(computation_cost) / float(computation_all)) * 100 398 | 399 | 400 | # skips = [mask.data.le(0.5).float().mean() for mask in masks] 401 | # if skip_ratios.len != len(skips): 402 | # skip_ratios.set_len(len(skips)) 403 | loss = criterion(output, target_var) 404 | 405 | # measure accuracy and record loss 406 | prec1, = accuracy(output.data, target, topk=(1,)) 407 | top1.update(prec1.item(), input.size(0)) 408 | # skip_ratios.update(skips, input.size(0)) 409 | losses.update(loss.item(), input.size(0)) 410 | batch_time.update(time.time() - end) 411 | end = time.time() 412 | cp_record.update(cp_ratio,1) 413 | 414 | if i % args.print_freq == 0 or (i == (len(test_loader) - 1)): 415 | logging.info( 416 | 'Test: [{}/{}]\t' 417 | 'Time: {batch_time.val:.4f}({batch_time.avg:.4f})\t' 418 | 'Loss: {loss.val:.3f}({loss.avg:.3f})\t' 419 | 'Prec@1: {top1.val:.3f}({top1.avg:.3f})\t' 420 | 'Computation_Percentage:{cp_record.val:.3f}({cp_record.avg:.3f})\t'.format( 421 | i, len(test_loader), batch_time=batch_time, 422 | loss=losses, 423 | top1=top1, 424 | cp_record=cp_record 425 | ) 426 | ) 427 | 428 | if args.gate_type == 'rnn': 429 | model.module.control.repackage_hidden() 430 | 431 | logging.info(' * Prec@1 {top1.avg:.3f}, Loss {loss.avg:.3f}'.format( 432 | top1=top1, loss=losses)) 433 | 434 | for layer in range(network_depth - 1): 435 | print('layer{}_decision'.format(layer + 2)) 436 | 437 | for g in range(len(costs)): 438 | 439 | print('{}_ratio{}'.format(g,layerwise_decision_statistics[layer][g].avg)) 440 | 441 | # skip_summaries = [] 442 | # for idx in range(skip_ratios.len): 443 | # logging.info( 444 | # "{} layer skipping = {:.3f}".format( 445 | # idx, 446 | # skip_ratios.avg[idx], 447 | # ) 448 | # ) 449 | # skip_summaries.append(1-skip_ratios.avg[idx]) 450 | # compute `computational percentage` 451 | # cp = ((sum(skip_summaries) + 1) / (len(skip_summaries) + 1)) * 100 452 | # logging.info('*** Computation Percentage: {:.3f} %'.format(cp)) 453 | 454 | return top1.avg 455 | 456 | 457 | def test_model(args): 458 | # 
create model 459 | model = models_initial.__dict__[args.arch](args.pretrained) 460 | model = torch.nn.DataParallel(model).cuda() 461 | 462 | if args.resume: 463 | if os.path.isfile(args.resume): 464 | logging.info('=> loading checkpoint `{}`'.format(args.resume)) 465 | checkpoint = torch.load(args.resume) 466 | args.start_iter = checkpoint['iter'] 467 | best_prec1 = checkpoint['best_prec1'] 468 | model.load_state_dict(checkpoint['state_dict'],strict=False) 469 | logging.info('=> loaded checkpoint `{}` (iter: {})'.format( 470 | args.resume, checkpoint['iter'] 471 | )) 472 | else: 473 | logging.info('=> no checkpoint found at `{}`'.format(args.resume)) 474 | 475 | cudnn.benchmark = False 476 | test_loader = prepare_test_data(dataset=args.dataset, 477 | batch_size=args.batch_size, 478 | shuffle=False, 479 | num_workers=args.workers) 480 | criterion = nn.CrossEntropyLoss().cuda() 481 | 482 | with torch.no_grad(): 483 | validate(args, test_loader, model, criterion) 484 | 485 | 486 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 487 | torch.save(state, filename) 488 | if is_best: 489 | save_path = os.path.dirname(filename) 490 | shutil.copyfile(filename, os.path.join(save_path, 491 | 'model_best.pth.tar')) 492 | 493 | 494 | class AverageMeter(object): 495 | """Computes and stores the average and current value""" 496 | 497 | def __init__(self): 498 | self.reset() 499 | 500 | def reset(self): 501 | self.val = 0 502 | self.avg = 0 503 | self.sum = 0 504 | self.count = 0 505 | 506 | def update(self, val, n=1): 507 | self.val = val 508 | self.sum += val * n 509 | self.count += n 510 | self.avg = self.sum / self.count 511 | 512 | 513 | class ListAverageMeter(object): 514 | """Computes and stores the average and current values of a list""" 515 | def __init__(self): 516 | self.len = 10000 # set up the maximum length 517 | self.reset() 518 | 519 | def reset(self): 520 | self.val = [0] * self.len 521 | self.avg = [0] * self.len 522 | self.sum = [0] * self.len 523 | self.count = 0 524 | 525 | def set_len(self, n): 526 | self.len = n 527 | self.reset() 528 | 529 | def update(self, vals, n=1): 530 | assert len(vals) == self.len, 'length of vals not equal to self.len' 531 | self.val = vals 532 | for i in range(self.len): 533 | self.sum[i] += self.val[i] * n 534 | self.count += n 535 | for i in range(self.len): 536 | self.avg[i] = self.sum[i] / self.count 537 | 538 | 539 | def adjust_learning_rate(args, optimizer, _iter): 540 | """ divide lr by 10 at 32k and 48k """ 541 | if args.warm_up and (_iter < 400): 542 | lr = 0.01 543 | elif 32000 <= _iter < 48000: 544 | lr = args.lr * (args.step_ratio ** 1) 545 | elif _iter >= 48000: 546 | lr = args.lr * (args.step_ratio ** 2) 547 | else: 548 | lr = args.lr 549 | 550 | if _iter % args.eval_every == 0: 551 | logging.info('Iter [{}] learning rate = {}'.format(_iter, lr)) 552 | 553 | for param_group in optimizer.param_groups: 554 | param_group['lr'] = lr 555 | 556 | 557 | def accuracy(output, target, topk=(1,)): 558 | """Computes the precision@k for the specified values of k""" 559 | maxk = max(topk) 560 | batch_size = target.size(0) 561 | 562 | _, pred = output.topk(maxk, 1, True, True) 563 | pred = pred.t() 564 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 565 | 566 | res = [] 567 | for k in topk: 568 | correct_k = correct[:k].view(-1).float().sum(0) 569 | res.append(correct_k.mul_(100.0 / batch_size)) 570 | return res 571 | 572 | 573 | if __name__ == '__main__': 574 | main() 575 | 
--------------------------------------------------------------------------------