├── .gitignore
├── LICENSE
├── README.md
├── cifar10_models
│   ├── densenet.py
│   ├── googlenet.py
│   ├── inception.py
│   ├── mobilenetv2.py
│   ├── resnet.py
│   ├── resnet_orig.py
│   └── vgg.py
├── data.py
├── module.py
├── schduler.py
└── train.py

/.gitignore:
--------------------------------------------------------------------------------
cifar10/
wandb/
*.ckpt
checkpoints/
*.pt
*.zip

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 Huy Phan

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# PyTorch models trained on the CIFAR-10 dataset
- I modified the official [TorchVision](https://pytorch.org/docs/stable/torchvision/models.html) implementations of popular CNN models and trained them on the CIFAR-10 dataset.
- I changed the *number of classes, filter size, stride, and padding* in the original code so that it works with CIFAR-10.
- I also share the **weights** of these models, so you can just load the weights and use them.
- The code is highly reproducible and readable thanks to PyTorch Lightning.

## Statistics of supported models
| No. | Model        | Val. Acc. | No. Params | Size   |
|:---:|:-------------|----------:|-----------:|-------:|
| 1   | vgg11_bn     | 92.39%    | 28.150 M   | 108 MB |
| 2   | vgg13_bn     | 94.22%    | 28.334 M   | 109 MB |
| 3   | vgg16_bn     | 94.00%    | 33.647 M   | 129 MB |
| 4   | vgg19_bn     | 93.95%    | 38.959 M   | 149 MB |
| 5   | resnet18     | 93.07%    | 11.174 M   | 43 MB  |
| 6   | resnet34     | 93.34%    | 21.282 M   | 82 MB  |
| 7   | resnet50     | 93.65%    | 23.521 M   | 91 MB  |
| 8   | densenet121  | 94.06%    | 6.956 M    | 28 MB  |
| 9   | densenet161  | 94.07%    | 26.483 M   | 103 MB |
| 10  | densenet169  | 94.05%    | 12.493 M   | 49 MB  |
| 11  | mobilenet_v2 | 93.91%    | 2.237 M    | 9 MB   |
| 12  | googlenet    | 92.85%    | 5.491 M    | 22 MB  |
| 13  | inception_v3 | 93.74%    | 21.640 M   | 83 MB  |

## Details Report & Run Logs
Weights & Biases report for this project: [WandB Report](https://wandb.ai/huyvnphan/cifar10/reports/CIFAR10-Classification-using-PyTorch---VmlldzozOTg0ODQ?accessToken=9m2q1ajhppuziprsq9tlryynvmqbkrbvjdoktrz7o6gtqilmtqbv2r9jjrtb2tqq)

Weights & Biases run logs for this project: [WandB Run Log](https://wandb.ai/huyvnphan/cifar10). For each run you can see the hyper-parameters, training accuracy, validation accuracy, loss, and time taken.

## How To Cite
[![DOI](https://zenodo.org/badge/195914773.svg)](https://zenodo.org/badge/latestdoi/195914773)

## How to use pretrained models

**Automatically download and extract the weights from Box (933 MB)**
```bash
python train.py --download_weights 1
```
Or use the [Google Drive](https://drive.google.com/file/d/17fmN8eQdLpq2jIMQ_X0IXDPXfI9oVWgq/view?usp=sharing) backup link (you have to download and extract it manually).

**Load model and run**
```python
from cifar10_models.vgg import vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn

# Untrained model
my_model = vgg11_bn()

# Pretrained model
my_model = vgg11_bn(pretrained=True)
my_model.eval()  # for evaluation
```

If you use your own images, all models expect data to be in the range [0, 1], then normalized by
```python
mean = [0.4914, 0.4822, 0.4465]
std = [0.2471, 0.2435, 0.2616]
```

## How to train models from scratch
Check `train.py` to see all available hyper-parameter choices.
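The flag names below are inferred from the `self.hparams.*` fields that `module.py` and `data.py` read; the values shown are placeholders, so check `train.py` for the actual names and defaults. A typical invocation might look like:

```bash
python train.py \
    --classifier resnet18 \
    --data_dir ./cifar10 \
    --batch_size 256 \
    --max_epochs 100 \
    --learning_rate 1e-2 \
    --weight_decay 1e-2 \
    --num_workers 8
```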
60 | To reproduce the same accuracy use the default hyper-parameters 61 | 62 | `python train.py --classifier resnet18` 63 | 64 | ## How to test pretrained models 65 | `python train.py --test_phase 1 --pretrained 1 --classifier resnet18` 66 | 67 | Output 68 | 69 | `{'acc/test': tensor(93.0689, device='cuda:0')}` 70 | 71 | 72 | ## Requirements 73 | **Just to use pretrained models** 74 | - pytorch = 1.7.0 75 | 76 | **To train & test** 77 | - pytorch = 1.7.0 78 | - torchvision = 0.7.0 79 | - tensorboard = 2.2.1 80 | - pytorch-lightning = 1.1.0 81 | -------------------------------------------------------------------------------- /cifar10_models/densenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import OrderedDict 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | __all__ = ["DenseNet", "densenet121", "densenet169", "densenet161"] 9 | 10 | 11 | class _DenseLayer(nn.Sequential): 12 | def __init__(self, num_input_features, growth_rate, bn_size, drop_rate): 13 | super(_DenseLayer, self).__init__() 14 | self.add_module("norm1", nn.BatchNorm2d(num_input_features)), 15 | self.add_module("relu1", nn.ReLU(inplace=True)), 16 | self.add_module( 17 | "conv1", 18 | nn.Conv2d( 19 | num_input_features, 20 | bn_size * growth_rate, 21 | kernel_size=1, 22 | stride=1, 23 | bias=False, 24 | ), 25 | ), 26 | self.add_module("norm2", nn.BatchNorm2d(bn_size * growth_rate)), 27 | self.add_module("relu2", nn.ReLU(inplace=True)), 28 | self.add_module( 29 | "conv2", 30 | nn.Conv2d( 31 | bn_size * growth_rate, 32 | growth_rate, 33 | kernel_size=3, 34 | stride=1, 35 | padding=1, 36 | bias=False, 37 | ), 38 | ), 39 | self.drop_rate = drop_rate 40 | 41 | def forward(self, x): 42 | new_features = super(_DenseLayer, self).forward(x) 43 | if self.drop_rate > 0: 44 | new_features = F.dropout( 45 | new_features, p=self.drop_rate, training=self.training 46 | ) 47 | return torch.cat([x, new_features], 1) 48 | 49 | 50 | class _DenseBlock(nn.Sequential): 51 | def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate): 52 | super(_DenseBlock, self).__init__() 53 | for i in range(num_layers): 54 | layer = _DenseLayer( 55 | num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate 56 | ) 57 | self.add_module("denselayer%d" % (i + 1), layer) 58 | 59 | 60 | class _Transition(nn.Sequential): 61 | def __init__(self, num_input_features, num_output_features): 62 | super(_Transition, self).__init__() 63 | self.add_module("norm", nn.BatchNorm2d(num_input_features)) 64 | self.add_module("relu", nn.ReLU(inplace=True)) 65 | self.add_module( 66 | "conv", 67 | nn.Conv2d( 68 | num_input_features, 69 | num_output_features, 70 | kernel_size=1, 71 | stride=1, 72 | bias=False, 73 | ), 74 | ) 75 | self.add_module("pool", nn.AvgPool2d(kernel_size=2, stride=2)) 76 | 77 | 78 | class DenseNet(nn.Module): 79 | r"""Densenet-BC model class, based on 80 | `"Densely Connected Convolutional Networks" `_ 81 | 82 | Args: 83 | growth_rate (int) - how many filters to add each layer (`k` in paper) 84 | block_config (list of 4 ints) - how many layers in each pooling block 85 | num_init_features (int) - the number of filters to learn in the first convolution layer 86 | bn_size (int) - multiplicative factor for number of bottle neck layers 87 | (i.e. 
bn_size * k features in the bottleneck layer) 88 | drop_rate (float) - dropout rate after each dense layer 89 | num_classes (int) - number of classification classes 90 | """ 91 | 92 | def __init__( 93 | self, 94 | growth_rate=32, 95 | block_config=(6, 12, 24, 16), 96 | num_init_features=64, 97 | bn_size=4, 98 | drop_rate=0, 99 | num_classes=10, 100 | ): 101 | 102 | super(DenseNet, self).__init__() 103 | 104 | # First convolution 105 | 106 | # CIFAR-10: kernel_size 7 ->3, stride 2->1, padding 3->1 107 | self.features = nn.Sequential( 108 | OrderedDict( 109 | [ 110 | ( 111 | "conv0", 112 | nn.Conv2d( 113 | 3, 114 | num_init_features, 115 | kernel_size=3, 116 | stride=1, 117 | padding=1, 118 | bias=False, 119 | ), 120 | ), 121 | ("norm0", nn.BatchNorm2d(num_init_features)), 122 | ("relu0", nn.ReLU(inplace=True)), 123 | ("pool0", nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), 124 | ] 125 | ) 126 | ) 127 | # END 128 | 129 | # Each denseblock 130 | num_features = num_init_features 131 | for i, num_layers in enumerate(block_config): 132 | block = _DenseBlock( 133 | num_layers=num_layers, 134 | num_input_features=num_features, 135 | bn_size=bn_size, 136 | growth_rate=growth_rate, 137 | drop_rate=drop_rate, 138 | ) 139 | self.features.add_module("denseblock%d" % (i + 1), block) 140 | num_features = num_features + num_layers * growth_rate 141 | if i != len(block_config) - 1: 142 | trans = _Transition( 143 | num_input_features=num_features, 144 | num_output_features=num_features // 2, 145 | ) 146 | self.features.add_module("transition%d" % (i + 1), trans) 147 | num_features = num_features // 2 148 | 149 | # Final batch norm 150 | self.features.add_module("norm5", nn.BatchNorm2d(num_features)) 151 | 152 | # Linear layer 153 | self.classifier = nn.Linear(num_features, num_classes) 154 | 155 | # Official init from torch repo. 
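# The loop below is the stock torchvision DenseNet initialization:
#   Conv2d weights -> Kaiming-normal; BatchNorm2d -> weight 1, bias 0;
#   Linear -> bias 0 (the classifier weight keeps PyTorch's default init).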
156 | for m in self.modules(): 157 | if isinstance(m, nn.Conv2d): 158 | nn.init.kaiming_normal_(m.weight) 159 | elif isinstance(m, nn.BatchNorm2d): 160 | nn.init.constant_(m.weight, 1) 161 | nn.init.constant_(m.bias, 0) 162 | elif isinstance(m, nn.Linear): 163 | nn.init.constant_(m.bias, 0) 164 | 165 | def forward(self, x): 166 | features = self.features(x) 167 | out = F.relu(features, inplace=True) 168 | out = F.adaptive_avg_pool2d(out, (1, 1)).view(features.size(0), -1) 169 | out = self.classifier(out) 170 | return out 171 | 172 | 173 | def _densenet( 174 | arch, 175 | growth_rate, 176 | block_config, 177 | num_init_features, 178 | pretrained, 179 | progress, 180 | device, 181 | **kwargs 182 | ): 183 | model = DenseNet(growth_rate, block_config, num_init_features, **kwargs) 184 | if pretrained: 185 | script_dir = os.path.dirname(__file__) 186 | state_dict = torch.load( 187 | script_dir + "/state_dicts/" + arch + ".pt", map_location=device 188 | ) 189 | model.load_state_dict(state_dict) 190 | return model 191 | 192 | 193 | def densenet121(pretrained=False, progress=True, device="cpu", **kwargs): 194 | r"""Densenet-121 model from 195 | `"Densely Connected Convolutional Networks" `_ 196 | 197 | Args: 198 | pretrained (bool): If True, returns a model pre-trained on ImageNet 199 | progress (bool): If True, displays a progress bar of the download to stderr 200 | """ 201 | return _densenet( 202 | "densenet121", 32, (6, 12, 24, 16), 64, pretrained, progress, device, **kwargs 203 | ) 204 | 205 | 206 | def densenet161(pretrained=False, progress=True, device="cpu", **kwargs): 207 | r"""Densenet-161 model from 208 | `"Densely Connected Convolutional Networks" `_ 209 | 210 | Args: 211 | pretrained (bool): If True, returns a model pre-trained on ImageNet 212 | progress (bool): If True, displays a progress bar of the download to stderr 213 | """ 214 | return _densenet( 215 | "densenet161", 48, (6, 12, 36, 24), 96, pretrained, progress, device, **kwargs 216 | ) 217 | 218 | 219 | def densenet169(pretrained=False, progress=True, device="cpu", **kwargs): 220 | r"""Densenet-169 model from 221 | `"Densely Connected Convolutional Networks" `_ 222 | 223 | Args: 224 | pretrained (bool): If True, returns a model pre-trained on ImageNet 225 | progress (bool): If True, displays a progress bar of the download to stderr 226 | """ 227 | return _densenet( 228 | "densenet169", 32, (6, 12, 32, 32), 64, pretrained, progress, device, **kwargs 229 | ) 230 | -------------------------------------------------------------------------------- /cifar10_models/googlenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import namedtuple 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | __all__ = ["GoogLeNet", "googlenet"] 9 | 10 | 11 | _GoogLeNetOuputs = namedtuple( 12 | "GoogLeNetOuputs", ["logits", "aux_logits2", "aux_logits1"] 13 | ) 14 | 15 | 16 | def googlenet(pretrained=False, progress=True, device="cpu", **kwargs): 17 | r"""GoogLeNet (Inception v1) model architecture from 18 | `"Going Deeper with Convolutions" `_. 19 | 20 | Args: 21 | pretrained (bool): If True, returns a model pre-trained on ImageNet 22 | progress (bool): If True, displays a progress bar of the download to stderr 23 | aux_logits (bool): If True, adds two auxiliary branches that can improve training. 
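(Note: this CIFAR-10 port constructs ``GoogLeNet()`` with its default ``aux_logits=False``, so the auxiliary branches are off regardless of the torchvision default described next.)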
24 | Default: *False* when pretrained is True otherwise *True* 25 | transform_input (bool): If True, preprocesses the input according to the method with which it 26 | was trained on ImageNet. Default: *False* 27 | """ 28 | model = GoogLeNet() 29 | if pretrained: 30 | script_dir = os.path.dirname(__file__) 31 | state_dict = torch.load( 32 | script_dir + "/state_dicts/googlenet.pt", map_location=device 33 | ) 34 | model.load_state_dict(state_dict) 35 | return model 36 | 37 | 38 | class GoogLeNet(nn.Module): 39 | 40 | # CIFAR10: aux_logits True->False 41 | def __init__(self, num_classes=10, aux_logits=False, transform_input=False): 42 | super(GoogLeNet, self).__init__() 43 | self.aux_logits = aux_logits 44 | self.transform_input = transform_input 45 | 46 | # CIFAR10: out_channels 64->192, kernel_size 7->3, stride 2->1, padding 3->1 47 | self.conv1 = BasicConv2d(3, 192, kernel_size=3, stride=1, padding=1) 48 | # self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True) 49 | # self.conv2 = BasicConv2d(64, 64, kernel_size=1) 50 | # self.conv3 = BasicConv2d(64, 192, kernel_size=3, padding=1) 51 | # self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True) 52 | # END 53 | 54 | self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32) 55 | self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64) 56 | 57 | # CIFAR10: padding 0->1, ciel_model True->False 58 | self.maxpool3 = nn.MaxPool2d(3, stride=2, padding=1, ceil_mode=False) 59 | # END 60 | 61 | self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64) 62 | self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64) 63 | self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64) 64 | self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64) 65 | self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128) 66 | 67 | # CIFAR10: kernel_size 2->3, padding 0->1, ciel_model True->False 68 | self.maxpool4 = nn.MaxPool2d(3, stride=2, padding=1, ceil_mode=False) 69 | # END 70 | 71 | self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128) 72 | self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128) 73 | 74 | if aux_logits: 75 | self.aux1 = InceptionAux(512, num_classes) 76 | self.aux2 = InceptionAux(528, num_classes) 77 | 78 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 79 | self.dropout = nn.Dropout(0.2) 80 | self.fc = nn.Linear(1024, num_classes) 81 | 82 | # if init_weights: 83 | # self._initialize_weights() 84 | 85 | # def _initialize_weights(self): 86 | # for m in self.modules(): 87 | # if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): 88 | # import scipy.stats as stats 89 | # X = stats.truncnorm(-2, 2, scale=0.01) 90 | # values = torch.as_tensor(X.rvs(m.weight.numel()), dtype=m.weight.dtype) 91 | # values = values.view(m.weight.size()) 92 | # with torch.no_grad(): 93 | # m.weight.copy_(values) 94 | # elif isinstance(m, nn.BatchNorm2d): 95 | # nn.init.constant_(m.weight, 1) 96 | # nn.init.constant_(m.bias, 0) 97 | 98 | def forward(self, x): 99 | if self.transform_input: 100 | x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5 101 | x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5 102 | x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5 103 | x = torch.cat((x_ch0, x_ch1, x_ch2), 1) 104 | 105 | # N x 3 x 224 x 224 106 | x = self.conv1(x) 107 | 108 | # CIFAR10 109 | # N x 64 x 112 x 112 110 | # x = self.maxpool1(x) 111 | # N x 64 x 56 x 56 112 | # x = self.conv2(x) 113 | # N x 64 x 56 x 56 114 | # x = self.conv3(x) 115 | # N x 192 x 56 x 56 116 | # x 
= self.maxpool2(x) 117 | # END 118 | 119 | # N x 192 x 28 x 28 120 | x = self.inception3a(x) 121 | # N x 256 x 28 x 28 122 | x = self.inception3b(x) 123 | # N x 480 x 28 x 28 124 | x = self.maxpool3(x) 125 | # N x 480 x 14 x 14 126 | x = self.inception4a(x) 127 | # N x 512 x 14 x 14 128 | if self.training and self.aux_logits: 129 | aux1 = self.aux1(x) 130 | 131 | x = self.inception4b(x) 132 | # N x 512 x 14 x 14 133 | x = self.inception4c(x) 134 | # N x 512 x 14 x 14 135 | x = self.inception4d(x) 136 | # N x 528 x 14 x 14 137 | if self.training and self.aux_logits: 138 | aux2 = self.aux2(x) 139 | 140 | x = self.inception4e(x) 141 | # N x 832 x 14 x 14 142 | x = self.maxpool4(x) 143 | # N x 832 x 7 x 7 144 | x = self.inception5a(x) 145 | # N x 832 x 7 x 7 146 | x = self.inception5b(x) 147 | # N x 1024 x 7 x 7 148 | 149 | x = self.avgpool(x) 150 | # N x 1024 x 1 x 1 151 | x = x.view(x.size(0), -1) 152 | # N x 1024 153 | x = self.dropout(x) 154 | x = self.fc(x) 155 | # N x 1000 (num_classes) 156 | if self.training and self.aux_logits: 157 | return _GoogLeNetOuputs(x, aux2, aux1) 158 | return x 159 | 160 | 161 | class Inception(nn.Module): 162 | def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj): 163 | super(Inception, self).__init__() 164 | 165 | self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1) 166 | 167 | self.branch2 = nn.Sequential( 168 | BasicConv2d(in_channels, ch3x3red, kernel_size=1), 169 | BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1), 170 | ) 171 | 172 | self.branch3 = nn.Sequential( 173 | BasicConv2d(in_channels, ch5x5red, kernel_size=1), 174 | BasicConv2d(ch5x5red, ch5x5, kernel_size=3, padding=1), 175 | ) 176 | 177 | self.branch4 = nn.Sequential( 178 | nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True), 179 | BasicConv2d(in_channels, pool_proj, kernel_size=1), 180 | ) 181 | 182 | def forward(self, x): 183 | branch1 = self.branch1(x) 184 | branch2 = self.branch2(x) 185 | branch3 = self.branch3(x) 186 | branch4 = self.branch4(x) 187 | 188 | outputs = [branch1, branch2, branch3, branch4] 189 | return torch.cat(outputs, 1) 190 | 191 | 192 | class InceptionAux(nn.Module): 193 | def __init__(self, in_channels, num_classes): 194 | super(InceptionAux, self).__init__() 195 | self.conv = BasicConv2d(in_channels, 128, kernel_size=1) 196 | 197 | self.fc1 = nn.Linear(2048, 1024) 198 | self.fc2 = nn.Linear(1024, num_classes) 199 | 200 | def forward(self, x): 201 | # aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14 202 | x = F.adaptive_avg_pool2d(x, (4, 4)) 203 | # aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4 204 | x = self.conv(x) 205 | # N x 128 x 4 x 4 206 | x = x.view(x.size(0), -1) 207 | # N x 2048 208 | x = F.relu(self.fc1(x), inplace=True) 209 | # N x 2048 210 | x = F.dropout(x, 0.7, training=self.training) 211 | # N x 2048 212 | x = self.fc2(x) 213 | # N x 1024 214 | 215 | return x 216 | 217 | 218 | class BasicConv2d(nn.Module): 219 | def __init__(self, in_channels, out_channels, **kwargs): 220 | super(BasicConv2d, self).__init__() 221 | self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) 222 | self.bn = nn.BatchNorm2d(out_channels, eps=0.001) 223 | 224 | def forward(self, x): 225 | x = self.conv(x) 226 | x = self.bn(x) 227 | return F.relu(x, inplace=True) 228 | -------------------------------------------------------------------------------- /cifar10_models/inception.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import 
namedtuple 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | __all__ = ["Inception3", "inception_v3"] 9 | 10 | 11 | _InceptionOuputs = namedtuple("InceptionOuputs", ["logits", "aux_logits"]) 12 | 13 | 14 | def inception_v3(pretrained=False, progress=True, device="cpu", **kwargs): 15 | r"""Inception v3 model architecture from 16 | `"Rethinking the Inception Architecture for Computer Vision" `_. 17 | 18 | .. note:: 19 | **Important**: In contrast to the other models the inception_v3 expects tensors with a size of 20 | N x 3 x 299 x 299, so ensure your images are sized accordingly. 21 | 22 | Args: 23 | pretrained (bool): If True, returns a model pre-trained on ImageNet 24 | progress (bool): If True, displays a progress bar of the download to stderr 25 | aux_logits (bool): If True, add an auxiliary branch that can improve training. 26 | Default: *True* 27 | transform_input (bool): If True, preprocesses the input according to the method with which it 28 | was trained on ImageNet. Default: *False* 29 | """ 30 | model = Inception3() 31 | if pretrained: 32 | script_dir = os.path.dirname(__file__) 33 | state_dict = torch.load( 34 | script_dir + "/state_dicts/inception_v3.pt", map_location=device 35 | ) 36 | model.load_state_dict(state_dict) 37 | return model 38 | 39 | 40 | class Inception3(nn.Module): 41 | # CIFAR10: aux_logits True->False 42 | def __init__(self, num_classes=10, aux_logits=False, transform_input=False): 43 | super(Inception3, self).__init__() 44 | self.aux_logits = aux_logits 45 | self.transform_input = transform_input 46 | 47 | # CIFAR10: stride 2->1, padding 0 -> 1 48 | self.Conv2d_1a_3x3 = BasicConv2d(3, 192, kernel_size=3, stride=1, padding=1) 49 | # self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3) 50 | # self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1) 51 | # self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1) 52 | # self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3) 53 | self.Mixed_5b = InceptionA(192, pool_features=32) 54 | self.Mixed_5c = InceptionA(256, pool_features=64) 55 | self.Mixed_5d = InceptionA(288, pool_features=64) 56 | self.Mixed_6a = InceptionB(288) 57 | self.Mixed_6b = InceptionC(768, channels_7x7=128) 58 | self.Mixed_6c = InceptionC(768, channels_7x7=160) 59 | self.Mixed_6d = InceptionC(768, channels_7x7=160) 60 | self.Mixed_6e = InceptionC(768, channels_7x7=192) 61 | if aux_logits: 62 | self.AuxLogits = InceptionAux(768, num_classes) 63 | self.Mixed_7a = InceptionD(768) 64 | self.Mixed_7b = InceptionE(1280) 65 | self.Mixed_7c = InceptionE(2048) 66 | self.fc = nn.Linear(2048, num_classes) 67 | 68 | # for m in self.modules(): 69 | # if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): 70 | # import scipy.stats as stats 71 | # stddev = m.stddev if hasattr(m, 'stddev') else 0.1 72 | # X = stats.truncnorm(-2, 2, scale=stddev) 73 | # values = torch.as_tensor(X.rvs(m.weight.numel()), dtype=m.weight.dtype) 74 | # values = values.view(m.weight.size()) 75 | # with torch.no_grad(): 76 | # m.weight.copy_(values) 77 | # elif isinstance(m, nn.BatchNorm2d): 78 | # nn.init.constant_(m.weight, 1) 79 | # nn.init.constant_(m.bias, 0) 80 | 81 | def forward(self, x): 82 | if self.transform_input: 83 | x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5 84 | x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5 85 | x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5 86 | x = torch.cat((x_ch0, x_ch1, x_ch2), 1) 87 | # N x 3 x 299 x 299 88 
| x = self.Conv2d_1a_3x3(x) 89 | 90 | # CIFAR10 91 | # N x 32 x 149 x 149 92 | # x = self.Conv2d_2a_3x3(x) 93 | # N x 32 x 147 x 147 94 | # x = self.Conv2d_2b_3x3(x) 95 | # N x 64 x 147 x 147 96 | # x = F.max_pool2d(x, kernel_size=3, stride=2) 97 | # N x 64 x 73 x 73 98 | # x = self.Conv2d_3b_1x1(x) 99 | # N x 80 x 73 x 73 100 | # x = self.Conv2d_4a_3x3(x) 101 | # N x 192 x 71 x 71 102 | # x = F.max_pool2d(x, kernel_size=3, stride=2) 103 | # N x 192 x 35 x 35 104 | x = self.Mixed_5b(x) 105 | # N x 256 x 35 x 35 106 | x = self.Mixed_5c(x) 107 | # N x 288 x 35 x 35 108 | x = self.Mixed_5d(x) 109 | # N x 288 x 35 x 35 110 | x = self.Mixed_6a(x) 111 | # N x 768 x 17 x 17 112 | x = self.Mixed_6b(x) 113 | # N x 768 x 17 x 17 114 | x = self.Mixed_6c(x) 115 | # N x 768 x 17 x 17 116 | x = self.Mixed_6d(x) 117 | # N x 768 x 17 x 17 118 | x = self.Mixed_6e(x) 119 | # N x 768 x 17 x 17 120 | if self.training and self.aux_logits: 121 | aux = self.AuxLogits(x) 122 | # N x 768 x 17 x 17 123 | x = self.Mixed_7a(x) 124 | # N x 1280 x 8 x 8 125 | x = self.Mixed_7b(x) 126 | # N x 2048 x 8 x 8 127 | x = self.Mixed_7c(x) 128 | # N x 2048 x 8 x 8 129 | # Adaptive average pooling 130 | x = F.adaptive_avg_pool2d(x, (1, 1)) 131 | # N x 2048 x 1 x 1 132 | x = F.dropout(x, training=self.training) 133 | # N x 2048 x 1 x 1 134 | x = x.view(x.size(0), -1) 135 | # N x 2048 136 | x = self.fc(x) 137 | # N x 1000 (num_classes) 138 | if self.training and self.aux_logits: 139 | return _InceptionOuputs(x, aux) 140 | return x 141 | 142 | 143 | class InceptionA(nn.Module): 144 | def __init__(self, in_channels, pool_features): 145 | super(InceptionA, self).__init__() 146 | self.branch1x1 = BasicConv2d(in_channels, 64, kernel_size=1) 147 | 148 | self.branch5x5_1 = BasicConv2d(in_channels, 48, kernel_size=1) 149 | self.branch5x5_2 = BasicConv2d(48, 64, kernel_size=5, padding=2) 150 | 151 | self.branch3x3dbl_1 = BasicConv2d(in_channels, 64, kernel_size=1) 152 | self.branch3x3dbl_2 = BasicConv2d(64, 96, kernel_size=3, padding=1) 153 | self.branch3x3dbl_3 = BasicConv2d(96, 96, kernel_size=3, padding=1) 154 | 155 | self.branch_pool = BasicConv2d(in_channels, pool_features, kernel_size=1) 156 | 157 | def forward(self, x): 158 | branch1x1 = self.branch1x1(x) 159 | 160 | branch5x5 = self.branch5x5_1(x) 161 | branch5x5 = self.branch5x5_2(branch5x5) 162 | 163 | branch3x3dbl = self.branch3x3dbl_1(x) 164 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 165 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) 166 | 167 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) 168 | branch_pool = self.branch_pool(branch_pool) 169 | 170 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] 171 | return torch.cat(outputs, 1) 172 | 173 | 174 | class InceptionB(nn.Module): 175 | def __init__(self, in_channels): 176 | super(InceptionB, self).__init__() 177 | self.branch3x3 = BasicConv2d(in_channels, 384, kernel_size=3, stride=2) 178 | 179 | self.branch3x3dbl_1 = BasicConv2d(in_channels, 64, kernel_size=1) 180 | self.branch3x3dbl_2 = BasicConv2d(64, 96, kernel_size=3, padding=1) 181 | self.branch3x3dbl_3 = BasicConv2d(96, 96, kernel_size=3, stride=2) 182 | 183 | def forward(self, x): 184 | branch3x3 = self.branch3x3(x) 185 | 186 | branch3x3dbl = self.branch3x3dbl_1(x) 187 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 188 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) 189 | 190 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=2) 191 | 192 | outputs = [branch3x3, branch3x3dbl, branch_pool] 193 | return 
torch.cat(outputs, 1) 194 | 195 | 196 | class InceptionC(nn.Module): 197 | def __init__(self, in_channels, channels_7x7): 198 | super(InceptionC, self).__init__() 199 | self.branch1x1 = BasicConv2d(in_channels, 192, kernel_size=1) 200 | 201 | c7 = channels_7x7 202 | self.branch7x7_1 = BasicConv2d(in_channels, c7, kernel_size=1) 203 | self.branch7x7_2 = BasicConv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3)) 204 | self.branch7x7_3 = BasicConv2d(c7, 192, kernel_size=(7, 1), padding=(3, 0)) 205 | 206 | self.branch7x7dbl_1 = BasicConv2d(in_channels, c7, kernel_size=1) 207 | self.branch7x7dbl_2 = BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0)) 208 | self.branch7x7dbl_3 = BasicConv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3)) 209 | self.branch7x7dbl_4 = BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0)) 210 | self.branch7x7dbl_5 = BasicConv2d(c7, 192, kernel_size=(1, 7), padding=(0, 3)) 211 | 212 | self.branch_pool = BasicConv2d(in_channels, 192, kernel_size=1) 213 | 214 | def forward(self, x): 215 | branch1x1 = self.branch1x1(x) 216 | 217 | branch7x7 = self.branch7x7_1(x) 218 | branch7x7 = self.branch7x7_2(branch7x7) 219 | branch7x7 = self.branch7x7_3(branch7x7) 220 | 221 | branch7x7dbl = self.branch7x7dbl_1(x) 222 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) 223 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) 224 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) 225 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) 226 | 227 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) 228 | branch_pool = self.branch_pool(branch_pool) 229 | 230 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] 231 | return torch.cat(outputs, 1) 232 | 233 | 234 | class InceptionD(nn.Module): 235 | def __init__(self, in_channels): 236 | super(InceptionD, self).__init__() 237 | self.branch3x3_1 = BasicConv2d(in_channels, 192, kernel_size=1) 238 | self.branch3x3_2 = BasicConv2d(192, 320, kernel_size=3, stride=2) 239 | 240 | self.branch7x7x3_1 = BasicConv2d(in_channels, 192, kernel_size=1) 241 | self.branch7x7x3_2 = BasicConv2d(192, 192, kernel_size=(1, 7), padding=(0, 3)) 242 | self.branch7x7x3_3 = BasicConv2d(192, 192, kernel_size=(7, 1), padding=(3, 0)) 243 | self.branch7x7x3_4 = BasicConv2d(192, 192, kernel_size=3, stride=2) 244 | 245 | def forward(self, x): 246 | branch3x3 = self.branch3x3_1(x) 247 | branch3x3 = self.branch3x3_2(branch3x3) 248 | 249 | branch7x7x3 = self.branch7x7x3_1(x) 250 | branch7x7x3 = self.branch7x7x3_2(branch7x7x3) 251 | branch7x7x3 = self.branch7x7x3_3(branch7x7x3) 252 | branch7x7x3 = self.branch7x7x3_4(branch7x7x3) 253 | 254 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=2) 255 | outputs = [branch3x3, branch7x7x3, branch_pool] 256 | return torch.cat(outputs, 1) 257 | 258 | 259 | class InceptionE(nn.Module): 260 | def __init__(self, in_channels): 261 | super(InceptionE, self).__init__() 262 | self.branch1x1 = BasicConv2d(in_channels, 320, kernel_size=1) 263 | 264 | self.branch3x3_1 = BasicConv2d(in_channels, 384, kernel_size=1) 265 | self.branch3x3_2a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1)) 266 | self.branch3x3_2b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0)) 267 | 268 | self.branch3x3dbl_1 = BasicConv2d(in_channels, 448, kernel_size=1) 269 | self.branch3x3dbl_2 = BasicConv2d(448, 384, kernel_size=3, padding=1) 270 | self.branch3x3dbl_3a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1)) 271 | self.branch3x3dbl_3b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0)) 272 | 273 | 
self.branch_pool = BasicConv2d(in_channels, 192, kernel_size=1) 274 | 275 | def forward(self, x): 276 | branch1x1 = self.branch1x1(x) 277 | 278 | branch3x3 = self.branch3x3_1(x) 279 | branch3x3 = [ 280 | self.branch3x3_2a(branch3x3), 281 | self.branch3x3_2b(branch3x3), 282 | ] 283 | branch3x3 = torch.cat(branch3x3, 1) 284 | 285 | branch3x3dbl = self.branch3x3dbl_1(x) 286 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 287 | branch3x3dbl = [ 288 | self.branch3x3dbl_3a(branch3x3dbl), 289 | self.branch3x3dbl_3b(branch3x3dbl), 290 | ] 291 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 292 | 293 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) 294 | branch_pool = self.branch_pool(branch_pool) 295 | 296 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 297 | return torch.cat(outputs, 1) 298 | 299 | 300 | class InceptionAux(nn.Module): 301 | def __init__(self, in_channels, num_classes): 302 | super(InceptionAux, self).__init__() 303 | self.conv0 = BasicConv2d(in_channels, 128, kernel_size=1) 304 | self.conv1 = BasicConv2d(128, 768, kernel_size=5) 305 | self.conv1.stddev = 0.01 306 | self.fc = nn.Linear(768, num_classes) 307 | self.fc.stddev = 0.001 308 | 309 | def forward(self, x): 310 | # N x 768 x 17 x 17 311 | x = F.avg_pool2d(x, kernel_size=5, stride=3) 312 | # N x 768 x 5 x 5 313 | x = self.conv0(x) 314 | # N x 128 x 5 x 5 315 | x = self.conv1(x) 316 | # N x 768 x 1 x 1 317 | # Adaptive average pooling 318 | x = F.adaptive_avg_pool2d(x, (1, 1)) 319 | # N x 768 x 1 x 1 320 | x = x.view(x.size(0), -1) 321 | # N x 768 322 | x = self.fc(x) 323 | # N x 1000 324 | return x 325 | 326 | 327 | class BasicConv2d(nn.Module): 328 | def __init__(self, in_channels, out_channels, **kwargs): 329 | super(BasicConv2d, self).__init__() 330 | self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) 331 | self.bn = nn.BatchNorm2d(out_channels, eps=0.001) 332 | 333 | def forward(self, x): 334 | x = self.conv(x) 335 | x = self.bn(x) 336 | return F.relu(x, inplace=True) 337 | -------------------------------------------------------------------------------- /cifar10_models/mobilenetv2.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | __all__ = ["MobileNetV2", "mobilenet_v2"] 7 | 8 | 9 | class ConvBNReLU(nn.Sequential): 10 | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): 11 | padding = (kernel_size - 1) // 2 12 | super(ConvBNReLU, self).__init__( 13 | nn.Conv2d( 14 | in_planes, 15 | out_planes, 16 | kernel_size, 17 | stride, 18 | padding, 19 | groups=groups, 20 | bias=False, 21 | ), 22 | nn.BatchNorm2d(out_planes), 23 | nn.ReLU6(inplace=True), 24 | ) 25 | 26 | 27 | class InvertedResidual(nn.Module): 28 | def __init__(self, inp, oup, stride, expand_ratio): 29 | super(InvertedResidual, self).__init__() 30 | self.stride = stride 31 | assert stride in [1, 2] 32 | 33 | hidden_dim = int(round(inp * expand_ratio)) 34 | self.use_res_connect = self.stride == 1 and inp == oup 35 | 36 | layers = [] 37 | if expand_ratio != 1: 38 | # pw 39 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) 40 | layers.extend( 41 | [ 42 | # dw 43 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), 44 | # pw-linear 45 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 46 | nn.BatchNorm2d(oup), 47 | ] 48 | ) 49 | self.conv = nn.Sequential(*layers) 50 | 51 | def forward(self, x): 52 | if self.use_res_connect: 53 | return x + self.conv(x) 54 | 
else: 55 | return self.conv(x) 56 | 57 | 58 | class MobileNetV2(nn.Module): 59 | def __init__(self, num_classes=10, width_mult=1.0): 60 | super(MobileNetV2, self).__init__() 61 | block = InvertedResidual 62 | input_channel = 32 63 | last_channel = 1280 64 | 65 | # CIFAR10 66 | inverted_residual_setting = [ 67 | # t, c, n, s 68 | [1, 16, 1, 1], 69 | [6, 24, 2, 1], # Stride 2 -> 1 for CIFAR-10 70 | [6, 32, 3, 2], 71 | [6, 64, 4, 2], 72 | [6, 96, 3, 1], 73 | [6, 160, 3, 2], 74 | [6, 320, 1, 1], 75 | ] 76 | # END 77 | 78 | # building first layer 79 | input_channel = int(input_channel * width_mult) 80 | self.last_channel = int(last_channel * max(1.0, width_mult)) 81 | 82 | # CIFAR10: stride 2 -> 1 83 | features = [ConvBNReLU(3, input_channel, stride=1)] 84 | # END 85 | 86 | # building inverted residual blocks 87 | for t, c, n, s in inverted_residual_setting: 88 | output_channel = int(c * width_mult) 89 | for i in range(n): 90 | stride = s if i == 0 else 1 91 | features.append( 92 | block(input_channel, output_channel, stride, expand_ratio=t) 93 | ) 94 | input_channel = output_channel 95 | # building last several layers 96 | features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1)) 97 | # make it nn.Sequential 98 | self.features = nn.Sequential(*features) 99 | 100 | # building classifier 101 | self.classifier = nn.Sequential( 102 | nn.Dropout(0.2), 103 | nn.Linear(self.last_channel, num_classes), 104 | ) 105 | 106 | # weight initialization 107 | for m in self.modules(): 108 | if isinstance(m, nn.Conv2d): 109 | nn.init.kaiming_normal_(m.weight, mode="fan_out") 110 | if m.bias is not None: 111 | nn.init.zeros_(m.bias) 112 | elif isinstance(m, nn.BatchNorm2d): 113 | nn.init.ones_(m.weight) 114 | nn.init.zeros_(m.bias) 115 | elif isinstance(m, nn.Linear): 116 | nn.init.normal_(m.weight, 0, 0.01) 117 | nn.init.zeros_(m.bias) 118 | 119 | def forward(self, x): 120 | x = self.features(x) 121 | x = x.mean([2, 3]) 122 | x = self.classifier(x) 123 | return x 124 | 125 | 126 | def mobilenet_v2(pretrained=False, progress=True, device="cpu", **kwargs): 127 | """ 128 | Constructs a MobileNetV2 architecture from 129 | `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" `_. 
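(CIFAR-10 variant: the stem convolution and the second inverted-residual stage use stride 1 instead of 2, and ``pretrained=True`` loads CIFAR-10 weights from ``state_dicts/mobilenet_v2.pt`` rather than ImageNet weights.)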
130 | 131 | Args: 132 | pretrained (bool): If True, returns a model pre-trained on ImageNet 133 | progress (bool): If True, displays a progress bar of the download to stderr 134 | """ 135 | model = MobileNetV2(**kwargs) 136 | if pretrained: 137 | script_dir = os.path.dirname(__file__) 138 | state_dict = torch.load( 139 | script_dir + "/state_dicts/mobilenet_v2.pt", map_location=device 140 | ) 141 | model.load_state_dict(state_dict) 142 | return model 143 | -------------------------------------------------------------------------------- /cifar10_models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import os 4 | 5 | __all__ = [ 6 | "ResNet", 7 | "resnet18", 8 | "resnet34", 9 | "resnet50", 10 | ] 11 | 12 | 13 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 14 | """3x3 convolution with padding""" 15 | return nn.Conv2d( 16 | in_planes, 17 | out_planes, 18 | kernel_size=3, 19 | stride=stride, 20 | padding=dilation, 21 | groups=groups, 22 | bias=False, 23 | dilation=dilation, 24 | ) 25 | 26 | 27 | def conv1x1(in_planes, out_planes, stride=1): 28 | """1x1 convolution""" 29 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 30 | 31 | 32 | class BasicBlock(nn.Module): 33 | expansion = 1 34 | 35 | def __init__( 36 | self, 37 | inplanes, 38 | planes, 39 | stride=1, 40 | downsample=None, 41 | groups=1, 42 | base_width=64, 43 | dilation=1, 44 | norm_layer=None, 45 | ): 46 | super(BasicBlock, self).__init__() 47 | if norm_layer is None: 48 | norm_layer = nn.BatchNorm2d 49 | if groups != 1 or base_width != 64: 50 | raise ValueError("BasicBlock only supports groups=1 and base_width=64") 51 | if dilation > 1: 52 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 53 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 54 | self.conv1 = conv3x3(inplanes, planes, stride) 55 | self.bn1 = norm_layer(planes) 56 | self.relu = nn.ReLU(inplace=True) 57 | self.conv2 = conv3x3(planes, planes) 58 | self.bn2 = norm_layer(planes) 59 | self.downsample = downsample 60 | self.stride = stride 61 | 62 | def forward(self, x): 63 | identity = x 64 | 65 | out = self.conv1(x) 66 | out = self.bn1(out) 67 | out = self.relu(out) 68 | 69 | out = self.conv2(out) 70 | out = self.bn2(out) 71 | 72 | if self.downsample is not None: 73 | identity = self.downsample(x) 74 | 75 | out += identity 76 | out = self.relu(out) 77 | 78 | return out 79 | 80 | 81 | class Bottleneck(nn.Module): 82 | expansion = 4 83 | 84 | def __init__( 85 | self, 86 | inplanes, 87 | planes, 88 | stride=1, 89 | downsample=None, 90 | groups=1, 91 | base_width=64, 92 | dilation=1, 93 | norm_layer=None, 94 | ): 95 | super(Bottleneck, self).__init__() 96 | if norm_layer is None: 97 | norm_layer = nn.BatchNorm2d 98 | width = int(planes * (base_width / 64.0)) * groups 99 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 100 | self.conv1 = conv1x1(inplanes, width) 101 | self.bn1 = norm_layer(width) 102 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 103 | self.bn2 = norm_layer(width) 104 | self.conv3 = conv1x1(width, planes * self.expansion) 105 | self.bn3 = norm_layer(planes * self.expansion) 106 | self.relu = nn.ReLU(inplace=True) 107 | self.downsample = downsample 108 | self.stride = stride 109 | 110 | def forward(self, x): 111 | identity = x 112 | 113 | out = self.conv1(x) 114 | out = self.bn1(out) 115 | out = self.relu(out) 
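# Note: the 3x3 conv below carries the stride when this block downsamples
# (torchvision places the stride on conv2 rather than on the first 1x1).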
116 | 117 | out = self.conv2(out) 118 | out = self.bn2(out) 119 | out = self.relu(out) 120 | 121 | out = self.conv3(out) 122 | out = self.bn3(out) 123 | 124 | if self.downsample is not None: 125 | identity = self.downsample(x) 126 | 127 | out += identity 128 | out = self.relu(out) 129 | 130 | return out 131 | 132 | 133 | class ResNet(nn.Module): 134 | def __init__( 135 | self, 136 | block, 137 | layers, 138 | num_classes=10, 139 | zero_init_residual=False, 140 | groups=1, 141 | width_per_group=64, 142 | replace_stride_with_dilation=None, 143 | norm_layer=None, 144 | ): 145 | super(ResNet, self).__init__() 146 | if norm_layer is None: 147 | norm_layer = nn.BatchNorm2d 148 | self._norm_layer = norm_layer 149 | 150 | self.inplanes = 64 151 | self.dilation = 1 152 | if replace_stride_with_dilation is None: 153 | # each element in the tuple indicates if we should replace 154 | # the 2x2 stride with a dilated convolution instead 155 | replace_stride_with_dilation = [False, False, False] 156 | if len(replace_stride_with_dilation) != 3: 157 | raise ValueError( 158 | "replace_stride_with_dilation should be None " 159 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation) 160 | ) 161 | self.groups = groups 162 | self.base_width = width_per_group 163 | 164 | # CIFAR10: kernel_size 7 -> 3, stride 2 -> 1, padding 3->1 165 | self.conv1 = nn.Conv2d( 166 | 3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False 167 | ) 168 | # END 169 | 170 | self.bn1 = norm_layer(self.inplanes) 171 | self.relu = nn.ReLU(inplace=True) 172 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 173 | self.layer1 = self._make_layer(block, 64, layers[0]) 174 | self.layer2 = self._make_layer( 175 | block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0] 176 | ) 177 | self.layer3 = self._make_layer( 178 | block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1] 179 | ) 180 | self.layer4 = self._make_layer( 181 | block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2] 182 | ) 183 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 184 | self.fc = nn.Linear(512 * block.expansion, num_classes) 185 | 186 | for m in self.modules(): 187 | if isinstance(m, nn.Conv2d): 188 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 189 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 190 | nn.init.constant_(m.weight, 1) 191 | nn.init.constant_(m.bias, 0) 192 | 193 | # Zero-initialize the last BN in each residual branch, 194 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
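# (Concretely, the loop below zeroes bn3.weight in Bottleneck blocks and
# bn2.weight in BasicBlock blocks.)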
195 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 196 | if zero_init_residual: 197 | for m in self.modules(): 198 | if isinstance(m, Bottleneck): 199 | nn.init.constant_(m.bn3.weight, 0) 200 | elif isinstance(m, BasicBlock): 201 | nn.init.constant_(m.bn2.weight, 0) 202 | 203 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 204 | norm_layer = self._norm_layer 205 | downsample = None 206 | previous_dilation = self.dilation 207 | if dilate: 208 | self.dilation *= stride 209 | stride = 1 210 | if stride != 1 or self.inplanes != planes * block.expansion: 211 | downsample = nn.Sequential( 212 | conv1x1(self.inplanes, planes * block.expansion, stride), 213 | norm_layer(planes * block.expansion), 214 | ) 215 | 216 | layers = [] 217 | layers.append( 218 | block( 219 | self.inplanes, 220 | planes, 221 | stride, 222 | downsample, 223 | self.groups, 224 | self.base_width, 225 | previous_dilation, 226 | norm_layer, 227 | ) 228 | ) 229 | self.inplanes = planes * block.expansion 230 | for _ in range(1, blocks): 231 | layers.append( 232 | block( 233 | self.inplanes, 234 | planes, 235 | groups=self.groups, 236 | base_width=self.base_width, 237 | dilation=self.dilation, 238 | norm_layer=norm_layer, 239 | ) 240 | ) 241 | 242 | return nn.Sequential(*layers) 243 | 244 | def forward(self, x): 245 | x = self.conv1(x) 246 | x = self.bn1(x) 247 | x = self.relu(x) 248 | x = self.maxpool(x) 249 | 250 | x = self.layer1(x) 251 | x = self.layer2(x) 252 | x = self.layer3(x) 253 | x = self.layer4(x) 254 | 255 | x = self.avgpool(x) 256 | x = x.reshape(x.size(0), -1) 257 | x = self.fc(x) 258 | 259 | return x 260 | 261 | 262 | def _resnet(arch, block, layers, pretrained, progress, device, **kwargs): 263 | model = ResNet(block, layers, **kwargs) 264 | if pretrained: 265 | script_dir = os.path.dirname(__file__) 266 | state_dict = torch.load( 267 | script_dir + "/state_dicts/" + arch + ".pt", map_location=device 268 | ) 269 | model.load_state_dict(state_dict) 270 | return model 271 | 272 | 273 | def resnet18(pretrained=False, progress=True, device="cpu", **kwargs): 274 | """Constructs a ResNet-18 model. 275 | Args: 276 | pretrained (bool): If True, returns a model pre-trained on ImageNet 277 | progress (bool): If True, displays a progress bar of the download to stderr 278 | """ 279 | return _resnet( 280 | "resnet18", BasicBlock, [2, 2, 2, 2], pretrained, progress, device, **kwargs 281 | ) 282 | 283 | 284 | def resnet34(pretrained=False, progress=True, device="cpu", **kwargs): 285 | """Constructs a ResNet-34 model. 286 | Args: 287 | pretrained (bool): If True, returns a model pre-trained on ImageNet 288 | progress (bool): If True, displays a progress bar of the download to stderr 289 | """ 290 | return _resnet( 291 | "resnet34", BasicBlock, [3, 4, 6, 3], pretrained, progress, device, **kwargs 292 | ) 293 | 294 | 295 | def resnet50(pretrained=False, progress=True, device="cpu", **kwargs): 296 | """Constructs a ResNet-50 model. 
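In this repo the stem is the CIFAR-10 3x3/stride-1 conv defined in ``ResNet`` above, and ``pretrained=True`` loads CIFAR-10 weights from ``state_dicts/resnet50.pt``, not ImageNet weights.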
297 | Args: 298 | pretrained (bool): If True, returns a model pre-trained on ImageNet 299 | progress (bool): If True, displays a progress bar of the download to stderr 300 | """ 301 | return _resnet( 302 | "resnet50", Bottleneck, [3, 4, 6, 3], pretrained, progress, device, **kwargs 303 | ) 304 | -------------------------------------------------------------------------------- /cifar10_models/resnet_orig.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import os 5 | 6 | # Credit to https://github.com/akamaster/pytorch_resnet_cifar10 7 | 8 | __all__ = ["resnet_orig"] 9 | 10 | 11 | class LambdaLayer(nn.Module): 12 | def __init__(self, lambd): 13 | super(LambdaLayer, self).__init__() 14 | self.lambd = lambd 15 | 16 | def forward(self, x): 17 | return self.lambd(x) 18 | 19 | 20 | class BasicBlock(nn.Module): 21 | expansion = 1 22 | 23 | def __init__(self, in_planes, planes, stride=1, option="A"): 24 | super(BasicBlock, self).__init__() 25 | self.conv1 = nn.Conv2d( 26 | in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False 27 | ) 28 | self.bn1 = nn.BatchNorm2d(planes) 29 | self.conv2 = nn.Conv2d( 30 | planes, planes, kernel_size=3, stride=1, padding=1, bias=False 31 | ) 32 | self.bn2 = nn.BatchNorm2d(planes) 33 | 34 | self.shortcut = nn.Sequential() 35 | if stride != 1 or in_planes != planes: 36 | if option == "A": 37 | """ 38 | For CIFAR10 ResNet paper uses option A. 39 | """ 40 | self.shortcut = LambdaLayer( 41 | lambda x: F.pad( 42 | x[:, :, ::2, ::2], 43 | (0, 0, 0, 0, planes // 4, planes // 4), 44 | "constant", 45 | 0, 46 | ) 47 | ) 48 | elif option == "B": 49 | self.shortcut = nn.Sequential( 50 | nn.Conv2d( 51 | in_planes, 52 | self.expansion * planes, 53 | kernel_size=1, 54 | stride=stride, 55 | bias=False, 56 | ), 57 | nn.BatchNorm2d(self.expansion * planes), 58 | ) 59 | 60 | def forward(self, x): 61 | out = F.relu(self.bn1(self.conv1(x))) 62 | out = self.bn2(self.conv2(out)) 63 | out += self.shortcut(x) 64 | out = F.relu(out) 65 | return out 66 | 67 | 68 | class ResNet(nn.Module): 69 | def __init__(self, block, num_blocks, num_classes=10): 70 | super(ResNet, self).__init__() 71 | self.in_planes = 16 72 | 73 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 74 | self.bn1 = nn.BatchNorm2d(16) 75 | self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1) 76 | self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2) 77 | self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2) 78 | self.linear = nn.Linear(64, num_classes) 79 | 80 | def _make_layer(self, block, planes, num_blocks, stride): 81 | strides = [stride] + [1] * (num_blocks - 1) 82 | layers = [] 83 | for stride in strides: 84 | layers.append(block(self.in_planes, planes, stride)) 85 | self.in_planes = planes * block.expansion 86 | 87 | return nn.Sequential(*layers) 88 | 89 | def forward(self, x): 90 | out = F.relu(self.bn1(self.conv1(x))) 91 | out = self.layer1(out) 92 | out = self.layer2(out) 93 | out = self.layer3(out) 94 | out = F.avg_pool2d(out, out.size()[3]) 95 | out = out.view(out.size(0), -1) 96 | out = self.linear(out) 97 | return out 98 | 99 | 100 | def resnet_orig(pretrained=True, device="cpu"): 101 | net = ResNet(BasicBlock, [3, 3, 3]) 102 | if pretrained: 103 | script_dir = os.path.dirname(__file__) 104 | state_dict = torch.load( 105 | script_dir + "/state_dicts/resnet_orig.pt", map_location=device 106 | ) 107 | 
net.load_state_dict(state_dict) 108 | return net 109 | -------------------------------------------------------------------------------- /cifar10_models/vgg.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | __all__ = [ 7 | "VGG", 8 | "vgg11_bn", 9 | "vgg13_bn", 10 | "vgg16_bn", 11 | "vgg19_bn", 12 | ] 13 | 14 | 15 | class VGG(nn.Module): 16 | def __init__(self, features, num_classes=10, init_weights=True): 17 | super(VGG, self).__init__() 18 | self.features = features 19 | # CIFAR 10 (7, 7) to (1, 1) 20 | # self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) 21 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 22 | 23 | self.classifier = nn.Sequential( 24 | nn.Linear(512 * 1 * 1, 4096), 25 | # nn.Linear(512 * 7 * 7, 4096), 26 | nn.ReLU(True), 27 | nn.Dropout(), 28 | nn.Linear(4096, 4096), 29 | nn.ReLU(True), 30 | nn.Dropout(), 31 | nn.Linear(4096, num_classes), 32 | ) 33 | if init_weights: 34 | self._initialize_weights() 35 | 36 | def forward(self, x): 37 | x = self.features(x) 38 | x = self.avgpool(x) 39 | x = x.view(x.size(0), -1) 40 | x = self.classifier(x) 41 | return x 42 | 43 | def _initialize_weights(self): 44 | for m in self.modules(): 45 | if isinstance(m, nn.Conv2d): 46 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 47 | if m.bias is not None: 48 | nn.init.constant_(m.bias, 0) 49 | elif isinstance(m, nn.BatchNorm2d): 50 | nn.init.constant_(m.weight, 1) 51 | nn.init.constant_(m.bias, 0) 52 | elif isinstance(m, nn.Linear): 53 | nn.init.normal_(m.weight, 0, 0.01) 54 | nn.init.constant_(m.bias, 0) 55 | 56 | 57 | def make_layers(cfg, batch_norm=False): 58 | layers = [] 59 | in_channels = 3 60 | for v in cfg: 61 | if v == "M": 62 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 63 | else: 64 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 65 | if batch_norm: 66 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 67 | else: 68 | layers += [conv2d, nn.ReLU(inplace=True)] 69 | in_channels = v 70 | return nn.Sequential(*layers) 71 | 72 | 73 | cfgs = { 74 | "A": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"], 75 | "B": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"], 76 | "D": [ 77 | 64, 78 | 64, 79 | "M", 80 | 128, 81 | 128, 82 | "M", 83 | 256, 84 | 256, 85 | 256, 86 | "M", 87 | 512, 88 | 512, 89 | 512, 90 | "M", 91 | 512, 92 | 512, 93 | 512, 94 | "M", 95 | ], 96 | "E": [ 97 | 64, 98 | 64, 99 | "M", 100 | 128, 101 | 128, 102 | "M", 103 | 256, 104 | 256, 105 | 256, 106 | 256, 107 | "M", 108 | 512, 109 | 512, 110 | 512, 111 | 512, 112 | "M", 113 | 512, 114 | 512, 115 | 512, 116 | 512, 117 | "M", 118 | ], 119 | } 120 | 121 | 122 | def _vgg(arch, cfg, batch_norm, pretrained, progress, device, **kwargs): 123 | if pretrained: 124 | kwargs["init_weights"] = False 125 | model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs) 126 | if pretrained: 127 | script_dir = os.path.dirname(__file__) 128 | state_dict = torch.load( 129 | script_dir + "/state_dicts/" + arch + ".pt", map_location=device 130 | ) 131 | model.load_state_dict(state_dict) 132 | return model 133 | 134 | 135 | def vgg11_bn(pretrained=False, progress=True, device="cpu", **kwargs): 136 | """VGG 11-layer model (configuration "A") with batch normalization 137 | 138 | Args: 139 | pretrained (bool): If True, returns a model pre-trained on ImageNet 140 | progress (bool): If True, displays a progress bar of the download to stderr 141 | """ 142 
| return _vgg("vgg11_bn", "A", True, pretrained, progress, device, **kwargs) 143 | 144 | 145 | def vgg13_bn(pretrained=False, progress=True, device="cpu", **kwargs): 146 | """VGG 13-layer model (configuration "B") with batch normalization 147 | 148 | Args: 149 | pretrained (bool): If True, returns a model pre-trained on ImageNet 150 | progress (bool): If True, displays a progress bar of the download to stderr 151 | """ 152 | return _vgg("vgg13_bn", "B", True, pretrained, progress, device, **kwargs) 153 | 154 | 155 | def vgg16_bn(pretrained=False, progress=True, device="cpu", **kwargs): 156 | """VGG 16-layer model (configuration "D") with batch normalization 157 | 158 | Args: 159 | pretrained (bool): If True, returns a model pre-trained on ImageNet 160 | progress (bool): If True, displays a progress bar of the download to stderr 161 | """ 162 | return _vgg("vgg16_bn", "D", True, pretrained, progress, device, **kwargs) 163 | 164 | 165 | def vgg19_bn(pretrained=False, progress=True, device="cpu", **kwargs): 166 | """VGG 19-layer model (configuration 'E') with batch normalization 167 | 168 | Args: 169 | pretrained (bool): If True, returns a model pre-trained on ImageNet 170 | progress (bool): If True, displays a progress bar of the download to stderr 171 | """ 172 | return _vgg("vgg19_bn", "E", True, pretrained, progress, device, **kwargs) 173 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import zipfile 3 | 4 | import pytorch_lightning as pl 5 | import requests 6 | from torch.utils.data import DataLoader 7 | from torchvision import transforms as T 8 | from torchvision.datasets import CIFAR10 9 | from tqdm import tqdm 10 | 11 | 12 | class CIFAR10Data(pl.LightningDataModule): 13 | def __init__(self, args): 14 | super().__init__() 15 | self.hparams = args 16 | self.mean = (0.4914, 0.4822, 0.4465) 17 | self.std = (0.2471, 0.2435, 0.2616) 18 | 19 | def download_weights(): 20 | url = ( 21 | "https://rutgers.box.com/shared/static/gkw08ecs797j2et1ksmbg1w5t3idf5r5.zip" 22 | ) 23 | 24 | # Streaming, so we can iterate over the response. 25 | r = requests.get(url, stream=True) 26 | 27 | # Total size in Mebibyte 28 | total_size = int(r.headers.get("content-length", 0)) 29 | block_size = 2 ** 20 # Mebibyte 30 | t = tqdm(total=total_size, unit="MiB", unit_scale=True) 31 | 32 | with open("state_dicts.zip", "wb") as f: 33 | for data in r.iter_content(block_size): 34 | t.update(len(data)) 35 | f.write(data) 36 | t.close() 37 | 38 | if total_size != 0 and t.n != total_size: 39 | raise Exception("Error, something went wrong") 40 | 41 | print("Download successful. 
Unzipping file...") 42 | path_to_zip_file = os.path.join(os.getcwd(), "state_dicts.zip") 43 | directory_to_extract_to = os.path.join(os.getcwd(), "cifar10_models") 44 | with zipfile.ZipFile(path_to_zip_file, "r") as zip_ref: 45 | zip_ref.extractall(directory_to_extract_to) 46 | print("Unzip file successful!") 47 | 48 | def train_dataloader(self): 49 | transform = T.Compose( 50 | [ 51 | T.RandomCrop(32, padding=4), 52 | T.RandomHorizontalFlip(), 53 | T.ToTensor(), 54 | T.Normalize(self.mean, self.std), 55 | ] 56 | ) 57 | dataset = CIFAR10(root=self.hparams.data_dir, train=True, transform=transform) 58 | dataloader = DataLoader( 59 | dataset, 60 | batch_size=self.hparams.batch_size, 61 | num_workers=self.hparams.num_workers, 62 | shuffle=True, 63 | drop_last=True, 64 | pin_memory=True, 65 | ) 66 | return dataloader 67 | 68 | def val_dataloader(self): 69 | transform = T.Compose( 70 | [ 71 | T.ToTensor(), 72 | T.Normalize(self.mean, self.std), 73 | ] 74 | ) 75 | dataset = CIFAR10(root=self.hparams.data_dir, train=False, transform=transform) 76 | dataloader = DataLoader( 77 | dataset, 78 | batch_size=self.hparams.batch_size, 79 | num_workers=self.hparams.num_workers, 80 | drop_last=True, 81 | pin_memory=True, 82 | ) 83 | return dataloader 84 | 85 | def test_dataloader(self): 86 | return self.val_dataloader() 87 | -------------------------------------------------------------------------------- /module.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import torch 3 | from pytorch_lightning.metrics import Accuracy 4 | 5 | from cifar10_models.densenet import densenet121, densenet161, densenet169 6 | from cifar10_models.googlenet import googlenet 7 | from cifar10_models.inception import inception_v3 8 | from cifar10_models.mobilenetv2 import mobilenet_v2 9 | from cifar10_models.resnet import resnet18, resnet34, resnet50 10 | from cifar10_models.vgg import vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn 11 | from schduler import WarmupCosineLR 12 | 13 | all_classifiers = { 14 | "vgg11_bn": vgg11_bn(), 15 | "vgg13_bn": vgg13_bn(), 16 | "vgg16_bn": vgg16_bn(), 17 | "vgg19_bn": vgg19_bn(), 18 | "resnet18": resnet18(), 19 | "resnet34": resnet34(), 20 | "resnet50": resnet50(), 21 | "densenet121": densenet121(), 22 | "densenet161": densenet161(), 23 | "densenet169": densenet169(), 24 | "mobilenet_v2": mobilenet_v2(), 25 | "googlenet": googlenet(), 26 | "inception_v3": inception_v3(), 27 | } 28 | 29 | 30 | class CIFAR10Module(pl.LightningModule): 31 | def __init__(self, hparams): 32 | super().__init__() 33 | self.hparams = hparams 34 | 35 | self.criterion = torch.nn.CrossEntropyLoss() 36 | self.accuracy = Accuracy() 37 | 38 | self.model = all_classifiers[self.hparams.classifier] 39 | 40 | def forward(self, batch): 41 | images, labels = batch 42 | predictions = self.model(images) 43 | loss = self.criterion(predictions, labels) 44 | accuracy = self.accuracy(predictions, labels) 45 | return loss, accuracy * 100 46 | 47 | def training_step(self, batch, batch_nb): 48 | loss, accuracy = self.forward(batch) 49 | self.log("loss/train", loss) 50 | self.log("acc/train", accuracy) 51 | return loss 52 | 53 | def validation_step(self, batch, batch_nb): 54 | loss, accuracy = self.forward(batch) 55 | self.log("loss/val", loss) 56 | self.log("acc/val", accuracy) 57 | 58 | def test_step(self, batch, batch_nb): 59 | loss, accuracy = self.forward(batch) 60 | self.log("acc/test", accuracy) 61 | 62 | def configure_optimizers(self): 63 | optimizer = 
/schduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | import warnings 3 | from typing import List 4 | 5 | from torch.optim import Optimizer 6 | from torch.optim.lr_scheduler import _LRScheduler 7 | 8 | 9 | class WarmupCosineLR(_LRScheduler): 10 | """ 11 | Sets the learning rate of each parameter group to follow a linear warmup schedule 12 | between warmup_start_lr and base_lr, followed by a cosine annealing schedule between 13 | base_lr and eta_min. 14 | .. warning:: 15 | It is recommended to call :func:`.step()` for :class:`WarmupCosineLR` 16 | after each iteration, as calling it after each epoch will keep the starting lr at 17 | warmup_start_lr for the whole first epoch, which is effectively 0 in most cases. 18 | .. warning:: 19 | Passing epoch to :func:`.step()` is deprecated and comes with an EPOCH_DEPRECATION_WARNING. 20 | It calls the :func:`_get_closed_form_lr()` method for this scheduler instead of 21 | :func:`get_lr()`. Though this does not change the behavior of the scheduler, when passing 22 | the epoch param to :func:`.step()`, the user should call :func:`.step()` before calling the 23 | train and validation methods. 24 | Args: 25 | optimizer (Optimizer): Wrapped optimizer. 26 | warmup_epochs (int): Maximum number of iterations for linear warmup 27 | max_epochs (int): Maximum number of iterations 28 | warmup_start_lr (float): Learning rate to start the linear warmup. Default: 1e-8. 29 | eta_min (float): Minimum learning rate. Default: 1e-8. 30 | last_epoch (int): The index of the last epoch. Default: -1. 31 | Example: 32 | >>> layer = nn.Linear(10, 1) 33 | >>> optimizer = Adam(layer.parameters(), lr=0.02) 34 | >>> scheduler = WarmupCosineLR(optimizer, warmup_epochs=10, max_epochs=40) 35 | >>> # 36 | >>> # the default case 37 | >>> for epoch in range(40): 38 | ... # train(...) 39 | ... # validate(...) 40 | ... scheduler.step() 41 | >>> # 42 | >>> # passing epoch param case 43 | >>> for epoch in range(40): 44 | ... scheduler.step(epoch) 45 | ... # train(...) 46 | ... # validate(...)
47 | """ 48 | 49 | def __init__( 50 | self, 51 | optimizer: Optimizer, 52 | warmup_epochs: int, 53 | max_epochs: int, 54 | warmup_start_lr: float = 1e-8, 55 | eta_min: float = 1e-8, 56 | last_epoch: int = -1, 57 | ) -> None: 58 | 59 | self.warmup_epochs = warmup_epochs 60 | self.max_epochs = max_epochs 61 | self.warmup_start_lr = warmup_start_lr 62 | self.eta_min = eta_min 63 | 64 | super(WarmupCosineLR, self).__init__(optimizer, last_epoch) 65 | 66 | def get_lr(self) -> List[float]: 67 | """ 68 | Compute learning rate using chainable form of the scheduler 69 | """ 70 | if not self._get_lr_called_within_step: 71 | warnings.warn( 72 | "To get the last learning rate computed by the scheduler, " 73 | "please use `get_last_lr()`.", 74 | UserWarning, 75 | ) 76 | 77 | if self.last_epoch == 0: 78 | return [self.warmup_start_lr] * len(self.base_lrs) 79 | elif self.last_epoch < self.warmup_epochs: 80 | return [ 81 | group["lr"] 82 | + (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1) 83 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) 84 | ] 85 | elif self.last_epoch == self.warmup_epochs: 86 | return self.base_lrs 87 | elif (self.last_epoch - 1 - self.max_epochs) % ( 88 | 2 * (self.max_epochs - self.warmup_epochs) 89 | ) == 0: 90 | return [ 91 | group["lr"] 92 | + (base_lr - self.eta_min) 93 | * (1 - math.cos(math.pi / (self.max_epochs - self.warmup_epochs))) 94 | / 2 95 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) 96 | ] 97 | 98 | return [ 99 | ( 100 | 1 101 | + math.cos( 102 | math.pi 103 | * (self.last_epoch - self.warmup_epochs) 104 | / (self.max_epochs - self.warmup_epochs) 105 | ) 106 | ) 107 | / ( 108 | 1 109 | + math.cos( 110 | math.pi 111 | * (self.last_epoch - self.warmup_epochs - 1) 112 | / (self.max_epochs - self.warmup_epochs) 113 | ) 114 | ) 115 | * (group["lr"] - self.eta_min) 116 | + self.eta_min 117 | for group in self.optimizer.param_groups 118 | ] 119 | 120 | def _get_closed_form_lr(self) -> List[float]: 121 | """ 122 | Called when epoch is passed as a param to the `step` function of the scheduler. 
123 | """ 124 | if self.last_epoch < self.warmup_epochs: 125 | return [ 126 | self.warmup_start_lr 127 | + self.last_epoch 128 | * (base_lr - self.warmup_start_lr) 129 | / (self.warmup_epochs - 1) 130 | for base_lr in self.base_lrs 131 | ] 132 | 133 | return [ 134 | self.eta_min 135 | + 0.5 136 | * (base_lr - self.eta_min) 137 | * ( 138 | 1 139 | + math.cos( 140 | math.pi 141 | * (self.last_epoch - self.warmup_epochs) 142 | / (self.max_epochs - self.warmup_epochs) 143 | ) 144 | ) 145 | for base_lr in self.base_lrs 146 | ] 147 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | from argparse import ArgumentParser 3 | 4 | import torch 5 | from pytorch_lightning import Trainer, seed_everything 6 | from pytorch_lightning.callbacks import ModelCheckpoint 7 | from pytorch_lightning.loggers import WandbLogger, TensorBoardLogger 8 | 9 | from data import CIFAR10Data 10 | from module import CIFAR10Module 11 | 12 | 13 | def main(args): 14 | 15 | if bool(args.download_weights): 16 | CIFAR10Data.download_weights() 17 | else: 18 | seed_everything(0) 19 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id 20 | 21 | if args.logger == "wandb": 22 | logger = WandbLogger(name=args.classifier, project="cifar10") 23 | elif args.logger == "tensorboard": 24 | logger = TensorBoardLogger("cifar10", name=args.classifier) 25 | 26 | checkpoint = ModelCheckpoint(monitor="acc/val", mode="max", save_last=False) 27 | 28 | trainer = Trainer( 29 | fast_dev_run=bool(args.dev), 30 | logger=logger if not bool(args.dev + args.test_phase) else None, 31 | gpus=-1, 32 | deterministic=True, 33 | weights_summary=None, 34 | log_every_n_steps=1, 35 | max_epochs=args.max_epochs, 36 | checkpoint_callback=checkpoint, 37 | precision=args.precision, 38 | ) 39 | 40 | model = CIFAR10Module(args) 41 | data = CIFAR10Data(args) 42 | 43 | if bool(args.pretrained): 44 | state_dict = os.path.join( 45 | "cifar10_models", "state_dicts", args.classifier + ".pt" 46 | ) 47 | model.model.load_state_dict(torch.load(state_dict)) 48 | 49 | if bool(args.test_phase): 50 | trainer.test(model, data.test_dataloader()) 51 | else: 52 | trainer.fit(model, data) 53 | trainer.test() 54 | 55 | 56 | if __name__ == "__main__": 57 | parser = ArgumentParser() 58 | 59 | # PROGRAM level args 60 | parser.add_argument("--data_dir", type=str, default="/data/huy/cifar10") 61 | parser.add_argument("--download_weights", type=int, default=0, choices=[0, 1]) 62 | parser.add_argument("--test_phase", type=int, default=0, choices=[0, 1]) 63 | parser.add_argument("--dev", type=int, default=0, choices=[0, 1]) 64 | parser.add_argument( 65 | "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"] 66 | ) 67 | 68 | # TRAINER args 69 | parser.add_argument("--classifier", type=str, default="resnet18") 70 | parser.add_argument("--pretrained", type=int, default=0, choices=[0, 1]) 71 | 72 | parser.add_argument("--precision", type=int, default=32, choices=[16, 32]) 73 | parser.add_argument("--batch_size", type=int, default=256) 74 | parser.add_argument("--max_epochs", type=int, default=100) 75 | parser.add_argument("--num_workers", type=int, default=8) 76 | parser.add_argument("--gpu_id", type=str, default="3") 77 | 78 | parser.add_argument("--learning_rate", type=float, default=1e-2) 79 | parser.add_argument("--weight_decay", type=float, default=1e-2) 80 | 81 | args = parser.parse_args() 82 | main(args) 83 | 
--------------------------------------------------------------------------------
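End-to-end inference sketch (an editorial addition, not part of the repository). It assumes the weights were fetched first with python train.py --download_weights 1, and that resnet18 accepts the same pretrained/device keywords as the VGG factories above; download=True here fetches the CIFAR-10 images themselves, not the weights.

import torch
from torchvision import transforms as T
from torchvision.datasets import CIFAR10
from cifar10_models.resnet import resnet18

# Same normalization constants as CIFAR10Data in data.py.
mean = (0.4914, 0.4822, 0.4465)
std = (0.2471, 0.2435, 0.2616)
transform = T.Compose([T.ToTensor(), T.Normalize(mean, std)])
dataset = CIFAR10(root="cifar10", train=False, transform=transform, download=True)

model = resnet18(pretrained=True)  # loads cifar10_models/state_dicts/resnet18.pt
model.eval()

image, label = dataset[0]
with torch.no_grad():
    prediction = model(image.unsqueeze(0)).argmax(dim=1).item()
print(f"predicted class {prediction}, true class {label}")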