├── models
│   ├── __init__.py
│   ├── glouncv
│   │   ├── __init__.py
│   │   ├── mobilenetv2.py
│   │   ├── alexnet.py
│   │   ├── alexnet_bn.py
│   │   ├── preresnet_cifar.py
│   │   └── preresnet.py
│   ├── imagenet_presnet.py
│   └── cifar100_presnet.py
├── monitors
│   ├── __init__.py
│   └── metrics.py
├── quantizer
│   ├── __init__.py
│   └── uniq.py
├── requirements.txt
├── lr_scheduler.py
├── README.md
├── main.py
└── utils.py

--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/monitors/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/quantizer/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/models/glouncv/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
absl-py==0.10.0
cachetools==4.1.1
certifi==2020.6.20
chardet==3.0.4
future==0.18.2
google-auth==1.21.2
google-auth-oauthlib==0.4.1
grpcio==1.32.0
idna==2.10
importlib-metadata==1.7.0
Markdown==3.2.2
numpy==1.18.5
oauthlib==3.1.0
Pillow==7.2.0
pkg-resources==0.0.0
protobuf==3.13.0
pyasn1==0.4.8
pyasn1-modules==0.2.8
requests==2.24.0
requests-oauthlib==1.3.0
rsa==4.6
six==1.15.0
tensorboard==2.3.0
tensorboard-plugin-wit==1.7.0
torch==1.4.0
torchvision==0.5.0
urllib3==1.25.10
Werkzeug==1.0.1
zipp==1.2.0

--------------------------------------------------------------------------------
/lr_scheduler.py:
--------------------------------------------------------------------------------
class ConstantWarmupScheduler(object):

    def __init__(self, optimizer, min_lr=0.001, total_epoch=5, after_lr=0.01, after_scheduler=None):
        self.optimizer = optimizer
        self.total_epoch = total_epoch
        self.min_lr = min_lr
        self.after_lr = after_lr
        self.after_scheduler = after_scheduler
        self._current_epoch = 0
        super(ConstantWarmupScheduler, self).__init__()

    def step(self):
        # Hold the learning rate at `min_lr` during warm-up, then switch to
        # `after_lr` and delegate to `after_scheduler`.
        if self._current_epoch < self.total_epoch:
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = self.min_lr
        else:
            if self._current_epoch == self.total_epoch:
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = self.after_lr

            self.after_scheduler.step()
        self._current_epoch += 1

    def state_dict(self):
        return self.after_scheduler.state_dict() \
            if self._current_epoch >= self.total_epoch else None
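The warm-up scheduler above holds the learning rate at `min_lr` for the first `total_epoch` epochs, then switches to `after_lr` and delegates to `after_scheduler`. A minimal usage sketch mirroring the setup in `main.py` (the model and hyperparameter values here are placeholders):

```python
import torch
import torch.optim as optim

from lr_scheduler import ConstantWarmupScheduler

model = torch.nn.Linear(10, 10)  # placeholder model
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# 5 warm-up epochs at lr=0.001, then cosine annealing starting from lr=0.1.
scheduler = ConstantWarmupScheduler(
    optimizer, min_lr=0.001, total_epoch=5, after_lr=0.1,
    after_scheduler=optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=115))

# main.py also resets the parameter groups to the warm-up lr before the first
# epoch, because step() only runs at the end of each epoch.
for param_group in optimizer.param_groups:
    param_group['lr'] = 0.001

for epoch in range(120):
    # ... train one epoch ...
    scheduler.step()
```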
--------------------------------------------------------------------------------
/monitors/metrics.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


def write_metrics(writer, epoch, net, wt_optimizer, train_loss, train_acc1, test_loss, test_acc1, prefix="Train"):

    writer.add_scalar('%s_Train/Loss' % (prefix), train_loss, epoch)
    writer.add_scalar('%s_Train/Acc1' % (prefix), train_acc1, epoch)
    writer.add_scalar('%s_Test/Loss' % (prefix), test_loss, epoch)
    writer.add_scalar('%s_Test/Acc1' % (prefix), test_acc1, epoch)
    writer.add_scalar('%s_Train/LR' % (prefix), wt_optimizer.param_groups[0]['lr'], epoch)

    # Quantizer step sizes (a scalar for layer-wise mode, a tensor for kernel-wise mode).
    for n, param in net.named_parameters():
        if ".delta" in n:
            if param.ndim == 0:
                writer.add_scalar('{}_Train/delta_{}'.format(prefix, n), param, epoch)
            else:
                writer.add_histogram('{}_Train/delta_{}'.format(prefix, n), param, epoch)

    # Weight histograms
    for n, m in net.named_modules():
        if isinstance(m, (nn.Conv2d, nn.Linear, nn.BatchNorm2d)):
            writer.add_histogram('{}_Train/{}.weight'.format(prefix, n), m.weight, epoch)
            if m.weight.grad is not None:
                writer.add_histogram('{}_Train/{}.weight.grad'.format(prefix, n), m.weight.grad, epoch)

            if m.bias is not None:
                writer.add_histogram('{}_Train/{}.bias'.format(prefix, n), m.bias, epoch)
                if m.bias.grad is not None:
                    writer.add_histogram('{}_Train/{}.bias.grad'.format(prefix, n), m.bias.grad, epoch)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# UniQ

This repo contains the code and data of the following paper:

**Training Multi-bit Quantized and Binarized Networks with A Learnable Symmetric Quantizer**


## Prerequisites
Use the package manager [pip](https://pip.pypa.io/en/stable/) to install the library dependencies:

```bash
pip install -r requirements.txt
```

## Training

```bash
export CUDA_VISIBLE_DEVICES=[GPU_IDs] && \
python main.py --train_id [training_id] \
    --lr [learning_rate_value] --wd [weight_decay_value] --batch-size [batch_size] \
    --dataset [dataset_name] --arch [architecture_name] \
    --bit [bit-width] --epochs [training_epochs] \
    --data_root [path_to_dataset] \
    --init_from [path_to_pretrained_model] \
    --train_scheme uniq --quant_mode [quantization_mode] \
    --num_calibration_batches [number_of_batches_for_initialization]
```



## Testing

```bash
export CUDA_VISIBLE_DEVICES=[GPU_IDs] && \
python main.py --train_id [training_id] \
    --batch-size [batch_size] \
    --dataset [dataset_name] --arch [architecture_name] \
    --bit [bit-width] \
    --data_root [path_to_dataset] \
    --init_from [path_to_trained_model] \
    --train_scheme uniq --quant_mode [quantization_mode] \
    -e
```


| Arguments | Description |
| ------------- | ------------- |
| `--train_id` | ID for experiment management (arbitrary). |
| `--lr` | Learning rate |
| `--wd` | Weight decay |
| `--batch-size` | Batch size |
| `--dataset` | Dataset name<br>Possible values: `cifar100`, `imagenet` |
| `--data_root` | Path to the dataset directory |
| `--arch` | Architecture name<br>Possible values: `presnet18`, `presnet32`, `glouncv-presnet34`, `glouncv-mobilenetv2_w1` |
| `--bit` | Bit-width (W/A) |
| `--epochs` | Number of training epochs |
| `--init_from` | Path to the pretrained model. |
| `--train_scheme` | Training scheme<br>Possible values: `fp32` (normal training), `uniq` (low-bit quantization training) |
| `--quant_mode` | Quantization mode<br>Possible values: `layer_wise` (layer-wise quantization), `kernel_wise` (kernel-wise quantization) |
| `--num_calibration_batches` | Number of batches used for initialization |


For the details of each experiment and its hyperparameter settings, we refer readers to the paper and the `main.py` file.
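As a concrete illustration, a 4-bit CIFAR-100 run could look like the following (the hyperparameter values here are placeholders, not the paper's exact settings):

```bash
export CUDA_VISIBLE_DEVICES=0 && \
python main.py --train_id presnet32_w4a4 \
    --lr 0.04 --wd 1e-4 --batch-size 128 \
    --dataset cifar100 --arch presnet32 \
    --bit 4 --epochs 300 \
    --data_root ./data \
    --init_from checkpoints/presnet32_fp32.pth \
    --train_scheme uniq --quant_mode layer_wise \
    --num_calibration_batches 100
```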
## Citation
If you find UniQ useful in your research, please consider citing:
```
@ARTICLE{9383003,
  author={P. {Pham} and J. A. {Abraham} and J. {Chung}},
  journal={IEEE Access},
  title={Training Multi-Bit Quantized and Binarized Networks with a Learnable Symmetric Quantizer},
  year={2021},
  volume={9},
  number={},
  pages={47194-47203},
  doi={10.1109/ACCESS.2021.3067889}}
```

## Contributing
Pull requests are welcome. For major changes, please open an issue first to discuss what you would like to change.

Please make sure to update tests as appropriate.
--------------------------------------------------------------------------------
/models/imagenet_presnet.py:
--------------------------------------------------------------------------------
"""
Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).

# based on
https://github.com/kuangliu/pytorch-cifar/blob/master/models/preact_resnet.py
https://raw.githubusercontent.com/NVlabs/Taylor_pruning/b21ed61ac41cb59a9879a95350bd752ab26ffd91/models/preact_resnet.py
"""

'''Pre-activation ResNet in PyTorch.
Reference:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Identity Mappings in Deep Residual Networks. arXiv:1603.05027
'''

import torch
import torch.nn as nn
import torch.nn.functional as F


class PreActBlock(nn.Module):
    '''Pre-activation version of the BasicBlock.'''
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(PreActBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.relu1 = nn.ReLU(inplace=True)

        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.relu2 = nn.ReLU(inplace=True)

        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
            )

    def forward(self, x):
        out = self.relu1(self.bn1(x))

        if hasattr(self, 'shortcut'):
            shortcut = self.shortcut(out)
        else:
            shortcut = x

        out = self.conv1(out)
        out = self.bn2(out)
        out = self.relu2(out)
        out = self.conv2(out)
        out = out + shortcut
        return out


class PreActBottleneck(nn.Module):
    '''Pre-activation version of the original Bottleneck module.'''
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(PreActBottleneck, self).__init__()

        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)

        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)

        self.bn3 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)

        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False))

    def forward(self, x):
        out = F.relu(self.bn1(x))
        input_out = out

        out = self.conv1(out)
        out = self.bn2(out)
        out = F.relu(out)

        out = self.conv2(out)
        out = self.bn3(out)

        out = F.relu(out)
        out = self.conv3(out)

        if hasattr(self, 'shortcut'):
            shortcut = self.shortcut(input_out)
        else:
            shortcut = x

        out = out + shortcut
        return out


class PreActResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=1000):
        super(PreActResNet, self).__init__()

        self.in_planes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu1 = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)

        # Pre-activation
        self.bn2 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.relu2 = nn.ReLU(inplace=True)

        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if isinstance(m, nn.BatchNorm2d):
                m.bias.data.zero_()

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)
        out = self.maxpool(out)

        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)

        out = self.bn2(out)
        out = self.relu2(out)

        out = self.avgpool(out)

        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out


def PreActResNet18():
    return PreActResNet(PreActBlock, [2,2,2,2])

def PreActResNet34():
    return PreActResNet(PreActBlock, [3,4,6,3])

def PreActResNet50():
    return PreActResNet(PreActBottleneck, [3,4,6,3])

def PreActResNet101():
    return PreActResNet(PreActBottleneck, [3,4,23,3])

def PreActResNet152():
    return PreActResNet(PreActBottleneck, [3,8,36,3])


def test():
    net = PreActResNet18()
    # ImageNet-sized input: the stride-2 stem and AvgPool2d(7) assume 224x224 images.
    y = net(torch.randn(1, 3, 224, 224))
    print(y.size())

# test()
--------------------------------------------------------------------------------
/quantizer/uniq.py:
--------------------------------------------------------------------------------
import torch as t
import torch.nn.functional as F
import torch.nn as nn
import torch
import math


def grad_scale(x,
scale): 9 | y = x 10 | y_grad = x * scale 11 | return (y - y_grad).detach() + y_grad 12 | 13 | 14 | def round_pass(x): 15 | y = x.round() 16 | y_grad = x 17 | return (y - y_grad).detach() + y_grad 18 | 19 | 20 | class STATUS(object): 21 | INIT_READY = 0 22 | INIT_DONE = 1 23 | NOT_READY = -1 24 | 25 | 26 | 27 | class UniQQuantizer(t.nn.Module): 28 | def __init__(self, bit, is_activation=False, **kwargs): 29 | super(UniQQuantizer,self).__init__() 30 | 31 | self.bit = bit 32 | self.is_activation = is_activation 33 | self.delta_normal = {1: 1.595769121605729, 2: 0.9956866859435065, 3: 0.5860194414434872, 4: 0.33520061219993685, 5: 0.18813879027991698, 6: 0.10406300944201481, 7: 0.05686767238235839, 8: 0.03076238758025524, 9: 0.016498958773102656} 34 | self.delta_positive_normal = {1: 1.22399153, 2: 0.65076985, 3: 0.35340955, 4: 0.19324868, 5: 0.10548752, 6: 0.0572659, 7: 0.03087133, 8: 0.01652923, 9: 0.00879047} 35 | self.quant_mode = kwargs.get('quant_mode', 'layer_wise') 36 | self.layer_type = kwargs.get('layer_type', 'conv') 37 | 38 | if self.quant_mode == 'layer_wise': 39 | self.delta = nn.Parameter(torch.tensor(0.0), requires_grad=True) 40 | 41 | elif self.quant_mode == 'kernel_wise': 42 | assert kwargs['num_channels'] > 1 43 | if self.layer_type == 'conv': 44 | shape = [1, kwargs['num_channels'], 1, 1] if self.is_activation else [kwargs['num_channels'], 1, 1, 1] 45 | self.delta = nn.Parameter(torch.Tensor(*shape), requires_grad=True) 46 | else: 47 | shape = [1, kwargs['num_channels']] if self.is_activation else [kwargs['num_channels'], 1] 48 | self.delta = nn.Parameter(torch.Tensor(*shape), requires_grad=True) 49 | 50 | self.kwargs = kwargs 51 | self.register_buffer('init_state', torch.tensor(STATUS.NOT_READY)) 52 | self.register_buffer('min_val', torch.tensor(0.0, dtype=torch.float)) 53 | self.register_buffer('max_val', torch.tensor(2**(self.bit) - 1, dtype=torch.float)) 54 | 55 | 56 | def set_init_state(self, value): 57 | self.init_state.fill_(value) 58 | 59 | def initialization(self, x): 60 | if self.is_activation: 61 | if self.quant_mode == 'kernel_wise': 62 | if self.layer_type == 'conv': 63 | _meanx = (x.detach()**2).view(x.shape[0], -1, x.shape[2] * x.shape[3]).mean(2, True).mean(0, True).view(1, -1, 1, 1) 64 | 65 | elif self.layer_type == 'linear': 66 | _meanx = (x.detach()**2).mean(1, True).mean(0, True).view(1, 1) 67 | 68 | _meanx[_meanx==0] = _meanx[_meanx!=0].min() 69 | pre_relu_std = ((2*_meanx))**0.5 70 | else: 71 | pre_relu_std = (2*((x.detach()**2).mean()))**0.5 72 | self.delta.data.copy_(torch.max(self.delta.data, pre_relu_std * self.delta_positive_normal[self.bit])) 73 | 74 | else: 75 | 76 | if self.quant_mode == 'kernel_wise': 77 | if self.layer_type == 'conv': 78 | std = x.detach().view(-1, x.shape[1] * x.shape[2] * x.shape[3]).std(1, True).view(-1, 1, 1, 1) 79 | if self.layer_type == 'linear': 80 | std = x.detach().view(-1, x.shape[1]).std(1, True).view(-1, 1) 81 | else: 82 | std = x.detach().std() 83 | self.delta.data.copy_( std * self.delta_normal[self.bit]) 84 | 85 | def forward(self, x): 86 | if self.training and self.init_state == STATUS.INIT_READY: 87 | self.initialization(x) 88 | 89 | # Quantization 90 | if self.is_activation: 91 | if self.quant_mode == 'kernel_wise': 92 | g = 1.0 / math.sqrt((x.numel() / x.shape[1]) * (2**self.bit -1)) 93 | else: 94 | g = 1.0 / math.sqrt(x.numel() * (2**self.bit -1)) 95 | 96 | step_size = grad_scale(self.delta, g) 97 | x = x / step_size 98 | x = round_pass(torch.min(torch.max(x, self.min_val), self.max_val)) * step_size 99 | 
        else:

            if self.quant_mode == 'kernel_wise':
                g = 1.0 / math.sqrt((x.numel() / x.shape[0]) * max((2**(self.bit-1) - 1), 1))
            else:
                g = 1.0 / math.sqrt(x.numel() * max((2**(self.bit-1) - 1), 1))

            step_size = grad_scale(self.delta, g)
            alpha = step_size * self.max_val * 0.5
            x = (x + alpha) / step_size
            x = round_pass(torch.min(torch.max(x, self.min_val), self.max_val)) * step_size - alpha

        return x

    def extra_repr(self):
        return "bit=%s, is_activation=%s, quant_mode=%s" % \
            (self.bit, self.is_activation, self.kwargs.get('quant_mode', 'layer_wise'))


class UniQConv2d(nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, bit=4, quant_mode='layer_wise'):

        super(UniQConv2d, self).__init__(
            in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
            stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)

        # Use per-channel (kernel-wise) quantization (optional) for weights only.
        self.quan_w = UniQQuantizer(bit=bit, is_activation=False, quant_mode=quant_mode, num_channels=out_channels)
        self.quan_a = UniQQuantizer(bit=bit, is_activation=True, quant_mode='layer_wise', num_channels=in_channels)
        self.bit = bit

    def forward(self, x):
        if self.bit == 32:
            return F.conv2d(x, self.weight, self.bias, self.stride,
                            self.padding, self.dilation, self.groups)
        else:
            return F.conv2d(self.quan_a(x), self.quan_w(self.weight), self.bias, self.stride,
                            self.padding, self.dilation, self.groups)


class UniQInputConv2d(nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, bit=4, quant_mode='layer_wise'):

        super(UniQInputConv2d, self).__init__(
            in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
            stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)

        # Always use `layer_wise` mode for the first layer; the signed
        # (weight-style) quantizer is used for its input, which can be negative.
        self.quan_w = UniQQuantizer(bit=bit, is_activation=False, quant_mode=quant_mode, num_channels=out_channels)
        self.quan_a = UniQQuantizer(bit=bit, is_activation=False, quant_mode='layer_wise', num_channels=in_channels)
        self.bit = bit

    def forward(self, x):
        if self.bit == 32:
            return F.conv2d(x, self.weight, self.bias, self.stride,
                            self.padding, self.dilation, self.groups)
        else:
            return F.conv2d(self.quan_a(x), self.quan_w(self.weight), self.bias, self.stride,
                            self.padding, self.dilation, self.groups)


class UniQLinear(nn.Linear):
    def __init__(self, in_features, out_features, bias=True, bit=4, quant_mode='layer_wise'):

        super(UniQLinear, self).__init__(in_features=in_features, out_features=out_features, bias=bias)

        # Always use `layer_wise` activation quantization for the last layer.
        self.quan_w = UniQQuantizer(bit=bit, is_activation=False, quant_mode=quant_mode, num_channels=out_features, layer_type='linear')
        self.quan_a = UniQQuantizer(bit=bit, is_activation=True, quant_mode='layer_wise', num_channels=in_features, layer_type='linear')
        self.bit = bit

    def forward(self, x):
        if self.bit == 32:
            return F.linear(x, self.weight, self.bias)
        else:
            return F.linear(self.quan_a(x), self.quan_w(self.weight), self.bias)
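For reference, a minimal sketch (not part of the original repository) of how the quantized layers above are exercised: construct a `UniQConv2d`, flag its quantizers for calibration with `STATUS.INIT_READY`, run a training-mode forward pass so `UniQQuantizer.initialization()` estimates the step size `delta`, then mark initialization done. This mirrors `simple_initialization()` in `main.py`; the shapes and bit-width below are arbitrary:

```python
import torch
from quantizer.uniq import UniQConv2d, STATUS

# A 4-bit quantized 3x3 convolution (weights and activations share the bit-width).
conv = UniQConv2d(16, 32, kernel_size=3, padding=1, bias=False, bit=4)

# Flag both quantizers for calibration; the next training-mode forward pass
# estimates the step size `delta` from batch statistics.
conv.train()
conv.quan_w.set_init_state(STATUS.INIT_READY)
conv.quan_a.set_init_state(STATUS.INIT_READY)

x = torch.randn(8, 16, 32, 32)
_ = conv(x)  # triggers UniQQuantizer.initialization()

conv.quan_w.set_init_state(STATUS.INIT_DONE)
conv.quan_a.set_init_state(STATUS.INIT_DONE)

# Subsequent forward passes quantize weights and activations with the learned delta.
out = conv(x)
print(out.shape)  # torch.Size([8, 32, 32, 32])
```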
-------------------------------------------------------------------------------- /models/cifar100_presnet.py: -------------------------------------------------------------------------------- 1 | ''' 2 | resnet for cifar in pytorch 3 | 4 | Reference: 5 | [1] K. He, X. Zhang, S. Ren, and J. Sun. Deep residual learning for image recognition. In CVPR, 2016. 6 | [2] K. He, X. Zhang, S. Ren, and J. Sun. Identity mappings in deep residual networks. In ECCV, 2016. 7 | ''' 8 | 9 | import torch 10 | import torch.nn as nn 11 | import math 12 | 13 | 14 | def conv3x3(in_planes, out_planes, stride=1): 15 | " 3x3 convolution with padding " 16 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 17 | 18 | 19 | class BasicBlock(nn.Module): 20 | expansion=1 21 | 22 | def __init__(self, inplanes, planes, stride=1, downsample=None): 23 | super(BasicBlock, self).__init__() 24 | self.conv1 = conv3x3(inplanes, planes, stride) 25 | self.bn1 = nn.BatchNorm2d(planes) 26 | self.relu = nn.ReLU(inplace=True) 27 | self.conv2 = conv3x3(planes, planes) 28 | self.bn2 = nn.BatchNorm2d(planes) 29 | self.downsample = downsample 30 | self.stride = stride 31 | 32 | def forward(self, x): 33 | residual = x 34 | 35 | out = self.conv1(x) 36 | out = self.bn1(out) 37 | out = self.relu(out) 38 | 39 | out = self.conv2(out) 40 | out = self.bn2(out) 41 | 42 | if self.downsample is not None: 43 | residual = self.downsample(x) 44 | 45 | out += residual 46 | out = self.relu(out) 47 | 48 | return out 49 | 50 | 51 | class Bottleneck(nn.Module): 52 | expansion=4 53 | 54 | def __init__(self, inplanes, planes, stride=1, downsample=None): 55 | super(Bottleneck, self).__init__() 56 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 57 | self.bn1 = nn.BatchNorm2d(planes) 58 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 59 | self.bn2 = nn.BatchNorm2d(planes) 60 | self.conv3 = nn.Conv2d(planes, planes*4, kernel_size=1, bias=False) 61 | self.bn3 = nn.BatchNorm2d(planes*4) 62 | self.relu = nn.ReLU(inplace=True) 63 | self.downsample = downsample 64 | self.stride = stride 65 | 66 | def forward(self, x): 67 | residual = x 68 | 69 | out = self.conv1(x) 70 | out = self.bn1(out) 71 | out = self.relu(out) 72 | 73 | out = self.conv2(out) 74 | out = self.bn2(out) 75 | out = self.relu(out) 76 | 77 | out = self.conv3(out) 78 | out = self.bn3(out) 79 | 80 | if self.downsample is not None: 81 | residual = self.downsample(x) 82 | 83 | out += residual 84 | out = self.relu(out) 85 | 86 | return out 87 | 88 | 89 | class PreActBasicBlock(nn.Module): 90 | expansion = 1 91 | 92 | def __init__(self, inplanes, planes, stride=1, downsample=None): 93 | super(PreActBasicBlock, self).__init__() 94 | self.bn1 = nn.BatchNorm2d(inplanes) 95 | self.relu = nn.ReLU(inplace=True) 96 | self.conv1 = conv3x3(inplanes, planes, stride) 97 | self.bn2 = nn.BatchNorm2d(planes) 98 | self.conv2 = conv3x3(planes, planes) 99 | self.downsample = downsample 100 | self.stride = stride 101 | 102 | def forward(self, x): 103 | residual = x 104 | 105 | out = self.bn1(x) 106 | out = self.relu(out) 107 | 108 | if self.downsample is not None: 109 | residual = self.downsample(out) 110 | 111 | out = self.conv1(out) 112 | 113 | out = self.bn2(out) 114 | out = self.relu(out) 115 | out = self.conv2(out) 116 | 117 | out += residual 118 | 119 | return out 120 | 121 | 122 | class PreActBottleneck(nn.Module): 123 | expansion = 4 124 | 125 | def __init__(self, inplanes, planes, stride=1, downsample=None): 
126 | super(PreActBottleneck, self).__init__() 127 | self.bn1 = nn.BatchNorm2d(inplanes) 128 | self.relu = nn.ReLU(inplace=True) 129 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 130 | self.bn2 = nn.BatchNorm2d(planes) 131 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 132 | self.bn3 = nn.BatchNorm2d(planes) 133 | self.conv3 = nn.Conv2d(planes, planes*4, kernel_size=1, bias=False) 134 | self.downsample = downsample 135 | self.stride = stride 136 | 137 | def forward(self, x): 138 | residual = x 139 | 140 | out = self.bn1(x) 141 | out = self.relu(out) 142 | 143 | if self.downsample is not None: 144 | residual = self.downsample(out) 145 | 146 | out = self.conv1(out) 147 | 148 | out = self.bn2(out) 149 | out = self.relu(out) 150 | out = self.conv2(out) 151 | 152 | out = self.bn3(out) 153 | out = self.relu(out) 154 | out = self.conv3(out) 155 | 156 | out += residual 157 | 158 | return out 159 | 160 | 161 | class ResNet_Cifar(nn.Module): 162 | 163 | def __init__(self, block, layers, num_classes=10): 164 | super(ResNet_Cifar, self).__init__() 165 | self.inplanes = 16 166 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 167 | self.bn1 = nn.BatchNorm2d(16) 168 | self.relu = nn.ReLU(inplace=True) 169 | self.layer1 = self._make_layer(block, 16, layers[0]) 170 | self.layer2 = self._make_layer(block, 32, layers[1], stride=2) 171 | self.layer3 = self._make_layer(block, 64, layers[2], stride=2) 172 | self.avgpool = nn.AvgPool2d(8, stride=1) 173 | self.fc = nn.Linear(64 * block.expansion, num_classes) 174 | 175 | for m in self.modules(): 176 | if isinstance(m, nn.Conv2d): 177 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 178 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 179 | elif isinstance(m, nn.BatchNorm2d): 180 | m.weight.data.fill_(1) 181 | m.bias.data.zero_() 182 | 183 | def _make_layer(self, block, planes, blocks, stride=1): 184 | downsample = None 185 | if stride != 1 or self.inplanes != planes * block.expansion: 186 | downsample = nn.Sequential( 187 | nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), 188 | nn.BatchNorm2d(planes * block.expansion) 189 | ) 190 | 191 | layers = [] 192 | layers.append(block(self.inplanes, planes, stride, downsample)) 193 | self.inplanes = planes * block.expansion 194 | for _ in range(1, blocks): 195 | layers.append(block(self.inplanes, planes)) 196 | 197 | return nn.Sequential(*layers) 198 | 199 | def forward(self, x): 200 | x = self.conv1(x) 201 | x = self.bn1(x) 202 | x = self.relu(x) 203 | 204 | x = self.layer1(x) 205 | x = self.layer2(x) 206 | x = self.layer3(x) 207 | 208 | x = self.avgpool(x) 209 | x = x.view(x.size(0), -1) 210 | x = self.fc(x) 211 | 212 | return x 213 | 214 | 215 | class PreAct_ResNet_Cifar(nn.Module): 216 | 217 | def __init__(self, block, layers, num_classes=10): 218 | super(PreAct_ResNet_Cifar, self).__init__() 219 | self.inplanes = 16 220 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 221 | self.layer1 = self._make_layer(block, 16, layers[0]) 222 | self.layer2 = self._make_layer(block, 32, layers[1], stride=2) 223 | self.layer3 = self._make_layer(block, 64, layers[2], stride=2) 224 | self.bn = nn.BatchNorm2d(64*block.expansion) 225 | self.relu = nn.ReLU(inplace=True) 226 | self.avgpool = nn.AvgPool2d(8, stride=1) 227 | self.fc = nn.Linear(64*block.expansion, num_classes) 228 | 229 | for m in self.modules(): 230 | if isinstance(m, nn.Conv2d): 231 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 232 | m.weight.data.normal_(0, math.sqrt(2. 
/ n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes*block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes*block.expansion, kernel_size=1, stride=stride, bias=False)
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes*block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.bn(x)
        x = self.relu(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


def resnet20_cifar(**kwargs):
    model = ResNet_Cifar(BasicBlock, [3, 3, 3], **kwargs)
    return model


def resnet32_cifar(**kwargs):
    model = ResNet_Cifar(BasicBlock, [5, 5, 5], **kwargs)
    return model


def resnet44_cifar(**kwargs):
    model = ResNet_Cifar(BasicBlock, [7, 7, 7], **kwargs)
    return model


def resnet56_cifar(**kwargs):
    model = ResNet_Cifar(BasicBlock, [9, 9, 9], **kwargs)
    return model


def resnet110_cifar(**kwargs):
    model = ResNet_Cifar(BasicBlock, [18, 18, 18], **kwargs)
    return model


def resnet1202_cifar(**kwargs):
    model = ResNet_Cifar(BasicBlock, [200, 200, 200], **kwargs)
    return model


def resnet164_cifar(**kwargs):
    model = ResNet_Cifar(Bottleneck, [18, 18, 18], **kwargs)
    return model


def resnet1001_cifar(**kwargs):
    model = ResNet_Cifar(Bottleneck, [111, 111, 111], **kwargs)
    return model


def preact_resnet20_cifar(**kwargs):
    model = PreAct_ResNet_Cifar(PreActBasicBlock, [3, 3, 3], **kwargs)
    return model


def preact_resnet32_cifar(**kwargs):
    model = PreAct_ResNet_Cifar(PreActBasicBlock, [5, 5, 5], **kwargs)
    return model


def preact_resnet110_cifar(**kwargs):
    model = PreAct_ResNet_Cifar(PreActBasicBlock, [18, 18, 18], **kwargs)
    return model


def preact_resnet164_cifar(**kwargs):
    model = PreAct_ResNet_Cifar(PreActBottleneck, [18, 18, 18], **kwargs)
    return model


def preact_resnet1001_cifar(**kwargs):
    model = PreAct_ResNet_Cifar(PreActBottleneck, [111, 111, 111], **kwargs)
    return model


if __name__ == '__main__':
    net = resnet20_cifar()
    # CIFAR-sized input: AvgPool2d(8) after two stride-2 stages assumes 32x32 images.
    y = net(torch.randn(1, 3, 32, 32))
    print(net)
    print(y.size())
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

import os
import time
import math
import argparse
import warnings
import numpy as np

from functools import partial
from torch.utils.tensorboard import SummaryWriter
from monitors.metrics import write_metrics

import lr_scheduler
import utils

from models.imagenet_presnet import PreActResNet18
from models.glouncv.alexnet import alexnet
from models.glouncv.preresnet import preresnet34
from models.glouncv.mobilenetv2 import mobilenetv2_w1
from models.cifar100_presnet import preact_resnet32_cifar


parser = argparse.ArgumentParser(description='PyTorch ImageNet/CIFAR Training')
parser.add_argument('--lr', default=0.1, type=float, help='Main learning rate')
parser.add_argument('--warmup_lr', default=0.001, type=float, help='Warmup learning rate')

parser.add_argument('--wd', default=1e-4, type=float, help='weight decay')
parser.add_argument('--bit', default=4, type=int, help='bit-width for UniQ quantizer')

parser.add_argument('--dataset', default='imagenette', type=str,
                    help='dataset name for training')
parser.add_argument('--data_root', default='/soc_local/data/pytorch/imagenet/', type=str,
                    help='path to dataset')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N',
                    help='mini-batch size (default: 256), this is the total '
                         'batch size of all GPUs on the current node when '
                         'using Data Parallel or Distributed Data Parallel')
parser.add_argument('--arch', default='presnet18', type=str,
                    choices=['presnet18', 'presnet32', 'glouncv-presnet34', 'glouncv-mobilenetv2_w1'],
                    help='network architecture')

parser.add_argument('--init_from', type=str,
                    help='init weights from checkpoint')

parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--epochs', default=120, type=int, help='number of training epochs')

parser.add_argument('--train_id', type=str, default='train-01',
                    help='training id, used to collect experiment results')

parser.add_argument('--train_scheme', type=str, default='fp32', choices=['fp32', 'uniq'],
                    help='Training scheme')

parser.add_argument('--optimizer', type=str, default='sgd', choices=['sgd', 'adam'],
                    help='Optimizer selection.')

parser.add_argument('--output_dir', type=str, default='outputs',
                    help='output directory')

parser.add_argument('--print_freq', default=10, type=int, help='log print frequency.')


parser.add_argument('--quant_mode', type=str, default='layer_wise', choices=['layer_wise', 'kernel_wise'],
                    help='Quantization mode')

parser.add_argument('--num_calibration_batches', default=100, type=int, help='number of calibration training batches')

parser.add_argument('--enable_warmup', dest='enable_warmup', action='store_true',
                    help='Enable warm-up learning rate.')

parser.add_argument('--warmup_epochs', default=5, type=int, help='number of epochs for warm-up')

parser.add_argument('--dropout_ratio', default=0.1, type=float, help='dropout ratio for AlexNet.')


args = parser.parse_args()
print("Script arguments:\n", args)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
best_acc = 0
start_epoch = 0
working_dir = os.path.join(args.output_dir, args.train_id)
os.makedirs(working_dir, exist_ok=True)
writer = SummaryWriter(working_dir)


# Setup data.
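# NOTE: utils.get_dataloaders (defined in utils.py) is expected to return a
# (train, test) pair of torch DataLoaders for the chosen dataset and batch size.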
print('==> Preparing data..')
trainloader, testloader = utils.get_dataloaders(dataset=args.dataset, batch_size=args.batch_size, data_root=args.data_root)

# Setup model
# ----------------------------------------
print('==> Building model..')
net = None
if args.dataset == "imagenet":
    models = {
        'presnet18': PreActResNet18,
        'glouncv-alexnet': alexnet,
        'glouncv-presnet34': preresnet34,
        'glouncv-mobilenetv2_w1': mobilenetv2_w1
    }
    model_fn = models.get(args.arch, None)
    assert model_fn is not None, "Unsupported architecture for ImageNet: %s" % args.arch
    net = model_fn()

elif args.dataset == "cifar100":
    assert args.arch == "presnet32"
    net = preact_resnet32_cifar(num_classes=100)

assert net is not None



# Module replacement
# ---------------------------------
if args.train_scheme.startswith("uniq"):
    from quantizer.uniq import UniQConv2d, UniQInputConv2d, UniQLinear
    if args.bit > 1:
        replacement_dict = {
            nn.Conv2d: partial(UniQConv2d, bit=args.bit, quant_mode=args.quant_mode),
            nn.Linear: partial(UniQLinear, bit=args.bit, quant_mode=args.quant_mode)}
        exception_dict = {
            '__first__': partial(UniQInputConv2d, bit=8),
            '__last__': partial(UniQLinear, bit=8),
        }

        if args.arch == "glouncv-mobilenetv2_w1":
            exception_dict['__last__'] = partial(UniQConv2d, bit=8)
        net = utils.replace_module(net, replacement_dict=replacement_dict, exception_dict=exception_dict, arch=args.arch)

    else:
        # All settings for binary neural networks.
        assert args.wd == 0
        replacement_dict = {nn.Conv2d: partial(UniQConv2d, bit=1, quant_mode=args.quant_mode),
                            nn.Linear: partial(UniQLinear, bit=1, quant_mode=args.quant_mode)}
        exception_dict = {
            '__first__': partial(UniQInputConv2d, bit=32),
            '__last__': partial(UniQLinear, bit=32),
            '__downsampling__': partial(UniQConv2d, bit=32, quant_mode=args.quant_mode)
        }

        if args.arch == "glouncv-mobilenetv2_w1":
            exception_dict['__last__'] = partial(UniQConv2d, bit=32)
        net = utils.replace_module(net, replacement_dict=replacement_dict, exception_dict=exception_dict, arch=args.arch)

# The following part is used for dropout ratio modification.
if args.arch.startswith("glouncv-alexnet"):
    net.output.fc1.dropout = nn.Dropout(p=args.dropout_ratio, inplace=False)
    net.output.fc2.dropout = nn.Dropout(p=args.dropout_ratio, inplace=False)



net = net.to(device)
if device == 'cuda':
    net = torch.nn.DataParallel(net)
    cudnn.benchmark = True

print(net)
print("Number of learnable parameters: ", sum(p.numel() for p in net.parameters() if p.requires_grad) / 1e6, "M")
time.sleep(5)



# Loading checkpoint
# -----------------------------
if args.init_from and os.path.isfile(args.init_from):
    print('==> Initializing from checkpoint: ', args.init_from)
    checkpoint = torch.load(args.init_from)
    loaded_params = {}
    for k, v in checkpoint['net'].items():
        # Checkpoints may or may not carry the DataParallel "module." prefix.
        if not k.startswith("module."):
            loaded_params["module." + k] = v
        else:
            loaded_params[k] = v

    net_state_dict = net.state_dict()
    net_state_dict.update(loaded_params)
    net.load_state_dict(net_state_dict)
else:
    warnings.warn("No checkpoint file is provided !!!")
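# NOTE: the quantizer step sizes (`delta`) and any `alpha` parameters are
# excluded from weight decay below via skip_keys (see utils.add_weight_decay).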
params = utils.add_weight_decay(net, weight_decay=args.wd, skip_keys=['delta', 'alpha'])
criterion = nn.CrossEntropyLoss()

# Setup optimizer
# ----------------------------
if args.optimizer == 'sgd':
    print("==> Use SGD optimizer")
    optimizer = optim.SGD(params, lr=args.lr,
                          momentum=0.9, weight_decay=args.wd)
elif args.optimizer == 'adam':
    print("==> Use Adam optimizer")
    optimizer = optim.Adam(params, lr=args.lr, weight_decay=args.wd)


# Setup LR scheduler
# ----------------------------
if args.enable_warmup:
    lr_scheduler = lr_scheduler.ConstantWarmupScheduler(optimizer=optimizer, min_lr=args.warmup_lr, total_epoch=args.warmup_epochs, after_lr=args.lr,
                                                        after_scheduler=optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs - args.warmup_epochs))
else:
    lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)



def train(epoch):
    global args

    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0

    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        if batch_idx % args.print_freq == 0:
            print("[Train] Epoch=", epoch, " BatchID=", batch_idx, 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                  % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))

    return (train_loss/(batch_idx+1), correct/total)


def test(epoch):
    global best_acc, args

    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            if batch_idx % args.print_freq == 0:
                print("[Test] Epoch=", epoch, " BatchID=", batch_idx, 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                      % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

    # Save checkpoint (best top-1 accuracy only).
    acc = 100.*correct/total
    if acc > best_acc:
        best_acc = acc
        utils.save_checkpoint(net, lr_scheduler, optimizer, acc, epoch,
                              filename=os.path.join(working_dir, 'ckpt_best.pth'))
        print('Saving..')
    print('Best accuracy: ', best_acc)

    return (test_loss/(batch_idx+1), correct/total)
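# UniQ calibration: before low-bit training starts, each quantizer estimates its
# step size `delta` from a few forward passes over training batches (see
# UniQQuantizer.initialization in quantizer/uniq.py). The INIT_READY/INIT_DONE
# flags below gate that one-off calibration pass.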
def simple_initialization(num_batches=100):
    net.train()
    from quantizer.uniq import STATUS, UniQConv2d, UniQInputConv2d, UniQLinear
    for n, m in net.named_modules():
        if isinstance(m, (UniQConv2d, UniQInputConv2d, UniQLinear)):
            assert getattr(m, 'quan_a', None) is not None
            assert getattr(m, 'quan_w', None) is not None
            m.quan_a.set_init_state(STATUS.INIT_READY)
            m.quan_w.set_init_state(STATUS.INIT_READY)

    for batch_idx, (inputs, _) in enumerate(trainloader):
        inputs = inputs.to(device)
        net(inputs)
        if batch_idx + 1 == num_batches:
            break

    for n, m in net.named_modules():
        if isinstance(m, (UniQConv2d, UniQInputConv2d, UniQLinear)):
            assert getattr(m, 'quan_a', None) is not None
            assert getattr(m, 'quan_w', None) is not None
            m.quan_a.set_init_state(STATUS.INIT_DONE)
            m.quan_w.set_init_state(STATUS.INIT_DONE)



if args.evaluate:
    print("==> Start evaluating ...")
    test(-1)
    exit()



# Main training
# -----------------------------------------------
# Reset to 'warmup_lr' if we are using the warmup strategy.
if args.enable_warmup:
    assert args.bit == 1
    for param_group in optimizer.param_groups:
        param_group['lr'] = args.warmup_lr

# Initialization
# ------------------------------------------------
if args.bit != 32 and args.train_scheme in ["uniq", ]:
    simple_initialization(num_batches=args.num_calibration_batches)

# Training
# -----------------------------------------------
for epoch in range(start_epoch, args.epochs):
    train_loss, train_acc1 = train(epoch)
    test_loss, test_acc1 = test(epoch)

    if lr_scheduler is not None:
        lr_scheduler.step()

    write_metrics(writer, epoch, net,
                  optimizer, train_loss, train_acc1, test_loss, test_acc1, prefix="Standard_Training")
--------------------------------------------------------------------------------
/models/glouncv/mobilenetv2.py:
--------------------------------------------------------------------------------
"""
MobileNetV2 for ImageNet-1K, implemented in PyTorch.
Original paper: 'MobileNetV2: Inverted Residuals and Linear Bottlenecks,' https://arxiv.org/abs/1801.04381.
"""

__all__ = ['MobileNetV2', 'mobilenetv2_w1', 'mobilenetv2_w3d4', 'mobilenetv2_wd2', 'mobilenetv2_wd4', 'mobilenetv2b_w1',
           'mobilenetv2b_w3d4', 'mobilenetv2b_wd2', 'mobilenetv2b_wd4']

import os
import torch.nn as nn
import torch.nn.init as init
from .common import conv1x1, conv1x1_block, conv3x3_block, dwconv3x3_block


class LinearBottleneck(nn.Module):
    """
    So-called 'Linear Bottleneck' layer. It is used as a MobileNetV2 unit.
    Parameters:
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    stride : int or tuple/list of 2 int
        Strides of the second convolution layer.
    expansion : bool
        Whether to do expansion of channels.
28 | remove_exp_conv : bool 29 | Whether to remove expansion convolution. 30 | """ 31 | def __init__(self, 32 | in_channels, 33 | out_channels, 34 | stride, 35 | expansion, 36 | remove_exp_conv): 37 | super(LinearBottleneck, self).__init__() 38 | self.residual = (in_channels == out_channels) and (stride == 1) 39 | mid_channels = in_channels * 6 if expansion else in_channels 40 | self.use_exp_conv = (expansion or (not remove_exp_conv)) 41 | 42 | if self.use_exp_conv: 43 | self.conv1 = conv1x1_block( 44 | in_channels=in_channels, 45 | out_channels=mid_channels, 46 | activation="relu6") 47 | self.conv2 = dwconv3x3_block( 48 | in_channels=mid_channels, 49 | out_channels=mid_channels, 50 | stride=stride, 51 | activation="relu6") 52 | self.conv3 = conv1x1_block( 53 | in_channels=mid_channels, 54 | out_channels=out_channels, 55 | activation=None) 56 | 57 | def forward(self, x): 58 | if self.residual: 59 | identity = x 60 | if self.use_exp_conv: 61 | x = self.conv1(x) 62 | x = self.conv2(x) 63 | x = self.conv3(x) 64 | if self.residual: 65 | x = x + identity 66 | return x 67 | 68 | 69 | class MobileNetV2(nn.Module): 70 | """ 71 | MobileNetV2 model from 'MobileNetV2: Inverted Residuals and Linear Bottlenecks,' https://arxiv.org/abs/1801.04381. 72 | Parameters: 73 | ---------- 74 | channels : list of list of int 75 | Number of output channels for each unit. 76 | init_block_channels : int 77 | Number of output channels for the initial unit. 78 | final_block_channels : int 79 | Number of output channels for the final block of the feature extractor. 80 | remove_exp_conv : bool 81 | Whether to remove expansion convolution. 82 | in_channels : int, default 3 83 | Number of input channels. 84 | in_size : tuple of two ints, default (224, 224) 85 | Spatial size of the expected input image. 86 | num_classes : int, default 1000 87 | Number of classification classes. 
88 | """ 89 | def __init__(self, 90 | channels, 91 | init_block_channels, 92 | final_block_channels, 93 | remove_exp_conv, 94 | in_channels=3, 95 | in_size=(224, 224), 96 | num_classes=1000): 97 | super(MobileNetV2, self).__init__() 98 | self.in_size = in_size 99 | self.num_classes = num_classes 100 | 101 | self.features = nn.Sequential() 102 | self.features.add_module("init_block", conv3x3_block( 103 | in_channels=in_channels, 104 | out_channels=init_block_channels, 105 | stride=2, 106 | activation="relu6")) 107 | in_channels = init_block_channels 108 | for i, channels_per_stage in enumerate(channels): 109 | stage = nn.Sequential() 110 | for j, out_channels in enumerate(channels_per_stage): 111 | stride = 2 if (j == 0) and (i != 0) else 1 112 | expansion = (i != 0) or (j != 0) 113 | stage.add_module("unit{}".format(j + 1), LinearBottleneck( 114 | in_channels=in_channels, 115 | out_channels=out_channels, 116 | stride=stride, 117 | expansion=expansion, 118 | remove_exp_conv=remove_exp_conv)) 119 | in_channels = out_channels 120 | self.features.add_module("stage{}".format(i + 1), stage) 121 | self.features.add_module("final_block", conv1x1_block( 122 | in_channels=in_channels, 123 | out_channels=final_block_channels, 124 | activation="relu6")) 125 | in_channels = final_block_channels 126 | self.features.add_module("final_pool", nn.AvgPool2d( 127 | kernel_size=7, 128 | stride=1)) 129 | 130 | self.output = conv1x1( 131 | in_channels=in_channels, 132 | out_channels=num_classes, 133 | bias=False) 134 | 135 | self._init_params() 136 | 137 | def _init_params(self): 138 | for name, module in self.named_modules(): 139 | if isinstance(module, nn.Conv2d): 140 | init.kaiming_uniform_(module.weight) 141 | if module.bias is not None: 142 | init.constant_(module.bias, 0) 143 | 144 | def forward(self, x): 145 | x = self.features(x) 146 | x = self.output(x) 147 | x = x.view(x.size(0), -1) 148 | return x 149 | 150 | 151 | def get_mobilenetv2(width_scale, 152 | remove_exp_conv=False, 153 | model_name=None, 154 | pretrained=False, 155 | root=os.path.join("~", ".torch", "models"), 156 | **kwargs): 157 | """ 158 | Create MobileNetV2 model with specific parameters. 159 | Parameters: 160 | ---------- 161 | width_scale : float 162 | Scale factor for width of layers. 163 | remove_exp_conv : bool, default False 164 | Whether to remove expansion convolution. 165 | model_name : str or None, default None 166 | Model name for loading pretrained model. 167 | pretrained : bool, default False 168 | Whether to load the pretrained weights for model. 169 | root : str, default '~/.torch/models' 170 | Location for keeping the model parameters. 
171 | """ 172 | 173 | init_block_channels = 32 174 | final_block_channels = 1280 175 | layers = [1, 2, 3, 4, 3, 3, 1] 176 | downsample = [0, 1, 1, 1, 0, 1, 0] 177 | channels_per_layers = [16, 24, 32, 64, 96, 160, 320] 178 | 179 | from functools import reduce 180 | channels = reduce( 181 | lambda x, y: x + [[y[0]] * y[1]] if y[2] != 0 else x[:-1] + [x[-1] + [y[0]] * y[1]], 182 | zip(channels_per_layers, layers, downsample), 183 | [[]]) 184 | 185 | if width_scale != 1.0: 186 | channels = [[int(cij * width_scale) for cij in ci] for ci in channels] 187 | init_block_channels = int(init_block_channels * width_scale) 188 | if width_scale > 1.0: 189 | final_block_channels = int(final_block_channels * width_scale) 190 | 191 | net = MobileNetV2( 192 | channels=channels, 193 | init_block_channels=init_block_channels, 194 | final_block_channels=final_block_channels, 195 | remove_exp_conv=remove_exp_conv, 196 | **kwargs) 197 | 198 | if pretrained: 199 | if (model_name is None) or (not model_name): 200 | raise ValueError("Parameter `model_name` should be properly initialized for loading pretrained model.") 201 | from .model_store import download_model 202 | download_model( 203 | net=net, 204 | model_name=model_name, 205 | local_model_store_dir_path=root) 206 | 207 | return net 208 | 209 | 210 | def mobilenetv2_w1(**kwargs): 211 | """ 212 | 1.0 MobileNetV2-224 model from 'MobileNetV2: Inverted Residuals and Linear Bottlenecks,' 213 | https://arxiv.org/abs/1801.04381. 214 | Parameters: 215 | ---------- 216 | pretrained : bool, default False 217 | Whether to load the pretrained weights for model. 218 | root : str, default '~/.torch/models' 219 | Location for keeping the model parameters. 220 | """ 221 | return get_mobilenetv2(width_scale=1.0, model_name="mobilenetv2_w1", **kwargs) 222 | 223 | 224 | def mobilenetv2_w3d4(**kwargs): 225 | """ 226 | 0.75 MobileNetV2-224 model from 'MobileNetV2: Inverted Residuals and Linear Bottlenecks,' 227 | https://arxiv.org/abs/1801.04381. 228 | Parameters: 229 | ---------- 230 | pretrained : bool, default False 231 | Whether to load the pretrained weights for model. 232 | root : str, default '~/.torch/models' 233 | Location for keeping the model parameters. 234 | """ 235 | return get_mobilenetv2(width_scale=0.75, model_name="mobilenetv2_w3d4", **kwargs) 236 | 237 | 238 | def mobilenetv2_wd2(**kwargs): 239 | """ 240 | 0.5 MobileNetV2-224 model from 'MobileNetV2: Inverted Residuals and Linear Bottlenecks,' 241 | https://arxiv.org/abs/1801.04381. 242 | Parameters: 243 | ---------- 244 | pretrained : bool, default False 245 | Whether to load the pretrained weights for model. 246 | root : str, default '~/.torch/models' 247 | Location for keeping the model parameters. 248 | """ 249 | return get_mobilenetv2(width_scale=0.5, model_name="mobilenetv2_wd2", **kwargs) 250 | 251 | 252 | def mobilenetv2_wd4(**kwargs): 253 | """ 254 | 0.25 MobileNetV2-224 model from 'MobileNetV2: Inverted Residuals and Linear Bottlenecks,' 255 | https://arxiv.org/abs/1801.04381. 256 | Parameters: 257 | ---------- 258 | pretrained : bool, default False 259 | Whether to load the pretrained weights for model. 260 | root : str, default '~/.torch/models' 261 | Location for keeping the model parameters. 262 | """ 263 | return get_mobilenetv2(width_scale=0.25, model_name="mobilenetv2_wd4", **kwargs) 264 | 265 | 266 | def mobilenetv2b_w1(**kwargs): 267 | """ 268 | 1.0 MobileNetV2b-224 model from 'MobileNetV2: Inverted Residuals and Linear Bottlenecks,' 269 | https://arxiv.org/abs/1801.04381. 
270 | Parameters: 271 | ---------- 272 | pretrained : bool, default False 273 | Whether to load the pretrained weights for model. 274 | root : str, default '~/.torch/models' 275 | Location for keeping the model parameters. 276 | """ 277 | return get_mobilenetv2(width_scale=1.0, remove_exp_conv=True, model_name="mobilenetv2b_w1", **kwargs) 278 | 279 | 280 | def mobilenetv2b_w3d4(**kwargs): 281 | """ 282 | 0.75 MobileNetV2b-224 model from 'MobileNetV2: Inverted Residuals and Linear Bottlenecks,' 283 | https://arxiv.org/abs/1801.04381. 284 | Parameters: 285 | ---------- 286 | pretrained : bool, default False 287 | Whether to load the pretrained weights for model. 288 | root : str, default '~/.torch/models' 289 | Location for keeping the model parameters. 290 | """ 291 | return get_mobilenetv2(width_scale=0.75, remove_exp_conv=True, model_name="mobilenetv2b_w3d4", **kwargs) 292 | 293 | 294 | def mobilenetv2b_wd2(**kwargs): 295 | """ 296 | 0.5 MobileNetV2b-224 model from 'MobileNetV2: Inverted Residuals and Linear Bottlenecks,' 297 | https://arxiv.org/abs/1801.04381. 298 | Parameters: 299 | ---------- 300 | pretrained : bool, default False 301 | Whether to load the pretrained weights for model. 302 | root : str, default '~/.torch/models' 303 | Location for keeping the model parameters. 304 | """ 305 | return get_mobilenetv2(width_scale=0.5, remove_exp_conv=True, model_name="mobilenetv2b_wd2", **kwargs) 306 | 307 | 308 | def mobilenetv2b_wd4(**kwargs): 309 | """ 310 | 0.25 MobileNetV2b-224 model from 'MobileNetV2: Inverted Residuals and Linear Bottlenecks,' 311 | https://arxiv.org/abs/1801.04381. 312 | Parameters: 313 | ---------- 314 | pretrained : bool, default False 315 | Whether to load the pretrained weights for model. 316 | root : str, default '~/.torch/models' 317 | Location for keeping the model parameters. 
318 | """ 319 | return get_mobilenetv2(width_scale=0.25, remove_exp_conv=True, model_name="mobilenetv2b_wd4", **kwargs) 320 | 321 | 322 | def _calc_width(net): 323 | import numpy as np 324 | net_params = filter(lambda p: p.requires_grad, net.parameters()) 325 | weight_count = 0 326 | for param in net_params: 327 | weight_count += np.prod(param.size()) 328 | return weight_count 329 | 330 | 331 | def _test(): 332 | import torch 333 | 334 | pretrained = False 335 | 336 | models = [ 337 | mobilenetv2_w1, 338 | mobilenetv2_w3d4, 339 | mobilenetv2_wd2, 340 | mobilenetv2_wd4, 341 | mobilenetv2b_w1, 342 | mobilenetv2b_w3d4, 343 | mobilenetv2b_wd2, 344 | mobilenetv2b_wd4, 345 | ] 346 | 347 | for model in models: 348 | 349 | net = model(pretrained=pretrained) 350 | 351 | # net.train() 352 | net.eval() 353 | weight_count = _calc_width(net) 354 | print("m={}, {}".format(model.__name__, weight_count)) 355 | assert (model != mobilenetv2_w1 or weight_count == 3504960) 356 | assert (model != mobilenetv2_w3d4 or weight_count == 2627592) 357 | assert (model != mobilenetv2_wd2 or weight_count == 1964736) 358 | assert (model != mobilenetv2_wd4 or weight_count == 1516392) 359 | assert (model != mobilenetv2b_w1 or weight_count == 3503872) 360 | assert (model != mobilenetv2b_w3d4 or weight_count == 2626968) 361 | assert (model != mobilenetv2b_wd2 or weight_count == 1964448) 362 | assert (model != mobilenetv2b_wd4 or weight_count == 1516312) 363 | 364 | x = torch.randn(1, 3, 224, 224) 365 | y = net(x) 366 | y.sum().backward() 367 | assert (tuple(y.size()) == (1, 1000)) 368 | 369 | 370 | if __name__ == "__main__": 371 | _test() -------------------------------------------------------------------------------- /models/glouncv/alexnet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.nn.init as init 5 | from inspect import isfunction 6 | 7 | class ConvBlock(nn.Module): 8 | """ 9 | Standard convolution block with Batch normalization and activation. 10 | Parameters: 11 | ---------- 12 | in_channels : int 13 | Number of input channels. 14 | out_channels : int 15 | Number of output channels. 16 | kernel_size : int or tuple/list of 2 int 17 | Convolution window size. 18 | stride : int or tuple/list of 2 int 19 | Strides of the convolution. 20 | padding : int, or tuple/list of 2 int, or tuple/list of 4 int 21 | Padding value for convolution layer. 22 | dilation : int or tuple/list of 2 int, default 1 23 | Dilation value for convolution layer. 24 | groups : int, default 1 25 | Number of groups. 26 | bias : bool, default False 27 | Whether the layer uses a bias vector. 28 | use_bn : bool, default True 29 | Whether to use BatchNorm layer. 30 | bn_eps : float, default 1e-5 31 | Small float added to variance in Batch norm. 32 | activation : function or str or None, default nn.ReLU(inplace=True) 33 | Activation function or name of activation function. 
34 | """ 35 | def __init__(self, 36 | in_channels, 37 | out_channels, 38 | kernel_size, 39 | stride, 40 | padding, 41 | dilation=1, 42 | groups=1, 43 | bias=False, 44 | use_bn=True, 45 | bn_eps=1e-5, 46 | activation=(lambda: nn.ReLU(inplace=True))): 47 | super(ConvBlock, self).__init__() 48 | self.activate = (activation is not None) 49 | self.use_bn = use_bn 50 | self.use_pad = (isinstance(padding, (list, tuple)) and (len(padding) == 4)) 51 | 52 | if self.use_pad: 53 | self.pad = nn.ZeroPad2d(padding=padding) 54 | padding = 0 55 | self.conv = nn.Conv2d( 56 | in_channels=in_channels, 57 | out_channels=out_channels, 58 | kernel_size=kernel_size, 59 | stride=stride, 60 | padding=padding, 61 | dilation=dilation, 62 | groups=groups, 63 | bias=bias) 64 | if self.use_bn: 65 | self.bn = nn.BatchNorm2d( 66 | num_features=out_channels, 67 | eps=bn_eps) 68 | if self.activate: 69 | self.activ = get_activation_layer(activation) 70 | 71 | def forward(self, x): 72 | if self.use_pad: 73 | x = self.pad(x) 74 | x = self.conv(x) 75 | if self.use_bn: 76 | x = self.bn(x) 77 | if self.activate: 78 | x = self.activ(x) 79 | return x 80 | 81 | 82 | 83 | class AlexConv(ConvBlock): 84 | """ 85 | AlexNet specific convolution block. 86 | Parameters: 87 | ---------- 88 | in_channels : int 89 | Number of input channels. 90 | out_channels : int 91 | Number of output channels. 92 | kernel_size : int or tuple/list of 2 int 93 | Convolution window size. 94 | stride : int or tuple/list of 2 int 95 | Strides of the convolution. 96 | padding : int or tuple/list of 2 int 97 | Padding value for convolution layer. 98 | use_lrn : bool 99 | Whether to use LRN layer. 100 | """ 101 | def __init__(self, 102 | in_channels, 103 | out_channels, 104 | kernel_size, 105 | stride, 106 | padding, 107 | use_lrn): 108 | super(AlexConv, self).__init__( 109 | in_channels=in_channels, 110 | out_channels=out_channels, 111 | kernel_size=kernel_size, 112 | stride=stride, 113 | padding=padding, 114 | bias=True, 115 | use_bn=False) 116 | self.use_lrn = use_lrn 117 | 118 | def forward(self, x): 119 | x = super(AlexConv, self).forward(x) 120 | if self.use_lrn: 121 | x = F.local_response_norm(x, size=5, k=2.0) 122 | return x 123 | 124 | 125 | class AlexDense(nn.Module): 126 | """ 127 | AlexNet specific dense block. 128 | Parameters: 129 | ---------- 130 | in_channels : int 131 | Number of input channels. 132 | out_channels : int 133 | Number of output channels. 134 | """ 135 | def __init__(self, 136 | in_channels, 137 | out_channels): 138 | super(AlexDense, self).__init__() 139 | self.fc = nn.Linear( 140 | in_features=in_channels, 141 | out_features=out_channels) 142 | self.activ = nn.ReLU(inplace=True) 143 | self.dropout = nn.Dropout(p=0.5) 144 | 145 | def forward(self, x): 146 | x = self.fc(x) 147 | x = self.activ(x) 148 | x = self.dropout(x) 149 | return x 150 | 151 | 152 | class AlexOutputBlock(nn.Module): 153 | """ 154 | AlexNet specific output block. 155 | Parameters: 156 | ---------- 157 | in_channels : int 158 | Number of input channels. 159 | classes : int 160 | Number of classification classes. 
161 | """ 162 | def __init__(self, 163 | in_channels, 164 | classes): 165 | super(AlexOutputBlock, self).__init__() 166 | mid_channels = 4096 167 | 168 | self.fc1 = AlexDense( 169 | in_channels=in_channels, 170 | out_channels=mid_channels) 171 | self.fc2 = AlexDense( 172 | in_channels=mid_channels, 173 | out_channels=mid_channels) 174 | self.fc3 = nn.Linear( 175 | in_features=mid_channels, 176 | out_features=classes) 177 | 178 | def forward(self, x): 179 | x = self.fc1(x) 180 | x = self.fc2(x) 181 | x = self.fc3(x) 182 | return x 183 | 184 | 185 | class AlexNet(nn.Module): 186 | """ 187 | AlexNet model from 'One weird trick for parallelizing convolutional neural networks,' 188 | https://arxiv.org/abs/1404.5997. 189 | Parameters: 190 | ---------- 191 | channels : list of list of int 192 | Number of output channels for each unit. 193 | kernel_sizes : list of list of int 194 | Convolution window sizes for each unit. 195 | strides : list of list of int or tuple/list of 2 int 196 | Strides of the convolution for each unit. 197 | paddings : list of list of int or tuple/list of 2 int 198 | Padding value for convolution layer for each unit. 199 | use_lrn : bool 200 | Whether to use LRN layer. 201 | in_channels : int, default 3 202 | Number of input channels. 203 | in_size : tuple of two ints, default (224, 224) 204 | Spatial size of the expected input image. 205 | num_classes : int, default 1000 206 | Number of classification classes. 207 | """ 208 | def __init__(self, 209 | channels, 210 | kernel_sizes, 211 | strides, 212 | paddings, 213 | use_lrn, 214 | in_channels=3, 215 | in_size=(224, 224), 216 | num_classes=1000): 217 | super(AlexNet, self).__init__() 218 | self.in_size = in_size 219 | self.num_classes = num_classes 220 | 221 | self.features = nn.Sequential() 222 | for i, channels_per_stage in enumerate(channels): 223 | use_lrn_i = use_lrn and (i in [0, 1]) 224 | stage = nn.Sequential() 225 | for j, out_channels in enumerate(channels_per_stage): 226 | stage.add_module("unit{}".format(j + 1), AlexConv( 227 | in_channels=in_channels, 228 | out_channels=out_channels, 229 | kernel_size=kernel_sizes[i][j], 230 | stride=strides[i][j], 231 | padding=paddings[i][j], 232 | use_lrn=use_lrn_i)) 233 | in_channels = out_channels 234 | stage.add_module("pool{}".format(i + 1), nn.MaxPool2d( 235 | kernel_size=3, 236 | stride=2, 237 | padding=0, 238 | ceil_mode=True)) 239 | self.features.add_module("stage{}".format(i + 1), stage) 240 | 241 | self.output = AlexOutputBlock( 242 | in_channels=(in_channels * 6 * 6), 243 | classes=num_classes) 244 | 245 | self._init_params() 246 | 247 | def _init_params(self): 248 | for name, module in self.named_modules(): 249 | if isinstance(module, nn.Conv2d): 250 | init.kaiming_uniform_(module.weight) 251 | if module.bias is not None: 252 | init.constant_(module.bias, 0) 253 | 254 | def forward(self, x): 255 | x = self.features(x) 256 | x = x.view(x.size(0), -1) 257 | x = self.output(x) 258 | return x 259 | 260 | 261 | 262 | def get_activation_layer(activation): 263 | """ 264 | Create activation layer from string/function. 265 | Parameters: 266 | ---------- 267 | activation : function, or str, or nn.Module 268 | Activation function or name of activation function. 269 | Returns 270 | ------- 271 | nn.Module 272 | Activation layer. 
273 | """ 274 | assert (activation is not None) 275 | if isfunction(activation): 276 | return activation() 277 | elif isinstance(activation, str): 278 | if activation == "relu": 279 | return nn.ReLU(inplace=True) 280 | elif activation == "relu6": 281 | return nn.ReLU6(inplace=True) 282 | elif activation == "swish": 283 | return Swish() 284 | elif activation == "hswish": 285 | return HSwish(inplace=True) 286 | elif activation == "sigmoid": 287 | return nn.Sigmoid() 288 | elif activation == "hsigmoid": 289 | return HSigmoid() 290 | elif activation == "identity": 291 | return Identity() 292 | else: 293 | raise NotImplementedError() 294 | else: 295 | assert (isinstance(activation, nn.Module)) 296 | return activation 297 | 298 | 299 | def get_alexnet(version="a", 300 | model_name=None, 301 | pretrained=False, 302 | root=os.path.join("~", ".torch", "models"), 303 | **kwargs): 304 | """ 305 | Create AlexNet model with specific parameters. 306 | Parameters: 307 | ---------- 308 | version : str, default 'a' 309 | Version of AlexNet ('a' or 'b'). 310 | model_name : str or None, default None 311 | Model name for loading pretrained model. 312 | pretrained : bool, default False 313 | Whether to load the pretrained weights for model. 314 | root : str, default '~/.torch/models' 315 | Location for keeping the model parameters. 316 | """ 317 | if version == "a": 318 | channels = [[96], [256], [384, 384, 256]] 319 | kernel_sizes = [[11], [5], [3, 3, 3]] 320 | strides = [[4], [1], [1, 1, 1]] 321 | paddings = [[0], [2], [1, 1, 1]] 322 | use_lrn = True 323 | elif version == "b": 324 | channels = [[64], [192], [384, 256, 256]] 325 | kernel_sizes = [[11], [5], [3, 3, 3]] 326 | strides = [[4], [1], [1, 1, 1]] 327 | paddings = [[2], [2], [1, 1, 1]] 328 | use_lrn = False 329 | else: 330 | raise ValueError("Unsupported AlexNet version {}".format(version)) 331 | 332 | net = AlexNet( 333 | channels=channels, 334 | kernel_sizes=kernel_sizes, 335 | strides=strides, 336 | paddings=paddings, 337 | use_lrn=use_lrn, 338 | **kwargs) 339 | 340 | if pretrained: 341 | if (model_name is None) or (not model_name): 342 | raise ValueError("Parameter `model_name` should be properly initialized for loading pretrained model.") 343 | # from .model_store import download_model 344 | # download_model( 345 | # net=net, 346 | # model_name=model_name, 347 | # local_model_store_dir_path=root) 348 | 349 | return net 350 | 351 | 352 | def alexnet(**kwargs): 353 | """ 354 | AlexNet model from 'One weird trick for parallelizing convolutional neural networks,' 355 | https://arxiv.org/abs/1404.5997. 356 | Parameters: 357 | ---------- 358 | pretrained : bool, default False 359 | Whether to load the pretrained weights for model. 360 | root : str, default '~/.torch/models' 361 | Location for keeping the model parameters. 362 | """ 363 | return get_alexnet(model_name="alexnet", **kwargs) 364 | 365 | 366 | def alexnetb(**kwargs): 367 | """ 368 | AlexNet-b model from 'One weird trick for parallelizing convolutional neural networks,' 369 | https://arxiv.org/abs/1404.5997. Non-standard version. 370 | Parameters: 371 | ---------- 372 | pretrained : bool, default False 373 | Whether to load the pretrained weights for model. 374 | root : str, default '~/.torch/models' 375 | Location for keeping the model parameters. 
376 | """ 377 | return get_alexnet(version="b", model_name="alexnetb", **kwargs) 378 | 379 | 380 | def _calc_width(net): 381 | import numpy as np 382 | net_params = filter(lambda p: p.requires_grad, net.parameters()) 383 | weight_count = 0 384 | for param in net_params: 385 | weight_count += np.prod(param.size()) 386 | return weight_count 387 | 388 | 389 | def _test(): 390 | import torch 391 | 392 | pretrained = False 393 | 394 | models = [ 395 | alexnet, 396 | alexnetb, 397 | ] 398 | 399 | for model in models: 400 | 401 | net = model(pretrained=pretrained) 402 | print (net) 403 | # net.train() 404 | net.eval() 405 | weight_count = _calc_width(net) 406 | print("m={}, {}".format(model.__name__, weight_count)) 407 | assert (model != alexnet or weight_count == 62378344) 408 | assert (model != alexnetb or weight_count == 61100840) 409 | 410 | x = torch.randn(1, 3, 224, 224) 411 | y = net(x) 412 | # y.sum().backward() 413 | assert (tuple(y.size()) == (1, 1000)) 414 | 415 | 416 | if __name__ == "__main__": 417 | _test() -------------------------------------------------------------------------------- /models/glouncv/alexnet_bn.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.nn.init as init 5 | from inspect import isfunction 6 | 7 | class ConvBlock(nn.Module): 8 | """ 9 | Standard convolution block with Batch normalization and activation. 10 | Parameters: 11 | ---------- 12 | in_channels : int 13 | Number of input channels. 14 | out_channels : int 15 | Number of output channels. 16 | kernel_size : int or tuple/list of 2 int 17 | Convolution window size. 18 | stride : int or tuple/list of 2 int 19 | Strides of the convolution. 20 | padding : int, or tuple/list of 2 int, or tuple/list of 4 int 21 | Padding value for convolution layer. 22 | dilation : int or tuple/list of 2 int, default 1 23 | Dilation value for convolution layer. 24 | groups : int, default 1 25 | Number of groups. 26 | bias : bool, default False 27 | Whether the layer uses a bias vector. 28 | use_bn : bool, default True 29 | Whether to use BatchNorm layer. 30 | bn_eps : float, default 1e-5 31 | Small float added to variance in Batch norm. 32 | activation : function or str or None, default nn.ReLU(inplace=True) 33 | Activation function or name of activation function. 
34 | """ 35 | def __init__(self, 36 | in_channels, 37 | out_channels, 38 | kernel_size, 39 | stride, 40 | padding, 41 | dilation=1, 42 | groups=1, 43 | bias=False, 44 | use_bn=True, 45 | bn_eps=1e-5, 46 | activation=(lambda: nn.ReLU(inplace=True))): 47 | super(ConvBlock, self).__init__() 48 | self.activate = (activation is not None) 49 | self.use_bn = use_bn 50 | self.use_pad = (isinstance(padding, (list, tuple)) and (len(padding) == 4)) 51 | 52 | if self.use_pad: 53 | self.pad = nn.ZeroPad2d(padding=padding) 54 | padding = 0 55 | self.conv = nn.Conv2d( 56 | in_channels=in_channels, 57 | out_channels=out_channels, 58 | kernel_size=kernel_size, 59 | stride=stride, 60 | padding=padding, 61 | dilation=dilation, 62 | groups=groups, 63 | bias=bias) 64 | if self.use_bn: 65 | self.bn = nn.BatchNorm2d( 66 | num_features=out_channels, 67 | eps=bn_eps) 68 | if self.activate: 69 | self.activ = get_activation_layer(activation) 70 | 71 | def forward(self, x): 72 | if self.use_pad: 73 | x = self.pad(x) 74 | x = self.conv(x) 75 | if self.use_bn: 76 | x = self.bn(x) 77 | if self.activate: 78 | x = self.activ(x) 79 | return x 80 | 81 | 82 | 83 | class AlexConv(ConvBlock): 84 | """ 85 | AlexNet specific convolution block. 86 | Parameters: 87 | ---------- 88 | in_channels : int 89 | Number of input channels. 90 | out_channels : int 91 | Number of output channels. 92 | kernel_size : int or tuple/list of 2 int 93 | Convolution window size. 94 | stride : int or tuple/list of 2 int 95 | Strides of the convolution. 96 | padding : int or tuple/list of 2 int 97 | Padding value for convolution layer. 98 | use_lrn : bool 99 | Whether to use LRN layer. 100 | """ 101 | def __init__(self, 102 | in_channels, 103 | out_channels, 104 | kernel_size, 105 | stride, 106 | padding, 107 | use_lrn): 108 | super(AlexConv, self).__init__( 109 | in_channels=in_channels, 110 | out_channels=out_channels, 111 | kernel_size=kernel_size, 112 | stride=stride, 113 | padding=padding, 114 | bias=True, 115 | use_bn=use_lrn) 116 | self.use_lrn = False #hardcoding. 117 | 118 | def forward(self, x): 119 | x = super(AlexConv, self).forward(x) 120 | if self.use_lrn: 121 | x = F.local_response_norm(x, size=5, k=2.0) 122 | return x 123 | 124 | 125 | class AlexDense(nn.Module): 126 | """ 127 | AlexNet specific dense block. 128 | Parameters: 129 | ---------- 130 | in_channels : int 131 | Number of input channels. 132 | out_channels : int 133 | Number of output channels. 134 | """ 135 | def __init__(self, 136 | in_channels, 137 | out_channels): 138 | super(AlexDense, self).__init__() 139 | self.fc = nn.Linear( 140 | in_features=in_channels, 141 | out_features=out_channels) 142 | self.activ = nn.ReLU(inplace=True) 143 | self.dropout = nn.Dropout(p=0.5) 144 | 145 | def forward(self, x): 146 | x = self.fc(x) 147 | x = self.activ(x) 148 | x = self.dropout(x) 149 | return x 150 | 151 | 152 | class AlexOutputBlock(nn.Module): 153 | """ 154 | AlexNet specific output block. 155 | Parameters: 156 | ---------- 157 | in_channels : int 158 | Number of input channels. 159 | classes : int 160 | Number of classification classes. 
161 | """ 162 | def __init__(self, 163 | in_channels, 164 | classes): 165 | super(AlexOutputBlock, self).__init__() 166 | mid_channels = 4096 167 | 168 | self.fc1 = AlexDense( 169 | in_channels=in_channels, 170 | out_channels=mid_channels) 171 | self.fc2 = AlexDense( 172 | in_channels=mid_channels, 173 | out_channels=mid_channels) 174 | self.fc3 = nn.Linear( 175 | in_features=mid_channels, 176 | out_features=classes) 177 | 178 | def forward(self, x): 179 | x = self.fc1(x) 180 | x = self.fc2(x) 181 | x = self.fc3(x) 182 | return x 183 | 184 | 185 | class AlexNet(nn.Module): 186 | """ 187 | AlexNet model from 'One weird trick for parallelizing convolutional neural networks,' 188 | https://arxiv.org/abs/1404.5997. 189 | Parameters: 190 | ---------- 191 | channels : list of list of int 192 | Number of output channels for each unit. 193 | kernel_sizes : list of list of int 194 | Convolution window sizes for each unit. 195 | strides : list of list of int or tuple/list of 2 int 196 | Strides of the convolution for each unit. 197 | paddings : list of list of int or tuple/list of 2 int 198 | Padding value for convolution layer for each unit. 199 | use_lrn : bool 200 | Whether to use LRN layer. 201 | in_channels : int, default 3 202 | Number of input channels. 203 | in_size : tuple of two ints, default (224, 224) 204 | Spatial size of the expected input image. 205 | num_classes : int, default 1000 206 | Number of classification classes. 207 | """ 208 | def __init__(self, 209 | channels, 210 | kernel_sizes, 211 | strides, 212 | paddings, 213 | use_lrn, 214 | in_channels=3, 215 | in_size=(224, 224), 216 | num_classes=1000): 217 | super(AlexNet, self).__init__() 218 | self.in_size = in_size 219 | self.num_classes = num_classes 220 | 221 | self.features = nn.Sequential() 222 | for i, channels_per_stage in enumerate(channels): 223 | use_lrn_i = use_lrn and (i in [0, 1]) 224 | stage = nn.Sequential() 225 | for j, out_channels in enumerate(channels_per_stage): 226 | stage.add_module("unit{}".format(j + 1), AlexConv( 227 | in_channels=in_channels, 228 | out_channels=out_channels, 229 | kernel_size=kernel_sizes[i][j], 230 | stride=strides[i][j], 231 | padding=paddings[i][j], 232 | use_lrn=use_lrn_i)) 233 | in_channels = out_channels 234 | stage.add_module("pool{}".format(i + 1), nn.MaxPool2d( 235 | kernel_size=3, 236 | stride=2, 237 | padding=0, 238 | ceil_mode=True)) 239 | self.features.add_module("stage{}".format(i + 1), stage) 240 | 241 | self.output = AlexOutputBlock( 242 | in_channels=(in_channels * 6 * 6), 243 | classes=num_classes) 244 | 245 | self._init_params() 246 | 247 | def _init_params(self): 248 | for name, module in self.named_modules(): 249 | if isinstance(module, nn.Conv2d): 250 | init.kaiming_uniform_(module.weight) 251 | if module.bias is not None: 252 | init.constant_(module.bias, 0) 253 | 254 | def forward(self, x): 255 | x = self.features(x) 256 | x = x.view(x.size(0), -1) 257 | x = self.output(x) 258 | return x 259 | 260 | 261 | 262 | def get_activation_layer(activation): 263 | """ 264 | Create activation layer from string/function. 265 | Parameters: 266 | ---------- 267 | activation : function, or str, or nn.Module 268 | Activation function or name of activation function. 269 | Returns 270 | ------- 271 | nn.Module 272 | Activation layer. 
273 | """ 274 | assert (activation is not None) 275 | if isfunction(activation): 276 | return activation() 277 | elif isinstance(activation, str): 278 | if activation == "relu": 279 | return nn.ReLU(inplace=True) 280 | elif activation == "relu6": 281 | return nn.ReLU6(inplace=True) 282 | elif activation == "swish": 283 | return Swish() 284 | elif activation == "hswish": 285 | return HSwish(inplace=True) 286 | elif activation == "sigmoid": 287 | return nn.Sigmoid() 288 | elif activation == "hsigmoid": 289 | return HSigmoid() 290 | elif activation == "identity": 291 | return Identity() 292 | else: 293 | raise NotImplementedError() 294 | else: 295 | assert (isinstance(activation, nn.Module)) 296 | return activation 297 | 298 | 299 | def get_alexnet(version="a", 300 | model_name=None, 301 | pretrained=False, 302 | root=os.path.join("~", ".torch", "models"), 303 | **kwargs): 304 | """ 305 | Create AlexNet model with specific parameters. 306 | Parameters: 307 | ---------- 308 | version : str, default 'a' 309 | Version of AlexNet ('a' or 'b'). 310 | model_name : str or None, default None 311 | Model name for loading pretrained model. 312 | pretrained : bool, default False 313 | Whether to load the pretrained weights for model. 314 | root : str, default '~/.torch/models' 315 | Location for keeping the model parameters. 316 | """ 317 | if version == "a": 318 | channels = [[96], [256], [384, 384, 256]] 319 | kernel_sizes = [[11], [5], [3, 3, 3]] 320 | strides = [[4], [1], [1, 1, 1]] 321 | paddings = [[0], [2], [1, 1, 1]] 322 | use_lrn = True 323 | elif version == "b": 324 | channels = [[64], [192], [384, 256, 256]] 325 | kernel_sizes = [[11], [5], [3, 3, 3]] 326 | strides = [[4], [1], [1, 1, 1]] 327 | paddings = [[2], [2], [1, 1, 1]] 328 | use_lrn = False 329 | else: 330 | raise ValueError("Unsupported AlexNet version {}".format(version)) 331 | 332 | net = AlexNet( 333 | channels=channels, 334 | kernel_sizes=kernel_sizes, 335 | strides=strides, 336 | paddings=paddings, 337 | use_lrn=use_lrn, 338 | **kwargs) 339 | 340 | if pretrained: 341 | if (model_name is None) or (not model_name): 342 | raise ValueError("Parameter `model_name` should be properly initialized for loading pretrained model.") 343 | # from .model_store import download_model 344 | # download_model( 345 | # net=net, 346 | # model_name=model_name, 347 | # local_model_store_dir_path=root) 348 | 349 | return net 350 | 351 | 352 | def alexnet(**kwargs): 353 | """ 354 | AlexNet model from 'One weird trick for parallelizing convolutional neural networks,' 355 | https://arxiv.org/abs/1404.5997. 356 | Parameters: 357 | ---------- 358 | pretrained : bool, default False 359 | Whether to load the pretrained weights for model. 360 | root : str, default '~/.torch/models' 361 | Location for keeping the model parameters. 362 | """ 363 | return get_alexnet(model_name="alexnet", **kwargs) 364 | 365 | 366 | def alexnetb(**kwargs): 367 | """ 368 | AlexNet-b model from 'One weird trick for parallelizing convolutional neural networks,' 369 | https://arxiv.org/abs/1404.5997. Non-standard version. 370 | Parameters: 371 | ---------- 372 | pretrained : bool, default False 373 | Whether to load the pretrained weights for model. 374 | root : str, default '~/.torch/models' 375 | Location for keeping the model parameters. 
376 | """ 377 | return get_alexnet(version="b", model_name="alexnetb", **kwargs) 378 | 379 | 380 | def _calc_width(net): 381 | import numpy as np 382 | net_params = filter(lambda p: p.requires_grad, net.parameters()) 383 | weight_count = 0 384 | for param in net_params: 385 | weight_count += np.prod(param.size()) 386 | return weight_count 387 | 388 | 389 | def _test(): 390 | import torch 391 | 392 | pretrained = False 393 | 394 | models = [ 395 | alexnet, 396 | alexnetb, 397 | ] 398 | 399 | for model in models: 400 | 401 | net = model(pretrained=pretrained) 402 | print (net) 403 | # net.train() 404 | net.eval() 405 | weight_count = _calc_width(net) 406 | print("m={}, {}".format(model.__name__, weight_count)) 407 | assert (model != alexnet or weight_count == 62378344) 408 | assert (model != alexnetb or weight_count == 61100840) 409 | 410 | x = torch.randn(1, 3, 224, 224) 411 | y = net(x) 412 | # y.sum().backward() 413 | assert (tuple(y.size()) == (1, 1000)) 414 | 415 | 416 | if __name__ == "__main__": 417 | _test() -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | '''Some helper functions for PyTorch, including: 2 | - get_mean_and_std: calculate the mean and std value of dataset. 3 | - msr_init: net parameter initialization. 4 | - progress_bar: progress bar mimic xlua.progress. 5 | ''' 6 | import os 7 | import sys 8 | import time 9 | import math 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.init as init 14 | import torchvision 15 | import torchvision.transforms as transforms 16 | 17 | def get_mean_and_std(dataset): 18 | '''Compute the mean and std value of dataset.''' 19 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2) 20 | mean = torch.zeros(3) 21 | std = torch.zeros(3) 22 | print('==> Computing mean and std..') 23 | for inputs, targets in dataloader: 24 | for i in range(3): 25 | mean[i] += inputs[:,i,:,:].mean() 26 | std[i] += inputs[:,i,:,:].std() 27 | mean.div_(len(dataset)) 28 | std.div_(len(dataset)) 29 | return mean, std 30 | 31 | def init_params(net): 32 | '''Init layer parameters.''' 33 | for m in net.modules(): 34 | if isinstance(m, nn.Conv2d): 35 | init.kaiming_normal(m.weight, mode='fan_out') 36 | if m.bias: 37 | init.constant(m.bias, 0) 38 | elif isinstance(m, nn.BatchNorm2d): 39 | init.constant(m.weight, 1) 40 | init.constant(m.bias, 0) 41 | elif isinstance(m, nn.Linear): 42 | init.normal(m.weight, std=1e-3) 43 | if m.bias: 44 | init.constant(m.bias, 0) 45 | 46 | 47 | _, term_width = os.popen('stty size', 'r').read().split() 48 | term_width = int(term_width) 49 | 50 | TOTAL_BAR_LENGTH = 65. 51 | last_time = time.time() 52 | begin_time = last_time 53 | def progress_bar(current, total, msg=None): 54 | global last_time, begin_time 55 | if current == 0: 56 | begin_time = time.time() # Reset for new bar. 
57 | 58 | cur_len = int(TOTAL_BAR_LENGTH*current/total) 59 | rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1 60 | 61 | sys.stdout.write(' [') 62 | for i in range(cur_len): 63 | sys.stdout.write('=') 64 | sys.stdout.write('>') 65 | for i in range(rest_len): 66 | sys.stdout.write('.') 67 | sys.stdout.write(']') 68 | 69 | cur_time = time.time() 70 | step_time = cur_time - last_time 71 | last_time = cur_time 72 | tot_time = cur_time - begin_time 73 | 74 | L = [] 75 | L.append(' Step: %s' % format_time(step_time)) 76 | L.append(' | Tot: %s' % format_time(tot_time)) 77 | if msg: 78 | L.append(' | ' + msg) 79 | 80 | msg = ''.join(L) 81 | sys.stdout.write(msg) 82 | for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3): 83 | sys.stdout.write(' ') 84 | 85 | # Go back to the center of the bar. 86 | for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2): 87 | sys.stdout.write('\b') 88 | sys.stdout.write(' %d/%d ' % (current+1, total)) 89 | 90 | if current < total-1: 91 | sys.stdout.write('\r') 92 | else: 93 | sys.stdout.write('\n') 94 | sys.stdout.flush() 95 | 96 | def format_time(seconds): 97 | days = int(seconds / 3600/24) 98 | seconds = seconds - days*3600*24 99 | hours = int(seconds / 3600) 100 | seconds = seconds - hours*3600 101 | minutes = int(seconds / 60) 102 | seconds = seconds - minutes*60 103 | secondsf = int(seconds) 104 | seconds = seconds - secondsf 105 | millis = int(seconds*1000) 106 | 107 | f = '' 108 | i = 1 109 | if days > 0: 110 | f += str(days) + 'D' 111 | i += 1 112 | if hours > 0 and i <= 2: 113 | f += str(hours) + 'h' 114 | i += 1 115 | if minutes > 0 and i <= 2: 116 | f += str(minutes) + 'm' 117 | i += 1 118 | if secondsf > 0 and i <= 2: 119 | f += str(secondsf) + 's' 120 | i += 1 121 | if millis > 0 and i <= 2: 122 | f += str(millis) + 'ms' 123 | i += 1 124 | if f == '': 125 | f = '0ms' 126 | return f 127 | 128 | 129 | def replace_all(model, replacement_dict={}): 130 | """ 131 | Replace all layers in the original model with new layers corresponding to `replacement_dict`. 
132 |     Example input:
133 |         replacement_dict = { nn.Conv2d: partial(NIPS2019_QConv2d, bit=args.bit) }
134 |     """
135 |
136 |     def __replace_module(model):
137 |         for module_name in model._modules:
138 |             m = model._modules[module_name]
139 |
140 |             if type(m) in replacement_dict.keys():
141 |                 if isinstance(m, nn.Conv2d):
142 |                     new_module = replacement_dict[type(m)]
143 |                     model._modules[module_name] = new_module(in_channels=m.in_channels,
144 |                         out_channels=m.out_channels, kernel_size=m.kernel_size,
145 |                         stride=m.stride, padding=m.padding, dilation=m.dilation,
146 |                         groups=m.groups, bias=(m.bias is not None))
147 |
148 |                 elif isinstance(m, nn.Linear):
149 |                     new_module = replacement_dict[type(m)]
150 |                     model._modules[module_name] = new_module(in_features=m.in_features,
151 |                         out_features=m.out_features,
152 |                         bias=(m.bias is not None))
153 |
154 |             elif len(model._modules[module_name]._modules) > 0:
155 |                 __replace_module(model._modules[module_name])
156 |
157 |     __replace_module(model)
158 |     return model
159 |
160 |
161 | def replace_single_module(new_cls, current_module):
162 |     m = current_module
163 |     if isinstance(m, nn.Conv2d):
164 |         return new_cls(in_channels=m.in_channels,
165 |             out_channels=m.out_channels, kernel_size=m.kernel_size,
166 |             stride=m.stride, padding=m.padding, dilation=m.dilation,
167 |             groups=m.groups, bias=(m.bias is not None))
168 |
169 |     elif isinstance(m, nn.Linear):
170 |         return new_cls(in_features=m.in_features, out_features=m.out_features, bias=(m.bias is not None))
171 |
172 |     return None
173 |
174 |
175 |
176 | def replace_module(model, replacement_dict={}, exception_dict={}, arch="presnet18"):
177 |     """
178 |     Replace all layers in the original model with new layers corresponding to `replacement_dict`.
179 |     Example input:
180 |         replacement_dict = { nn.Conv2d: partial(NIPS2019_QConv2d, bit=args.bit) }
181 |         exception_dict = {
182 |             'conv1': partial(NIPS2019_QConv2d, bit=8),
183 |             'fc': partial(NIPS2019_QLinear, bit=8),
184 |         }
185 |     """
186 |     assert arch in ["presnet32", "presnet18", "glouncv-alexnet", "glouncv-alexnet-bn", "postech-alexnet", "glouncv-presnet34", "glouncv-presnet50", "glouncv-mobilenetv2_w1"],\
187 |         ("Unsupported architecture: {}".format(arch))
188 |
189 |     model = replace_all(model, replacement_dict=replacement_dict)
190 |
191 |     if arch == "presnet32":
192 |         model.conv1 = replace_single_module(new_cls=exception_dict['__first__'], current_module=model.conv1)
193 |         model.fc = replace_single_module(new_cls=exception_dict['__last__'], current_module=model.fc)
194 |
195 |         if "__downsampling__" in exception_dict.keys():
196 |             new_conv_cls = exception_dict['__downsampling__']
197 |             model.layer2[0].downsample[0] = replace_single_module(new_cls=new_conv_cls, current_module=model.layer2[0].downsample[0])
198 |             model.layer3[0].downsample[0] = replace_single_module(new_cls=new_conv_cls, current_module=model.layer3[0].downsample[0])
199 |
200 |     if arch == "presnet18":
201 |         model.conv1 = replace_single_module(new_cls=exception_dict['__first__'], current_module=model.conv1)
202 |         model.fc = replace_single_module(new_cls=exception_dict['__last__'], current_module=model.fc)
203 |
204 |         if "__downsampling__" in exception_dict.keys():
205 |             new_conv_cls = exception_dict['__downsampling__']
206 |             model.layer2[0].shortcut[0] = replace_single_module(new_cls=new_conv_cls, current_module=model.layer2[0].shortcut[0])
207 |             model.layer3[0].shortcut[0] = replace_single_module(new_cls=new_conv_cls, current_module=model.layer3[0].shortcut[0])
208 |             model.layer4[0].shortcut[0] =
replace_single_module(new_cls=new_conv_cls, current_module=model.layer4[0].shortcut[0] ) 209 | 210 | if arch == "glouncv-presnet34": 211 | model.features.init_block.conv = replace_single_module(new_cls=exception_dict['__first__'], current_module=model.features.init_block.conv) 212 | model.output = replace_single_module(new_cls=exception_dict['__last__'], current_module=model.output) 213 | 214 | if "__downsampling__" in exception_dict.keys(): 215 | new_conv_cls = exception_dict['__downsampling__'] 216 | model.features.stage2.unit1.identity_conv = replace_single_module(new_cls=new_conv_cls, current_module=model.features.stage2.unit1.identity_conv ) 217 | model.features.stage3.unit1.identity_conv = replace_single_module(new_cls=new_conv_cls, current_module=model.features.stage3.unit1.identity_conv ) 218 | model.features.stage4.unit1.identity_conv = replace_single_module(new_cls=new_conv_cls, current_module=model.features.stage4.unit1.identity_conv ) 219 | 220 | if arch == "glouncv-presnet50": 221 | model.features.init_block.conv = replace_single_module(new_cls=exception_dict['__first__'], current_module=model.features.init_block.conv) 222 | model.output = replace_single_module(new_cls=exception_dict['__last__'], current_module=model.output) 223 | if "__downsampling__" in exception_dict.keys(): 224 | new_conv_cls = exception_dict['__downsampling__'] 225 | model.features.stage1.unit1.identity_conv = replace_single_module(new_cls=new_conv_cls, current_module=model.features.stage1.unit1.identity_conv ) 226 | model.features.stage2.unit1.identity_conv = replace_single_module(new_cls=new_conv_cls, current_module=model.features.stage2.unit1.identity_conv ) 227 | model.features.stage3.unit1.identity_conv = replace_single_module(new_cls=new_conv_cls, current_module=model.features.stage3.unit1.identity_conv ) 228 | model.features.stage4.unit1.identity_conv = replace_single_module(new_cls=new_conv_cls, current_module=model.features.stage4.unit1.identity_conv ) 229 | 230 | if arch in ["glouncv-alexnet", "glouncv-alexnet-bn"]: 231 | model.features.stage1.unit1.conv = replace_single_module(new_cls=exception_dict['__first__'], current_module=model.features.stage1.unit1.conv) 232 | model.output.fc3 = replace_single_module(new_cls=exception_dict['__last__'], current_module=model.output.fc3) 233 | 234 | if arch == "glouncv-mobilenetv2_w1": 235 | model.features.init_block.conv = replace_single_module(new_cls=exception_dict['__first__'], current_module=model.features.init_block.conv) 236 | model.output = replace_single_module(new_cls=exception_dict['__last__'], current_module=model.output) 237 | return model 238 | 239 | 240 | 241 | 242 | def get_dataloaders(dataset="cifar100", batch_size=128, data_root="~/data"): 243 | if dataset in ("imagenet", "imagenette"): 244 | traindir = os.path.join(data_root, 'train') 245 | valdir = os.path.join(data_root, 'val') 246 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 247 | std=[0.229, 0.224, 0.225]) 248 | 249 | train_dataset = torchvision.datasets.ImageFolder( 250 | traindir, 251 | transforms.Compose([ 252 | transforms.RandomResizedCrop(224), 253 | transforms.RandomHorizontalFlip(), 254 | transforms.ToTensor(), 255 | normalize, 256 | ])) 257 | 258 | 259 | trainloader = torch.utils.data.DataLoader( 260 | train_dataset, batch_size=batch_size, shuffle=True, 261 | num_workers=4, pin_memory=True, sampler=None) 262 | 263 | testloader = torch.utils.data.DataLoader( 264 | torchvision.datasets.ImageFolder(valdir, transforms.Compose([ 265 | transforms.Resize(256), 266 | 
transforms.CenterCrop(224), 267 | transforms.ToTensor(), 268 | normalize, 269 | ])), 270 | batch_size=batch_size, shuffle=False, 271 | num_workers=4, pin_memory=True) 272 | 273 | 274 | elif dataset == "cifar100": 275 | 276 | transform_train = transforms.Compose([ 277 | transforms.RandomCrop(32, padding=4), 278 | transforms.RandomHorizontalFlip(), 279 | # transforms.RandomRotation(15), #ResNet20, #ResNet32 does not have enough capacity for this transformation. 280 | transforms.ToTensor(), 281 | transforms.Normalize(mean=(0.5070751592371323, 0.48654887331495095, 0.4409178433670343), 282 | std=(0.2673342858792401, 0.2564384629170883, 0.27615047132568404)), 283 | ]) 284 | 285 | transform_test = transforms.Compose([ 286 | transforms.ToTensor(), 287 | transforms.Normalize(mean=(0.5070751592371323, 0.48654887331495095, 0.4409178433670343), 288 | std=(0.2673342858792401, 0.2564384629170883, 0.27615047132568404)), 289 | ]) 290 | 291 | 292 | trainloader = torch.utils.data.DataLoader(torchvision.datasets.CIFAR100( 293 | root=data_root, train=True, download=True, 294 | transform=transform_train), 295 | batch_size=batch_size, shuffle=True, 296 | num_workers=4) 297 | 298 | testloader = torch.utils.data.DataLoader(torchvision.datasets.CIFAR100( 299 | root=data_root, train=False, download=True, 300 | transform=transform_test), 301 | batch_size=batch_size, shuffle=False, 302 | num_workers=4) 303 | 304 | else: 305 | raise NotImplementedError('Not support this type of dataset: ' + dataset) 306 | 307 | return trainloader, testloader 308 | 309 | 310 | def save_checkpoint(net, lr_scheduler, optimizer, acc, epoch, filename='ckpt_best.pth'): 311 | state = { 312 | 'net': net.state_dict(), 313 | 'acc': acc, 314 | 'epoch': epoch, 315 | 'lr_scheduler': lr_scheduler.state_dict() if lr_scheduler is not None \ 316 | else None, 317 | 'optimizer': optimizer.state_dict() if optimizer is not None \ 318 | else None, 319 | } 320 | torch.save(state, filename) 321 | 322 | def add_weight_decay(model, weight_decay, skip_keys): 323 | decay, no_decay = [], [] 324 | for name, param in model.named_parameters(): 325 | if not param.requires_grad: 326 | continue # frozen weights 327 | added = False 328 | for skip_key in skip_keys: 329 | if skip_key in name: 330 | print ("Skip weight decay for: ", name) 331 | no_decay.append(param) 332 | added = True 333 | break 334 | if not added: 335 | decay.append(param) 336 | return [{'params': no_decay, 'weight_decay': 0.}, {'params': decay, 'weight_decay': weight_decay}] 337 | -------------------------------------------------------------------------------- /models/glouncv/preresnet_cifar.py: -------------------------------------------------------------------------------- 1 | """ 2 | PreResNet for CIFAR/SVHN, implemented in PyTorch. 3 | Original papers: 'Identity Mappings in Deep Residual Networks,' https://arxiv.org/abs/1603.05027. 
4 | """ 5 | 6 | __all__ = ['CIFARPreResNet', 'preresnet20_cifar10', 'preresnet20_cifar100', 'preresnet20_svhn', 7 | 'preresnet56_cifar10', 'preresnet56_cifar100', 'preresnet56_svhn', 8 | 'preresnet110_cifar10', 'preresnet110_cifar100', 'preresnet110_svhn', 9 | 'preresnet164bn_cifar10', 'preresnet164bn_cifar100', 'preresnet164bn_svhn', 10 | 'preresnet272bn_cifar10', 'preresnet272bn_cifar100', 'preresnet272bn_svhn', 11 | 'preresnet542bn_cifar10', 'preresnet542bn_cifar100', 'preresnet542bn_svhn', 12 | 'preresnet1001_cifar10', 'preresnet1001_cifar100', 'preresnet1001_svhn', 13 | 'preresnet1202_cifar10', 'preresnet1202_cifar100', 'preresnet1202_svhn'] 14 | 15 | import os 16 | import torch.nn as nn 17 | import torch.nn.init as init 18 | from .common import conv3x3 19 | from .preresnet import PreResUnit, PreResActivation 20 | 21 | 22 | class CIFARPreResNet(nn.Module): 23 | """ 24 | PreResNet model for CIFAR from 'Identity Mappings in Deep Residual Networks,' https://arxiv.org/abs/1603.05027. 25 | Parameters: 26 | ---------- 27 | channels : list of list of int 28 | Number of output channels for each unit. 29 | init_block_channels : int 30 | Number of output channels for the initial unit. 31 | bottleneck : bool 32 | Whether to use a bottleneck or simple block in units. 33 | in_channels : int, default 3 34 | Number of input channels. 35 | in_size : tuple of two ints, default (32, 32) 36 | Spatial size of the expected input image. 37 | num_classes : int, default 10 38 | Number of classification classes. 39 | """ 40 | def __init__(self, 41 | channels, 42 | init_block_channels, 43 | bottleneck, 44 | in_channels=3, 45 | in_size=(32, 32), 46 | num_classes=10): 47 | super(CIFARPreResNet, self).__init__() 48 | self.in_size = in_size 49 | self.num_classes = num_classes 50 | 51 | self.features = nn.Sequential() 52 | self.features.add_module("init_block", conv3x3( 53 | in_channels=in_channels, 54 | out_channels=init_block_channels)) 55 | in_channels = init_block_channels 56 | for i, channels_per_stage in enumerate(channels): 57 | stage = nn.Sequential() 58 | for j, out_channels in enumerate(channels_per_stage): 59 | stride = 2 if (j == 0) and (i != 0) else 1 60 | stage.add_module("unit{}".format(j + 1), PreResUnit( 61 | in_channels=in_channels, 62 | out_channels=out_channels, 63 | stride=stride, 64 | bottleneck=bottleneck, 65 | conv1_stride=False)) 66 | in_channels = out_channels 67 | self.features.add_module("stage{}".format(i + 1), stage) 68 | self.features.add_module("post_activ", PreResActivation(in_channels=in_channels)) 69 | self.features.add_module("final_pool", nn.AvgPool2d( 70 | kernel_size=8, 71 | stride=1)) 72 | 73 | self.output = nn.Linear( 74 | in_features=in_channels, 75 | out_features=num_classes) 76 | 77 | self._init_params() 78 | 79 | def _init_params(self): 80 | for name, module in self.named_modules(): 81 | if isinstance(module, nn.Conv2d): 82 | init.kaiming_uniform_(module.weight) 83 | if module.bias is not None: 84 | init.constant_(module.bias, 0) 85 | 86 | def forward(self, x): 87 | x = self.features(x) 88 | x = x.view(x.size(0), -1) 89 | x = self.output(x) 90 | return x 91 | 92 | 93 | def get_preresnet_cifar(num_classes, 94 | blocks, 95 | bottleneck, 96 | model_name=None, 97 | pretrained=False, 98 | root=os.path.join("~", ".torch", "models"), 99 | **kwargs): 100 | """ 101 | Create PreResNet model for CIFAR with specific parameters. 102 | Parameters: 103 | ---------- 104 | num_classes : int 105 | Number of classification classes. 106 | blocks : int 107 | Number of blocks. 
108 | bottleneck : bool 109 | Whether to use a bottleneck or simple block in units. 110 | model_name : str or None, default None 111 | Model name for loading pretrained model. 112 | pretrained : bool, default False 113 | Whether to load the pretrained weights for model. 114 | root : str, default '~/.torch/models' 115 | Location for keeping the model parameters. 116 | """ 117 | assert (num_classes in [10, 100]) 118 | 119 | if bottleneck: 120 | assert ((blocks - 2) % 9 == 0) 121 | layers = [(blocks - 2) // 9] * 3 122 | else: 123 | assert ((blocks - 2) % 6 == 0) 124 | layers = [(blocks - 2) // 6] * 3 125 | 126 | channels_per_layers = [16, 32, 64] 127 | init_block_channels = 16 128 | 129 | channels = [[ci] * li for (ci, li) in zip(channels_per_layers, layers)] 130 | 131 | if bottleneck: 132 | channels = [[cij * 4 for cij in ci] for ci in channels] 133 | 134 | net = CIFARPreResNet( 135 | channels=channels, 136 | init_block_channels=init_block_channels, 137 | bottleneck=bottleneck, 138 | num_classes=num_classes, 139 | **kwargs) 140 | 141 | if pretrained: 142 | if (model_name is None) or (not model_name): 143 | raise ValueError("Parameter `model_name` should be properly initialized for loading pretrained model.") 144 | from .model_store import download_model 145 | download_model( 146 | net=net, 147 | model_name=model_name, 148 | local_model_store_dir_path=root) 149 | 150 | return net 151 | 152 | 153 | def preresnet20_cifar10(num_classes=10, **kwargs): 154 | """ 155 | PreResNet-20 model for CIFAR-10 from 'Identity Mappings in Deep Residual Networks,' 156 | https://arxiv.org/abs/1603.05027. 157 | Parameters: 158 | ---------- 159 | num_classes : int, default 10 160 | Number of classification classes. 161 | pretrained : bool, default False 162 | Whether to load the pretrained weights for model. 163 | root : str, default '~/.torch/models' 164 | Location for keeping the model parameters. 165 | """ 166 | return get_preresnet_cifar(num_classes=num_classes, blocks=20, bottleneck=False, model_name="preresnet20_cifar10", 167 | **kwargs) 168 | 169 | 170 | def preresnet20_cifar100(num_classes=100, **kwargs): 171 | """ 172 | PreResNet-20 model for CIFAR-100 from 'Identity Mappings in Deep Residual Networks,' 173 | https://arxiv.org/abs/1603.05027. 174 | Parameters: 175 | ---------- 176 | num_classes : int, default 100 177 | Number of classification classes. 178 | pretrained : bool, default False 179 | Whether to load the pretrained weights for model. 180 | root : str, default '~/.torch/models' 181 | Location for keeping the model parameters. 182 | """ 183 | return get_preresnet_cifar(num_classes=num_classes, blocks=20, bottleneck=False, model_name="preresnet20_cifar100", 184 | **kwargs) 185 | 186 | 187 | def preresnet20_svhn(num_classes=10, **kwargs): 188 | """ 189 | PreResNet-20 model for SVHN from 'Identity Mappings in Deep Residual Networks,' 190 | https://arxiv.org/abs/1603.05027. 191 | Parameters: 192 | ---------- 193 | num_classes : int, default 10 194 | Number of classification classes. 195 | pretrained : bool, default False 196 | Whether to load the pretrained weights for model. 197 | root : str, default '~/.torch/models' 198 | Location for keeping the model parameters. 
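    Examples
    --------
    For the non-bottleneck CIFAR/SVHN PreResNets, `get_preresnet_cifar`
    above derives (20 - 2) / 6 = 3 units per stage at widths 16/32/64.
    Illustrative smoke test on a 32x32 input:

    >>> import torch
    >>> net = preresnet20_svhn()
    >>> tuple(net(torch.randn(1, 3, 32, 32)).size())
    (1, 10)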
199 | """ 200 | return get_preresnet_cifar(num_classes=num_classes, blocks=20, bottleneck=False, model_name="preresnet20_svhn", 201 | **kwargs) 202 | 203 | 204 | def preresnet56_cifar10(num_classes=10, **kwargs): 205 | """ 206 | PreResNet-56 model for CIFAR-10 from 'Identity Mappings in Deep Residual Networks,' 207 | https://arxiv.org/abs/1603.05027. 208 | Parameters: 209 | ---------- 210 | num_classes : int, default 10 211 | Number of classification classes. 212 | pretrained : bool, default False 213 | Whether to load the pretrained weights for model. 214 | root : str, default '~/.torch/models' 215 | Location for keeping the model parameters. 216 | """ 217 | return get_preresnet_cifar(num_classes=num_classes, blocks=56, bottleneck=False, model_name="preresnet56_cifar10", 218 | **kwargs) 219 | 220 | 221 | def preresnet56_cifar100(num_classes=100, **kwargs): 222 | """ 223 | PreResNet-56 model for CIFAR-100 from 'Identity Mappings in Deep Residual Networks,' 224 | https://arxiv.org/abs/1603.05027. 225 | Parameters: 226 | ---------- 227 | num_classes : int, default 100 228 | Number of classification classes. 229 | pretrained : bool, default False 230 | Whether to load the pretrained weights for model. 231 | root : str, default '~/.torch/models' 232 | Location for keeping the model parameters. 233 | """ 234 | return get_preresnet_cifar(num_classes=num_classes, blocks=56, bottleneck=False, model_name="preresnet56_cifar100", 235 | **kwargs) 236 | 237 | 238 | def preresnet56_svhn(num_classes=10, **kwargs): 239 | """ 240 | PreResNet-56 model for SVHN from 'Identity Mappings in Deep Residual Networks,' 241 | https://arxiv.org/abs/1603.05027. 242 | Parameters: 243 | ---------- 244 | num_classes : int, default 10 245 | Number of classification classes. 246 | pretrained : bool, default False 247 | Whether to load the pretrained weights for model. 248 | root : str, default '~/.torch/models' 249 | Location for keeping the model parameters. 250 | """ 251 | return get_preresnet_cifar(num_classes=num_classes, blocks=56, bottleneck=False, model_name="preresnet56_svhn", 252 | **kwargs) 253 | 254 | 255 | def preresnet110_cifar10(num_classes=10, **kwargs): 256 | """ 257 | PreResNet-110 model for CIFAR-10 from 'Identity Mappings in Deep Residual Networks,' 258 | https://arxiv.org/abs/1603.05027. 259 | Parameters: 260 | ---------- 261 | num_classes : int, default 10 262 | Number of classification classes. 263 | pretrained : bool, default False 264 | Whether to load the pretrained weights for model. 265 | root : str, default '~/.torch/models' 266 | Location for keeping the model parameters. 267 | """ 268 | return get_preresnet_cifar(num_classes=num_classes, blocks=110, bottleneck=False, model_name="preresnet110_cifar10", 269 | **kwargs) 270 | 271 | 272 | def preresnet110_cifar100(num_classes=100, **kwargs): 273 | """ 274 | PreResNet-110 model for CIFAR-100 from 'Identity Mappings in Deep Residual Networks,' 275 | https://arxiv.org/abs/1603.05027. 276 | Parameters: 277 | ---------- 278 | num_classes : int, default 100 279 | Number of classification classes. 280 | pretrained : bool, default False 281 | Whether to load the pretrained weights for model. 282 | root : str, default '~/.torch/models' 283 | Location for keeping the model parameters. 
284 | """ 285 | return get_preresnet_cifar(num_classes=num_classes, blocks=110, bottleneck=False, 286 | model_name="preresnet110_cifar100", **kwargs) 287 | 288 | 289 | def preresnet110_svhn(num_classes=10, **kwargs): 290 | """ 291 | PreResNet-110 model for SVHN from 'Identity Mappings in Deep Residual Networks,' 292 | https://arxiv.org/abs/1603.05027. 293 | Parameters: 294 | ---------- 295 | num_classes : int, default 10 296 | Number of classification classes. 297 | pretrained : bool, default False 298 | Whether to load the pretrained weights for model. 299 | root : str, default '~/.torch/models' 300 | Location for keeping the model parameters. 301 | """ 302 | return get_preresnet_cifar(num_classes=num_classes, blocks=110, bottleneck=False, model_name="preresnet110_svhn", 303 | **kwargs) 304 | 305 | 306 | def preresnet164bn_cifar10(num_classes=10, **kwargs): 307 | """ 308 | PreResNet-164(BN) model for CIFAR-10 from 'Identity Mappings in Deep Residual Networks,' 309 | https://arxiv.org/abs/1603.05027. 310 | Parameters: 311 | ---------- 312 | num_classes : int, default 10 313 | Number of classification classes. 314 | pretrained : bool, default False 315 | Whether to load the pretrained weights for model. 316 | root : str, default '~/.torch/models' 317 | Location for keeping the model parameters. 318 | """ 319 | return get_preresnet_cifar(num_classes=num_classes, blocks=164, bottleneck=True, 320 | model_name="preresnet164bn_cifar10", **kwargs) 321 | 322 | 323 | def preresnet164bn_cifar100(num_classes=100, **kwargs): 324 | """ 325 | PreResNet-164(BN) model for CIFAR-100 from 'Identity Mappings in Deep Residual Networks,' 326 | https://arxiv.org/abs/1603.05027. 327 | Parameters: 328 | ---------- 329 | num_classes : int, default 100 330 | Number of classification classes. 331 | pretrained : bool, default False 332 | Whether to load the pretrained weights for model. 333 | root : str, default '~/.torch/models' 334 | Location for keeping the model parameters. 335 | """ 336 | return get_preresnet_cifar(num_classes=num_classes, blocks=164, bottleneck=True, 337 | model_name="preresnet164bn_cifar100", **kwargs) 338 | 339 | 340 | def preresnet164bn_svhn(num_classes=10, **kwargs): 341 | """ 342 | PreResNet-164(BN) model for SVHN from 'Identity Mappings in Deep Residual Networks,' 343 | https://arxiv.org/abs/1603.05027. 344 | Parameters: 345 | ---------- 346 | num_classes : int, default 10 347 | Number of classification classes. 348 | pretrained : bool, default False 349 | Whether to load the pretrained weights for model. 350 | root : str, default '~/.torch/models' 351 | Location for keeping the model parameters. 352 | """ 353 | return get_preresnet_cifar(num_classes=num_classes, blocks=164, bottleneck=True, 354 | model_name="preresnet164bn_svhn", **kwargs) 355 | 356 | 357 | def preresnet272bn_cifar10(num_classes=10, **kwargs): 358 | """ 359 | PreResNet-272(BN) model for CIFAR-10 from 'Identity Mappings in Deep Residual Networks,' 360 | https://arxiv.org/abs/1603.05027. 361 | Parameters: 362 | ---------- 363 | num_classes : int, default 10 364 | Number of classification classes. 365 | pretrained : bool, default False 366 | Whether to load the pretrained weights for model. 367 | root : str, default '~/.torch/models' 368 | Location for keeping the model parameters. 
369 | """ 370 | return get_preresnet_cifar(num_classes=num_classes, blocks=272, bottleneck=True, 371 | model_name="preresnet272bn_cifar10", **kwargs) 372 | 373 | 374 | def preresnet272bn_cifar100(num_classes=100, **kwargs): 375 | """ 376 | PreResNet-272(BN) model for CIFAR-100 from 'Identity Mappings in Deep Residual Networks,' 377 | https://arxiv.org/abs/1603.05027. 378 | Parameters: 379 | ---------- 380 | num_classes : int, default 100 381 | Number of classification classes. 382 | pretrained : bool, default False 383 | Whether to load the pretrained weights for model. 384 | root : str, default '~/.torch/models' 385 | Location for keeping the model parameters. 386 | """ 387 | return get_preresnet_cifar(num_classes=num_classes, blocks=272, bottleneck=True, 388 | model_name="preresnet272bn_cifar100", **kwargs) 389 | 390 | 391 | def preresnet272bn_svhn(num_classes=10, **kwargs): 392 | """ 393 | PreResNet-272(BN) model for SVHN from 'Identity Mappings in Deep Residual Networks,' 394 | https://arxiv.org/abs/1603.05027. 395 | Parameters: 396 | ---------- 397 | num_classes : int, default 10 398 | Number of classification classes. 399 | pretrained : bool, default False 400 | Whether to load the pretrained weights for model. 401 | root : str, default '~/.torch/models' 402 | Location for keeping the model parameters. 403 | """ 404 | return get_preresnet_cifar(num_classes=num_classes, blocks=272, bottleneck=True, 405 | model_name="preresnet272bn_svhn", **kwargs) 406 | 407 | 408 | def preresnet542bn_cifar10(num_classes=10, **kwargs): 409 | """ 410 | PreResNet-542(BN) model for CIFAR-10 from 'Identity Mappings in Deep Residual Networks,' 411 | https://arxiv.org/abs/1603.05027. 412 | Parameters: 413 | ---------- 414 | num_classes : int, default 10 415 | Number of classification classes. 416 | pretrained : bool, default False 417 | Whether to load the pretrained weights for model. 418 | root : str, default '~/.torch/models' 419 | Location for keeping the model parameters. 420 | """ 421 | return get_preresnet_cifar(num_classes=num_classes, blocks=542, bottleneck=True, 422 | model_name="preresnet542bn_cifar10", **kwargs) 423 | 424 | 425 | def preresnet542bn_cifar100(num_classes=100, **kwargs): 426 | """ 427 | PreResNet-542(BN) model for CIFAR-100 from 'Identity Mappings in Deep Residual Networks,' 428 | https://arxiv.org/abs/1603.05027. 429 | Parameters: 430 | ---------- 431 | num_classes : int, default 100 432 | Number of classification classes. 433 | pretrained : bool, default False 434 | Whether to load the pretrained weights for model. 435 | root : str, default '~/.torch/models' 436 | Location for keeping the model parameters. 437 | """ 438 | return get_preresnet_cifar(num_classes=num_classes, blocks=542, bottleneck=True, 439 | model_name="preresnet542bn_cifar100", **kwargs) 440 | 441 | 442 | def preresnet542bn_svhn(num_classes=10, **kwargs): 443 | """ 444 | PreResNet-542(BN) model for SVHN from 'Identity Mappings in Deep Residual Networks,' 445 | https://arxiv.org/abs/1603.05027. 446 | Parameters: 447 | ---------- 448 | num_classes : int, default 10 449 | Number of classification classes. 450 | pretrained : bool, default False 451 | Whether to load the pretrained weights for model. 452 | root : str, default '~/.torch/models' 453 | Location for keeping the model parameters. 
454 | """ 455 | return get_preresnet_cifar(num_classes=num_classes, blocks=542, bottleneck=True, 456 | model_name="preresnet542bn_svhn", **kwargs) 457 | 458 | 459 | def preresnet1001_cifar10(num_classes=10, **kwargs): 460 | """ 461 | PreResNet-1001 model for CIFAR-10 from 'Identity Mappings in Deep Residual Networks,' 462 | https://arxiv.org/abs/1603.05027. 463 | Parameters: 464 | ---------- 465 | num_classes : int, default 10 466 | Number of classification classes. 467 | pretrained : bool, default False 468 | Whether to load the pretrained weights for model. 469 | root : str, default '~/.torch/models' 470 | Location for keeping the model parameters. 471 | """ 472 | return get_preresnet_cifar(num_classes=num_classes, blocks=1001, bottleneck=True, 473 | model_name="preresnet1001_cifar10", **kwargs) 474 | 475 | 476 | def preresnet1001_cifar100(num_classes=100, **kwargs): 477 | """ 478 | PreResNet-1001 model for CIFAR-100 from 'Identity Mappings in Deep Residual Networks,' 479 | https://arxiv.org/abs/1603.05027. 480 | Parameters: 481 | ---------- 482 | num_classes : int, default 100 483 | Number of classification classes. 484 | pretrained : bool, default False 485 | Whether to load the pretrained weights for model. 486 | root : str, default '~/.torch/models' 487 | Location for keeping the model parameters. 488 | """ 489 | return get_preresnet_cifar(num_classes=num_classes, blocks=1001, bottleneck=True, 490 | model_name="preresnet1001_cifar100", **kwargs) 491 | 492 | 493 | def preresnet1001_svhn(num_classes=10, **kwargs): 494 | """ 495 | PreResNet-1001 model for SVHN from 'Identity Mappings in Deep Residual Networks,' 496 | https://arxiv.org/abs/1603.05027. 497 | Parameters: 498 | ---------- 499 | num_classes : int, default 10 500 | Number of classification classes. 501 | pretrained : bool, default False 502 | Whether to load the pretrained weights for model. 503 | root : str, default '~/.torch/models' 504 | Location for keeping the model parameters. 505 | """ 506 | return get_preresnet_cifar(num_classes=num_classes, blocks=1001, bottleneck=True, 507 | model_name="preresnet1001_svhn", **kwargs) 508 | 509 | 510 | def preresnet1202_cifar10(num_classes=10, **kwargs): 511 | """ 512 | PreResNet-1202 model for CIFAR-10 from 'Identity Mappings in Deep Residual Networks,' 513 | https://arxiv.org/abs/1603.05027. 514 | Parameters: 515 | ---------- 516 | num_classes : int, default 10 517 | Number of classification classes. 518 | pretrained : bool, default False 519 | Whether to load the pretrained weights for model. 520 | root : str, default '~/.torch/models' 521 | Location for keeping the model parameters. 522 | """ 523 | return get_preresnet_cifar(num_classes=num_classes, blocks=1202, bottleneck=False, 524 | model_name="preresnet1202_cifar10", **kwargs) 525 | 526 | 527 | def preresnet1202_cifar100(num_classes=100, **kwargs): 528 | """ 529 | PreResNet-1202 model for CIFAR-100 from 'Identity Mappings in Deep Residual Networks,' 530 | https://arxiv.org/abs/1603.05027. 531 | Parameters: 532 | ---------- 533 | num_classes : int, default 100 534 | Number of classification classes. 535 | pretrained : bool, default False 536 | Whether to load the pretrained weights for model. 537 | root : str, default '~/.torch/models' 538 | Location for keeping the model parameters. 
539 | """ 540 | return get_preresnet_cifar(num_classes=num_classes, blocks=1202, bottleneck=False, 541 | model_name="preresnet1202_cifar100", **kwargs) 542 | 543 | 544 | def preresnet1202_svhn(num_classes=10, **kwargs): 545 | """ 546 | PreResNet-1202 model for SVHN from 'Identity Mappings in Deep Residual Networks,' 547 | https://arxiv.org/abs/1603.05027. 548 | Parameters: 549 | ---------- 550 | num_classes : int, default 10 551 | Number of classification classes. 552 | pretrained : bool, default False 553 | Whether to load the pretrained weights for model. 554 | root : str, default '~/.torch/models' 555 | Location for keeping the model parameters. 556 | """ 557 | return get_preresnet_cifar(num_classes=num_classes, blocks=1202, bottleneck=False, 558 | model_name="preresnet1202_svhn", **kwargs) 559 | 560 | 561 | def _calc_width(net): 562 | import numpy as np 563 | net_params = filter(lambda p: p.requires_grad, net.parameters()) 564 | weight_count = 0 565 | for param in net_params: 566 | weight_count += np.prod(param.size()) 567 | return weight_count 568 | 569 | 570 | def _test(): 571 | import torch 572 | 573 | pretrained = False 574 | 575 | models = [ 576 | (preresnet20_cifar10, 10), 577 | (preresnet20_cifar100, 100), 578 | (preresnet20_svhn, 10), 579 | (preresnet56_cifar10, 10), 580 | (preresnet56_cifar100, 100), 581 | (preresnet56_svhn, 10), 582 | (preresnet110_cifar10, 10), 583 | (preresnet110_cifar100, 100), 584 | (preresnet110_svhn, 10), 585 | (preresnet164bn_cifar10, 10), 586 | (preresnet164bn_cifar100, 100), 587 | (preresnet164bn_svhn, 10), 588 | (preresnet272bn_cifar10, 10), 589 | (preresnet272bn_cifar100, 100), 590 | (preresnet272bn_svhn, 10), 591 | (preresnet542bn_cifar10, 10), 592 | (preresnet542bn_cifar100, 100), 593 | (preresnet542bn_svhn, 10), 594 | (preresnet1001_cifar10, 10), 595 | (preresnet1001_cifar100, 100), 596 | (preresnet1001_svhn, 10), 597 | (preresnet1202_cifar10, 10), 598 | (preresnet1202_cifar100, 100), 599 | (preresnet1202_svhn, 10), 600 | ] 601 | 602 | for model, num_classes in models: 603 | 604 | net = model(pretrained=pretrained) 605 | 606 | # net.train() 607 | net.eval() 608 | weight_count = _calc_width(net) 609 | print("m={}, {}".format(model.__name__, weight_count)) 610 | assert (model != preresnet20_cifar10 or weight_count == 272282) 611 | assert (model != preresnet20_cifar100 or weight_count == 278132) 612 | assert (model != preresnet20_svhn or weight_count == 272282) 613 | assert (model != preresnet56_cifar10 or weight_count == 855578) 614 | assert (model != preresnet56_cifar100 or weight_count == 861428) 615 | assert (model != preresnet56_svhn or weight_count == 855578) 616 | assert (model != preresnet110_cifar10 or weight_count == 1730522) 617 | assert (model != preresnet110_cifar100 or weight_count == 1736372) 618 | assert (model != preresnet110_svhn or weight_count == 1730522) 619 | assert (model != preresnet164bn_cifar10 or weight_count == 1703258) 620 | assert (model != preresnet164bn_cifar100 or weight_count == 1726388) 621 | assert (model != preresnet164bn_svhn or weight_count == 1703258) 622 | assert (model != preresnet272bn_cifar10 or weight_count == 2816090) 623 | assert (model != preresnet272bn_cifar100 or weight_count == 2839220) 624 | assert (model != preresnet272bn_svhn or weight_count == 2816090) 625 | assert (model != preresnet542bn_cifar10 or weight_count == 5598170) 626 | assert (model != preresnet542bn_cifar100 or weight_count == 5621300) 627 | assert (model != preresnet542bn_svhn or weight_count == 5598170) 628 | assert (model != 
preresnet1001_cifar10 or weight_count == 10327706) 629 | assert (model != preresnet1001_cifar100 or weight_count == 10350836) 630 | assert (model != preresnet1001_svhn or weight_count == 10327706) 631 | assert (model != preresnet1202_cifar10 or weight_count == 19423834) 632 | assert (model != preresnet1202_cifar100 or weight_count == 19429684) 633 | assert (model != preresnet1202_svhn or weight_count == 19423834) 634 | 635 | x = torch.randn(1, 3, 32, 32) 636 | y = net(x) 637 | y.sum().backward() 638 | assert (tuple(y.size()) == (1, num_classes)) 639 | 640 | 641 | if __name__ == "__main__": 642 | _test() -------------------------------------------------------------------------------- /models/glouncv/preresnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | PreResNet for ImageNet-1K, implemented in PyTorch. 3 | Original paper: 'Identity Mappings in Deep Residual Networks,' https://arxiv.org/abs/1603.05027. 4 | """ 5 | 6 | __all__ = ['PreResNet', 'preresnet10', 'preresnet12', 'preresnet14', 'preresnetbc14b', 'preresnet16', 'preresnet18_wd4', 7 | 'preresnet18_wd2', 'preresnet18_w3d4', 'preresnet18', 'preresnet26', 'preresnetbc26b', 'preresnet34', 8 | 'preresnetbc38b', 'preresnet50', 'preresnet50b', 'preresnet101', 'preresnet101b', 'preresnet152', 9 | 'preresnet152b', 'preresnet200', 'preresnet200b', 'preresnet269b', 'PreResBlock', 'PreResBottleneck', 10 | 'PreResUnit', 'PreResInitBlock', 'PreResActivation'] 11 | 12 | import os 13 | import torch.nn as nn 14 | import torch.nn.init as init 15 | from .common import pre_conv1x1_block, pre_conv3x3_block, conv1x1 16 | 17 | 18 | class PreResBlock(nn.Module): 19 | """ 20 | Simple PreResNet block for residual path in PreResNet unit. 21 | Parameters: 22 | ---------- 23 | in_channels : int 24 | Number of input channels. 25 | out_channels : int 26 | Number of output channels. 27 | stride : int or tuple/list of 2 int 28 | Strides of the convolution. 29 | bias : bool, default False 30 | Whether the layer uses a bias vector. 31 | use_bn : bool, default True 32 | Whether to use BatchNorm layer. 33 | """ 34 | def __init__(self, 35 | in_channels, 36 | out_channels, 37 | stride, 38 | bias=False, 39 | use_bn=True): 40 | super(PreResBlock, self).__init__() 41 | self.conv1 = pre_conv3x3_block( 42 | in_channels=in_channels, 43 | out_channels=out_channels, 44 | stride=stride, 45 | bias=bias, 46 | use_bn=use_bn, 47 | return_preact=True) 48 | self.conv2 = pre_conv3x3_block( 49 | in_channels=out_channels, 50 | out_channels=out_channels, 51 | bias=bias, 52 | use_bn=use_bn) 53 | 54 | def forward(self, x): 55 | x, x_pre_activ = self.conv1(x) 56 | x = self.conv2(x) 57 | return x, x_pre_activ 58 | 59 | 60 | class PreResBottleneck(nn.Module): 61 | """ 62 | PreResNet bottleneck block for residual path in PreResNet unit. 63 | Parameters: 64 | ---------- 65 | in_channels : int 66 | Number of input channels. 67 | out_channels : int 68 | Number of output channels. 69 | stride : int or tuple/list of 2 int 70 | Strides of the convolution. 71 | conv1_stride : bool 72 | Whether to use stride in the first or the second convolution layer of the block. 
73 | """ 74 | def __init__(self, 75 | in_channels, 76 | out_channels, 77 | stride, 78 | conv1_stride): 79 | super(PreResBottleneck, self).__init__() 80 | mid_channels = out_channels // 4 81 | 82 | self.conv1 = pre_conv1x1_block( 83 | in_channels=in_channels, 84 | out_channels=mid_channels, 85 | stride=(stride if conv1_stride else 1), 86 | return_preact=True) 87 | self.conv2 = pre_conv3x3_block( 88 | in_channels=mid_channels, 89 | out_channels=mid_channels, 90 | stride=(1 if conv1_stride else stride)) 91 | self.conv3 = pre_conv1x1_block( 92 | in_channels=mid_channels, 93 | out_channels=out_channels) 94 | 95 | def forward(self, x): 96 | x, x_pre_activ = self.conv1(x) 97 | x = self.conv2(x) 98 | x = self.conv3(x) 99 | return x, x_pre_activ 100 | 101 | 102 | class PreResUnit(nn.Module): 103 | """ 104 | PreResNet unit with residual connection. 105 | Parameters: 106 | ---------- 107 | in_channels : int 108 | Number of input channels. 109 | out_channels : int 110 | Number of output channels. 111 | stride : int or tuple/list of 2 int 112 | Strides of the convolution. 113 | bias : bool, default False 114 | Whether the layer uses a bias vector. 115 | use_bn : bool, default True 116 | Whether to use BatchNorm layer. 117 | bottleneck : bool, default True 118 | Whether to use a bottleneck or simple block in units. 119 | conv1_stride : bool, default False 120 | Whether to use stride in the first or the second convolution layer of the block. 121 | """ 122 | def __init__(self, 123 | in_channels, 124 | out_channels, 125 | stride, 126 | bias=False, 127 | use_bn=True, 128 | bottleneck=True, 129 | conv1_stride=False): 130 | super(PreResUnit, self).__init__() 131 | self.resize_identity = (in_channels != out_channels) or (stride != 1) 132 | 133 | if bottleneck: 134 | self.body = PreResBottleneck( 135 | in_channels=in_channels, 136 | out_channels=out_channels, 137 | stride=stride, 138 | conv1_stride=conv1_stride) 139 | else: 140 | self.body = PreResBlock( 141 | in_channels=in_channels, 142 | out_channels=out_channels, 143 | stride=stride, 144 | bias=bias, 145 | use_bn=use_bn) 146 | if self.resize_identity: 147 | self.identity_conv = conv1x1( 148 | in_channels=in_channels, 149 | out_channels=out_channels, 150 | stride=stride, 151 | bias=bias) 152 | 153 | def forward(self, x): 154 | identity = x 155 | x, x_pre_activ = self.body(x) 156 | if self.resize_identity: 157 | identity = self.identity_conv(x_pre_activ) 158 | x = x + identity 159 | return x 160 | 161 | 162 | class PreResInitBlock(nn.Module): 163 | """ 164 | PreResNet specific initial block. 165 | Parameters: 166 | ---------- 167 | in_channels : int 168 | Number of input channels. 169 | out_channels : int 170 | Number of output channels. 171 | """ 172 | def __init__(self, 173 | in_channels, 174 | out_channels): 175 | super(PreResInitBlock, self).__init__() 176 | self.conv = nn.Conv2d( 177 | in_channels=in_channels, 178 | out_channels=out_channels, 179 | kernel_size=7, 180 | stride=2, 181 | padding=3, 182 | bias=False) 183 | self.bn = nn.BatchNorm2d(num_features=out_channels) 184 | self.activ = nn.ReLU(inplace=True) 185 | self.pool = nn.MaxPool2d( 186 | kernel_size=3, 187 | stride=2, 188 | padding=1) 189 | 190 | def forward(self, x): 191 | x = self.conv(x) 192 | x = self.bn(x) 193 | x = self.activ(x) 194 | x = self.pool(x) 195 | return x 196 | 197 | 198 | class PreResActivation(nn.Module): 199 | """ 200 | PreResNet pure pre-activation block without convolution layer. It's used by itself as the final block. 
201 | Parameters: 202 | ---------- 203 | in_channels : int 204 | Number of input channels. 205 | """ 206 | def __init__(self, 207 | in_channels): 208 | super(PreResActivation, self).__init__() 209 | self.bn = nn.BatchNorm2d(num_features=in_channels) 210 | self.activ = nn.ReLU(inplace=True) 211 | 212 | def forward(self, x): 213 | x = self.bn(x) 214 | x = self.activ(x) 215 | return x 216 | 217 | 218 | class PreResNet(nn.Module): 219 | """ 220 | PreResNet model from 'Identity Mappings in Deep Residual Networks,' https://arxiv.org/abs/1603.05027. 221 | Parameters: 222 | ---------- 223 | channels : list of list of int 224 | Number of output channels for each unit. 225 | init_block_channels : int 226 | Number of output channels for the initial unit. 227 | bottleneck : bool 228 | Whether to use a bottleneck or simple block in units. 229 | conv1_stride : bool 230 | Whether to use stride in the first or the second convolution layer in units. 231 | in_channels : int, default 3 232 | Number of input channels. 233 | in_size : tuple of two ints, default (224, 224) 234 | Spatial size of the expected input image. 235 | num_classes : int, default 1000 236 | Number of classification classes. 237 | """ 238 | def __init__(self, 239 | channels, 240 | init_block_channels, 241 | bottleneck, 242 | conv1_stride, 243 | in_channels=3, 244 | in_size=(224, 224), 245 | num_classes=1000): 246 | super(PreResNet, self).__init__() 247 | self.in_size = in_size 248 | self.num_classes = num_classes 249 | 250 | self.features = nn.Sequential() 251 | self.features.add_module("init_block", PreResInitBlock( 252 | in_channels=in_channels, 253 | out_channels=init_block_channels)) 254 | in_channels = init_block_channels 255 | for i, channels_per_stage in enumerate(channels): 256 | stage = nn.Sequential() 257 | for j, out_channels in enumerate(channels_per_stage): 258 | stride = 1 if (i == 0) or (j != 0) else 2 259 | stage.add_module("unit{}".format(j + 1), PreResUnit( 260 | in_channels=in_channels, 261 | out_channels=out_channels, 262 | stride=stride, 263 | bottleneck=bottleneck, 264 | conv1_stride=conv1_stride)) 265 | in_channels = out_channels 266 | self.features.add_module("stage{}".format(i + 1), stage) 267 | self.features.add_module("post_activ", PreResActivation(in_channels=in_channels)) 268 | self.features.add_module("final_pool", nn.AvgPool2d( 269 | kernel_size=7, 270 | stride=1)) 271 | 272 | self.output = nn.Linear( 273 | in_features=in_channels, 274 | out_features=num_classes) 275 | 276 | self._init_params() 277 | 278 | def _init_params(self): 279 | for name, module in self.named_modules(): 280 | if isinstance(module, nn.Conv2d): 281 | init.kaiming_uniform_(module.weight) 282 | if module.bias is not None: 283 | init.constant_(module.bias, 0) 284 | 285 | def forward(self, x): 286 | x = self.features(x) 287 | x = x.view(x.size(0), -1) 288 | x = self.output(x) 289 | return x 290 | 291 | 292 | def get_preresnet(blocks, 293 | bottleneck=None, 294 | conv1_stride=True, 295 | width_scale=1.0, 296 | model_name=None, 297 | pretrained=False, 298 | root=os.path.join("~", ".torch", "models"), 299 | **kwargs): 300 | """ 301 | Create PreResNet model with specific parameters. 302 | Parameters: 303 | ---------- 304 | blocks : int 305 | Number of blocks. 306 | bottleneck : bool, default None 307 | Whether to use a bottleneck or simple block in units. 308 | conv1_stride : bool, default True 309 | Whether to use stride in the first or the second convolution layer in units. 
310 | width_scale : float, default 1.0 311 | Scale factor for width of layers. 312 | model_name : str or None, default None 313 | Model name for loading pretrained model. 314 | pretrained : bool, default False 315 | Whether to load the pretrained weights for model. 316 | root : str, default '~/.torch/models' 317 | Location for keeping the model parameters. 318 | """ 319 | if bottleneck is None: 320 | bottleneck = (blocks >= 50) 321 | 322 | if blocks == 10: 323 | layers = [1, 1, 1, 1] 324 | elif blocks == 12: 325 | layers = [2, 1, 1, 1] 326 | elif blocks == 14 and not bottleneck: 327 | layers = [2, 2, 1, 1] 328 | elif (blocks == 14) and bottleneck: 329 | layers = [1, 1, 1, 1] 330 | elif blocks == 16: 331 | layers = [2, 2, 2, 1] 332 | elif blocks == 18: 333 | layers = [2, 2, 2, 2] 334 | elif (blocks == 26) and not bottleneck: 335 | layers = [3, 3, 3, 3] 336 | elif (blocks == 26) and bottleneck: 337 | layers = [2, 2, 2, 2] 338 | elif blocks == 34: 339 | layers = [3, 4, 6, 3] 340 | elif (blocks == 38) and bottleneck: 341 | layers = [3, 3, 3, 3] 342 | elif blocks == 50: 343 | layers = [3, 4, 6, 3] 344 | elif blocks == 101: 345 | layers = [3, 4, 23, 3] 346 | elif blocks == 152: 347 | layers = [3, 8, 36, 3] 348 | elif blocks == 200: 349 | layers = [3, 24, 36, 3] 350 | elif blocks == 269: 351 | layers = [3, 30, 48, 8] 352 | else: 353 | raise ValueError("Unsupported PreResNet with number of blocks: {}".format(blocks)) 354 | 355 | if bottleneck: 356 | assert (sum(layers) * 3 + 2 == blocks) 357 | else: 358 | assert (sum(layers) * 2 + 2 == blocks) 359 | 360 | init_block_channels = 64 361 | channels_per_layers = [64, 128, 256, 512] 362 | 363 | if bottleneck: 364 | bottleneck_factor = 4 365 | channels_per_layers = [ci * bottleneck_factor for ci in channels_per_layers] 366 | 367 | channels = [[ci] * li for (ci, li) in zip(channels_per_layers, layers)] 368 | 369 | if width_scale != 1.0: 370 | channels = [[int(cij * width_scale) if (i != len(channels) - 1) or (j != len(ci) - 1) else cij 371 | for j, cij in enumerate(ci)] for i, ci in enumerate(channels)] 372 | init_block_channels = int(init_block_channels * width_scale) 373 | 374 | net = PreResNet( 375 | channels=channels, 376 | init_block_channels=init_block_channels, 377 | bottleneck=bottleneck, 378 | conv1_stride=conv1_stride, 379 | **kwargs) 380 | 381 | if pretrained: 382 | if (model_name is None) or (not model_name): 383 | raise ValueError("Parameter `model_name` should be properly initialized for loading pretrained model.") 384 | from .model_store import download_model 385 | download_model( 386 | net=net, 387 | model_name=model_name, 388 | local_model_store_dir_path=root) 389 | 390 | return net 391 | 392 | 393 | def preresnet10(**kwargs): 394 | """ 395 | PreResNet-10 model from 'Identity Mappings in Deep Residual Networks,' https://arxiv.org/abs/1603.05027. 396 | It's an experimental model. 397 | Parameters: 398 | ---------- 399 | pretrained : bool, default False 400 | Whether to load the pretrained weights for model. 401 | root : str, default '~/.torch/models' 402 | Location for keeping the model parameters. 403 | """ 404 | return get_preresnet(blocks=10, model_name="preresnet10", **kwargs) 405 | 406 | 407 | def preresnet12(**kwargs): 408 | """ 409 | PreResNet-12 model from 'Identity Mappings in Deep Residual Networks,' https://arxiv.org/abs/1603.05027. 410 | It's an experimental model. 411 | Parameters: 412 | ---------- 413 | pretrained : bool, default False 414 | Whether to load the pretrained weights for model. 
415 | root : str, default '~/.torch/models' 416 | Location for keeping the model parameters. 417 | """ 418 | return get_preresnet(blocks=12, model_name="preresnet12", **kwargs) 419 | 420 | 421 | def preresnet14(**kwargs): 422 | """ 423 | PreResNet-14 model from 'Identity Mappings in Deep Residual Networks,' https://arxiv.org/abs/1603.05027. 424 | It's an experimental model. 425 | Parameters: 426 | ---------- 427 | pretrained : bool, default False 428 | Whether to load the pretrained weights for model. 429 | root : str, default '~/.torch/models' 430 | Location for keeping the model parameters. 431 | """ 432 | return get_preresnet(blocks=14, model_name="preresnet14", **kwargs) 433 | 434 | 435 | def preresnetbc14b(**kwargs): 436 | """ 437 | PreResNet-BC-14b model from 'Identity Mappings in Deep Residual Networks,' https://arxiv.org/abs/1603.05027. 438 | It's an experimental model (bottleneck compressed). 439 | Parameters: 440 | ---------- 441 | pretrained : bool, default False 442 | Whether to load the pretrained weights for model. 443 | root : str, default '~/.torch/models' 444 | Location for keeping the model parameters. 445 | """ 446 | return get_preresnet(blocks=14, bottleneck=True, conv1_stride=False, model_name="preresnetbc14b", **kwargs) 447 | 448 | 449 | def preresnet16(**kwargs): 450 | """ 451 | PreResNet-16 model from 'Identity Mappings in Deep Residual Networks,' https://arxiv.org/abs/1603.05027. 452 | It's an experimental model. 453 | Parameters: 454 | ---------- 455 | pretrained : bool, default False 456 | Whether to load the pretrained weights for model. 457 | root : str, default '~/.torch/models' 458 | Location for keeping the model parameters. 459 | """ 460 | return get_preresnet(blocks=16, model_name="preresnet16", **kwargs) 461 | 462 | 463 | def preresnet18_wd4(**kwargs): 464 | """ 465 | PreResNet-18 model with 0.25 width scale from 'Identity Mappings in Deep Residual Networks,' 466 | https://arxiv.org/abs/1603.05027. It's an experimental model. 467 | Parameters: 468 | ---------- 469 | pretrained : bool, default False 470 | Whether to load the pretrained weights for model. 471 | root : str, default '~/.torch/models' 472 | Location for keeping the model parameters. 473 | """ 474 | return get_preresnet(blocks=18, width_scale=0.25, model_name="preresnet18_wd4", **kwargs) 475 | 476 | 477 | def preresnet18_wd2(**kwargs): 478 | """ 479 | PreResNet-18 model with 0.5 width scale from 'Identity Mappings in Deep Residual Networks,' 480 | https://arxiv.org/abs/1603.05027. It's an experimental model. 481 | Parameters: 482 | ---------- 483 | pretrained : bool, default False 484 | Whether to load the pretrained weights for model. 485 | root : str, default '~/.torch/models' 486 | Location for keeping the model parameters. 487 | """ 488 | return get_preresnet(blocks=18, width_scale=0.5, model_name="preresnet18_wd2", **kwargs) 489 | 490 | 491 | def preresnet18_w3d4(**kwargs): 492 | """ 493 | PreResNet-18 model with 0.75 width scale from 'Identity Mappings in Deep Residual Networks,' 494 | https://arxiv.org/abs/1603.05027. It's an experimental model. 495 | Parameters: 496 | ---------- 497 | pretrained : bool, default False 498 | Whether to load the pretrained weights for model. 499 | root : str, default '~/.torch/models' 500 | Location for keeping the model parameters. 
501 | """ 502 | return get_preresnet(blocks=18, width_scale=0.75, model_name="preresnet18_w3d4", **kwargs) 503 | 504 | 505 | def preresnet18(**kwargs): 506 | """ 507 | PreResNet-18 model from 'Identity Mappings in Deep Residual Networks,' https://arxiv.org/abs/1603.05027. 508 | Parameters: 509 | ---------- 510 | pretrained : bool, default False 511 | Whether to load the pretrained weights for model. 512 | root : str, default '~/.torch/models' 513 | Location for keeping the model parameters. 514 | """ 515 | return get_preresnet(blocks=18, model_name="preresnet18", **kwargs) 516 | 517 | 518 | def preresnet26(**kwargs): 519 | """ 520 | PreResNet-26 model from 'Identity Mappings in Deep Residual Networks,' https://arxiv.org/abs/1603.05027. 521 | It's an experimental model. 522 | Parameters: 523 | ---------- 524 | pretrained : bool, default False 525 | Whether to load the pretrained weights for model. 526 | root : str, default '~/.torch/models' 527 | Location for keeping the model parameters. 528 | """ 529 | return get_preresnet(blocks=26, bottleneck=False, model_name="preresnet26", **kwargs) 530 | 531 | 532 | def preresnetbc26b(**kwargs): 533 | """ 534 | PreResNet-BC-26b model from 'Identity Mappings in Deep Residual Networks,' https://arxiv.org/abs/1603.05027. 535 | It's an experimental model (bottleneck compressed). 536 | Parameters: 537 | ---------- 538 | pretrained : bool, default False 539 | Whether to load the pretrained weights for model. 540 | root : str, default '~/.torch/models' 541 | Location for keeping the model parameters. 542 | """ 543 | return get_preresnet(blocks=26, bottleneck=True, conv1_stride=False, model_name="preresnetbc26b", **kwargs) 544 | 545 | 546 | def preresnet34(**kwargs): 547 | """ 548 | PreResNet-34 model from 'Identity Mappings in Deep Residual Networks,' https://arxiv.org/abs/1603.05027. 549 | Parameters: 550 | ---------- 551 | pretrained : bool, default False 552 | Whether to load the pretrained weights for model. 553 | root : str, default '~/.torch/models' 554 | Location for keeping the model parameters. 555 | """ 556 | return get_preresnet(blocks=34, model_name="preresnet34", **kwargs) 557 | 558 | 559 | def preresnetbc38b(**kwargs): 560 | """ 561 | PreResNet-BC-38b model from 'Identity Mappings in Deep Residual Networks,' https://arxiv.org/abs/1603.05027. 562 | It's an experimental model (bottleneck compressed). 563 | Parameters: 564 | ---------- 565 | pretrained : bool, default False 566 | Whether to load the pretrained weights for model. 567 | root : str, default '~/.torch/models' 568 | Location for keeping the model parameters. 569 | """ 570 | return get_preresnet(blocks=38, bottleneck=True, conv1_stride=False, model_name="preresnetbc38b", **kwargs) 571 | 572 | 573 | def preresnet50(**kwargs): 574 | """ 575 | PreResNet-50 model from 'Identity Mappings in Deep Residual Networks,' https://arxiv.org/abs/1603.05027. 576 | Parameters: 577 | ---------- 578 | pretrained : bool, default False 579 | Whether to load the pretrained weights for model. 580 | root : str, default '~/.torch/models' 581 | Location for keeping the model parameters. 582 | """ 583 | return get_preresnet(blocks=50, model_name="preresnet50", **kwargs) 584 | 585 | 586 | def preresnet50b(**kwargs): 587 | """ 588 | PreResNet-50 model with stride at the second convolution in bottleneck block from 'Identity Mappings in Deep 589 | Residual Networks,' https://arxiv.org/abs/1603.05027. 
590 | Parameters: 591 | ---------- 592 | pretrained : bool, default False 593 | Whether to load the pretrained weights for model. 594 | root : str, default '~/.torch/models' 595 | Location for keeping the model parameters. 596 | """ 597 | return get_preresnet(blocks=50, conv1_stride=False, model_name="preresnet50b", **kwargs) 598 | 599 | 600 | def preresnet101(**kwargs): 601 | """ 602 | PreResNet-101 model from 'Identity Mappings in Deep Residual Networks,' https://arxiv.org/abs/1603.05027. 603 | Parameters: 604 | ---------- 605 | pretrained : bool, default False 606 | Whether to load the pretrained weights for model. 607 | root : str, default '~/.torch/models' 608 | Location for keeping the model parameters. 609 | """ 610 | return get_preresnet(blocks=101, model_name="preresnet101", **kwargs) 611 | 612 | 613 | def preresnet101b(**kwargs): 614 | """ 615 | PreResNet-101 model with stride at the second convolution in bottleneck block from 'Identity Mappings in Deep 616 | Residual Networks,' https://arxiv.org/abs/1603.05027. 617 | Parameters: 618 | ---------- 619 | pretrained : bool, default False 620 | Whether to load the pretrained weights for model. 621 | root : str, default '~/.torch/models' 622 | Location for keeping the model parameters. 623 | """ 624 | return get_preresnet(blocks=101, conv1_stride=False, model_name="preresnet101b", **kwargs) 625 | 626 | 627 | def preresnet152(**kwargs): 628 | """ 629 | PreResNet-152 model from 'Identity Mappings in Deep Residual Networks,' https://arxiv.org/abs/1603.05027. 630 | Parameters: 631 | ---------- 632 | pretrained : bool, default False 633 | Whether to load the pretrained weights for model. 634 | root : str, default '~/.torch/models' 635 | Location for keeping the model parameters. 636 | """ 637 | return get_preresnet(blocks=152, model_name="preresnet152", **kwargs) 638 | 639 | 640 | def preresnet152b(**kwargs): 641 | """ 642 | PreResNet-152 model with stride at the second convolution in bottleneck block from 'Identity Mappings in Deep 643 | Residual Networks,' https://arxiv.org/abs/1603.05027. 644 | Parameters: 645 | ---------- 646 | pretrained : bool, default False 647 | Whether to load the pretrained weights for model. 648 | root : str, default '~/.torch/models' 649 | Location for keeping the model parameters. 650 | """ 651 | return get_preresnet(blocks=152, conv1_stride=False, model_name="preresnet152b", **kwargs) 652 | 653 | 654 | def preresnet200(**kwargs): 655 | """ 656 | PreResNet-200 model from 'Identity Mappings in Deep Residual Networks,' https://arxiv.org/abs/1603.05027. 657 | Parameters: 658 | ---------- 659 | pretrained : bool, default False 660 | Whether to load the pretrained weights for model. 661 | root : str, default '~/.torch/models' 662 | Location for keeping the model parameters. 663 | """ 664 | return get_preresnet(blocks=200, model_name="preresnet200", **kwargs) 665 | 666 | 667 | def preresnet200b(**kwargs): 668 | """ 669 | PreResNet-200 model with stride at the second convolution in bottleneck block from 'Identity Mappings in Deep 670 | Residual Networks,' https://arxiv.org/abs/1603.05027. 671 | Parameters: 672 | ---------- 673 | pretrained : bool, default False 674 | Whether to load the pretrained weights for model. 675 | root : str, default '~/.torch/models' 676 | Location for keeping the model parameters. 
677 | """ 678 | return get_preresnet(blocks=200, conv1_stride=False, model_name="preresnet200b", **kwargs) 679 | 680 | 681 | def preresnet269b(**kwargs): 682 | """ 683 | PreResNet-269 model with stride at the second convolution in bottleneck block from 'Identity Mappings in Deep 684 | Residual Networks,' https://arxiv.org/abs/1603.05027. 685 | Parameters: 686 | ---------- 687 | pretrained : bool, default False 688 | Whether to load the pretrained weights for model. 689 | root : str, default '~/.torch/models' 690 | Location for keeping the model parameters. 691 | """ 692 | return get_preresnet(blocks=269, conv1_stride=False, model_name="preresnet269b", **kwargs) 693 | 694 | 695 | def _calc_width(net): 696 | import numpy as np 697 | net_params = filter(lambda p: p.requires_grad, net.parameters()) 698 | weight_count = 0 699 | for param in net_params: 700 | weight_count += np.prod(param.size()) 701 | return weight_count 702 | 703 | 704 | def _test(): 705 | import torch 706 | 707 | pretrained = False 708 | 709 | models = [ 710 | preresnet10, 711 | preresnet12, 712 | preresnet14, 713 | preresnetbc14b, 714 | preresnet16, 715 | preresnet18_wd4, 716 | preresnet18_wd2, 717 | preresnet18_w3d4, 718 | preresnet18, 719 | preresnet26, 720 | preresnetbc26b, 721 | preresnet34, 722 | preresnetbc38b, 723 | preresnet50, 724 | preresnet50b, 725 | preresnet101, 726 | preresnet101b, 727 | preresnet152, 728 | preresnet152b, 729 | preresnet200, 730 | preresnet200b, 731 | preresnet269b, 732 | ] 733 | 734 | for model in models: 735 | 736 | net = model(pretrained=pretrained) 737 | 738 | # net.train() 739 | net.eval() 740 | weight_count = _calc_width(net) 741 | print("m={}, {}".format(model.__name__, weight_count)) 742 | assert (model != preresnet10 or weight_count == 5417128) 743 | assert (model != preresnet12 or weight_count == 5491112) 744 | assert (model != preresnet14 or weight_count == 5786536) 745 | assert (model != preresnetbc14b or weight_count == 10057384) 746 | assert (model != preresnet16 or weight_count == 6967208) 747 | assert (model != preresnet18_wd4 or weight_count == 3935960) 748 | assert (model != preresnet18_wd2 or weight_count == 5802440) 749 | assert (model != preresnet18_w3d4 or weight_count == 8473784) 750 | assert (model != preresnet18 or weight_count == 11687848) 751 | assert (model != preresnet26 or weight_count == 17958568) 752 | assert (model != preresnetbc26b or weight_count == 15987624) 753 | assert (model != preresnet34 or weight_count == 21796008) 754 | assert (model != preresnetbc38b or weight_count == 21917864) 755 | assert (model != preresnet50 or weight_count == 25549480) 756 | assert (model != preresnet50b or weight_count == 25549480) 757 | assert (model != preresnet101 or weight_count == 44541608) 758 | assert (model != preresnet101b or weight_count == 44541608) 759 | assert (model != preresnet152 or weight_count == 60185256) 760 | assert (model != preresnet152b or weight_count == 60185256) 761 | assert (model != preresnet200 or weight_count == 64666280) 762 | assert (model != preresnet200b or weight_count == 64666280) 763 | assert (model != preresnet269b or weight_count == 102065832) 764 | 765 | x = torch.randn(1, 3, 224, 224) 766 | y = net(x) 767 | y.sum().backward() 768 | assert (tuple(y.size()) == (1, 1000)) 769 | 770 | 771 | if __name__ == "__main__": 772 | _test() 773 | --------------------------------------------------------------------------------