├── models
├── __init__.py
├── glouncv
│ ├── __init__.py
│ ├── mobilenetv2.py
│ ├── alexnet.py
│ ├── alexnet_bn.py
│ ├── preresnet_cifar.py
│ └── preresnet.py
├── imagenet_presnet.py
└── cifar100_presnet.py
├── monitors
├── __init__.py
└── metrics.py
├── quantizer
├── __init__.py
└── uniq.py
├── requirements.txt
├── lr_scheduler.py
├── README.md
├── main.py
└── utils.py
/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/monitors/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/quantizer/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/glouncv/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==0.10.0
2 | cachetools==4.1.1
3 | certifi==2020.6.20
4 | chardet==3.0.4
5 | future==0.18.2
6 | google-auth==1.21.2
7 | google-auth-oauthlib==0.4.1
8 | grpcio==1.32.0
9 | idna==2.10
10 | importlib-metadata==1.7.0
11 | Markdown==3.2.2
12 | numpy==1.18.5
13 | oauthlib==3.1.0
14 | Pillow==7.2.0
15 | pkg-resources==0.0.0
16 | protobuf==3.13.0
17 | pyasn1==0.4.8
18 | pyasn1-modules==0.2.8
19 | requests==2.24.0
20 | requests-oauthlib==1.3.0
21 | rsa==4.6
22 | six==1.15.0
23 | tensorboard==2.3.0
24 | tensorboard-plugin-wit==1.7.0
25 | torch==1.4.0
26 | torchvision==0.5.0
27 | urllib3==1.25.10
28 | Werkzeug==1.0.1
29 | zipp==1.2.0
30 |
--------------------------------------------------------------------------------
/lr_scheduler.py:
--------------------------------------------------------------------------------
class ConstantWarmupScheduler(object):
    """Hold the learning rate constant during warm-up, then delegate to another scheduler.

    While fewer than ``total_epoch`` calls to :meth:`step` have been made, every
    parameter group of ``optimizer`` gets ``lr = min_lr``. On the first call
    after the warm-up the rate is set to ``after_lr``; on that call and every
    later one ``after_scheduler.step()`` is invoked as well.

    Args:
        optimizer: object exposing ``param_groups`` (a list of dicts with an
            ``'lr'`` entry).
        min_lr: learning rate applied during the warm-up.
        total_epoch: number of warm-up steps.
        after_lr: learning rate applied once when the warm-up ends.
        after_scheduler: optional scheduler exposing ``step()`` and
            ``state_dict()``; when ``None`` the rate simply stays at
            ``after_lr`` after the warm-up.
    """

    def __init__(self, optimizer, min_lr=0.001, total_epoch=5, after_lr=0.01, after_scheduler=None):
        super(ConstantWarmupScheduler, self).__init__()
        self.optimizer = optimizer
        self.total_epoch = total_epoch
        self.min_lr = min_lr
        self.after_lr = after_lr
        self.after_scheduler = after_scheduler
        self._current_epoch = 0

    def _set_lr(self, lr):
        """Write ``lr`` into every parameter group of the optimizer."""
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    def step(self):
        """Advance the schedule by one step."""
        if self._current_epoch < self.total_epoch:
            self._set_lr(self.min_lr)
        else:
            if self._current_epoch == self.total_epoch:
                # First step after the warm-up: switch to the main learning rate.
                self._set_lr(self.after_lr)
            # after_scheduler defaults to None; calling .step() on it would raise.
            if self.after_scheduler is not None:
                self.after_scheduler.step()
        self._current_epoch += 1

    def state_dict(self):
        """Return the wrapped scheduler's state once the warm-up is over.

        Returns ``None`` during the warm-up or when no ``after_scheduler`` was
        given. The previous implementation evaluated the same expression but
        never returned it, so callers always received ``None``.
        """
        if self.after_scheduler is None or self._current_epoch < self.total_epoch:
            return None
        return self.after_scheduler.state_dict()
--------------------------------------------------------------------------------
/monitors/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
def write_metrics(writer, epoch, net, wt_optimizer, train_loss, train_acc1, test_loss, test_acc1, prefix="Train"):
    """Log per-epoch scalars and parameter histograms to ``writer``.

    Args:
        writer: object exposing ``add_scalar(tag, value, step)`` and
            ``add_histogram(tag, values, step)``.
        epoch: step index passed to every writer call.
        net: module whose parameters and sub-modules are inspected.
        wt_optimizer: optimizer; the ``lr`` of its first parameter group is logged.
        train_loss, train_acc1, test_loss, test_acc1: scalar metrics to log.
        prefix: string prepended to every tag.
    """
    writer.add_scalar('%s_Train/Loss' % (prefix), train_loss, epoch)
    writer.add_scalar('%s_Train/Acc1'% (prefix), train_acc1, epoch)
    writer.add_scalar('%s_Test/Loss' % (prefix), test_loss, epoch)
    writer.add_scalar('%s_Test/Acc1' % (prefix), test_acc1, epoch)
    writer.add_scalar('%s_Train/LR' % (prefix), wt_optimizer.param_groups[0]['lr'], epoch)

    # Parameters whose name contains ".delta": scalars for 0-dim tensors,
    # histograms otherwise.
    for name, param in net.named_parameters():
        if ".delta" not in name:
            continue
        tag = '{}_Train/delta_{}'.format(prefix, name)
        if param.ndim == 0:
            writer.add_scalar(tag, param, epoch)
        else:
            writer.add_histogram(tag, param, epoch)

    # Weight / bias histograms (and their gradients) for conv, linear and BN layers.
    for name, module in net.named_modules():
        if not isinstance(module, (nn.Conv2d, nn.Linear, nn.BatchNorm2d)):
            continue
        for attr in ('weight', 'bias'):
            tensor = getattr(module, attr, None)
            if tensor is None:
                # e.g. bias=False layers or BatchNorm2d(affine=False)
                continue
            tag = '{}_Train/{}.{}'.format(prefix, name, attr)
            writer.add_histogram(tag, tensor, epoch)
            # .grad is None until a backward pass has populated it (and for
            # frozen parameters); a histogram of None cannot be written.
            if tensor.grad is not None:
                writer.add_histogram(tag + '.grad', tensor.grad, epoch)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # UniQ
2 |
3 | This repo contains the code and data of the following paper:
4 |
5 | **Training Multi-bit Quantized and Binarized Networks with A Learnable Symmetric Quantizer**
6 |
7 |
8 | ## Prerequisites
9 | Use the package manager [pip](https://pip.pypa.io/en/stable/) to install the library dependencies
10 |
11 |
12 | ```bash
13 | pip install -r requirements.txt
14 | ```
15 |
16 | ## Training
17 |
18 | ```bash
19 | export CUDA_VISIBLE_DEVICES=[GPU_IDs] && \
20 | python main.py --train_id [training_id] \
21 | --lr [learning_rate_value] --wd [weight_decay_value] --batch-size [batch_size] \
22 | --dataset [dataset_name] --arch [architecture_name] \
23 | --bit [bit-width] --epochs [training_epochs] \
24 | --data_root [path_to_dataset] \
25 | --init_from [path_to_pretrained_model] \
26 | --train_scheme uniq --quant_mode [quantization_mode] \
27 | --num_calibration_batches [number_of_batches_for_initialization]
28 | ```
29 |
30 |
31 |
32 | ## Testing
33 |
34 | ```bash
35 | export CUDA_VISIBLE_DEVICES=[GPU_IDs] && \
36 | python main.py --train_id [training_id] \
37 | --batch-size [batch_size] \
38 | --dataset [dataset_name] --arch [architecture_name] \
39 | --bit [bit-width] \
40 | --data_root [path_to_dataset] \
41 | --init_from [path_to_trained_model] \
42 | --train_scheme uniq --quant_mode [quantization_mode] \
43 | -e
44 | ```
45 |
46 |
47 | | Arguments | Description |
48 | | ------------- | ------------- |
49 | | `--train_id` | ID for experiment management (arbitrary). |
50 | | `--lr` | Learning rate |
51 | | `--wd` | Weight decay |
52 | | `--batch-size` | Batch size |
53 | | `--dataset` | Dataset name<br>Possible values: `cifar100`, `imagenet` |
54 | | `--data_root` | Path to the dataset directory |
55 | | `--arch` | Architecture name<br>Possible values: `presnet18`, `presnet32`, `glouncv-presnet34`, `glouncv-mobilenetv2_w1` |
56 | | `--bit` | Bit-width (W/A) |
57 | | `--epochs` | Number of training epochs |
58 | | `--init_from` | Path to the pretrained model. |
59 | | `--train_scheme` | Training scheme<br>Possible values: `fp32` (normal training), `uniq` (low-bit quantization training) |
60 | | `--quant_mode` | Quantization mode<br>Possible values: `layer_wise` (layer-wise quantization), `kernel_wise` (kernel-wise quantization) |
61 | | `--num_calibration_batches` | Number of batches used for initialization |
62 |
63 |
64 | For the details and hyperparameter settings of each experiment, we refer readers to the paper and the `main.py` file.
65 |
66 | ## Citation
67 | If you find UniQ useful in your research, please consider citing:
68 | ```
69 | @ARTICLE{9383003,
70 | author={P. {Pham} and J. A. {Abraham} and J. {Chung}},
71 | journal={IEEE Access},
72 | title={Training Multi-Bit Quantized and Binarized Networks with a Learnable Symmetric Quantizer},
73 | year={2021},
74 | volume={9},
75 | number={},
76 | pages={47194-47203},
77 | doi={10.1109/ACCESS.2021.3067889}}
78 | ```
79 |
80 | ## Contributing
81 | Pull requests are welcome. For major changes, please open an issue first to discuss what you would like to change.
82 |
83 | Please make sure to update tests as appropriate.
84 |
--------------------------------------------------------------------------------
/models/imagenet_presnet.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 |
5 | # based on
6 | https://github.com/kuangliu/pytorch-cifar/blob/master/models/preact_resnet.py
7 | https://raw.githubusercontent.com/NVlabs/Taylor_pruning/b21ed61ac41cb59a9879a95350bd752ab26ffd91/models/preact_resnet.py
8 | """
9 |
10 | '''Pre-activation ResNet in PyTorch.
11 | Reference:
12 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
13 | Identity Mappings in Deep Residual Networks. arXiv:1603.05027
14 | '''
15 |
16 | import torch
17 | import torch.nn as nn
18 | import torch.nn.functional as F
19 |
20 |
class PreActBlock(nn.Module):
    '''Two-convolution residual block in pre-activation (BN -> ReLU -> conv) order.'''
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(PreActBlock, self).__init__()
        out_planes = self.expansion * planes

        # First BN/ReLU/conv group; this convolution carries the stride.
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.relu1 = nn.ReLU(inplace=True)

        # Second BN/ReLU/conv group keeps the resolution.
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.relu2 = nn.ReLU(inplace=True)

        # A 1x1 projection is only created when the skip path would not match
        # the output shape; otherwise the attribute is left undefined.
        shape_preserved = stride == 1 and in_planes == out_planes
        if not shape_preserved:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
            )

    def forward(self, x):
        activated = self.relu1(self.bn1(x))

        # The projection (if any) consumes the pre-activated tensor, the
        # identity path uses the raw input.
        skip = self.shortcut(activated) if hasattr(self, 'shortcut') else x

        branch = self.conv1(activated)
        branch = self.relu2(self.bn2(branch))
        branch = self.conv2(branch)
        return branch + skip
54 |
55 |
class PreActBottleneck(nn.Module):
    '''Three-convolution (1x1 -> 3x3 -> 1x1) bottleneck block in pre-activation order.'''
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(PreActBottleneck, self).__init__()
        out_planes = self.expansion * planes

        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)

        # The 3x3 convolution is the one that applies the stride.
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)

        self.bn3 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)

        # Projection shortcut only when the identity would not match in shape.
        if stride != 1 or in_planes != out_planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False))

    def forward(self, x):
        activated = F.relu(self.bn1(x))

        branch = self.conv1(activated)
        branch = self.conv2(F.relu(self.bn2(branch)))
        branch = self.conv3(F.relu(self.bn3(branch)))

        # Projection reads the pre-activated tensor; identity reads the raw input.
        if hasattr(self, 'shortcut'):
            skip = self.shortcut(activated)
        else:
            skip = x
        return branch + skip
97 |
98 |
class PreActResNet(nn.Module):
    """Pre-activation ResNet: 7x7 stem, four residual stages, BN/ReLU, pooling, linear head.

    Args:
        block: residual block class; must expose an ``expansion`` class
            attribute and be constructible as ``block(in_planes, planes, stride)``.
        num_blocks: sequence of four ints, the number of blocks per stage.
        num_classes: size of the final linear layer's output.
    """

    def __init__(self, block, num_blocks, num_classes=1000):
        super(PreActResNet, self).__init__()

        self.in_planes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu1 = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)

        # Pre-activation
        self.bn2 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.relu2 = nn.ReLU(inplace=True)

        # Global average pooling. For a 224x224 input the final map is 7x7, so
        # this matches the former fixed AvgPool2d(7, stride=1); other input
        # sizes, which previously failed at the pooling or the fc layer, now
        # also reduce to one value per channel.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                m.bias.data.zero_()

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``."""
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for block_stride in strides:
            layers.append(block(self.in_planes, planes, block_stride))
            # The next block consumes this block's (expanded) output width.
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)
        out = self.maxpool(out)

        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)

        out = self.bn2(out)
        out = self.relu2(out)

        out = self.avgpool(out)

        # Flatten to (batch, features) for the classifier.
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out
155 |
156 |
def PreActResNet18():
    """Build a PreActResNet from PreActBlock with 2, 2, 2, 2 blocks per stage."""
    stage_depths = [2, 2, 2, 2]
    return PreActResNet(PreActBlock, stage_depths)


def PreActResNet34():
    """Build a PreActResNet from PreActBlock with 3, 4, 6, 3 blocks per stage."""
    stage_depths = [3, 4, 6, 3]
    return PreActResNet(PreActBlock, stage_depths)


def PreActResNet50():
    """Build a PreActResNet from PreActBottleneck with 3, 4, 6, 3 blocks per stage."""
    stage_depths = [3, 4, 6, 3]
    return PreActResNet(PreActBottleneck, stage_depths)


def PreActResNet101():
    """Build a PreActResNet from PreActBottleneck with 3, 4, 23, 3 blocks per stage."""
    stage_depths = [3, 4, 23, 3]
    return PreActResNet(PreActBottleneck, stage_depths)


def PreActResNet152():
    """Build a PreActResNet from PreActBottleneck with 3, 8, 36, 3 blocks per stage."""
    stage_depths = [3, 8, 36, 3]
    return PreActResNet(PreActBottleneck, stage_depths)
171 |
172 |
def test():
    """Smoke-test PreActResNet18 with one random ImageNet-sized image and print the output size."""
    net = PreActResNet18()
    # The network's fixed strides shrink the input by a factor of 32, so a
    # 224x224 image reaches the pooling layer as a 7x7 map. The former 32x32
    # input shrank to 1x1, too small for a 7x7 pooling window.
    y = net(torch.randn(1, 3, 224, 224))
    print(y.size())
177 |
178 | # test()
179 |
--------------------------------------------------------------------------------
/quantizer/uniq.py:
--------------------------------------------------------------------------------
1 | import torch as t
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 | import torch
5 | import math
6 |
7 |
def grad_scale(x, scale):
    """Return a tensor with the forward value of ``x`` whose gradient w.r.t. ``x`` is multiplied by ``scale``."""
    scaled = x * scale
    # The detached difference restores the forward value without contributing
    # to the gradient; only `scaled` back-propagates.
    correction = (x - scaled).detach()
    return correction + scaled
12 |
13 |
def round_pass(x):
    """Round ``x`` in the forward pass while letting the gradient pass through unchanged."""
    # The rounding offset is detached, so the gradient is that of the identity.
    rounding_offset = (x.round() - x).detach()
    return rounding_offset + x
18 |
19 |
class STATUS(object):
    """Integer codes stored in a quantizer's ``init_state`` buffer."""
    NOT_READY = -1   # value the buffer is created with
    INIT_READY = 0   # forward() runs initialization() while training in this state
    INIT_DONE = 1
24 |
25 |
26 |
class UniQQuantizer(t.nn.Module):
    """Uniform quantizer with a learnable step size ``delta``.

    Args:
        bit: bit-width; the quantizer uses the integer levels ``0 .. 2**bit - 1``.
        is_activation: when True the input is clamped to ``[0, (2**bit - 1) * step]``;
            otherwise the levels are shifted to be symmetric around zero.
        **kwargs:
            quant_mode: ``'layer_wise'`` (one scalar step, default) or
                ``'kernel_wise'`` (one step per channel).
            layer_type: ``'conv'`` (default) or ``'linear'``; selects the shape
                of the per-channel step in kernel-wise mode.
            num_channels: required (> 1) in kernel-wise mode.
    """

    def __init__(self, bit, is_activation=False, **kwargs):
        super(UniQQuantizer, self).__init__()

        self.bit = bit
        self.is_activation = is_activation
        # Per-bit-width multipliers applied to the measured spread of the data
        # in initialization(). NOTE(review): presumably the preferred step
        # sizes for a unit normal / rectified normal distribution - confirm
        # against the paper.
        self.delta_normal = {1: 1.595769121605729, 2: 0.9956866859435065, 3: 0.5860194414434872, 4: 0.33520061219993685, 5: 0.18813879027991698, 6: 0.10406300944201481, 7: 0.05686767238235839, 8: 0.03076238758025524, 9: 0.016498958773102656}
        self.delta_positive_normal = {1: 1.22399153, 2: 0.65076985, 3: 0.35340955, 4: 0.19324868, 5: 0.10548752, 6: 0.0572659, 7: 0.03087133, 8: 0.01652923, 9: 0.00879047}
        self.quant_mode = kwargs.get('quant_mode', 'layer_wise')
        self.layer_type = kwargs.get('layer_type', 'conv')

        if self.quant_mode == 'layer_wise':
            self.delta = nn.Parameter(torch.tensor(0.0), requires_grad=True)

        elif self.quant_mode == 'kernel_wise':
            assert kwargs['num_channels'] > 1
            num_channels = kwargs['num_channels']
            # Activations broadcast over the channel axis (dim 1), weights
            # over the output-channel axis (dim 0).
            if self.layer_type == 'conv':
                shape = [1, num_channels, 1, 1] if self.is_activation else [num_channels, 1, 1, 1]
            else:
                shape = [1, num_channels] if self.is_activation else [num_channels, 1]
            # Zero-filled like the layer-wise step. torch.Tensor(*shape) would
            # leave the memory uninitialised, and the activation path of
            # initialization() reads the current delta through torch.max().
            self.delta = nn.Parameter(torch.zeros(*shape), requires_grad=True)

        self.kwargs = kwargs
        self.register_buffer('init_state', torch.tensor(STATUS.NOT_READY))
        self.register_buffer('min_val', torch.tensor(0.0, dtype=torch.float))
        self.register_buffer('max_val', torch.tensor(2**(self.bit) - 1, dtype=torch.float))

    def set_init_state(self, value):
        """Overwrite the ``init_state`` buffer in place with ``value``."""
        self.init_state.fill_(value)

    def initialization(self, x):
        """Set ``delta`` from the statistics of ``x`` (no gradient is recorded)."""
        if self.is_activation:
            if self.quant_mode == 'kernel_wise':
                if self.layer_type == 'conv':
                    # Mean of x**2 per channel, averaged over space and batch.
                    _meanx = (x.detach()**2).view(x.shape[0], -1, x.shape[2] * x.shape[3]).mean(2, True).mean(0, True).view(1, -1, 1, 1)

                elif self.layer_type == 'linear':
                    _meanx = (x.detach()**2).mean(1, True).mean(0, True).view(1, 1)

                # Replace exact zeros by the smallest non-zero entry so that
                # no channel ends up with a zero step.
                _meanx[_meanx==0] = _meanx[_meanx!=0].min()
                pre_relu_std = ((2*_meanx))**0.5
            else:
                pre_relu_std = (2*((x.detach()**2).mean()))**0.5
            # Keep the larger of the current and the newly estimated step, so
            # repeated calls can only grow delta.
            self.delta.data.copy_(torch.max(self.delta.data, pre_relu_std * self.delta_positive_normal[self.bit]))

        else:

            if self.quant_mode == 'kernel_wise':
                if self.layer_type == 'conv':
                    std = x.detach().view(-1, x.shape[1] * x.shape[2] * x.shape[3]).std(1, True).view(-1, 1, 1, 1)
                if self.layer_type == 'linear':
                    std = x.detach().view(-1, x.shape[1]).std(1, True).view(-1, 1)
            else:
                std = x.detach().std()
            self.delta.data.copy_( std * self.delta_normal[self.bit])

    def forward(self, x):
        """Quantize ``x`` with the current step; rounding uses a straight-through gradient."""
        if self.training and self.init_state == STATUS.INIT_READY:
            self.initialization(x)

        # Quantization
        if self.is_activation:
            # g scales the gradient that reaches delta.
            if self.quant_mode == 'kernel_wise':
                g = 1.0 / math.sqrt((x.numel() / x.shape[1]) * (2**self.bit -1))
            else:
                g = 1.0 / math.sqrt(x.numel() * (2**self.bit -1))

            step_size = grad_scale(self.delta, g)
            x = x / step_size
            x = round_pass(torch.min(torch.max(x, self.min_val), self.max_val)) * step_size
        else:

            if self.quant_mode== 'kernel_wise':
                g = 1.0 / math.sqrt((x.numel() / x.shape[0]) * max((2**(self.bit-1) -1),1))
            else:
                g = 1.0 / math.sqrt(x.numel() * max((2**(self.bit-1) -1),1))

            step_size = grad_scale(self.delta, g)
            # Shift by half the quantization range so the levels are symmetric
            # around zero, quantize on [0, max_val], then shift back.
            alpha = step_size * self.max_val * 0.5
            x = (x + alpha) / step_size
            x = round_pass(torch.min(torch.max(x, self.min_val), self.max_val)) * step_size - alpha

        return x

    def extra_repr(self):
        return "bit=%s, is_activation=%s, quant_mode=%s" % \
            (self.bit, self.is_activation, self.kwargs.get('quant_mode', 'layer_wise'))
116 |
117 |
118 |
class UniQConv2d(nn.Conv2d):
    """Conv2d that quantizes its input and its weight before convolving.

    ``bit == 32`` bypasses both quantizers and behaves like a plain convolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, bit=4, quant_mode='layer_wise'):
        conv_kwargs = dict(
            in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
            stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
        super(UniQConv2d, self).__init__(**conv_kwargs)

        # Only the weight quantizer follows `quant_mode` (optionally
        # per-channel); the activation quantizer is always layer-wise.
        self.quan_w = UniQQuantizer(bit=bit, is_activation=False, quant_mode=quant_mode, num_channels=out_channels)
        self.quan_a = UniQQuantizer(bit=bit, is_activation=True, quant_mode='layer_wise', num_channels=in_channels)
        self.bit = bit

    def forward(self, x):
        inputs, weight = x, self.weight
        if self.bit != 32:
            inputs = self.quan_a(x)
            weight = self.quan_w(self.weight)
        return F.conv2d(inputs, weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)
140 |
class UniQInputConv2d(nn.Conv2d):
    """Conv2d variant whose input quantizer is built with ``is_activation=False``.

    That means the input goes through the quantizer's symmetric (signed)
    branch rather than the activation branch. ``bit == 32`` bypasses both
    quantizers.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, bit=4, quant_mode='layer_wise'):
        conv_kwargs = dict(
            in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
            stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
        super(UniQInputConv2d, self).__init__(**conv_kwargs)

        # The input quantizer is always `layer_wise`; the weight quantizer
        # follows `quant_mode`.
        self.quan_w = UniQQuantizer(bit=bit, is_activation=False, quant_mode=quant_mode, num_channels=out_channels)
        self.quan_a = UniQQuantizer(bit=bit, is_activation=False, quant_mode='layer_wise', num_channels=in_channels)
        self.bit = bit

    def forward(self, x):
        inputs, weight = x, self.weight
        if self.bit != 32:
            inputs = self.quan_a(x)
            weight = self.quan_w(self.weight)
        return F.conv2d(inputs, weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)
162 |
163 |
class UniQLinear(nn.Linear):
    """Linear layer that quantizes its input and its weight before the matmul.

    ``bit == 32`` bypasses both quantizers and behaves like a plain ``nn.Linear``.
    """

    def __init__(self, in_features, out_features, bias=True, bit=4, quant_mode='layer_wise'):
        super(UniQLinear, self).__init__(in_features=in_features, out_features=out_features, bias=bias)

        # The activation quantizer is always `layer_wise`; the weight
        # quantizer follows `quant_mode`.
        self.quan_w = UniQQuantizer(bit=bit, is_activation=False, quant_mode=quant_mode, num_channels=out_features, layer_type='linear')
        self.quan_a = UniQQuantizer(bit=bit, is_activation=True, quant_mode='layer_wise', num_channels=in_features, layer_type='linear')
        self.bit = bit

    def forward(self, x):
        inputs, weight = x, self.weight
        if self.bit != 32:
            inputs = self.quan_a(x)
            weight = self.quan_w(self.weight)
        return F.linear(inputs, weight, self.bias)
179 |
--------------------------------------------------------------------------------
/models/cifar100_presnet.py:
--------------------------------------------------------------------------------
1 | '''
2 | resnet for cifar in pytorch
3 |
4 | Reference:
5 | [1] K. He, X. Zhang, S. Ren, and J. Sun. Deep residual learning for image recognition. In CVPR, 2016.
6 | [2] K. He, X. Zhang, S. Ren, and J. Sun. Identity mappings in deep residual networks. In ECCV, 2016.
7 | '''
8 |
9 | import torch
10 | import torch.nn as nn
11 | import math
12 |
13 |
def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 convolution with padding 1."""
    # Positional order of nn.Conv2d: in_channels, out_channels, kernel_size, stride, padding.
    return nn.Conv2d(in_planes, out_planes, 3, stride, 1, bias=False)
17 |
18 |
class BasicBlock(nn.Module):
    """Post-activation residual block made of two 3x3 convolutions."""
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        # First conv carries the stride; the second keeps the resolution.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        # Optional module applied to the input to form the skip connection.
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # Identity skip unless a downsample module was supplied.
        residual = x if self.downsample is None else self.downsample(x)

        out += residual
        return self.relu(out)
49 |
50 |
class Bottleneck(nn.Module):
    """Post-activation bottleneck block: 1x1 -> 3x3 -> 1x1 convolutions, output width ``planes * 4``."""
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        widened = planes * 4
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        # The 3x3 convolution applies the stride.
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, widened, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(widened)
        self.relu = nn.ReLU(inplace=True)
        # Optional module applied to the input to form the skip connection.
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        # Identity skip unless a downsample module was supplied.
        residual = x if self.downsample is None else self.downsample(x)

        out += residual
        return self.relu(out)
87 |
88 |
class PreActBasicBlock(nn.Module):
    """Pre-activation residual block: (BN -> ReLU -> 3x3 conv) twice."""
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(PreActBasicBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.relu = nn.ReLU(inplace=True)
        # First conv carries the stride; the second keeps the resolution.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes)
        # Optional module forming the skip path from the pre-activated input.
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        activated = self.relu(self.bn1(x))

        # The downsample module reads the pre-activated tensor; the identity
        # skip uses the raw input.
        if self.downsample is None:
            residual = x
        else:
            residual = self.downsample(activated)

        out = self.conv1(activated)
        out = self.conv2(self.relu(self.bn2(out)))

        out += residual
        return out
120 |
121 |
class PreActBottleneck(nn.Module):
    """Pre-activation bottleneck block: BN -> ReLU before each of the 1x1, 3x3, 1x1 convolutions."""
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(PreActBottleneck, self).__init__()
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        # The 3x3 convolution applies the stride.
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes*4, kernel_size=1, bias=False)
        # Optional module forming the skip path from the pre-activated input.
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        activated = self.relu(self.bn1(x))

        # The downsample module reads the pre-activated tensor; the identity
        # skip uses the raw input.
        if self.downsample is None:
            residual = x
        else:
            residual = self.downsample(activated)

        out = self.conv1(activated)
        out = self.conv2(self.relu(self.bn2(out)))
        out = self.conv3(self.relu(self.bn3(out)))

        out += residual
        return out
159 |
160 |
class ResNet_Cifar(nn.Module):
    """Three-stage post-activation ResNet (16/32/64 base widths) with a linear classifier.

    Args:
        block: residual block class exposing ``expansion`` and constructible as
            ``block(inplanes, planes, stride, downsample)``.
        layers: sequence of three ints, the number of blocks per stage.
        num_classes: size of the classifier output.
    """

    def __init__(self, block, layers, num_classes=10):
        super(ResNet_Cifar, self).__init__()
        self.inplanes = 16
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(block, 16, layers[0])
        self.layer2 = self._make_layer(block, 32, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 64, layers[2], stride=2)
        # Global average pooling. A 32x32 input gives an 8x8 final map, where
        # this matches the former AvgPool2d(8, stride=1); other input sizes,
        # which previously failed at the pooling or the fc layer, now work too.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Normal init with variance 2 / (kernel area * out_channels).
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` blocks; only the first may stride/project."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            # 1x1 conv + BN projection so the skip matches the block output.
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion)
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.avgpool(x)
        # Flatten to (batch, features) for the classifier.
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x
213 |
214 |
class PreAct_ResNet_Cifar(nn.Module):
    """Three-stage pre-activation ResNet (16/32/64 base widths) with a final BN/ReLU and linear classifier.

    Args:
        block: residual block class exposing ``expansion`` and constructible as
            ``block(inplanes, planes, stride, downsample)``.
        layers: sequence of three ints, the number of blocks per stage.
        num_classes: size of the classifier output.
    """

    def __init__(self, block, layers, num_classes=10):
        super(PreAct_ResNet_Cifar, self).__init__()
        self.inplanes = 16
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.layer1 = self._make_layer(block, 16, layers[0])
        self.layer2 = self._make_layer(block, 32, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 64, layers[2], stride=2)
        self.bn = nn.BatchNorm2d(64*block.expansion)
        self.relu = nn.ReLU(inplace=True)
        # Global average pooling. A 32x32 input gives an 8x8 final map, where
        # this matches the former AvgPool2d(8, stride=1); other input sizes,
        # which previously failed at the pooling or the fc layer, now work too.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64*block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Normal init with variance 2 / (kernel area * out_channels).
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` blocks; only the first may stride/project."""
        downsample = None
        if stride != 1 or self.inplanes != planes*block.expansion:
            # Plain 1x1 conv projection (no BN) so the skip matches the block output.
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes*block.expansion, kernel_size=1, stride=stride, bias=False)
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes*block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.bn(x)
        x = self.relu(x)
        x = self.avgpool(x)
        # Flatten to (batch, features) for the classifier.
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x
265 |
266 |
267 |
def resnet20_cifar(**kwargs):
    """ResNet_Cifar of BasicBlock with 3 blocks per stage; kwargs go to the constructor."""
    return ResNet_Cifar(BasicBlock, [3, 3, 3], **kwargs)


def resnet32_cifar(**kwargs):
    """ResNet_Cifar of BasicBlock with 5 blocks per stage; kwargs go to the constructor."""
    return ResNet_Cifar(BasicBlock, [5, 5, 5], **kwargs)


def resnet44_cifar(**kwargs):
    """ResNet_Cifar of BasicBlock with 7 blocks per stage; kwargs go to the constructor."""
    return ResNet_Cifar(BasicBlock, [7, 7, 7], **kwargs)


def resnet56_cifar(**kwargs):
    """ResNet_Cifar of BasicBlock with 9 blocks per stage; kwargs go to the constructor."""
    return ResNet_Cifar(BasicBlock, [9, 9, 9], **kwargs)


def resnet110_cifar(**kwargs):
    """ResNet_Cifar of BasicBlock with 18 blocks per stage; kwargs go to the constructor."""
    return ResNet_Cifar(BasicBlock, [18, 18, 18], **kwargs)


def resnet1202_cifar(**kwargs):
    """ResNet_Cifar of BasicBlock with 200 blocks per stage; kwargs go to the constructor."""
    return ResNet_Cifar(BasicBlock, [200, 200, 200], **kwargs)


def resnet164_cifar(**kwargs):
    """ResNet_Cifar of Bottleneck with 18 blocks per stage; kwargs go to the constructor."""
    return ResNet_Cifar(Bottleneck, [18, 18, 18], **kwargs)


def resnet1001_cifar(**kwargs):
    """ResNet_Cifar of Bottleneck with 111 blocks per stage; kwargs go to the constructor."""
    return ResNet_Cifar(Bottleneck, [111, 111, 111], **kwargs)


def preact_resnet20_cifar(**kwargs):
    """PreAct_ResNet_Cifar of PreActBasicBlock with 3 blocks per stage; kwargs go to the constructor."""
    return PreAct_ResNet_Cifar(PreActBasicBlock, [3, 3, 3], **kwargs)


def preact_resnet32_cifar(**kwargs):
    """PreAct_ResNet_Cifar of PreActBasicBlock with 5 blocks per stage; kwargs go to the constructor."""
    return PreAct_ResNet_Cifar(PreActBasicBlock, [5, 5, 5], **kwargs)


def preact_resnet110_cifar(**kwargs):
    """PreAct_ResNet_Cifar of PreActBasicBlock with 18 blocks per stage; kwargs go to the constructor."""
    return PreAct_ResNet_Cifar(PreActBasicBlock, [18, 18, 18], **kwargs)


def preact_resnet164_cifar(**kwargs):
    """PreAct_ResNet_Cifar of PreActBottleneck with 18 blocks per stage; kwargs go to the constructor."""
    return PreAct_ResNet_Cifar(PreActBottleneck, [18, 18, 18], **kwargs)


def preact_resnet1001_cifar(**kwargs):
    """PreAct_ResNet_Cifar of PreActBottleneck with 111 blocks per stage; kwargs go to the constructor."""
    return PreAct_ResNet_Cifar(PreActBottleneck, [111, 111, 111], **kwargs)
328 |
329 |
if __name__ == '__main__':
    # Smoke test: build the smallest model and push one random image through it.
    net = resnet20_cifar()
    # CIFAR-sized input: the two stride-2 stages reduce 32x32 to an 8x8 map
    # before pooling. The former 64x64 input produced features that do not
    # fit the classifier.
    y = net(torch.randn(1, 3, 32, 32))
    print(net)
    print(y.size())
335 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | import torch.nn.functional as F
5 | import torch.backends.cudnn as cudnn
6 |
7 | import torchvision
8 | import torchvision.transforms as transforms
9 |
10 | import os
11 | import time
12 | import math
13 | import argparse
14 | import warnings
15 | import numpy as np
16 |
17 | from functools import partial
18 | from torch.utils.tensorboard import SummaryWriter
19 | from monitors.metrics import write_metrics
20 |
21 | import lr_scheduler
22 | import utils
23 |
24 | from models.imagenet_presnet import PreActResNet18
25 | from models.glouncv.alexnet import alexnet
26 | from models.glouncv.preresnet import preresnet34
27 | from models.glouncv.mobilenetv2 import mobilenetv2_w1
28 | from models.cifar100_presnet import preact_resnet32_cifar
29 |
30 |
# Command-line interface.
# ----------------------------------------
parser = argparse.ArgumentParser(description='PyTorch ImageNet/CIFAR Training')

# Optimisation hyper-parameters.
parser.add_argument('--lr', default=0.1, type=float, help='Main learning rate')
parser.add_argument('--warmup_lr', default=0.001, type=float, help='Warmup learning rate')
parser.add_argument('--wd', default=1e-4, type=float, help='weight decay')
parser.add_argument('--bit', default=4, type=int, help='bit-width for UniQ quantizer')

# Data.
parser.add_argument('--dataset', default='imagenette', type=str,
                    help='dataset name for training')
parser.add_argument('--data_root', default='/soc_local/data/pytorch/imagenet/', type=str,
                    help='path to dataset')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N',
                    help='mini-batch size (default: 256), this is the total '
                         'batch size of all GPUs on the current node when '
                         'using Data Parallel or Distributed Data Parallel')

# Model.
# The default must be one of the selectable architectures: argparse does not
# validate defaults against `choices`, so the previous default ('resnet18')
# slipped through parsing and crashed later during model construction.
# 'glouncv-alexnet' is listed because the model table and the dropout
# handling further down the script both support it.
parser.add_argument('--arch', default='presnet18', type=str,
                    choices=['presnet18', 'presnet32', 'glouncv-alexnet',
                             'glouncv-presnet34', 'glouncv-mobilenetv2_w1'],
                    help='network architecture')
parser.add_argument('--init_from', type=str,
                    help='init weights from checkpoint')

# Run control.
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--epochs', default=120, type=int, help='number of training epochs')
parser.add_argument('--train_id', type=str, default='train-01',
                    help='training id, is used for collect experiment results')
parser.add_argument('--train_scheme', type=str, default='fp32', choices=['fp32', 'uniq'],
                    help='Training scheme')
parser.add_argument('--optimizer', type=str, default='sgd', choices=['sgd', 'adam'],
                    help='Optimizer selection.')
parser.add_argument('--output_dir', type=str, default='outputs',
                    help='output directory')
parser.add_argument('--print_freq', default=10, type=int, help='log print frequency.')

# Quantization.
parser.add_argument('--quant_mode', type=str, default='layer_wise', choices=['layer_wise', 'kernel_wise'],
                    help='Quantization mode')
parser.add_argument('--num_calibration_batches', default=100, type=int,
                    help='number of calibration training batches')

# Warm-up and regularisation.
parser.add_argument('--enable_warmup', dest='enable_warmup', action='store_true',
                    help='Enable warm-up learning rate.')
parser.add_argument('--warmup_epochs', default=5, type=int, help='number of epochs for warm-up')
parser.add_argument('--dropout_ratio', default=0.1, type=float, help='dropout ratio for AlexNet.')


args = parser.parse_args()
print ("Script arguments:\n", args)
88 |
# Global run state read (and, for best_acc, updated) by train()/test().
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
best_acc = 0
start_epoch = 0

# Every artefact of this run (TensorBoard events, checkpoints) goes under
# <output_dir>/<train_id>.
working_dir = os.path.join(args.output_dir, args.train_id)
os.makedirs(working_dir, exist_ok=True)
writer = SummaryWriter(working_dir)


# Setup data.
print('==> Preparing data..')
loaders = utils.get_dataloaders(
    dataset=args.dataset,
    batch_size=args.batch_size,
    data_root=args.data_root,
)
trainloader, testloader = loaders
100 |
# Setup model
# ----------------------------------------
# Only two datasets have a model builder here: "imagenet" (architecture picked
# from the table below) and "cifar100" (fixed to presnet32). Anything else is
# rejected with an explicit error instead of failing later on an undefined
# `net` or on calling `None`.
print('==> Building model..')
if args.dataset == "imagenet":
    models = {
        'presnet18': PreActResNet18,
        'glouncv-alexnet': alexnet,
        'glouncv-presnet34': preresnet34,
        'glouncv-mobilenetv2_w1': mobilenetv2_w1
    }
    if args.arch not in models:
        raise ValueError(
            "Unsupported architecture '%s' for dataset 'imagenet'; expected one of %s"
            % (args.arch, sorted(models)))
    net = models[args.arch]()

elif args.dataset == "cifar100":
    if args.arch != "presnet32":
        raise ValueError(
            "Dataset 'cifar100' only supports --arch presnet32, got '%s'" % args.arch)
    net = preact_resnet32_cifar(num_classes=100)

else:
    raise ValueError(
        "No model builder for dataset '%s'; expected 'imagenet' or 'cifar100'" % args.dataset)

if net is None:
    raise ValueError("Model constructor for '%s' returned None" % args.arch)
118 |
119 |
120 |
# Module replacement
# ---------------------------------
# For the "uniq" scheme every nn.Conv2d / nn.Linear is swapped for its UniQ
# counterpart. The first and last layers (and, for binary networks, the
# down-sampling convolutions) get their own bit-width via `exception_dict`.
if args.train_scheme.startswith("uniq"):
    from quantizer.uniq import UniQConv2d, UniQInputConv2d, UniQLinear

    is_binary = not (args.bit > 1)
    if is_binary:
        # All settings for binary neural networks.
        assert args.wd == 0

    # Bit-width of the bulk of the network vs. the first/last layers.
    body_bit = 1 if is_binary else args.bit
    edge_bit = 32 if is_binary else 8

    replacement_dict = {
        nn.Conv2d: partial(UniQConv2d, bit=body_bit, quant_mode=args.quant_mode),
        nn.Linear: partial(UniQLinear, bit=body_bit, quant_mode=args.quant_mode),
    }
    exception_dict = {
        '__first__': partial(UniQInputConv2d, bit=edge_bit),
        '__last__': partial(UniQLinear, bit=edge_bit),
    }
    if is_binary:
        exception_dict['__downsampling__'] = partial(UniQConv2d, bit=32, quant_mode=args.quant_mode)

    # For this architecture the last layer is replaced by a UniQConv2d
    # instead of a UniQLinear.
    if args.arch == "glouncv-mobilenetv2_w1":
        exception_dict['__last__'] = partial(UniQConv2d, bit=edge_bit)

    net = utils.replace_module(net, replacement_dict=replacement_dict,
                               exception_dict=exception_dict, arch=args.arch)

# The following part is used for dropout ratio modification.
if args.arch.startswith("glouncv-alexnet"):
    for dense_block in (net.output.fc1, net.output.fc2):
        dense_block.dropout = nn.Dropout(p=args.dropout_ratio, inplace=False)
157 |
158 |
159 |
# Move the model to the target device; on CUDA it is additionally wrapped in
# DataParallel and cuDNN benchmarking is switched on.
net = net.to(device)
if device == 'cuda':
    net = torch.nn.DataParallel(net)
    cudnn.benchmark = True

print (net)
num_learnable = sum(p.numel() for p in net.parameters() if p.requires_grad)
print ("Number of learnable parameters: ", num_learnable / 1e6, "M")
# Pause so the summary above can be read before training output starts.
time.sleep(5)
168 |
169 |
170 |
# Loading checkpoint
# -----------------------------
# The checkpoint's 'net' entry is merged into the current state dict, so
# parameters absent from the checkpoint keep their freshly initialised values.
if args.init_from and os.path.isfile(args.init_from):
    print('==> Initializing from checkpoint: ', args.init_from)
    # map_location lets a checkpoint saved on GPU be restored on a CPU-only host.
    checkpoint = torch.load(args.init_from, map_location=device)

    # `net` is wrapped in DataParallel only when running on CUDA, and only the
    # wrapped model has "module."-prefixed keys. Normalise the checkpoint keys
    # to whichever layout the current model uses.
    prefix = "module."
    net_is_wrapped = isinstance(net, torch.nn.DataParallel)
    loaded_params = {}
    for k, v in checkpoint['net'].items():
        key_is_prefixed = k.startswith(prefix)
        if net_is_wrapped and not key_is_prefixed:
            k = prefix + k
        elif key_is_prefixed and not net_is_wrapped:
            k = k[len(prefix):]
        loaded_params[k] = v

    net_state_dict = net.state_dict()
    net_state_dict.update(loaded_params)
    net.load_state_dict(net_state_dict)
else:
    warnings.warn("No checkpoint file is provided !!!")
188 |
189 |
190 |
criterion = nn.CrossEntropyLoss()

# Parameter groups come from utils.add_weight_decay, which is told to treat
# names containing 'delta' / 'alpha' specially via `skip_keys`.
params = utils.add_weight_decay(net, weight_decay=args.wd, skip_keys=['delta', 'alpha'])

# Setup optimizer
# ----------------------------
if args.optimizer == 'sgd':
    print ("==> Use SGD optimizer")
    optimizer = optim.SGD(params, lr=args.lr, momentum=0.9, weight_decay=args.wd)
elif args.optimizer == 'adam':
    print ("==> Use Adam optimizer")
    optimizer = optim.Adam(params, lr=args.lr, weight_decay=args.wd)
203 |
204 |
# Setup LR scheduler
# ----------------------------
# Cosine annealing spans every epoch that is not spent on warm-up.
# NOTE: after this block the name `lr_scheduler` refers to the scheduler
# instance and no longer to the imported `lr_scheduler` module.
cosine_epochs = args.epochs - args.warmup_epochs if args.enable_warmup else args.epochs
cosine_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cosine_epochs)

if args.enable_warmup:
    lr_scheduler = lr_scheduler.ConstantWarmupScheduler(
        optimizer=optimizer,
        min_lr=args.warmup_lr,
        total_epoch=args.warmup_epochs,
        after_lr=args.lr,
        after_scheduler=cosine_scheduler,
    )
else:
    lr_scheduler = cosine_scheduler
212 |
213 |
214 |
def train(epoch, ):
    """Run one training epoch over the global ``trainloader``.

    Uses the module-level ``net``, ``optimizer``, ``criterion``, ``device``
    and ``args`` (only ``args.print_freq`` is read).

    Parameters:
    ----------
    epoch : int
        Epoch index, used for logging only.

    Returns
    -------
    tuple of (float, float)
        Mean per-batch loss and accuracy as a fraction in [0, 1]. Both are
        0.0 when the loader yields no batches.
    """
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    num_batches = 0

    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        num_batches = batch_idx + 1

        if batch_idx % args.print_freq == 0:
            print ("[Train] Epoch=", epoch, " BatchID=", batch_idx, 'Loss: %.3f | Acc: %.3f%% (%d/%d)' \
                % (train_loss/num_batches, 100.*correct/total, correct, total))

    # Average over the number of batches, not the last batch index: the old
    # `train_loss/batch_idx` over-reported the loss and divided by zero when
    # the loader had a single batch.
    return (train_loss/max(num_batches, 1), correct/max(total, 1))
241 |
def test(epoch):
    """Evaluate the global ``net`` on ``testloader`` and checkpoint the best model.

    Uses the module-level ``net``, ``criterion``, ``device``, ``args``,
    ``lr_scheduler``, ``optimizer`` and ``working_dir``. When the accuracy
    beats the module-level ``best_acc``, that value is updated and
    ``utils.save_checkpoint`` writes ``ckpt_best.pth`` into ``working_dir``.

    Parameters:
    ----------
    epoch : int
        Epoch index, used for logging and stored with the checkpoint.

    Returns
    -------
    tuple of (float, float)
        Mean per-batch loss and accuracy as a fraction in [0, 1]. Both are
        0.0 when the loader yields no batches.
    """
    global best_acc

    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    num_batches = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            num_batches = batch_idx + 1

            if batch_idx % args.print_freq == 0:
                print ("[Test] Epoch=", epoch, " BatchID=", batch_idx, 'Loss: %.3f | Acc: %.3f%% (%d/%d)' \
                    % (test_loss/num_batches, 100.*correct/total, correct, total))

    # Save checkpoint.
    # NOTE(review): this also runs for `--evaluate` (called as test(-1)), so an
    # evaluation-only run can overwrite ckpt_best.pth in working_dir - confirm
    # that this is intended.
    acc = 100.*correct/max(total, 1)
    if acc > best_acc:
        best_acc = acc
        utils.save_checkpoint(net, lr_scheduler, optimizer, acc, epoch,
            filename=os.path.join(working_dir, 'ckpt_best.pth'))
        print('Saving..')
    print ('Best accuracy: ', best_acc)

    # Average over the number of batches, not the last batch index: the old
    # `test_loss/batch_idx` over-reported the loss and divided by zero when
    # the loader had a single batch.
    return (test_loss/max(num_batches, 1), correct/max(total, 1))
273 |
274 |
def simple_initialization(num_batches=100):
    """Calibrate the UniQ quantizers of the global ``net`` on training data.

    Every ``UniQConv2d`` / ``UniQInputConv2d`` / ``UniQLinear`` module has its
    ``quan_a`` and ``quan_w`` quantizers switched to ``STATUS.INIT_READY``,
    then training batches are forwarded through the network, and finally the
    quantizers are switched to ``STATUS.INIT_DONE``.

    Parameters:
    ----------
    num_batches : int, default 100
        Number of batches from ``trainloader`` to forward. The loop stops only
        when exactly this many batches were processed, so a value below 1 (or
        above the loader length) consumes the whole loader.
    """
    net.train()
    from quantizer.uniq import STATUS, UniQConv2d, UniQInputConv2d, UniQLinear

    quantized_types = (UniQConv2d, UniQInputConv2d, UniQLinear)

    def set_init_state(state):
        # Apply `state` to both quantizers of every UniQ layer in the network.
        for _, module in net.named_modules():
            if isinstance(module, quantized_types):
                assert getattr(module, 'quan_a', None) is not None
                assert getattr(module, 'quan_w', None) is not None
                module.quan_a.set_init_state(state)
                module.quan_w.set_init_state(state)

    set_init_state(STATUS.INIT_READY)

    # Forward passes only; the labels and the network outputs are discarded.
    for batch_idx, (inputs, _) in enumerate(trainloader):
        inputs = inputs.to(device)
        net(inputs)
        if batch_idx + 1 == num_batches:
            break

    set_init_state(STATUS.INIT_DONE)
297 |
298 |
299 |
# Evaluation-only mode: one pass over the validation set, then stop.
if args.evaluate:
    print ("==> Start evaluating ...")
    test(-1)
    exit()



# Main training
# -----------------------------------------------
# Reset to 'warmup_lr' if we are using warmup strategy.
if args.enable_warmup:
    assert args.bit == 1
    for group in optimizer.param_groups:
        group['lr'] = args.warmup_lr

# Initialization
# ------------------------------------------------
# Quantizer calibration is needed only for the "uniq" scheme below 32 bits.
if args.train_scheme == "uniq" and args.bit != 32:
    simple_initialization(num_batches=args.num_calibration_batches)

# Training
# -----------------------------------------------
for epoch in range(start_epoch, args.epochs):
    train_loss, train_acc1 = train(epoch)
    test_loss, test_acc1 = test(epoch)

    if lr_scheduler is not None:
        lr_scheduler.step()

    write_metrics(
        writer, epoch, net, optimizer,
        train_loss, train_acc1, test_loss, test_acc1,
        prefix="Standard_Training",
    )
331 |
--------------------------------------------------------------------------------
/models/glouncv/mobilenetv2.py:
--------------------------------------------------------------------------------
1 | """
2 | MobileNetV2 for ImageNet-1K, implemented in PyTorch.
3 | Original paper: 'MobileNetV2: Inverted Residuals and Linear Bottlenecks,' https://arxiv.org/abs/1801.04381.
4 | """
5 |
6 | __all__ = ['MobileNetV2', 'mobilenetv2_w1', 'mobilenetv2_w3d4', 'mobilenetv2_wd2', 'mobilenetv2_wd4', 'mobilenetv2b_w1',
7 | 'mobilenetv2b_w3d4', 'mobilenetv2b_wd2', 'mobilenetv2b_wd4']
8 |
9 | import os
10 | import torch.nn as nn
11 | import torch.nn.init as init
12 | from .common import conv1x1, conv1x1_block, conv3x3_block, dwconv3x3_block
13 |
14 |
class LinearBottleneck(nn.Module):
    """
    MobileNetV2 unit: optional 1x1 expansion, 3x3 depthwise convolution and a
    1x1 projection without activation. The input is added back to the output
    when the unit keeps both the channel count and the spatial stride.
    Parameters:
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    stride : int or tuple/list of 2 int
        Strides of the second convolution layer.
    expansion : bool
        Whether do expansion of channels (by a factor of 6).
    remove_exp_conv : bool
        Whether to remove expansion convolution (only has an effect when
        `expansion` is False).
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 stride,
                 expansion,
                 remove_exp_conv):
        super().__init__()
        self.use_exp_conv = expansion or (not remove_exp_conv)
        self.residual = (stride == 1) and (in_channels == out_channels)

        mid_channels = in_channels
        if expansion:
            mid_channels = in_channels * 6

        # Sub-module registration order (conv1, conv2, conv3) is kept stable
        # so that state-dict layout does not change.
        if self.use_exp_conv:
            self.conv1 = conv1x1_block(
                in_channels=in_channels,
                out_channels=mid_channels,
                activation="relu6")
        self.conv2 = dwconv3x3_block(
            in_channels=mid_channels,
            out_channels=mid_channels,
            stride=stride,
            activation="relu6")
        self.conv3 = conv1x1_block(
            in_channels=mid_channels,
            out_channels=out_channels,
            activation=None)

    def forward(self, x):
        out = self.conv1(x) if self.use_exp_conv else x
        out = self.conv3(self.conv2(out))
        if self.residual:
            out = out + x
        return out
67 |
68 |
class MobileNetV2(nn.Module):
    """
    MobileNetV2 model from 'MobileNetV2: Inverted Residuals and Linear Bottlenecks,' https://arxiv.org/abs/1801.04381.
    Parameters:
    ----------
    channels : list of list of int
        Number of output channels for each unit, grouped by stage.
    init_block_channels : int
        Number of output channels for the initial unit.
    final_block_channels : int
        Number of output channels for the final block of the feature extractor.
    remove_exp_conv : bool
        Whether to remove expansion convolution.
    in_channels : int, default 3
        Number of input channels.
    in_size : tuple of two ints, default (224, 224)
        Spatial size of the expected input image (stored, not used to size
        any layer here).
    num_classes : int, default 1000
        Number of classification classes.
    """
    def __init__(self,
                 channels,
                 init_block_channels,
                 final_block_channels,
                 remove_exp_conv,
                 in_channels=3,
                 in_size=(224, 224),
                 num_classes=1000):
        super().__init__()
        self.in_size = in_size
        self.num_classes = num_classes

        features = nn.Sequential()
        features.add_module("init_block", conv3x3_block(
            in_channels=in_channels,
            out_channels=init_block_channels,
            stride=2,
            activation="relu6"))

        width = init_block_channels
        for stage_idx, stage_widths in enumerate(channels):
            stage = nn.Sequential()
            for unit_idx, unit_width in enumerate(stage_widths):
                # Every stage except the first starts with a stride-2 unit;
                # only the very first unit of the network skips expansion.
                opens_later_stage = (unit_idx == 0) and (stage_idx != 0)
                is_first_unit = (stage_idx == 0) and (unit_idx == 0)
                stage.add_module("unit{}".format(unit_idx + 1), LinearBottleneck(
                    in_channels=width,
                    out_channels=unit_width,
                    stride=2 if opens_later_stage else 1,
                    expansion=not is_first_unit,
                    remove_exp_conv=remove_exp_conv))
                width = unit_width
            features.add_module("stage{}".format(stage_idx + 1), stage)

        features.add_module("final_block", conv1x1_block(
            in_channels=width,
            out_channels=final_block_channels,
            activation="relu6"))
        features.add_module("final_pool", nn.AvgPool2d(
            kernel_size=7,
            stride=1))
        self.features = features

        # The classifier is a bias-free 1x1 convolution; forward() flattens
        # its output.
        self.output = conv1x1(
            in_channels=final_block_channels,
            out_channels=num_classes,
            bias=False)

        self._init_params()

    def _init_params(self):
        # Kaiming-uniform weights and zero biases for every convolution.
        for module in self.modules():
            if not isinstance(module, nn.Conv2d):
                continue
            init.kaiming_uniform_(module.weight)
            if module.bias is not None:
                init.constant_(module.bias, 0)

    def forward(self, x):
        logits = self.output(self.features(x))
        return logits.view(logits.size(0), -1)
149 |
150 |
def get_mobilenetv2(width_scale,
                    remove_exp_conv=False,
                    model_name=None,
                    pretrained=False,
                    root=os.path.join("~", ".torch", "models"),
                    **kwargs):
    """
    Create MobileNetV2 model with specific parameters.
    Parameters:
    ----------
    width_scale : float
        Scale factor for width of layers.
    remove_exp_conv : bool, default False
        Whether to remove expansion convolution.
    model_name : str or None, default None
        Model name for loading pretrained model.
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    root : str, default '~/.torch/models'
        Location for keeping the model parameters.
    **kwargs
        Extra keyword arguments forwarded to `MobileNetV2`.
    Raises
    ------
    ValueError
        If `pretrained` is set but `model_name` is empty or None.
    """
    init_block_channels = 32
    final_block_channels = 1280

    # One entry per layer group: (unit width, number of units, whether the
    # group opens a new stage). Groups that do not open a stage are appended
    # to the stage before them.
    layer_groups = zip(
        [16, 24, 32, 64, 96, 160, 320],
        [1, 2, 3, 4, 3, 3, 1],
        [0, 1, 1, 1, 0, 1, 0])
    channels = [[]]
    for unit_width, unit_count, opens_stage in layer_groups:
        group = [unit_width] * unit_count
        if opens_stage != 0:
            channels.append(group)
        else:
            channels[-1].extend(group)

    if width_scale != 1.0:
        channels = [[int(width * width_scale) for width in stage] for stage in channels]
        init_block_channels = int(init_block_channels * width_scale)
        # The final block is widened but never narrowed.
        if width_scale > 1.0:
            final_block_channels = int(final_block_channels * width_scale)

    net = MobileNetV2(
        channels=channels,
        init_block_channels=init_block_channels,
        final_block_channels=final_block_channels,
        remove_exp_conv=remove_exp_conv,
        **kwargs)

    if not pretrained:
        return net

    if not model_name:
        raise ValueError("Parameter `model_name` should be properly initialized for loading pretrained model.")
    from .model_store import download_model
    download_model(
        net=net,
        model_name=model_name,
        local_model_store_dir_path=root)
    return net
208 |
209 |
def mobilenetv2_w1(**kwargs):
    """
    1.0 MobileNetV2-224 model from 'MobileNetV2: Inverted Residuals and Linear Bottlenecks,'
    https://arxiv.org/abs/1801.04381.
    Thin wrapper around `get_mobilenetv2` with `width_scale=1.0`.
    Parameters:
    ----------
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    root : str, default '~/.torch/models'
        Location for keeping the model parameters.
    """
    net = get_mobilenetv2(model_name="mobilenetv2_w1", width_scale=1.0, **kwargs)
    return net
222 |
223 |
def mobilenetv2_w3d4(**kwargs):
    """
    0.75 MobileNetV2-224 model from 'MobileNetV2: Inverted Residuals and Linear Bottlenecks,'
    https://arxiv.org/abs/1801.04381.
    Thin wrapper around `get_mobilenetv2` with `width_scale=0.75`.
    Parameters:
    ----------
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    root : str, default '~/.torch/models'
        Location for keeping the model parameters.
    """
    net = get_mobilenetv2(model_name="mobilenetv2_w3d4", width_scale=0.75, **kwargs)
    return net
236 |
237 |
def mobilenetv2_wd2(**kwargs):
    """
    0.5 MobileNetV2-224 model from 'MobileNetV2: Inverted Residuals and Linear Bottlenecks,'
    https://arxiv.org/abs/1801.04381.
    Thin wrapper around `get_mobilenetv2` with `width_scale=0.5`.
    Parameters:
    ----------
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    root : str, default '~/.torch/models'
        Location for keeping the model parameters.
    """
    net = get_mobilenetv2(model_name="mobilenetv2_wd2", width_scale=0.5, **kwargs)
    return net
250 |
251 |
def mobilenetv2_wd4(**kwargs):
    """
    0.25 MobileNetV2-224 model from 'MobileNetV2: Inverted Residuals and Linear Bottlenecks,'
    https://arxiv.org/abs/1801.04381.
    Thin wrapper around `get_mobilenetv2` with `width_scale=0.25`.
    Parameters:
    ----------
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    root : str, default '~/.torch/models'
        Location for keeping the model parameters.
    """
    net = get_mobilenetv2(model_name="mobilenetv2_wd4", width_scale=0.25, **kwargs)
    return net
264 |
265 |
def mobilenetv2b_w1(**kwargs):
    """
    1.0 MobileNetV2b-224 model from 'MobileNetV2: Inverted Residuals and Linear Bottlenecks,'
    https://arxiv.org/abs/1801.04381.
    Thin wrapper around `get_mobilenetv2` with `width_scale=1.0` and
    `remove_exp_conv=True`.
    Parameters:
    ----------
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    root : str, default '~/.torch/models'
        Location for keeping the model parameters.
    """
    net = get_mobilenetv2(model_name="mobilenetv2b_w1", width_scale=1.0, remove_exp_conv=True, **kwargs)
    return net
278 |
279 |
def mobilenetv2b_w3d4(**kwargs):
    """
    0.75 MobileNetV2b-224 model from 'MobileNetV2: Inverted Residuals and Linear Bottlenecks,'
    https://arxiv.org/abs/1801.04381.
    Thin wrapper around `get_mobilenetv2` with `width_scale=0.75` and
    `remove_exp_conv=True`.
    Parameters:
    ----------
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    root : str, default '~/.torch/models'
        Location for keeping the model parameters.
    """
    net = get_mobilenetv2(model_name="mobilenetv2b_w3d4", width_scale=0.75, remove_exp_conv=True, **kwargs)
    return net
292 |
293 |
def mobilenetv2b_wd2(**kwargs):
    """
    0.5 MobileNetV2b-224 model from 'MobileNetV2: Inverted Residuals and Linear Bottlenecks,'
    https://arxiv.org/abs/1801.04381.
    Thin wrapper around `get_mobilenetv2` with `width_scale=0.5` and
    `remove_exp_conv=True`.
    Parameters:
    ----------
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    root : str, default '~/.torch/models'
        Location for keeping the model parameters.
    """
    net = get_mobilenetv2(model_name="mobilenetv2b_wd2", width_scale=0.5, remove_exp_conv=True, **kwargs)
    return net
306 |
307 |
def mobilenetv2b_wd4(**kwargs):
    """
    0.25 MobileNetV2b-224 model from 'MobileNetV2: Inverted Residuals and Linear Bottlenecks,'
    https://arxiv.org/abs/1801.04381.
    Thin wrapper around `get_mobilenetv2` with `width_scale=0.25` and
    `remove_exp_conv=True`.
    Parameters:
    ----------
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    root : str, default '~/.torch/models'
        Location for keeping the model parameters.
    """
    net = get_mobilenetv2(model_name="mobilenetv2b_wd4", width_scale=0.25, remove_exp_conv=True, **kwargs)
    return net
320 |
321 |
def _calc_width(net):
    """
    Count the trainable scalar weights of a network.
    Parameters:
    ----------
    net : nn.Module
        Network whose parameters with `requires_grad` set are counted.
    Returns
    -------
    int
        Sum of the element counts of all trainable parameters.
    """
    import numpy as np
    return sum(np.prod(param.size()) for param in net.parameters() if param.requires_grad)
329 |
330 |
def _test():
    """
    Self-check: build every MobileNetV2 variant, compare its trainable weight
    count with the reference value, and run one forward/backward pass on a
    random 1x3x224x224 input.
    """
    import torch

    pretrained = False

    # (factory, expected number of trainable weights)
    reference_counts = [
        (mobilenetv2_w1, 3504960),
        (mobilenetv2_w3d4, 2627592),
        (mobilenetv2_wd2, 1964736),
        (mobilenetv2_wd4, 1516392),
        (mobilenetv2b_w1, 3503872),
        (mobilenetv2b_w3d4, 2626968),
        (mobilenetv2b_wd2, 1964448),
        (mobilenetv2b_wd4, 1516312),
    ]

    for model, expected_count in reference_counts:
        net = model(pretrained=pretrained)

        # net.train()
        net.eval()
        weight_count = _calc_width(net)
        print("m={}, {}".format(model.__name__, weight_count))
        assert weight_count == expected_count

        x = torch.randn(1, 3, 224, 224)
        y = net(x)
        y.sum().backward()
        assert (tuple(y.size()) == (1, 1000))


if __name__ == "__main__":
    _test()
--------------------------------------------------------------------------------
/models/glouncv/alexnet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch.nn.init as init
5 | from inspect import isfunction
6 |
class ConvBlock(nn.Module):
    """
    Convolution followed by optional Batch normalization and optional
    activation. A 4-element padding is applied through a separate
    `nn.ZeroPad2d` layer, in which case the convolution itself is unpadded.
    Parameters:
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    kernel_size : int or tuple/list of 2 int
        Convolution window size.
    stride : int or tuple/list of 2 int
        Strides of the convolution.
    padding : int, or tuple/list of 2 int, or tuple/list of 4 int
        Padding value for convolution layer.
    dilation : int or tuple/list of 2 int, default 1
        Dilation value for convolution layer.
    groups : int, default 1
        Number of groups.
    bias : bool, default False
        Whether the layer uses a bias vector.
    use_bn : bool, default True
        Whether to use BatchNorm layer.
    bn_eps : float, default 1e-5
        Small float added to variance in Batch norm.
    activation : function or str or None, default nn.ReLU(inplace=True)
        Activation function or name of activation function; None disables
        the activation layer.
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 padding,
                 dilation=1,
                 groups=1,
                 bias=False,
                 use_bn=True,
                 bn_eps=1e-5,
                 activation=(lambda: nn.ReLU(inplace=True))):
        super().__init__()
        self.activate = activation is not None
        self.use_bn = use_bn
        self.use_pad = isinstance(padding, (list, tuple)) and len(padding) == 4

        # Sub-modules are registered in the order pad, conv, bn, activ.
        conv_padding = padding
        if self.use_pad:
            self.pad = nn.ZeroPad2d(padding=padding)
            conv_padding = 0
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=conv_padding,
            dilation=dilation,
            groups=groups,
            bias=bias)
        if use_bn:
            self.bn = nn.BatchNorm2d(num_features=out_channels, eps=bn_eps)
        if self.activate:
            self.activ = get_activation_layer(activation)

    def forward(self, x):
        out = self.pad(x) if self.use_pad else x
        out = self.conv(out)
        if self.use_bn:
            out = self.bn(out)
        return self.activ(out) if self.activate else out
80 |
81 |
82 |
class AlexConv(ConvBlock):
    """
    AlexNet specific convolution block: a biased `ConvBlock` without Batch
    normalization, optionally followed by local response normalization.
    Parameters:
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    kernel_size : int or tuple/list of 2 int
        Convolution window size.
    stride : int or tuple/list of 2 int
        Strides of the convolution.
    padding : int or tuple/list of 2 int
        Padding value for convolution layer.
    use_lrn : bool
        Whether to use LRN layer.
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 padding,
                 use_lrn):
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=True,
            use_bn=False)
        self.use_lrn = use_lrn

    def forward(self, x):
        out = super().forward(x)
        if not self.use_lrn:
            return out
        return F.local_response_norm(out, size=5, k=2.0)
123 |
124 |
class AlexDense(nn.Module):
    """
    AlexNet specific dense block: linear layer, ReLU, then dropout (p=0.5).
    Parameters:
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    """
    def __init__(self,
                 in_channels,
                 out_channels):
        super().__init__()
        self.fc = nn.Linear(in_features=in_channels, out_features=out_channels)
        self.activ = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        return self.dropout(self.activ(self.fc(x)))
150 |
151 |
class AlexOutputBlock(nn.Module):
    """
    AlexNet specific output block: two 4096-wide `AlexDense` blocks followed
    by a plain linear classifier.
    Parameters:
    ----------
    in_channels : int
        Number of input channels.
    classes : int
        Number of classification classes.
    """
    def __init__(self,
                 in_channels,
                 classes):
        super().__init__()
        hidden = 4096

        self.fc1 = AlexDense(in_channels=in_channels, out_channels=hidden)
        self.fc2 = AlexDense(in_channels=hidden, out_channels=hidden)
        self.fc3 = nn.Linear(in_features=hidden, out_features=classes)

    def forward(self, x):
        return self.fc3(self.fc2(self.fc1(x)))
183 |
184 |
class AlexNet(nn.Module):
    """
    AlexNet model from 'One weird trick for parallelizing convolutional neural networks,'
    https://arxiv.org/abs/1404.5997.
    Parameters:
    ----------
    channels : list of list of int
        Number of output channels for each unit.
    kernel_sizes : list of list of int
        Convolution window sizes for each unit.
    strides : list of list of int or tuple/list of 2 int
        Strides of the convolution for each unit.
    paddings : list of list of int or tuple/list of 2 int
        Padding value for convolution layer for each unit.
    use_lrn : bool
        Whether to use LRN layer (applied in the first two stages only).
    in_channels : int, default 3
        Number of input channels.
    in_size : tuple of two ints, default (224, 224)
        Spatial size of the expected input image (stored, not used to size
        any layer here).
    num_classes : int, default 1000
        Number of classification classes.
    """
    def __init__(self,
                 channels,
                 kernel_sizes,
                 strides,
                 paddings,
                 use_lrn,
                 in_channels=3,
                 in_size=(224, 224),
                 num_classes=1000):
        super().__init__()
        self.in_size = in_size
        self.num_classes = num_classes

        features = nn.Sequential()
        width = in_channels
        for stage_idx, stage_widths in enumerate(channels):
            stage_uses_lrn = use_lrn and (stage_idx in [0, 1])
            stage = nn.Sequential()
            for unit_idx, unit_width in enumerate(stage_widths):
                stage.add_module("unit{}".format(unit_idx + 1), AlexConv(
                    in_channels=width,
                    out_channels=unit_width,
                    kernel_size=kernel_sizes[stage_idx][unit_idx],
                    stride=strides[stage_idx][unit_idx],
                    padding=paddings[stage_idx][unit_idx],
                    use_lrn=stage_uses_lrn))
                width = unit_width
            # Each stage ends with a 3x3, stride-2 max pooling.
            stage.add_module("pool{}".format(stage_idx + 1), nn.MaxPool2d(
                kernel_size=3,
                stride=2,
                padding=0,
                ceil_mode=True))
            features.add_module("stage{}".format(stage_idx + 1), stage)
        self.features = features

        # The classifier input size is hard-wired to a 6x6 final feature map.
        self.output = AlexOutputBlock(
            in_channels=(width * 6 * 6),
            classes=num_classes)

        self._init_params()

    def _init_params(self):
        # Kaiming-uniform weights and zero biases for every convolution.
        for module in self.modules():
            if not isinstance(module, nn.Conv2d):
                continue
            init.kaiming_uniform_(module.weight)
            if module.bias is not None:
                init.constant_(module.bias, 0)

    def forward(self, x):
        feats = self.features(x)
        flat = feats.view(feats.size(0), -1)
        return self.output(flat)
259 |
260 |
261 |
def get_activation_layer(activation):
    """
    Create activation layer from string/function.
    Parameters:
    ----------
    activation : function, or str, or nn.Module
        A zero-argument factory function (it is called and its result
        returned), the name of an activation, or a ready nn.Module (returned
        as is).
    Returns
    -------
    nn.Module
        Activation layer.
    Raises
    ------
    NotImplementedError
        If `activation` is a string that names no known activation.
    """
    assert activation is not None

    if isfunction(activation):
        return activation()

    if not isinstance(activation, str):
        assert isinstance(activation, nn.Module)
        return activation

    # Factories are lazy so that only the requested layer is instantiated.
    # NOTE(review): Swish, HSwish, HSigmoid and Identity are not among the
    # imports visible at the top of this file - confirm they are defined.
    factories = {
        "relu": lambda: nn.ReLU(inplace=True),
        "relu6": lambda: nn.ReLU6(inplace=True),
        "swish": lambda: Swish(),
        "hswish": lambda: HSwish(inplace=True),
        "sigmoid": lambda: nn.Sigmoid(),
        "hsigmoid": lambda: HSigmoid(),
        "identity": lambda: Identity(),
    }
    if activation not in factories:
        raise NotImplementedError()
    return factories[activation]()
297 |
298 |
def get_alexnet(version="a",
                model_name=None,
                pretrained=False,
                root=os.path.join("~", ".torch", "models"),
                **kwargs):
    """
    Build an AlexNet for the requested layout version.
    Parameters:
    ----------
    version : str, default 'a'
        Version of AlexNet ('a' or 'b').
    model_name : str or None, default None
        Model name for loading pretrained model.
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    root : str, default '~/.torch/models'
        Location for keeping the model parameters (currently not used by this
        function because the download call is disabled).
    """
    # Kernel sizes and strides are shared by both versions; only the channel
    # widths, the first-layer padding and the LRN flag differ.
    layout = dict(
        kernel_sizes=[[11], [5], [3, 3, 3]],
        strides=[[4], [1], [1, 1, 1]])
    if version == "a":
        layout.update(
            channels=[[96], [256], [384, 384, 256]],
            paddings=[[0], [2], [1, 1, 1]],
            use_lrn=True)
    elif version == "b":
        layout.update(
            channels=[[64], [192], [384, 256, 256]],
            paddings=[[2], [2], [1, 1, 1]],
            use_lrn=False)
    else:
        raise ValueError("Unsupported AlexNet version {}".format(version))

    net = AlexNet(**layout, **kwargs)

    # The weight download (model_store.download_model) is disabled, so asking
    # for pretrained weights only validates `model_name`.
    if pretrained and not model_name:
        raise ValueError("Parameter `model_name` should be properly initialized for loading pretrained model.")

    return net
350 |
351 |
def alexnet(**kwargs):
    """
    Standard AlexNet (layout version 'a') from 'One weird trick for parallelizing
    convolutional neural networks,' https://arxiv.org/abs/1404.5997.
    Parameters:
    ----------
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    root : str, default '~/.torch/models'
        Location for keeping the model parameters.
    """
    net = get_alexnet(model_name="alexnet", **kwargs)
    return net
364 |
365 |
def alexnetb(**kwargs):
    """
    Non-standard AlexNet-b (layout version 'b') from 'One weird trick for
    parallelizing convolutional neural networks,' https://arxiv.org/abs/1404.5997.
    Parameters:
    ----------
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    root : str, default '~/.torch/models'
        Location for keeping the model parameters.
    """
    net = get_alexnet(version="b", model_name="alexnetb", **kwargs)
    return net
378 |
379 |
def _calc_width(net):
    """
    Count the scalar entries of all parameters of `net` that require gradients.
    """
    import numpy as np
    total = 0
    for param in net.parameters():
        if not param.requires_grad:
            continue  # frozen parameters are not counted
        total += np.prod(param.size())
    return total
387 |
388 |
def _test():
    """
    Smoke test: build both AlexNet variants, check their trainable parameter
    counts, and check the output shape for a single 224x224 RGB input.
    """
    import torch

    pretrained = False

    # (constructor, expected number of trainable parameters)
    cases = (
        (alexnet, 62378344),
        (alexnetb, 61100840),
    )

    for model, expected_count in cases:
        net = model(pretrained=pretrained)
        print (net)
        net.eval()
        weight_count = _calc_width(net)
        print("m={}, {}".format(model.__name__, weight_count))
        assert (weight_count == expected_count)

        x = torch.randn(1, 3, 224, 224)
        y = net(x)
        assert (tuple(y.size()) == (1, 1000))
414 |
415 |
# Run the smoke test only when this module is executed as a script.
if __name__ == "__main__":
    _test()
--------------------------------------------------------------------------------
/models/glouncv/alexnet_bn.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch.nn.init as init
5 | from inspect import isfunction
6 |
class ConvBlock(nn.Module):
    """
    Convolution followed by optional batch normalization and optional activation.
    Parameters:
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    kernel_size : int or tuple/list of 2 int
        Convolution window size.
    stride : int or tuple/list of 2 int
        Strides of the convolution.
    padding : int, or tuple/list of 2 int, or tuple/list of 4 int
        Padding value for convolution layer. A 4-element value is applied by a
        separate ZeroPad2d layer and the convolution itself then uses padding 0.
    dilation : int or tuple/list of 2 int, default 1
        Dilation value for convolution layer.
    groups : int, default 1
        Number of groups.
    bias : bool, default False
        Whether the layer uses a bias vector.
    use_bn : bool, default True
        Whether to use BatchNorm layer.
    bn_eps : float, default 1e-5
        Small float added to variance in Batch norm.
    activation : function or str or None, default nn.ReLU(inplace=True)
        Activation function or name of activation function; None disables it.
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 padding,
                 dilation=1,
                 groups=1,
                 bias=False,
                 use_bn=True,
                 bn_eps=1e-5,
                 activation=(lambda: nn.ReLU(inplace=True))):
        super(ConvBlock, self).__init__()
        self.activate = (activation is not None)
        self.use_bn = use_bn
        self.use_pad = isinstance(padding, (list, tuple)) and (len(padding) == 4)

        conv_padding = padding
        if self.use_pad:
            # Asymmetric (4-value) padding is done explicitly before the conv.
            self.pad = nn.ZeroPad2d(padding=padding)
            conv_padding = 0
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=conv_padding,
            dilation=dilation,
            groups=groups,
            bias=bias)
        if self.use_bn:
            self.bn = nn.BatchNorm2d(num_features=out_channels, eps=bn_eps)
        if self.activate:
            self.activ = get_activation_layer(activation)

    def forward(self, x):
        out = self.pad(x) if self.use_pad else x
        out = self.conv(out)
        if self.use_bn:
            out = self.bn(out)
        if self.activate:
            out = self.activ(out)
        return out
80 |
81 |
82 |
class AlexConv(ConvBlock):
    """
    AlexNet specific convolution block (batch-norm variant).
    Parameters:
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    kernel_size : int or tuple/list of 2 int
        Convolution window size.
    stride : int or tuple/list of 2 int
        Strides of the convolution.
    padding : int or tuple/list of 2 int
        Padding value for convolution layer.
    use_lrn : bool
        In this variant the flag is forwarded to ConvBlock as `use_bn`, i.e. it
        switches batch normalization on. The LRN step itself is disabled
        because `self.use_lrn` is always set to False.
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 padding,
                 use_lrn):
        super(AlexConv, self).__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=True,
            use_bn=use_lrn)
        # Hard-coded: the local response normalization branch never runs.
        self.use_lrn = False

    def forward(self, x):
        out = super(AlexConv, self).forward(x)
        if not self.use_lrn:
            return out
        return F.local_response_norm(out, size=5, k=2.0)
123 |
124 |
class AlexDense(nn.Module):
    """
    AlexNet specific dense block: Linear -> ReLU -> Dropout(p=0.5).
    Parameters:
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    """
    def __init__(self,
                 in_channels,
                 out_channels):
        super(AlexDense, self).__init__()
        self.fc = nn.Linear(in_features=in_channels, out_features=out_channels)
        self.activ = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        hidden = self.activ(self.fc(x))
        return self.dropout(hidden)
150 |
151 |
class AlexOutputBlock(nn.Module):
    """
    AlexNet specific output block: two 4096-wide AlexDense blocks followed by a
    plain linear layer that produces the class scores.
    Parameters:
    ----------
    in_channels : int
        Number of input channels.
    classes : int
        Number of classification classes.
    """
    def __init__(self,
                 in_channels,
                 classes):
        super(AlexOutputBlock, self).__init__()
        hidden_width = 4096

        self.fc1 = AlexDense(in_channels=in_channels, out_channels=hidden_width)
        self.fc2 = AlexDense(in_channels=hidden_width, out_channels=hidden_width)
        self.fc3 = nn.Linear(in_features=hidden_width, out_features=classes)

    def forward(self, x):
        for layer in (self.fc1, self.fc2, self.fc3):
            x = layer(x)
        return x
183 |
184 |
class AlexNet(nn.Module):
    """
    AlexNet model from 'One weird trick for parallelizing convolutional neural networks,'
    https://arxiv.org/abs/1404.5997.
    Parameters:
    ----------
    channels : list of list of int
        Number of output channels for each unit.
    kernel_sizes : list of list of int
        Convolution window sizes for each unit.
    strides : list of list of int or tuple/list of 2 int
        Strides of the convolution for each unit.
    paddings : list of list of int or tuple/list of 2 int
        Padding value for convolution layer for each unit.
    use_lrn : bool
        Normalisation flag passed to the AlexConv units of the first two stages.
    in_channels : int, default 3
        Number of input channels.
    in_size : tuple of two ints, default (224, 224)
        Spatial size of the expected input image (stored, not otherwise used here).
    num_classes : int, default 1000
        Number of classification classes.
    """
    def __init__(self,
                 channels,
                 kernel_sizes,
                 strides,
                 paddings,
                 use_lrn,
                 in_channels=3,
                 in_size=(224, 224),
                 num_classes=1000):
        super(AlexNet, self).__init__()
        self.in_size = in_size
        self.num_classes = num_classes

        self.features = nn.Sequential()
        prev_channels = in_channels
        for stage_idx, stage_channels in enumerate(channels):
            # The normalisation flag is only forwarded to the first two stages.
            stage_lrn = use_lrn and (stage_idx in [0, 1])
            stage = nn.Sequential()
            for unit_idx, unit_channels in enumerate(stage_channels):
                unit = AlexConv(
                    in_channels=prev_channels,
                    out_channels=unit_channels,
                    kernel_size=kernel_sizes[stage_idx][unit_idx],
                    stride=strides[stage_idx][unit_idx],
                    padding=paddings[stage_idx][unit_idx],
                    use_lrn=stage_lrn)
                stage.add_module("unit{}".format(unit_idx + 1), unit)
                prev_channels = unit_channels
            pool = nn.MaxPool2d(
                kernel_size=3,
                stride=2,
                padding=0,
                ceil_mode=True)
            stage.add_module("pool{}".format(stage_idx + 1), pool)
            self.features.add_module("stage{}".format(stage_idx + 1), stage)

        # The classifier input width is hard-coded for a 6x6 final feature map.
        flat_features = prev_channels * 6 * 6
        self.output = AlexOutputBlock(
            in_channels=flat_features,
            classes=num_classes)

        self._init_params()

    def _init_params(self):
        # Kaiming-uniform weights and zero biases for every convolution.
        conv_layers = (m for _, m in self.named_modules() if isinstance(m, nn.Conv2d))
        for conv in conv_layers:
            init.kaiming_uniform_(conv.weight)
            if conv.bias is None:
                continue
            init.constant_(conv.bias, 0)

    def forward(self, x):
        feature_maps = self.features(x)
        # Flatten everything except the batch dimension for the dense head.
        flattened = feature_maps.view(feature_maps.size(0), -1)
        return self.output(flattened)
259 |
260 |
261 |
def get_activation_layer(activation):
    """
    Turn an activation specification into an activation layer.
    Parameters:
    ----------
    activation : function, or str, or nn.Module
        A factory function (called with no arguments), the name of an
        activation, or an already constructed module (returned unchanged).
    Returns
    -------
    nn.Module
        Activation layer.
    """
    assert (activation is not None)
    if isfunction(activation):
        return activation()
    if not isinstance(activation, str):
        assert (isinstance(activation, nn.Module))
        return activation

    # Each entry is wrapped in a lambda so a class name is only looked up when
    # that activation is actually requested.
    # NOTE(review): Swish, HSwish, HSigmoid and Identity are neither defined nor
    # imported in this module, so "swish", "hswish", "hsigmoid" and "identity"
    # fail with NameError at call time.
    builders = {
        "relu": lambda: nn.ReLU(inplace=True),
        "relu6": lambda: nn.ReLU6(inplace=True),
        "swish": lambda: Swish(),
        "hswish": lambda: HSwish(inplace=True),
        "sigmoid": lambda: nn.Sigmoid(),
        "hsigmoid": lambda: HSigmoid(),
        "identity": lambda: Identity(),
    }
    if activation not in builders:
        raise NotImplementedError()
    return builders[activation]()
297 |
298 |
def get_alexnet(version="a",
                model_name=None,
                pretrained=False,
                root=os.path.join("~", ".torch", "models"),
                **kwargs):
    """
    Build an AlexNet for the requested layout version.
    Parameters:
    ----------
    version : str, default 'a'
        Version of AlexNet ('a' or 'b').
    model_name : str or None, default None
        Model name for loading pretrained model.
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    root : str, default '~/.torch/models'
        Location for keeping the model parameters (currently not used by this
        function because the download call is disabled).
    """
    # Kernel sizes and strides are shared by both versions; only the channel
    # widths, the first-layer padding and the LRN flag differ.
    layout = dict(
        kernel_sizes=[[11], [5], [3, 3, 3]],
        strides=[[4], [1], [1, 1, 1]])
    if version == "a":
        layout.update(
            channels=[[96], [256], [384, 384, 256]],
            paddings=[[0], [2], [1, 1, 1]],
            use_lrn=True)
    elif version == "b":
        layout.update(
            channels=[[64], [192], [384, 256, 256]],
            paddings=[[2], [2], [1, 1, 1]],
            use_lrn=False)
    else:
        raise ValueError("Unsupported AlexNet version {}".format(version))

    net = AlexNet(**layout, **kwargs)

    # The weight download (model_store.download_model) is disabled, so asking
    # for pretrained weights only validates `model_name`.
    if pretrained and not model_name:
        raise ValueError("Parameter `model_name` should be properly initialized for loading pretrained model.")

    return net
350 |
351 |
def alexnet(**kwargs):
    """
    Standard AlexNet (layout version 'a') from 'One weird trick for parallelizing
    convolutional neural networks,' https://arxiv.org/abs/1404.5997.
    Parameters:
    ----------
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    root : str, default '~/.torch/models'
        Location for keeping the model parameters.
    """
    net = get_alexnet(model_name="alexnet", **kwargs)
    return net
364 |
365 |
def alexnetb(**kwargs):
    """
    Non-standard AlexNet-b (layout version 'b') from 'One weird trick for
    parallelizing convolutional neural networks,' https://arxiv.org/abs/1404.5997.
    Parameters:
    ----------
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    root : str, default '~/.torch/models'
        Location for keeping the model parameters.
    """
    net = get_alexnet(version="b", model_name="alexnetb", **kwargs)
    return net
378 |
379 |
def _calc_width(net):
    """
    Count the scalar entries of all parameters of `net` that require gradients.
    """
    import numpy as np
    total = 0
    for param in net.parameters():
        if not param.requires_grad:
            continue  # frozen parameters are not counted
        total += np.prod(param.size())
    return total
387 |
388 |
def _test():
    """
    Smoke test: build both AlexNet variants, check their trainable parameter
    counts, and check the output shape for a single 224x224 RGB input.
    """
    import torch

    pretrained = False

    # (constructor, expected number of trainable parameters).
    # In this batch-norm variant, version 'a' adds BatchNorm2d layers to the
    # units of the first two stages (96 and 256 channels), which contributes
    # 2 * (96 + 256) = 704 parameters on top of the plain AlexNet count of
    # 62378344. Version 'b' disables the flag, so its count is unchanged.
    cases = (
        (alexnet, 62378344 + 704),
        (alexnetb, 61100840),
    )

    for model, expected_count in cases:
        net = model(pretrained=pretrained)
        print(net)
        net.eval()
        weight_count = _calc_width(net)
        print("m={}, {}".format(model.__name__, weight_count))
        assert (weight_count == expected_count)

        x = torch.randn(1, 3, 224, 224)
        y = net(x)
        assert (tuple(y.size()) == (1, 1000))
414 |
415 |
# Run the smoke test only when this module is executed as a script.
if __name__ == "__main__":
    _test()
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
'''Some helper functions for PyTorch, including:
- get_mean_and_std: calculate the mean and std value of dataset.
- init_params: net parameter initialization.
- progress_bar: progress bar that mimics xlua.progress.
'''
6 | import os
7 | import sys
8 | import time
9 | import math
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.init as init
14 | import torchvision
15 | import torchvision.transforms as transforms
16 |
def get_mean_and_std(dataset):
    '''Compute per-channel mean and std statistics of a 3-channel dataset.

    Samples are loaded one at a time; the per-sample channel mean and channel
    std are accumulated and then divided by the number of samples. The returned
    std is therefore the average of per-sample stds, not the std over the whole
    dataset.
    '''
    loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
    mean, std = torch.zeros(3), torch.zeros(3)
    print('==> Computing mean and std..')
    for inputs, _ in loader:
        for channel in range(3):
            plane = inputs[:, channel, :, :]
            mean[channel] += plane.mean()
            std[channel] += plane.std()
    sample_count = len(dataset)
    mean.div_(sample_count)
    std.div_(sample_count)
    return mean, std
30 |
def init_params(net):
    '''Init layer parameters.

    - Conv2d: Kaiming-normal weights (fan_out mode), zero bias.
    - BatchNorm2d: weight 1, bias 0 (skipped when the layer has no affine
      parameters).
    - Linear: normal weights with std 1e-3, zero bias.
    '''
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            init.kaiming_normal_(m.weight, mode='fan_out')
            # `if m.bias:` would raise for a multi-element tensor; compare
            # against None to detect a missing bias instead.
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            if m.weight is not None:
                init.constant_(m.weight, 1)
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            init.normal_(m.weight, std=1e-3)
            if m.bias is not None:
                init.constant_(m.bias, 0)
45 |
46 |
# Terminal width used by progress_bar to pad and rewind the status line.
# `stty size` yields no usable output when there is no controlling terminal
# (e.g. output redirected to a file, or a batch job), which would make the
# unpacking fail and break importing this module; fall back to 80 columns.
try:
    _, term_width = os.popen('stty size', 'r').read().split()
    term_width = int(term_width)
except ValueError:
    term_width = 80

# Bar length in characters, and the timestamps progress_bar uses for its
# per-step and total timings.
TOTAL_BAR_LENGTH = 65.
last_time = time.time()
begin_time = last_time
def progress_bar(current, total, msg=None):
    """
    Draw a one-line text progress bar on stdout.

    `current` is the zero-based step index and `total` the number of steps.
    The line shows the bar, the time since the previous call, the time since
    the bar was (re)started at `current == 0`, and the optional `msg`. The line
    ends with a carriage return until the last step, which ends with a newline.
    """
    global last_time, begin_time
    if current == 0:
        begin_time = time.time()  # Reset for new bar.

    out = sys.stdout
    done_len = int(TOTAL_BAR_LENGTH*current/total)
    todo_len = int(TOTAL_BAR_LENGTH - done_len) - 1

    out.write(' [')
    out.write('=' * done_len)
    out.write('>')
    out.write('.' * todo_len)
    out.write(']')

    now = time.time()
    step_time = now - last_time
    last_time = now
    tot_time = now - begin_time

    status = ' Step: %s' % format_time(step_time)
    status += ' | Tot: %s' % format_time(tot_time)
    if msg:
        status += ' | ' + msg
    out.write(status)

    # Pad the rest of the terminal line with spaces.
    out.write(' ' * (term_width-int(TOTAL_BAR_LENGTH)-len(status)-3))

    # Go back to the center of the bar.
    out.write('\b' * (term_width-int(TOTAL_BAR_LENGTH/2)+2))
    out.write(' %d/%d ' % (current+1, total))

    out.write('\r' if current < total-1 else '\n')
    out.flush()
95 |
def format_time(seconds):
    """
    Format a duration in seconds as a short string made of at most the two
    largest non-zero units among days (D), hours (h), minutes (m), seconds (s)
    and milliseconds (ms). Returns '0ms' when every unit is zero.
    """
    days = int(seconds / 3600/24)
    seconds = seconds - days*3600*24
    hours = int(seconds / 3600)
    seconds = seconds - hours*3600
    minutes = int(seconds / 60)
    seconds = seconds - minutes*60
    secondsf = int(seconds)
    seconds = seconds - secondsf
    millis = int(seconds*1000)

    units = (
        (days, 'D'),
        (hours, 'h'),
        (minutes, 'm'),
        (secondsf, 's'),
        (millis, 'ms'),
    )
    pieces = []
    for amount, suffix in units:
        # Keep only the first two non-zero units.
        if amount > 0 and len(pieces) < 2:
            pieces.append(str(amount) + suffix)
    return ''.join(pieces) if pieces else '0ms'
127 |
128 |
def replace_all(model, replacement_dict=None):
    """
    Replace all layers in the original model with new layers corresponding to `replacement_dict`.
    E.g input example:
    replacement_dict={ nn.Conv2d : partial(NIPS2019_QConv2d, bit=args.bit) }

    The lookup uses the exact type of each sub-module. A matching Conv2d or
    Linear is replaced in place by a freshly constructed module with the same
    hyper-parameters (the old weights are not copied). Sub-modules whose type
    is not in the dictionary are searched recursively. Returns `model`.
    """
    # A None default avoids sharing one mutable dict between calls.
    if replacement_dict is None:
        replacement_dict = {}

    def __replace_module(model):
        for module_name in model._modules:
            m = model._modules[module_name]

            if type(m) in replacement_dict:
                new_module = replacement_dict[type(m)]
                if isinstance(m, nn.Conv2d):
                    # Identity test against None: `!=` on a tensor is not a
                    # reliable way to detect a missing bias.
                    model._modules[module_name] = new_module(
                        in_channels=m.in_channels,
                        out_channels=m.out_channels, kernel_size=m.kernel_size,
                        stride=m.stride, padding=m.padding, dilation=m.dilation,
                        groups=m.groups, bias=(m.bias is not None))

                elif isinstance(m, nn.Linear):
                    model._modules[module_name] = new_module(
                        in_features=m.in_features,
                        out_features=m.out_features,
                        bias=(m.bias is not None))

            elif len(m._modules) > 0:
                __replace_module(m)

    __replace_module(model)
    return model
159 |
160 |
def replace_single_module(new_cls, current_module):
    """
    Build a replacement for `current_module` using `new_cls`.

    For a Conv2d or a Linear, `new_cls` is called with the hyper-parameters of
    the current module (weights are not copied) and the new module is
    returned. For any other module type the function returns None.
    """
    m = current_module
    # Identity test against None: `!=` on a tensor is not a reliable way to
    # detect a missing bias.
    has_bias = getattr(m, 'bias', None) is not None
    if isinstance(m, nn.Conv2d):
        return new_cls(in_channels=m.in_channels,
                       out_channels=m.out_channels, kernel_size=m.kernel_size,
                       stride=m.stride, padding=m.padding, dilation=m.dilation,
                       groups=m.groups, bias=has_bias)

    elif isinstance(m, nn.Linear):
        return new_cls(in_features=m.in_features, out_features=m.out_features, bias=has_bias)

    return None
173 |
174 |
175 |
def replace_module(model, replacement_dict={}, exception_dict={}, arch="presnet18"):
    """
    Replace all layers in the original model with new layers corresponding to `replacement_dict`,
    then re-replace architecture-specific layers using `exception_dict`.
    E.g input example:
    replacement_dict={ nn.Conv2d : partial(NIPS2019_QConv2d, bit=args.bit) }
    exception_dict={
    '__first__': partial(NIPS2019_QConv2d, bit=8)
    '__last__': partial(NIPS2019_QLinear, bit=8)
    }
    The keys read from `exception_dict` are '__first__' and '__last__'
    (required by every architecture that has a dedicated branch below) and the
    optional '__downsampling__' for the shortcut convolutions of the residual
    architectures.
    """
    assert arch in ["presnet32", "presnet18", "glouncv-alexnet", "glouncv-alexnet-bn", "postech-alexnet", "glouncv-presnet34", "glouncv-presnet50", "glouncv-mobilenetv2_w1"],\
        ("Not support this type of architecture !")

    model = replace_all(model, replacement_dict=replacement_dict)

    def swap_attr(new_cls, owner, attr_name):
        # Rebuild `owner.<attr_name>` with `new_cls`.
        setattr(owner, attr_name,
                replace_single_module(new_cls=new_cls, current_module=getattr(owner, attr_name)))

    def swap_first(new_cls, container):
        # Rebuild the first element of an indexable container of modules.
        container[0] = replace_single_module(new_cls=new_cls, current_module=container[0])

    if arch in ("presnet32", "presnet18"):
        swap_attr(exception_dict['__first__'], model, "conv1")
        swap_attr(exception_dict['__last__'], model, "fc")

        if "__downsampling__" in exception_dict:
            new_conv_cls = exception_dict['__downsampling__']
            if arch == "presnet32":
                layer_names, branch = ("layer2", "layer3"), "downsample"
            else:
                layer_names, branch = ("layer2", "layer3", "layer4"), "shortcut"
            for layer_name in layer_names:
                first_block = getattr(model, layer_name)[0]
                swap_first(new_conv_cls, getattr(first_block, branch))

    elif arch in ("glouncv-presnet34", "glouncv-presnet50", "glouncv-mobilenetv2_w1"):
        swap_attr(exception_dict['__first__'], model.features.init_block, "conv")
        swap_attr(exception_dict['__last__'], model, "output")

        # Stages whose first unit has an identity convolution to replace.
        stage_names = {
            "glouncv-presnet34": ("stage2", "stage3", "stage4"),
            "glouncv-presnet50": ("stage1", "stage2", "stage3", "stage4"),
            "glouncv-mobilenetv2_w1": (),
        }[arch]
        if stage_names and "__downsampling__" in exception_dict:
            new_conv_cls = exception_dict['__downsampling__']
            for stage_name in stage_names:
                swap_attr(new_conv_cls, getattr(model.features, stage_name).unit1, "identity_conv")

    elif arch in ["glouncv-alexnet", "glouncv-alexnet-bn"]:
        swap_attr(exception_dict['__first__'], model.features.stage1.unit1, "conv")
        swap_attr(exception_dict['__last__'], model.output, "fc3")

    # "postech-alexnet" is accepted but has no dedicated exceptions: only the
    # generic replacement above is applied.
    return model
238 |
239 |
240 |
241 |
def _imagenet_style_loaders(batch_size, data_root):
    # ImageFolder datasets read from <data_root>/train and <data_root>/val.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    train_dataset = torchvision.datasets.ImageFolder(
        os.path.join(data_root, 'train'), train_transform)
    trainloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=4, pin_memory=True, sampler=None)

    val_dataset = torchvision.datasets.ImageFolder(
        os.path.join(data_root, 'val'), val_transform)
    testloader = torch.utils.data.DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False,
        num_workers=4, pin_memory=True)
    return trainloader, testloader


def _cifar100_loaders(batch_size, data_root):
    # CIFAR-100 train/test splits, downloaded to data_root when missing.
    normalize = transforms.Normalize(
        mean=(0.5070751592371323, 0.48654887331495095, 0.4409178433670343),
        std=(0.2673342858792401, 0.2564384629170883, 0.27615047132568404))
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    train_dataset = torchvision.datasets.CIFAR100(
        root=data_root, train=True, download=True, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)

    test_dataset = torchvision.datasets.CIFAR100(
        root=data_root, train=False, download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
    return trainloader, testloader


def get_dataloaders(dataset="cifar100", batch_size=128, data_root="~/data"):
    """
    Build the (train, test) DataLoader pair for the requested dataset.

    "imagenet" and "imagenette" use ImageFolder over `data_root`/train and
    `data_root`/val with 224x224 crops; "cifar100" uses torchvision's CIFAR100
    with random crop and horizontal flip for training. Any other name raises
    NotImplementedError.
    """
    if dataset in ("imagenet", "imagenette"):
        return _imagenet_style_loaders(batch_size, data_root)
    if dataset == "cifar100":
        return _cifar100_loaders(batch_size, data_root)
    raise NotImplementedError('Not support this type of dataset: ' + dataset)
308 |
309 |
def save_checkpoint(net, lr_scheduler, optimizer, acc, epoch, filename='ckpt_best.pth'):
    """
    Serialise a training checkpoint with torch.save.

    The saved dict has the keys 'net', 'acc', 'epoch', 'lr_scheduler' and
    'optimizer'; the scheduler and optimizer entries are their state dicts, or
    None when the corresponding object is None.
    """
    scheduler_state = None
    if lr_scheduler is not None:
        scheduler_state = lr_scheduler.state_dict()

    optimizer_state = None
    if optimizer is not None:
        optimizer_state = optimizer.state_dict()

    checkpoint = {
        'net': net.state_dict(),
        'acc': acc,
        'epoch': epoch,
        'lr_scheduler': scheduler_state,
        'optimizer': optimizer_state,
    }
    torch.save(checkpoint, filename)
321 |
def add_weight_decay(model, weight_decay, skip_keys):
    """
    Split the trainable parameters of `model` into two optimizer groups.

    A parameter whose name contains any of the `skip_keys` substrings goes into
    the first group (weight decay 0); every other trainable parameter goes into
    the second group (weight decay `weight_decay`). Parameters that do not
    require gradients are left out of both groups.
    """
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights
        if any(skip_key in name for skip_key in skip_keys):
            print ("Skip weight decay for: ", name)
            no_decay.append(param)
        else:
            decay.append(param)
    return [{'params': no_decay, 'weight_decay': 0.}, {'params': decay, 'weight_decay': weight_decay}]
337 |
--------------------------------------------------------------------------------
/models/glouncv/preresnet_cifar.py:
--------------------------------------------------------------------------------
1 | """
2 | PreResNet for CIFAR/SVHN, implemented in PyTorch.
3 | Original papers: 'Identity Mappings in Deep Residual Networks,' https://arxiv.org/abs/1603.05027.
4 | """
5 |
6 | __all__ = ['CIFARPreResNet', 'preresnet20_cifar10', 'preresnet20_cifar100', 'preresnet20_svhn',
7 | 'preresnet56_cifar10', 'preresnet56_cifar100', 'preresnet56_svhn',
8 | 'preresnet110_cifar10', 'preresnet110_cifar100', 'preresnet110_svhn',
9 | 'preresnet164bn_cifar10', 'preresnet164bn_cifar100', 'preresnet164bn_svhn',
10 | 'preresnet272bn_cifar10', 'preresnet272bn_cifar100', 'preresnet272bn_svhn',
11 | 'preresnet542bn_cifar10', 'preresnet542bn_cifar100', 'preresnet542bn_svhn',
12 | 'preresnet1001_cifar10', 'preresnet1001_cifar100', 'preresnet1001_svhn',
13 | 'preresnet1202_cifar10', 'preresnet1202_cifar100', 'preresnet1202_svhn']
14 |
15 | import os
16 | import torch.nn as nn
17 | import torch.nn.init as init
18 | from .common import conv3x3
19 | from .preresnet import PreResUnit, PreResActivation
20 |
21 |
class CIFARPreResNet(nn.Module):
    """
    PreResNet model for CIFAR from 'Identity Mappings in Deep Residual Networks,' https://arxiv.org/abs/1603.05027.
    Parameters:
    ----------
    channels : list of list of int
        Number of output channels for each unit.
    init_block_channels : int
        Number of output channels for the initial unit.
    bottleneck : bool
        Whether to use a bottleneck or simple block in units.
    in_channels : int, default 3
        Number of input channels.
    in_size : tuple of two ints, default (32, 32)
        Spatial size of the expected input image (stored, not otherwise used here).
    num_classes : int, default 10
        Number of classification classes.
    """
    def __init__(self,
                 channels,
                 init_block_channels,
                 bottleneck,
                 in_channels=3,
                 in_size=(32, 32),
                 num_classes=10):
        super(CIFARPreResNet, self).__init__()
        self.in_size = in_size
        self.num_classes = num_classes

        self.features = nn.Sequential()
        self.features.add_module("init_block", conv3x3(
            in_channels=in_channels,
            out_channels=init_block_channels))

        width = init_block_channels
        for stage_idx, stage_channels in enumerate(channels):
            stage = nn.Sequential()
            for unit_idx, unit_channels in enumerate(stage_channels):
                # Stride 2 only in the first unit of every stage after the first.
                downsample = (unit_idx == 0) and (stage_idx != 0)
                unit = PreResUnit(
                    in_channels=width,
                    out_channels=unit_channels,
                    stride=2 if downsample else 1,
                    bottleneck=bottleneck,
                    conv1_stride=False)
                stage.add_module("unit{}".format(unit_idx + 1), unit)
                width = unit_channels
            self.features.add_module("stage{}".format(stage_idx + 1), stage)

        self.features.add_module("post_activ", PreResActivation(in_channels=width))
        self.features.add_module("final_pool", nn.AvgPool2d(
            kernel_size=8,
            stride=1))

        self.output = nn.Linear(
            in_features=width,
            out_features=num_classes)

        self._init_params()

    def _init_params(self):
        # Kaiming-uniform weights and zero biases for every convolution.
        conv_layers = (m for _, m in self.named_modules() if isinstance(m, nn.Conv2d))
        for conv in conv_layers:
            init.kaiming_uniform_(conv.weight)
            if conv.bias is None:
                continue
            init.constant_(conv.bias, 0)

    def forward(self, x):
        feature_maps = self.features(x)
        # Flatten everything except the batch dimension for the linear head.
        flattened = feature_maps.view(feature_maps.size(0), -1)
        return self.output(flattened)
91 |
92 |
def get_preresnet_cifar(num_classes,
                        blocks,
                        bottleneck,
                        model_name=None,
                        pretrained=False,
                        root=os.path.join("~", ".torch", "models"),
                        **kwargs):
    """
    Assemble a CIFAR-style PreResNet of the requested depth and optionally load pretrained weights into it.

    Parameters:
    ----------
    num_classes : int
        Number of classification classes; must be 10 or 100.
    blocks : int
        Depth identifier of the network. `blocks - 2` must be divisible by 9 with bottleneck units and by 6
        otherwise.
    bottleneck : bool
        Whether to use a bottleneck or simple block in units.
    model_name : str or None, default None
        Name under which pretrained weights are looked up; required when `pretrained` is set.
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    root : str, default '~/.torch/models'
        Location handed to the model store for keeping the model parameters.
    **kwargs
        Extra keyword arguments passed on to the `CIFARPreResNet` constructor.

    Returns:
    -------
    CIFARPreResNet
        The constructed network.

    Raises:
    ------
    AssertionError
        If `num_classes` or `blocks` violates the constraints above.
    ValueError
        If `pretrained` is set while `model_name` is None or empty.
    """
    assert (num_classes in [10, 100])

    # The depth (minus 2) is split evenly over three stages; the divisor differs between unit types.
    depth_divisor = 9 if bottleneck else 6
    assert ((blocks - 2) % depth_divisor == 0)
    units_per_stage = (blocks - 2) // depth_divisor

    # Bottleneck units get four times the base output width of each stage.
    expansion = 4 if bottleneck else 1
    channels = [[base_width * expansion] * units_per_stage for base_width in (16, 32, 64)]

    net = CIFARPreResNet(
        channels=channels,
        init_block_channels=16,
        bottleneck=bottleneck,
        num_classes=num_classes,
        **kwargs)

    if pretrained:
        if not model_name:
            raise ValueError("Parameter `model_name` should be properly initialized for loading pretrained model.")
        # Imported lazily so the model store is only needed when weights are actually requested.
        from .model_store import download_model
        download_model(net=net, model_name=model_name, local_model_store_dir_path=root)

    return net
151 |
152 |
def preresnet20_cifar10(num_classes=10, **kwargs):
    """
    Build the PreResNet-20 network (simple units) for CIFAR-10, after 'Identity Mappings in Deep Residual
    Networks,' https://arxiv.org/abs/1603.05027.

    Parameters:
    ----------
    num_classes : int, default 10
        Number of classification classes.
    **kwargs
        Forwarded to `get_preresnet_cifar`, e.g. `pretrained` (bool, default False) to load pretrained weights
        and `root` (str, default '~/.torch/models') for the location of the model parameters.
    """
    return get_preresnet_cifar(
        num_classes=num_classes,
        blocks=20,
        bottleneck=False,
        model_name="preresnet20_cifar10",
        **kwargs)
168 |
169 |
def preresnet20_cifar100(num_classes=100, **kwargs):
    """
    Build the PreResNet-20 network (simple units) for CIFAR-100, after 'Identity Mappings in Deep Residual
    Networks,' https://arxiv.org/abs/1603.05027.

    Parameters:
    ----------
    num_classes : int, default 100
        Number of classification classes.
    **kwargs
        Forwarded to `get_preresnet_cifar`, e.g. `pretrained` (bool, default False) to load pretrained weights
        and `root` (str, default '~/.torch/models') for the location of the model parameters.
    """
    return get_preresnet_cifar(
        num_classes=num_classes,
        blocks=20,
        bottleneck=False,
        model_name="preresnet20_cifar100",
        **kwargs)
185 |
186 |
def preresnet20_svhn(num_classes=10, **kwargs):
    """
    Build the PreResNet-20 network (simple units) for SVHN, after 'Identity Mappings in Deep Residual
    Networks,' https://arxiv.org/abs/1603.05027.

    Parameters:
    ----------
    num_classes : int, default 10
        Number of classification classes.
    **kwargs
        Forwarded to `get_preresnet_cifar`, e.g. `pretrained` (bool, default False) to load pretrained weights
        and `root` (str, default '~/.torch/models') for the location of the model parameters.
    """
    return get_preresnet_cifar(
        num_classes=num_classes,
        blocks=20,
        bottleneck=False,
        model_name="preresnet20_svhn",
        **kwargs)
202 |
203 |
def preresnet56_cifar10(num_classes=10, **kwargs):
    """
    Build the PreResNet-56 network (simple units) for CIFAR-10, after 'Identity Mappings in Deep Residual
    Networks,' https://arxiv.org/abs/1603.05027.

    Parameters:
    ----------
    num_classes : int, default 10
        Number of classification classes.
    **kwargs
        Forwarded to `get_preresnet_cifar`, e.g. `pretrained` (bool, default False) to load pretrained weights
        and `root` (str, default '~/.torch/models') for the location of the model parameters.
    """
    return get_preresnet_cifar(
        num_classes=num_classes,
        blocks=56,
        bottleneck=False,
        model_name="preresnet56_cifar10",
        **kwargs)
219 |
220 |
def preresnet56_cifar100(num_classes=100, **kwargs):
    """
    Build the PreResNet-56 network (simple units) for CIFAR-100, after 'Identity Mappings in Deep Residual
    Networks,' https://arxiv.org/abs/1603.05027.

    Parameters:
    ----------
    num_classes : int, default 100
        Number of classification classes.
    **kwargs
        Forwarded to `get_preresnet_cifar`, e.g. `pretrained` (bool, default False) to load pretrained weights
        and `root` (str, default '~/.torch/models') for the location of the model parameters.
    """
    return get_preresnet_cifar(
        num_classes=num_classes,
        blocks=56,
        bottleneck=False,
        model_name="preresnet56_cifar100",
        **kwargs)
236 |
237 |
def preresnet56_svhn(num_classes=10, **kwargs):
    """
    Build the PreResNet-56 network (simple units) for SVHN, after 'Identity Mappings in Deep Residual
    Networks,' https://arxiv.org/abs/1603.05027.

    Parameters:
    ----------
    num_classes : int, default 10
        Number of classification classes.
    **kwargs
        Forwarded to `get_preresnet_cifar`, e.g. `pretrained` (bool, default False) to load pretrained weights
        and `root` (str, default '~/.torch/models') for the location of the model parameters.
    """
    return get_preresnet_cifar(
        num_classes=num_classes,
        blocks=56,
        bottleneck=False,
        model_name="preresnet56_svhn",
        **kwargs)
253 |
254 |
def preresnet110_cifar10(num_classes=10, **kwargs):
    """
    Build the PreResNet-110 network (simple units) for CIFAR-10, after 'Identity Mappings in Deep Residual
    Networks,' https://arxiv.org/abs/1603.05027.

    Parameters:
    ----------
    num_classes : int, default 10
        Number of classification classes.
    **kwargs
        Forwarded to `get_preresnet_cifar`, e.g. `pretrained` (bool, default False) to load pretrained weights
        and `root` (str, default '~/.torch/models') for the location of the model parameters.
    """
    return get_preresnet_cifar(
        num_classes=num_classes,
        blocks=110,
        bottleneck=False,
        model_name="preresnet110_cifar10",
        **kwargs)
270 |
271 |
def preresnet110_cifar100(num_classes=100, **kwargs):
    """
    Build the PreResNet-110 network (simple units) for CIFAR-100, after 'Identity Mappings in Deep Residual
    Networks,' https://arxiv.org/abs/1603.05027.

    Parameters:
    ----------
    num_classes : int, default 100
        Number of classification classes.
    **kwargs
        Forwarded to `get_preresnet_cifar`, e.g. `pretrained` (bool, default False) to load pretrained weights
        and `root` (str, default '~/.torch/models') for the location of the model parameters.
    """
    return get_preresnet_cifar(
        num_classes=num_classes,
        blocks=110,
        bottleneck=False,
        model_name="preresnet110_cifar100",
        **kwargs)
287 |
288 |
def preresnet110_svhn(num_classes=10, **kwargs):
    """
    Build the PreResNet-110 network (simple units) for SVHN, after 'Identity Mappings in Deep Residual
    Networks,' https://arxiv.org/abs/1603.05027.

    Parameters:
    ----------
    num_classes : int, default 10
        Number of classification classes.
    **kwargs
        Forwarded to `get_preresnet_cifar`, e.g. `pretrained` (bool, default False) to load pretrained weights
        and `root` (str, default '~/.torch/models') for the location of the model parameters.
    """
    return get_preresnet_cifar(
        num_classes=num_classes,
        blocks=110,
        bottleneck=False,
        model_name="preresnet110_svhn",
        **kwargs)
304 |
305 |
def preresnet164bn_cifar10(num_classes=10, **kwargs):
    """
    Build the PreResNet-164(BN) network (bottleneck units) for CIFAR-10, after 'Identity Mappings in Deep
    Residual Networks,' https://arxiv.org/abs/1603.05027.

    Parameters:
    ----------
    num_classes : int, default 10
        Number of classification classes.
    **kwargs
        Forwarded to `get_preresnet_cifar`, e.g. `pretrained` (bool, default False) to load pretrained weights
        and `root` (str, default '~/.torch/models') for the location of the model parameters.
    """
    return get_preresnet_cifar(
        num_classes=num_classes,
        blocks=164,
        bottleneck=True,
        model_name="preresnet164bn_cifar10",
        **kwargs)
321 |
322 |
def preresnet164bn_cifar100(num_classes=100, **kwargs):
    """
    Build the PreResNet-164(BN) network (bottleneck units) for CIFAR-100, after 'Identity Mappings in Deep
    Residual Networks,' https://arxiv.org/abs/1603.05027.

    Parameters:
    ----------
    num_classes : int, default 100
        Number of classification classes.
    **kwargs
        Forwarded to `get_preresnet_cifar`, e.g. `pretrained` (bool, default False) to load pretrained weights
        and `root` (str, default '~/.torch/models') for the location of the model parameters.
    """
    return get_preresnet_cifar(
        num_classes=num_classes,
        blocks=164,
        bottleneck=True,
        model_name="preresnet164bn_cifar100",
        **kwargs)
338 |
339 |
def preresnet164bn_svhn(num_classes=10, **kwargs):
    """
    Build the PreResNet-164(BN) network (bottleneck units) for SVHN, after 'Identity Mappings in Deep
    Residual Networks,' https://arxiv.org/abs/1603.05027.

    Parameters:
    ----------
    num_classes : int, default 10
        Number of classification classes.
    **kwargs
        Forwarded to `get_preresnet_cifar`, e.g. `pretrained` (bool, default False) to load pretrained weights
        and `root` (str, default '~/.torch/models') for the location of the model parameters.
    """
    return get_preresnet_cifar(
        num_classes=num_classes,
        blocks=164,
        bottleneck=True,
        model_name="preresnet164bn_svhn",
        **kwargs)
355 |
356 |
def preresnet272bn_cifar10(num_classes=10, **kwargs):
    """
    Build the PreResNet-272(BN) network (bottleneck units) for CIFAR-10, after 'Identity Mappings in Deep
    Residual Networks,' https://arxiv.org/abs/1603.05027.

    Parameters:
    ----------
    num_classes : int, default 10
        Number of classification classes.
    **kwargs
        Forwarded to `get_preresnet_cifar`, e.g. `pretrained` (bool, default False) to load pretrained weights
        and `root` (str, default '~/.torch/models') for the location of the model parameters.
    """
    return get_preresnet_cifar(
        num_classes=num_classes,
        blocks=272,
        bottleneck=True,
        model_name="preresnet272bn_cifar10",
        **kwargs)
372 |
373 |
def preresnet272bn_cifar100(num_classes=100, **kwargs):
    """
    Build the PreResNet-272(BN) network (bottleneck units) for CIFAR-100, after 'Identity Mappings in Deep
    Residual Networks,' https://arxiv.org/abs/1603.05027.

    Parameters:
    ----------
    num_classes : int, default 100
        Number of classification classes.
    **kwargs
        Forwarded to `get_preresnet_cifar`, e.g. `pretrained` (bool, default False) to load pretrained weights
        and `root` (str, default '~/.torch/models') for the location of the model parameters.
    """
    return get_preresnet_cifar(
        num_classes=num_classes,
        blocks=272,
        bottleneck=True,
        model_name="preresnet272bn_cifar100",
        **kwargs)
389 |
390 |
def preresnet272bn_svhn(num_classes=10, **kwargs):
    """
    Build the PreResNet-272(BN) network (bottleneck units) for SVHN, after 'Identity Mappings in Deep
    Residual Networks,' https://arxiv.org/abs/1603.05027.

    Parameters:
    ----------
    num_classes : int, default 10
        Number of classification classes.
    **kwargs
        Forwarded to `get_preresnet_cifar`, e.g. `pretrained` (bool, default False) to load pretrained weights
        and `root` (str, default '~/.torch/models') for the location of the model parameters.
    """
    return get_preresnet_cifar(
        num_classes=num_classes,
        blocks=272,
        bottleneck=True,
        model_name="preresnet272bn_svhn",
        **kwargs)
406 |
407 |
def preresnet542bn_cifar10(num_classes=10, **kwargs):
    """
    Build the PreResNet-542(BN) network (bottleneck units) for CIFAR-10, after 'Identity Mappings in Deep
    Residual Networks,' https://arxiv.org/abs/1603.05027.

    Parameters:
    ----------
    num_classes : int, default 10
        Number of classification classes.
    **kwargs
        Forwarded to `get_preresnet_cifar`, e.g. `pretrained` (bool, default False) to load pretrained weights
        and `root` (str, default '~/.torch/models') for the location of the model parameters.
    """
    return get_preresnet_cifar(
        num_classes=num_classes,
        blocks=542,
        bottleneck=True,
        model_name="preresnet542bn_cifar10",
        **kwargs)
423 |
424 |
def preresnet542bn_cifar100(num_classes=100, **kwargs):
    """
    Build the PreResNet-542(BN) network (bottleneck units) for CIFAR-100, after 'Identity Mappings in Deep
    Residual Networks,' https://arxiv.org/abs/1603.05027.

    Parameters:
    ----------
    num_classes : int, default 100
        Number of classification classes.
    **kwargs
        Forwarded to `get_preresnet_cifar`, e.g. `pretrained` (bool, default False) to load pretrained weights
        and `root` (str, default '~/.torch/models') for the location of the model parameters.
    """
    return get_preresnet_cifar(
        num_classes=num_classes,
        blocks=542,
        bottleneck=True,
        model_name="preresnet542bn_cifar100",
        **kwargs)
440 |
441 |
def preresnet542bn_svhn(num_classes=10, **kwargs):
    """
    Build the PreResNet-542(BN) network (bottleneck units) for SVHN, after 'Identity Mappings in Deep
    Residual Networks,' https://arxiv.org/abs/1603.05027.

    Parameters:
    ----------
    num_classes : int, default 10
        Number of classification classes.
    **kwargs
        Forwarded to `get_preresnet_cifar`, e.g. `pretrained` (bool, default False) to load pretrained weights
        and `root` (str, default '~/.torch/models') for the location of the model parameters.
    """
    return get_preresnet_cifar(
        num_classes=num_classes,
        blocks=542,
        bottleneck=True,
        model_name="preresnet542bn_svhn",
        **kwargs)
457 |
458 |
def preresnet1001_cifar10(num_classes=10, **kwargs):
    """
    Build the PreResNet-1001 network (bottleneck units) for CIFAR-10, after 'Identity Mappings in Deep
    Residual Networks,' https://arxiv.org/abs/1603.05027.

    Parameters:
    ----------
    num_classes : int, default 10
        Number of classification classes.
    **kwargs
        Forwarded to `get_preresnet_cifar`, e.g. `pretrained` (bool, default False) to load pretrained weights
        and `root` (str, default '~/.torch/models') for the location of the model parameters.
    """
    return get_preresnet_cifar(
        num_classes=num_classes,
        blocks=1001,
        bottleneck=True,
        model_name="preresnet1001_cifar10",
        **kwargs)
474 |
475 |
def preresnet1001_cifar100(num_classes=100, **kwargs):
    """
    Build the PreResNet-1001 network (bottleneck units) for CIFAR-100, after 'Identity Mappings in Deep
    Residual Networks,' https://arxiv.org/abs/1603.05027.

    Parameters:
    ----------
    num_classes : int, default 100
        Number of classification classes.
    **kwargs
        Forwarded to `get_preresnet_cifar`, e.g. `pretrained` (bool, default False) to load pretrained weights
        and `root` (str, default '~/.torch/models') for the location of the model parameters.
    """
    return get_preresnet_cifar(
        num_classes=num_classes,
        blocks=1001,
        bottleneck=True,
        model_name="preresnet1001_cifar100",
        **kwargs)
491 |
492 |
def preresnet1001_svhn(num_classes=10, **kwargs):
    """
    Build the PreResNet-1001 network (bottleneck units) for SVHN, after 'Identity Mappings in Deep
    Residual Networks,' https://arxiv.org/abs/1603.05027.

    Parameters:
    ----------
    num_classes : int, default 10
        Number of classification classes.
    **kwargs
        Forwarded to `get_preresnet_cifar`, e.g. `pretrained` (bool, default False) to load pretrained weights
        and `root` (str, default '~/.torch/models') for the location of the model parameters.
    """
    return get_preresnet_cifar(
        num_classes=num_classes,
        blocks=1001,
        bottleneck=True,
        model_name="preresnet1001_svhn",
        **kwargs)
508 |
509 |
def preresnet1202_cifar10(num_classes=10, **kwargs):
    """
    Build the PreResNet-1202 network (simple units) for CIFAR-10, after 'Identity Mappings in Deep Residual
    Networks,' https://arxiv.org/abs/1603.05027.

    Parameters:
    ----------
    num_classes : int, default 10
        Number of classification classes.
    **kwargs
        Forwarded to `get_preresnet_cifar`, e.g. `pretrained` (bool, default False) to load pretrained weights
        and `root` (str, default '~/.torch/models') for the location of the model parameters.
    """
    return get_preresnet_cifar(
        num_classes=num_classes,
        blocks=1202,
        bottleneck=False,
        model_name="preresnet1202_cifar10",
        **kwargs)
525 |
526 |
def preresnet1202_cifar100(num_classes=100, **kwargs):
    """
    Build the PreResNet-1202 network (simple units) for CIFAR-100, after 'Identity Mappings in Deep Residual
    Networks,' https://arxiv.org/abs/1603.05027.

    Parameters:
    ----------
    num_classes : int, default 100
        Number of classification classes.
    **kwargs
        Forwarded to `get_preresnet_cifar`, e.g. `pretrained` (bool, default False) to load pretrained weights
        and `root` (str, default '~/.torch/models') for the location of the model parameters.
    """
    return get_preresnet_cifar(
        num_classes=num_classes,
        blocks=1202,
        bottleneck=False,
        model_name="preresnet1202_cifar100",
        **kwargs)
542 |
543 |
def preresnet1202_svhn(num_classes=10, **kwargs):
    """
    Build the PreResNet-1202 network (simple units) for SVHN, after 'Identity Mappings in Deep Residual
    Networks,' https://arxiv.org/abs/1603.05027.

    Parameters:
    ----------
    num_classes : int, default 10
        Number of classification classes.
    **kwargs
        Forwarded to `get_preresnet_cifar`, e.g. `pretrained` (bool, default False) to load pretrained weights
        and `root` (str, default '~/.torch/models') for the location of the model parameters.
    """
    return get_preresnet_cifar(
        num_classes=num_classes,
        blocks=1202,
        bottleneck=False,
        model_name="preresnet1202_svhn",
        **kwargs)
559 |
560 |
def _calc_width(net):
    """
    Count the trainable parameters of a network.

    Parameters:
    ----------
    net : torch.nn.Module
        Model whose parameters are inspected.

    Returns:
    -------
    int
        Total number of scalar elements over all parameters that have `requires_grad` set; 0 if there are none.
    """
    # `numel()` always yields a Python int. The previous `np.prod(param.size())` returned the float 1.0 for a
    # 0-dim parameter (empty shape), which silently turned the whole count into a float.
    return sum(param.numel() for param in net.parameters() if param.requires_grad)
568 |
569 |
def _test():
    """
    Smoke-test every factory of this module.

    For each factory the network is built without pretrained weights and put in eval mode, its trainable
    parameter count is printed and compared with the recorded value, and a forward/backward pass on a random
    1x3x32x32 input is checked to yield an output of shape (1, num_classes).
    """
    import torch

    pretrained = False

    # (factory, number of classes, expected number of trainable parameters)
    cases = [
        (preresnet20_cifar10, 10, 272282),
        (preresnet20_cifar100, 100, 278132),
        (preresnet20_svhn, 10, 272282),
        (preresnet56_cifar10, 10, 855578),
        (preresnet56_cifar100, 100, 861428),
        (preresnet56_svhn, 10, 855578),
        (preresnet110_cifar10, 10, 1730522),
        (preresnet110_cifar100, 100, 1736372),
        (preresnet110_svhn, 10, 1730522),
        (preresnet164bn_cifar10, 10, 1703258),
        (preresnet164bn_cifar100, 100, 1726388),
        (preresnet164bn_svhn, 10, 1703258),
        (preresnet272bn_cifar10, 10, 2816090),
        (preresnet272bn_cifar100, 100, 2839220),
        (preresnet272bn_svhn, 10, 2816090),
        (preresnet542bn_cifar10, 10, 5598170),
        (preresnet542bn_cifar100, 100, 5621300),
        (preresnet542bn_svhn, 10, 5598170),
        (preresnet1001_cifar10, 10, 10327706),
        (preresnet1001_cifar100, 100, 10350836),
        (preresnet1001_svhn, 10, 10327706),
        (preresnet1202_cifar10, 10, 19423834),
        (preresnet1202_cifar100, 100, 19429684),
        (preresnet1202_svhn, 10, 19423834),
    ]

    for model, num_classes, expected_weight_count in cases:
        net = model(pretrained=pretrained)
        net.eval()

        weight_count = _calc_width(net)
        print("m={}, {}".format(model.__name__, weight_count))
        assert (weight_count == expected_weight_count)

        y = net(torch.randn(1, 3, 32, 32))
        y.sum().backward()
        assert (tuple(y.size()) == (1, num_classes))
640 |
# Run the smoke test when this file is executed as a script rather than imported.
if __name__ == "__main__":
    _test()
--------------------------------------------------------------------------------
/models/glouncv/preresnet.py:
--------------------------------------------------------------------------------
1 | """
2 | PreResNet for ImageNet-1K, implemented in PyTorch.
3 | Original paper: 'Identity Mappings in Deep Residual Networks,' https://arxiv.org/abs/1603.05027.
4 | """
5 |
# Names exported by a star-import of this module: the model class, the factory functions and the building
# blocks listed at the end (blocks, unit, initial block, activation).
__all__ = ['PreResNet', 'preresnet10', 'preresnet12', 'preresnet14', 'preresnetbc14b', 'preresnet16', 'preresnet18_wd4',
           'preresnet18_wd2', 'preresnet18_w3d4', 'preresnet18', 'preresnet26', 'preresnetbc26b', 'preresnet34',
           'preresnetbc38b', 'preresnet50', 'preresnet50b', 'preresnet101', 'preresnet101b', 'preresnet152',
           'preresnet152b', 'preresnet200', 'preresnet200b', 'preresnet269b', 'PreResBlock', 'PreResBottleneck',
           'PreResUnit', 'PreResInitBlock', 'PreResActivation']
11 |
12 | import os
13 | import torch.nn as nn
14 | import torch.nn.init as init
15 | from .common import pre_conv1x1_block, pre_conv3x3_block, conv1x1
16 |
17 |
class PreResBlock(nn.Module):
    """
    Residual branch of a simple (non-bottleneck) PreResNet unit: two pre-activated 3x3 convolution blocks.

    Parameters:
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    stride : int or tuple/list of 2 int
        Strides of the first convolution; the second one is created without an explicit stride.
    bias : bool, default False
        Whether the layer uses a bias vector.
    use_bn : bool, default True
        Whether to use BatchNorm layer.
    """
    def __init__(self, in_channels, out_channels, stride, bias=False, use_bn=True):
        super().__init__()
        # Settings that both convolution blocks have in common.
        shared = dict(out_channels=out_channels, bias=bias, use_bn=use_bn)
        self.conv1 = pre_conv3x3_block(in_channels=in_channels, stride=stride, return_preact=True, **shared)
        self.conv2 = pre_conv3x3_block(in_channels=out_channels, **shared)

    def forward(self, x):
        """
        Return `(out, x_pre_activ)`: the result of both convolutions and the second value yielded by `conv1`
        (which is created with `return_preact=True`).
        """
        out, x_pre_activ = self.conv1(x)
        return self.conv2(out), x_pre_activ
58 |
59 |
class PreResBottleneck(nn.Module):
    """
    Residual branch of a bottleneck PreResNet unit: pre-activated 1x1, 3x3 and 1x1 convolution blocks, with the
    two inner widths set to a quarter of `out_channels`.

    Parameters:
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    stride : int or tuple/list of 2 int
        Strides of the convolution.
    conv1_stride : bool
        Whether to use stride in the first or the second convolution layer of the block.
    """
    def __init__(self, in_channels, out_channels, stride, conv1_stride):
        super().__init__()
        mid_channels = out_channels // 4

        # The stride is applied exactly once: by conv1 when `conv1_stride` is set, otherwise by conv2.
        if conv1_stride:
            first_stride, second_stride = stride, 1
        else:
            first_stride, second_stride = 1, stride

        self.conv1 = pre_conv1x1_block(
            in_channels=in_channels,
            out_channels=mid_channels,
            stride=first_stride,
            return_preact=True)
        self.conv2 = pre_conv3x3_block(
            in_channels=mid_channels,
            out_channels=mid_channels,
            stride=second_stride)
        self.conv3 = pre_conv1x1_block(
            in_channels=mid_channels,
            out_channels=out_channels)

    def forward(self, x):
        """
        Return `(out, x_pre_activ)`: the result of the three convolutions and the second value yielded by
        `conv1` (which is created with `return_preact=True`).
        """
        out, x_pre_activ = self.conv1(x)
        out = self.conv3(self.conv2(out))
        return out, x_pre_activ
100 |
101 |
class PreResUnit(nn.Module):
    """
    PreResNet unit: a residual branch (`body`) added to a shortcut. The shortcut is the unit input itself, or a
    1x1 convolution of the branch's pre-activation whenever the channel count or the stride changes.

    Parameters:
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    stride : int or tuple/list of 2 int
        Strides of the convolution.
    bias : bool, default False
        Whether the layer uses a bias vector. Not passed to the bottleneck branch.
    use_bn : bool, default True
        Whether to use BatchNorm layer. Only passed to the simple (non-bottleneck) branch.
    bottleneck : bool, default True
        Whether to use a bottleneck or simple block in units.
    conv1_stride : bool, default False
        Whether to use stride in the first or the second convolution layer of the block. Only passed to the
        bottleneck branch.
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 stride,
                 bias=False,
                 use_bn=True,
                 bottleneck=True,
                 conv1_stride=False):
        super().__init__()
        # A projection shortcut is required as soon as the shape of the tensor changes.
        self.resize_identity = (in_channels != out_channels) or (stride != 1)

        io_args = dict(in_channels=in_channels, out_channels=out_channels, stride=stride)
        if bottleneck:
            self.body = PreResBottleneck(conv1_stride=conv1_stride, **io_args)
        else:
            self.body = PreResBlock(bias=bias, use_bn=use_bn, **io_args)

        if self.resize_identity:
            self.identity_conv = conv1x1(bias=bias, **io_args)

    def forward(self, x):
        """
        Return the branch output plus the shortcut (projected pre-activation or the untouched input).
        """
        residual, x_pre_activ = self.body(x)
        shortcut = self.identity_conv(x_pre_activ) if self.resize_identity else x
        return residual + shortcut
160 |
161 |
class PreResInitBlock(nn.Module):
    """
    Stem of the ImageNet PreResNet: bias-free 7x7 convolution with stride 2, batch normalization, ReLU and a
    3x3 max pooling with stride 2.

    Parameters:
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.activ = nn.ReLU(inplace=True)
        self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        """
        Apply convolution, batch normalization, ReLU and max pooling in that order.
        """
        return self.pool(self.activ(self.bn(self.conv(x))))
196 |
197 |
class PreResActivation(nn.Module):
    """
    Stand-alone pre-activation (batch normalization followed by ReLU, no convolution), used as the final block
    of the feature extractor.

    Parameters:
    ----------
    in_channels : int
        Number of input channels.
    """
    def __init__(self, in_channels):
        super().__init__()
        self.bn = nn.BatchNorm2d(in_channels)
        self.activ = nn.ReLU(inplace=True)

    def forward(self, x):
        """
        Normalize `x` and apply ReLU to the normalized tensor.
        """
        return self.activ(self.bn(x))
216 |
217 |
class PreResNet(nn.Module):
    """
    Pre-activation ResNet feature extractor with a linear classifier, after 'Identity Mappings in Deep Residual
    Networks,' https://arxiv.org/abs/1603.05027.

    `features` chains a `PreResInitBlock`, one `nn.Sequential` of `PreResUnit` per stage, a closing
    `PreResActivation` and an average pooling with a 7x7 window; `output` is the linear classifier.

    Parameters:
    ----------
    channels : list of list of int
        Output channels of every unit, grouped by stage.
    init_block_channels : int
        Output channels of the initial block.
    bottleneck : bool
        Whether units are built as bottleneck units instead of simple ones.
    conv1_stride : bool
        Whether to use stride in the first or the second convolution layer in units.
    in_channels : int, default 3
        Number of channels of the input image.
    in_size : tuple of two ints, default (224, 224)
        Expected spatial size of the input image. It is only stored on the instance; no layer is derived from it.
    num_classes : int, default 1000
        Number of outputs of the classifier.
    """
    def __init__(self,
                 channels,
                 init_block_channels,
                 bottleneck,
                 conv1_stride,
                 in_channels=3,
                 in_size=(224, 224),
                 num_classes=1000):
        super().__init__()
        self.in_size = in_size
        self.num_classes = num_classes

        self.features = nn.Sequential()
        self.features.add_module(
            "init_block",
            PreResInitBlock(in_channels=in_channels, out_channels=init_block_channels))

        # Channel count produced by the most recently added layer.
        width = init_block_channels
        for stage_idx, stage_widths in enumerate(channels, start=1):
            stage = nn.Sequential()
            for unit_idx, unit_width in enumerate(stage_widths, start=1):
                # Only the first unit of every stage except the first stage uses stride 2.
                downsample = (stage_idx != 1) and (unit_idx == 1)
                unit = PreResUnit(
                    in_channels=width,
                    out_channels=unit_width,
                    stride=(2 if downsample else 1),
                    bottleneck=bottleneck,
                    conv1_stride=conv1_stride)
                stage.add_module("unit{}".format(unit_idx), unit)
                width = unit_width
            self.features.add_module("stage{}".format(stage_idx), stage)

        self.features.add_module("post_activ", PreResActivation(in_channels=width))
        self.features.add_module("final_pool", nn.AvgPool2d(kernel_size=7, stride=1))

        self.output = nn.Linear(in_features=width, out_features=num_classes)

        self._init_params()

    def _init_params(self):
        """
        Re-initialize every `nn.Conv2d` in the model: Kaiming-uniform weights and, where a bias exists, zero bias.
        """
        for layer in self.modules():
            if not isinstance(layer, nn.Conv2d):
                continue
            init.kaiming_uniform_(layer.weight)
            if layer.bias is not None:
                init.constant_(layer.bias, 0)

    def forward(self, x):
        """
        Run `features`, flatten everything but the first dimension, and apply the classifier `output`.
        """
        feats = self.features(x)
        flat = feats.view(feats.size(0), -1)
        return self.output(flat)
290 |
291 |
def get_preresnet(blocks,
                  bottleneck=None,
                  conv1_stride=True,
                  width_scale=1.0,
                  model_name=None,
                  pretrained=False,
                  root=os.path.join("~", ".torch", "models"),
                  **kwargs):
    """
    Build a PreResNet of the requested depth, optionally width-scaled and with pretrained weights.
    Parameters:
    ----------
    blocks : int
        Depth of the network. Only the depths listed in the table below are supported.
    bottleneck : bool, default None
        Whether to use a bottleneck or simple block in units. When None, bottleneck units are chosen
        for depths of 50 and above.
    conv1_stride : bool, default True
        Whether to use stride in the first or the second convolution layer in units.
    width_scale : float, default 1.0
        Scale factor for width of layers.
    model_name : str or None, default None
        Model name for loading pretrained model.
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    root : str, default '~/.torch/models'
        Location for keeping the model parameters.
    """
    if bottleneck is None:
        bottleneck = (blocks >= 50)

    # (depth, required unit kind or None when the kind does not matter, number of units per stage).
    # Rows are checked in order and the first match wins.
    supported = (
        (10, None, [1, 1, 1, 1]),
        (12, None, [2, 1, 1, 1]),
        (14, False, [2, 2, 1, 1]),
        (14, True, [1, 1, 1, 1]),
        (16, None, [2, 2, 2, 1]),
        (18, None, [2, 2, 2, 2]),
        (26, False, [3, 3, 3, 3]),
        (26, True, [2, 2, 2, 2]),
        (34, None, [3, 4, 6, 3]),
        (38, True, [3, 3, 3, 3]),
        (50, None, [3, 4, 6, 3]),
        (101, None, [3, 4, 23, 3]),
        (152, None, [3, 8, 36, 3]),
        (200, None, [3, 24, 36, 3]),
        (269, None, [3, 30, 48, 8]),
    )
    layers = None
    for depth, unit_kind, stage_sizes in supported:
        if (blocks == depth) and (unit_kind is None or unit_kind == bool(bottleneck)):
            layers = stage_sizes
            break
    if layers is None:
        raise ValueError("Unsupported PreResNet with number of blocks: {}".format(blocks))

    # Consistency check between the unit kind and the depth: 3 per bottleneck unit, 2 per simple unit, plus 2.
    depth_per_unit = 3 if bottleneck else 2
    assert (sum(layers) * depth_per_unit + 2 == blocks)

    init_block_channels = 64
    # Bottleneck units use 4x wider outputs than simple units.
    width_multiplier = 4 if bottleneck else 1
    channels = [[base_width * width_multiplier] * unit_count
                for base_width, unit_count in zip((64, 128, 256, 512), layers)]

    if width_scale != 1.0:
        # Every unit is rescaled except the very last one, which keeps its original width.
        final_width = channels[-1][-1]
        channels = [[int(width * width_scale) for width in stage] for stage in channels]
        channels[-1][-1] = final_width
        init_block_channels = int(init_block_channels * width_scale)

    net = PreResNet(
        channels=channels,
        init_block_channels=init_block_channels,
        bottleneck=bottleneck,
        conv1_stride=conv1_stride,
        **kwargs)

    if not pretrained:
        return net

    if not model_name:
        raise ValueError("Parameter `model_name` should be properly initialized for loading pretrained model.")
    from .model_store import download_model
    download_model(
        net=net,
        model_name=model_name,
        local_model_store_dir_path=root)
    return net
391 |
392 |
def preresnet10(**kwargs):
    """
    Build the PreResNet-10 model from 'Identity Mappings in Deep Residual Networks,'
    https://arxiv.org/abs/1603.05027. This is an experimental model. All keyword arguments are forwarded
    to `get_preresnet`.
    Parameters:
    ----------
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    root : str, default '~/.torch/models'
        Location for keeping the model parameters.
    """
    net = get_preresnet(model_name="preresnet10", blocks=10, **kwargs)
    return net
405 |
406 |
def preresnet12(**kwargs):
    """
    Build the PreResNet-12 model from 'Identity Mappings in Deep Residual Networks,'
    https://arxiv.org/abs/1603.05027. This is an experimental model. All keyword arguments are forwarded
    to `get_preresnet`.
    Parameters:
    ----------
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    root : str, default '~/.torch/models'
        Location for keeping the model parameters.
    """
    net = get_preresnet(model_name="preresnet12", blocks=12, **kwargs)
    return net
419 |
420 |
def preresnet14(**kwargs):
    """
    Build the PreResNet-14 model from 'Identity Mappings in Deep Residual Networks,'
    https://arxiv.org/abs/1603.05027. This is an experimental model. All keyword arguments are forwarded
    to `get_preresnet`.
    Parameters:
    ----------
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    root : str, default '~/.torch/models'
        Location for keeping the model parameters.
    """
    net = get_preresnet(model_name="preresnet14", blocks=14, **kwargs)
    return net
433 |
434 |
def preresnetbc14b(**kwargs):
    """
    Build the PreResNet-BC-14b model from 'Identity Mappings in Deep Residual Networks,'
    https://arxiv.org/abs/1603.05027. This is an experimental model (bottleneck compressed): it uses
    bottleneck units with `conv1_stride` disabled. All keyword arguments are forwarded to `get_preresnet`.
    Parameters:
    ----------
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    root : str, default '~/.torch/models'
        Location for keeping the model parameters.
    """
    net = get_preresnet(
        model_name="preresnetbc14b",
        blocks=14,
        bottleneck=True,
        conv1_stride=False,
        **kwargs)
    return net
447 |
448 |
def preresnet16(**kwargs):
    """
    Build the PreResNet-16 model from 'Identity Mappings in Deep Residual Networks,'
    https://arxiv.org/abs/1603.05027. This is an experimental model. All keyword arguments are forwarded
    to `get_preresnet`.
    Parameters:
    ----------
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    root : str, default '~/.torch/models'
        Location for keeping the model parameters.
    """
    net = get_preresnet(model_name="preresnet16", blocks=16, **kwargs)
    return net
461 |
462 |
def preresnet18_wd4(**kwargs):
    """
    Build the PreResNet-18 model with a width scale of 0.25, from 'Identity Mappings in Deep Residual
    Networks,' https://arxiv.org/abs/1603.05027. This is an experimental model. All keyword arguments are
    forwarded to `get_preresnet`.
    Parameters:
    ----------
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    root : str, default '~/.torch/models'
        Location for keeping the model parameters.
    """
    net = get_preresnet(model_name="preresnet18_wd4", blocks=18, width_scale=0.25, **kwargs)
    return net
475 |
476 |
def preresnet18_wd2(**kwargs):
    """
    Build the PreResNet-18 model with a width scale of 0.5, from 'Identity Mappings in Deep Residual
    Networks,' https://arxiv.org/abs/1603.05027. This is an experimental model. All keyword arguments are
    forwarded to `get_preresnet`.
    Parameters:
    ----------
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    root : str, default '~/.torch/models'
        Location for keeping the model parameters.
    """
    net = get_preresnet(model_name="preresnet18_wd2", blocks=18, width_scale=0.5, **kwargs)
    return net
489 |
490 |
def preresnet18_w3d4(**kwargs):
    """
    Build the PreResNet-18 model with a width scale of 0.75, from 'Identity Mappings in Deep Residual
    Networks,' https://arxiv.org/abs/1603.05027. This is an experimental model. All keyword arguments are
    forwarded to `get_preresnet`.
    Parameters:
    ----------
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    root : str, default '~/.torch/models'
        Location for keeping the model parameters.
    """
    net = get_preresnet(model_name="preresnet18_w3d4", blocks=18, width_scale=0.75, **kwargs)
    return net
503 |
504 |
def preresnet18(**kwargs):
    """
    Build the PreResNet-18 model from 'Identity Mappings in Deep Residual Networks,'
    https://arxiv.org/abs/1603.05027. All keyword arguments are forwarded to `get_preresnet`.
    Parameters:
    ----------
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    root : str, default '~/.torch/models'
        Location for keeping the model parameters.
    """
    net = get_preresnet(model_name="preresnet18", blocks=18, **kwargs)
    return net
516 |
517 |
def preresnet26(**kwargs):
    """
    Build the PreResNet-26 model from 'Identity Mappings in Deep Residual Networks,'
    https://arxiv.org/abs/1603.05027. This is an experimental model built from simple (non-bottleneck)
    units. All keyword arguments are forwarded to `get_preresnet`.
    Parameters:
    ----------
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    root : str, default '~/.torch/models'
        Location for keeping the model parameters.
    """
    net = get_preresnet(model_name="preresnet26", blocks=26, bottleneck=False, **kwargs)
    return net
530 |
531 |
def preresnetbc26b(**kwargs):
    """
    Build the PreResNet-BC-26b model from 'Identity Mappings in Deep Residual Networks,'
    https://arxiv.org/abs/1603.05027. This is an experimental model (bottleneck compressed): it uses
    bottleneck units with `conv1_stride` disabled. All keyword arguments are forwarded to `get_preresnet`.
    Parameters:
    ----------
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    root : str, default '~/.torch/models'
        Location for keeping the model parameters.
    """
    net = get_preresnet(
        model_name="preresnetbc26b",
        blocks=26,
        bottleneck=True,
        conv1_stride=False,
        **kwargs)
    return net
544 |
545 |
def preresnet34(**kwargs):
    """
    Build the PreResNet-34 model from 'Identity Mappings in Deep Residual Networks,'
    https://arxiv.org/abs/1603.05027. All keyword arguments are forwarded to `get_preresnet`.
    Parameters:
    ----------
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    root : str, default '~/.torch/models'
        Location for keeping the model parameters.
    """
    net = get_preresnet(model_name="preresnet34", blocks=34, **kwargs)
    return net
557 |
558 |
def preresnetbc38b(**kwargs):
    """
    Build the PreResNet-BC-38b model from 'Identity Mappings in Deep Residual Networks,'
    https://arxiv.org/abs/1603.05027. This is an experimental model (bottleneck compressed): it uses
    bottleneck units with `conv1_stride` disabled. All keyword arguments are forwarded to `get_preresnet`.
    Parameters:
    ----------
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    root : str, default '~/.torch/models'
        Location for keeping the model parameters.
    """
    net = get_preresnet(
        model_name="preresnetbc38b",
        blocks=38,
        bottleneck=True,
        conv1_stride=False,
        **kwargs)
    return net
571 |
572 |
def preresnet50(**kwargs):
    """
    Build the PreResNet-50 model from 'Identity Mappings in Deep Residual Networks,'
    https://arxiv.org/abs/1603.05027. All keyword arguments are forwarded to `get_preresnet`.
    Parameters:
    ----------
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    root : str, default '~/.torch/models'
        Location for keeping the model parameters.
    """
    net = get_preresnet(model_name="preresnet50", blocks=50, **kwargs)
    return net
584 |
585 |
def preresnet50b(**kwargs):
    """
    Build the PreResNet-50 model with stride at the second convolution in bottleneck block, from
    'Identity Mappings in Deep Residual Networks,' https://arxiv.org/abs/1603.05027. All keyword arguments
    are forwarded to `get_preresnet`.
    Parameters:
    ----------
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    root : str, default '~/.torch/models'
        Location for keeping the model parameters.
    """
    net = get_preresnet(model_name="preresnet50b", blocks=50, conv1_stride=False, **kwargs)
    return net
598 |
599 |
def preresnet101(**kwargs):
    """
    Build the PreResNet-101 model from 'Identity Mappings in Deep Residual Networks,'
    https://arxiv.org/abs/1603.05027. All keyword arguments are forwarded to `get_preresnet`.
    Parameters:
    ----------
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    root : str, default '~/.torch/models'
        Location for keeping the model parameters.
    """
    net = get_preresnet(model_name="preresnet101", blocks=101, **kwargs)
    return net
611 |
612 |
def preresnet101b(**kwargs):
    """
    Build the PreResNet-101 model with stride at the second convolution in bottleneck block, from
    'Identity Mappings in Deep Residual Networks,' https://arxiv.org/abs/1603.05027. All keyword arguments
    are forwarded to `get_preresnet`.
    Parameters:
    ----------
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    root : str, default '~/.torch/models'
        Location for keeping the model parameters.
    """
    net = get_preresnet(model_name="preresnet101b", blocks=101, conv1_stride=False, **kwargs)
    return net
625 |
626 |
def preresnet152(**kwargs):
    """
    Build the PreResNet-152 model from 'Identity Mappings in Deep Residual Networks,'
    https://arxiv.org/abs/1603.05027. All keyword arguments are forwarded to `get_preresnet`.
    Parameters:
    ----------
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    root : str, default '~/.torch/models'
        Location for keeping the model parameters.
    """
    net = get_preresnet(model_name="preresnet152", blocks=152, **kwargs)
    return net
638 |
639 |
def preresnet152b(**kwargs):
    """
    Build the PreResNet-152 model with stride at the second convolution in bottleneck block, from
    'Identity Mappings in Deep Residual Networks,' https://arxiv.org/abs/1603.05027. All keyword arguments
    are forwarded to `get_preresnet`.
    Parameters:
    ----------
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    root : str, default '~/.torch/models'
        Location for keeping the model parameters.
    """
    net = get_preresnet(model_name="preresnet152b", blocks=152, conv1_stride=False, **kwargs)
    return net
652 |
653 |
def preresnet200(**kwargs):
    """
    Build the PreResNet-200 model from 'Identity Mappings in Deep Residual Networks,'
    https://arxiv.org/abs/1603.05027. All keyword arguments are forwarded to `get_preresnet`.
    Parameters:
    ----------
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    root : str, default '~/.torch/models'
        Location for keeping the model parameters.
    """
    net = get_preresnet(model_name="preresnet200", blocks=200, **kwargs)
    return net
665 |
666 |
def preresnet200b(**kwargs):
    """
    Build the PreResNet-200 model with stride at the second convolution in bottleneck block, from
    'Identity Mappings in Deep Residual Networks,' https://arxiv.org/abs/1603.05027. All keyword arguments
    are forwarded to `get_preresnet`.
    Parameters:
    ----------
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    root : str, default '~/.torch/models'
        Location for keeping the model parameters.
    """
    net = get_preresnet(model_name="preresnet200b", blocks=200, conv1_stride=False, **kwargs)
    return net
679 |
680 |
def preresnet269b(**kwargs):
    """
    Build the PreResNet-269 model with stride at the second convolution in bottleneck block, from
    'Identity Mappings in Deep Residual Networks,' https://arxiv.org/abs/1603.05027. All keyword arguments
    are forwarded to `get_preresnet`.
    Parameters:
    ----------
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    root : str, default '~/.torch/models'
        Location for keeping the model parameters.
    """
    net = get_preresnet(model_name="preresnet269b", blocks=269, conv1_stride=False, **kwargs)
    return net
693 |
694 |
def _calc_width(net):
    """
    Count the trainable parameters of a network.
    Parameters:
    ----------
    net : object
        Anything exposing `parameters()` that yields tensors (e.g. a torch module).
    Returns:
    -------
    int
        Total number of elements over all parameters whose `requires_grad` flag is set
        (0 when there are none).
    """
    # Tensor.numel() is an exact Python int for every shape, including 0-dim parameters. The previous
    # np.prod(param.size()) evaluated to the float 1.0 for an empty shape, silently turning the whole
    # count into a float, and needed a numpy import just for this.
    return sum(param.numel() for param in net.parameters() if param.requires_grad)
702 |
703 |
def _test():
    """
    Smoke-test every PreResNet constructor defined in this module.

    For each constructor, in order: build the network without pretrained weights, switch it to eval mode,
    print and check its trainable parameter count, then run a forward and a backward pass on a random
    1x3x224x224 input and check that the output has shape (1, 1000).
    """
    import torch

    pretrained = False

    # (constructor, expected number of trainable parameters)
    expected_weight_counts = [
        (preresnet10, 5417128),
        (preresnet12, 5491112),
        (preresnet14, 5786536),
        (preresnetbc14b, 10057384),
        (preresnet16, 6967208),
        (preresnet18_wd4, 3935960),
        (preresnet18_wd2, 5802440),
        (preresnet18_w3d4, 8473784),
        (preresnet18, 11687848),
        (preresnet26, 17958568),
        (preresnetbc26b, 15987624),
        (preresnet34, 21796008),
        (preresnetbc38b, 21917864),
        (preresnet50, 25549480),
        (preresnet50b, 25549480),
        (preresnet101, 44541608),
        (preresnet101b, 44541608),
        (preresnet152, 60185256),
        (preresnet152b, 60185256),
        (preresnet200, 64666280),
        (preresnet200b, 64666280),
        (preresnet269b, 102065832),
    ]

    for model, expected_count in expected_weight_counts:
        net = model(pretrained=pretrained)
        net.eval()

        weight_count = _calc_width(net)
        print("m={}, {}".format(model.__name__, weight_count))
        assert (weight_count == expected_count)

        # One forward/backward round trip on an ImageNet-sized random input.
        x = torch.randn(1, 3, 224, 224)
        y = net(x)
        y.sum().backward()
        assert (tuple(y.size()) == (1, 1000))
769 |
770 |
# Run the module's self-test only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    _test()
773 |
--------------------------------------------------------------------------------