├── .gitignore
├── README.md
├── archs
│   ├── cifar10
│   │   ├── AlexNet.py
│   │   ├── LeNet5.py
│   │   ├── densenet.py
│   │   ├── fc1.py
│   │   ├── resnet.py
│   │   └── vgg.py
│   ├── cifar100
│   │   ├── AlexNet.py
│   │   ├── LeNet5.py
│   │   ├── fc1.py
│   │   ├── resnet.py
│   │   └── vgg.py
│   └── mnist
│       ├── AlexNet.py
│       ├── LeNet5.py
│       ├── fc1.py
│       ├── resnet.py
│       └── vgg.py
├── combine_plots.py
├── main.py
├── requirements.txt
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | /data
2 | /dumps
3 | /plots
4 | /runs
5 | /saves
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Lottery Ticket Hypothesis in Pytorch
2 | [![Made With python 3.7](https://img.shields.io/badge/Made%20with-Python%203.7-brightgreen)]() [![Maintenance](https://img.shields.io/badge/Maintained%3F-no-red.svg)]() [![Open Source Love svg1](https://badges.frapsoft.com/os/v1/open-source.svg?v=103)]()
3 |
4 | This repository contains a **Pytorch** implementation of the paper [The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks](https://arxiv.org/abs/1803.03635) by [Jonathan Frankle](https://github.com/jfrankle) and [Michael Carbin](https://people.csail.mit.edu/mcarbin/) that can be **easily adapted to any model/dataset**.
5 |
6 | ## Requirements
7 | ```
8 | pip3 install -r requirements.txt
9 | ```
10 | ## How to run the code?
11 | ### Using datasets/architectures included with this repository:
12 | ```
13 | python3 main.py --prune_type=lt --arch_type=fc1 --dataset=mnist --prune_percent=10 --prune_iterations=35
14 | ```
15 | - `--prune_type` : Type of pruning
16 |   - Options : `lt` - Lottery Ticket Hypothesis, `reinit` - Random reinitialization
17 |   - Default : `lt`
18 | - `--arch_type` : Type of architecture
19 |   - Options : `fc1` - Simple fully connected network, `lenet5` - LeNet5, `AlexNet` - AlexNet, `resnet18` - Resnet18, `vgg16` - VGG16
20 |   - Default : `fc1`
21 | - `--dataset` : Choice of dataset
22 |   - Options : `mnist`, `fashionmnist`, `cifar10`, `cifar100`
23 |   - Default : `mnist`
24 | - `--prune_percent` : Percentage of weights to prune in each cycle.
25 |   - Default : `10`
26 | - `--prune_iterations` : Number of pruning cycles to perform.
27 |   - Default : `35`
28 | - `--lr` : Learning rate
29 |   - Default : `1.2e-3`
30 | - `--batch_size` : Batch size
31 |   - Default : `60`
32 | - `--end_iter` : Number of epochs
33 |   - Default : `100`
34 | - `--print_freq` : Frequency of printing accuracy and loss
35 |   - Default : `1`
36 | - `--valid_freq` : Frequency of validation
37 |   - Default : `1`
38 | - `--gpu` : Which GPU the program should use
39 |   - Default : `0`
40 | ### Using datasets/architectures that are not included with this repository:
41 | - Adding a new architecture :
42 |   - For example, suppose you want to add an architecture named `new_model` with `mnist` dataset compatibility.
43 |   - Go to the `/archs/mnist/` directory and create a file `new_model.py`.
44 |   - Now paste your **Pytorch-compatible** model inside `new_model.py`.
45 |   - **IMPORTANT** : Make sure the *input size*, *number of classes*, *number of channels*, and *batch size* in your `new_model.py` match the dataset that you are adding (in this case, `mnist`).
46 |   - Now open `main.py`, go to `line 36`, and look for the comment `# Data Loader`.
  - Find your corresponding dataset (in this case, `mnist`) and add `new_model` at the end of the line `from archs.mnist import AlexNet, LeNet5, fc1, vgg, resnet`.
47 |   - Now go to `line 82` and add the following to it :
48 | ```
49 | elif args.arch_type == "new_model":
50 |     model = new_model.new_model_name().to(device)
51 | ```
52 | Here, `new_model_name()` is the name of the model class that you defined inside `new_model.py`.
53 | - Adding a new dataset :
54 |   - For example, suppose you want to add a dataset named `new_dataset` with `fc1` architecture compatibility.
55 |   - Go to `/archs` and create a directory named `new_dataset`.
56 |   - Now go to `/archs/new_dataset/` and add a file named `fc1.py`, or copy-paste one from an existing dataset folder.
57 |   - **IMPORTANT** : Make sure the *input size*, *number of classes*, *number of channels*, and *batch size* in your `fc1.py` match the dataset that you are adding (in this case, `new_dataset`).
58 |   - Now open `main.py`, go to `line 58`, and add the following to it :
59 | ```
60 | elif args.dataset == "new_dataset":
61 |     traindataset = datasets.new_dataset('../data', train=True, download=True, transform=transform)
62 |     testdataset = datasets.new_dataset('../data', train=False, transform=transform)
63 | ```
64 | Also add `from archs.new_dataset import fc1` next to the other imports. **Note** that as of now, you can only add datasets that are [natively available in Pytorch](https://pytorch.org/docs/stable/torchvision/datasets.html).
65 |
66 | ## How to combine the plots of various `prune_type`?
67 | - Go to `combine_plots.py` and add/remove the datasets/archs whose combined plots you want to generate (*assuming that you have already executed the `main.py` code for those datasets/archs and produced the weights*).
68 | - Run `python3 combine_plots.py`.
69 | - Go to `/plots/lt/combined_plots/` to see the graphs.
70 |
71 | Kindly [raise an issue](https://github.com/rahulvigneswaran/Lottery-Ticket-Hypothesis-in-Pytorch/issues) if you have any problem with the instructions.
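## What the pruning loop does (sketch)
For intuition about what `main.py` runs with the options above, here is a minimal sketch of one lottery-ticket experiment: train, prune the smallest surviving weights layer-wise, then rewind the survivors. The names (`lottery_ticket`, `train_fn`) are hypothetical and are not this repository's API; see `main.py` for the real loop.
```
import numpy as np
import torch

def lottery_ticket(model, train_fn, prune_percent=10, prune_iterations=35, prune_type="lt"):
    # Snapshot the original initialization (the candidate "winning ticket").
    initial_state = {k: v.clone() for k, v in model.state_dict().items()}
    # One binary mask per weight tensor: 1 = kept, 0 = pruned.
    masks = {n: torch.ones_like(p) for n, p in model.named_parameters() if 'weight' in n}
    for _ in range(prune_iterations):
        train_fn(model, masks)  # train to completion, keeping masked weights at zero
        for n, p in model.named_parameters():
            if n not in masks:
                continue
            # Layer-wise magnitude pruning: drop the smallest prune_percent% of survivors.
            alive = p.data[masks[n] == 1].abs()
            threshold = np.percentile(alive.cpu().numpy(), prune_percent)
            masks[n][p.data.abs() < threshold] = 0
        if prune_type == "lt":
            model.load_state_dict(initial_state)  # rewind survivors to their original init
        else:                                     # "reinit": random-reinitialization control
            for m in model.modules():
                if hasattr(m, 'reset_parameters'):
                    m.reset_parameters()
        with torch.no_grad():
            for n, p in model.named_parameters():
                if n in masks:
                    p.mul_(masks[n])              # re-apply the mask
    return model, masks
```
In `lt` mode the surviving weights return to their original initialization after every round (the lottery-ticket experiment); in `reinit` mode the same mask is kept but the weights are freshly randomized (the control from the paper).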
72 |
73 |
74 | ## Datasets and Architectures that were already tested
75 |
76 | |              | fc1                | LeNet5             | AlexNet            | VGG16              | Resnet18           |
77 | |--------------|:------------------:|:------------------:|:------------------:|:------------------:|:------------------:|
78 | | MNIST        | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
79 | | CIFAR10      | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
80 | | FashionMNIST | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
81 | | CIFAR100     | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
82 |
83 |
84 | ## Repository Structure
85 | ```
86 | Lottery-Ticket-Hypothesis-in-Pytorch
87 | ├── archs
88 | │   ├── cifar10
89 | │   │   ├── AlexNet.py
90 | │   │   ├── densenet.py
91 | │   │   ├── fc1.py
92 | │   │   ├── LeNet5.py
93 | │   │   ├── resnet.py
94 | │   │   └── vgg.py
95 | │   ├── cifar100
96 | │   │   ├── AlexNet.py
97 | │   │   ├── fc1.py
98 | │   │   ├── LeNet5.py
99 | │   │   ├── resnet.py
100 | │   │   └── vgg.py
101 | │   └── mnist
102 | │       ├── AlexNet.py
103 | │       ├── fc1.py
104 | │       ├── LeNet5.py
105 | │       ├── resnet.py
106 | │       └── vgg.py
107 | ├── combine_plots.py
108 | ├── dumps
109 | ├── main.py
110 | ├── plots
111 | ├── README.md
112 | ├── requirements.txt
113 | ├── saves
114 | └── utils.py
115 |
116 | ```
117 |
118 | ## Interesting papers related to the Lottery Ticket Hypothesis that I enjoyed
119 | - [Deconstructing Lottery Tickets: Zeros, Signs, and the Supermask](https://eng.uber.com/deconstructing-lottery-tickets/)
120 |
121 | ## Acknowledgement
122 | Parts of the code were borrowed from [ktkth5](https://github.com/ktkth5/lottery-ticket-hyopothesis).
123 |
124 | ## Issue / Want to Contribute?
125 | Open a new issue or submit a pull request in case you face any difficulty with the code base or want to contribute to it.
126 | 127 | [![forthebadge](https://forthebadge.com/images/badges/built-with-love.svg)](https://github.com/rahulvigneswaran/Lottery-Ticket-Hypothesis-in-Pytorch/issues) 128 | 129 | Buy Me A Coffee 130 | 131 | -------------------------------------------------------------------------------- /archs/cifar10/AlexNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | __all__ = ['AlexNet', 'alexnet'] 6 | 7 | 8 | model_urls = { 9 | 'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth', 10 | } 11 | 12 | 13 | class AlexNet(nn.Module): 14 | 15 | def __init__(self, num_classes=10): 16 | super(AlexNet, self).__init__() 17 | self.features = nn.Sequential( 18 | nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=2), 19 | nn.ReLU(inplace=True), 20 | nn.MaxPool2d(kernel_size=3, stride=2), 21 | nn.Conv2d(64, 192, kernel_size=5, padding=2), 22 | nn.ReLU(inplace=True), 23 | nn.MaxPool2d(kernel_size=3, stride=2), 24 | nn.Conv2d(192, 384, kernel_size=3, padding=1), 25 | nn.ReLU(inplace=True), 26 | nn.Conv2d(384, 256, kernel_size=3, padding=1), 27 | nn.ReLU(inplace=True), 28 | nn.Conv2d(256, 256, kernel_size=3, padding=1), 29 | nn.ReLU(inplace=True), 30 | nn.MaxPool2d(kernel_size=3, stride=2), 31 | ) 32 | self.avgpool = nn.AdaptiveAvgPool2d((6, 6)) 33 | self.classifier = nn.Sequential( 34 | nn.Dropout(), 35 | nn.Linear(256 * 6 * 6, 4096), 36 | nn.ReLU(inplace=True), 37 | nn.Dropout(), 38 | nn.Linear(4096, 4096), 39 | nn.ReLU(inplace=True), 40 | nn.Linear(4096, num_classes), 41 | ) 42 | 43 | def forward(self, x): 44 | x = self.features(x) 45 | x = self.avgpool(x) 46 | x = torch.flatten(x, 1) 47 | x = self.classifier(x) 48 | return x 49 | -------------------------------------------------------------------------------- /archs/cifar10/LeNet5.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as func 3 | 4 | 5 | class LeNet5(nn.Module): 6 | def __init__(self, num_classes=10): 7 | super(LeNet5, self).__init__() 8 | self.conv1 = nn.Conv2d(3, 6, kernel_size=5) 9 | self.conv2 = nn.Conv2d(6, 16, kernel_size=5) 10 | self.fc1 = nn.Linear(16*5*5, 120) 11 | self.fc2 = nn.Linear(120, 84) 12 | self.fc3 = nn.Linear(84, num_classes) 13 | 14 | def forward(self, x): 15 | x = func.relu(self.conv1(x)) 16 | x = func.max_pool2d(x, 2) 17 | x = func.relu(self.conv2(x)) 18 | x = func.max_pool2d(x, 2) 19 | x = x.view(x.size(0), -1) 20 | x = func.relu(self.fc1(x)) 21 | x = func.relu(self.fc2(x)) 22 | x = self.fc3(x) 23 | return x -------------------------------------------------------------------------------- /archs/cifar10/densenet.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.utils.checkpoint as cp 6 | from collections import OrderedDict 7 | 8 | def _bn_function_factory(norm, relu, conv): 9 | def bn_function(*inputs): 10 | concated_features = torch.cat(inputs, 1) 11 | bottleneck_output = conv(relu(norm(concated_features))) 12 | return bottleneck_output 13 | 14 | return bn_function 15 | 16 | 17 | class _DenseLayer(nn.Sequential): 18 | def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, memory_efficient=False): 19 | super(_DenseLayer, self).__init__() 20 | self.add_module('norm1', nn.BatchNorm2d(num_input_features)), 21 | self.add_module('relu1', nn.ReLU(inplace=True)), 22 | 
self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * 23 | growth_rate, kernel_size=1, stride=1, 24 | bias=False)), 25 | self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)), 26 | self.add_module('relu2', nn.ReLU(inplace=True)), 27 | self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate, 28 | kernel_size=3, stride=1, padding=1, 29 | bias=False)), 30 | self.drop_rate = drop_rate 31 | self.memory_efficient = memory_efficient 32 | 33 | def forward(self, *prev_features): 34 | bn_function = _bn_function_factory(self.norm1, self.relu1, self.conv1) 35 | if self.memory_efficient and any(prev_feature.requires_grad for prev_feature in prev_features): 36 | bottleneck_output = cp.checkpoint(bn_function, *prev_features) 37 | else: 38 | bottleneck_output = bn_function(*prev_features) 39 | new_features = self.conv2(self.relu2(self.norm2(bottleneck_output))) 40 | if self.drop_rate > 0: 41 | new_features = F.dropout(new_features, p=self.drop_rate, 42 | training=self.training) 43 | return new_features 44 | 45 | 46 | class _DenseBlock(nn.Module): 47 | def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, memory_efficient=False): 48 | super(_DenseBlock, self).__init__() 49 | for i in range(num_layers): 50 | layer = _DenseLayer( 51 | num_input_features + i * growth_rate, 52 | growth_rate=growth_rate, 53 | bn_size=bn_size, 54 | drop_rate=drop_rate, 55 | memory_efficient=memory_efficient, 56 | ) 57 | self.add_module('denselayer%d' % (i + 1), layer) 58 | 59 | def forward(self, init_features): 60 | features = [init_features] 61 | for name, layer in self.named_children(): 62 | new_features = layer(*features) 63 | features.append(new_features) 64 | return torch.cat(features, 1) 65 | 66 | 67 | class _Transition(nn.Sequential): 68 | def __init__(self, num_input_features, num_output_features): 69 | super(_Transition, self).__init__() 70 | self.add_module('norm', nn.BatchNorm2d(num_input_features)) 71 | self.add_module('relu', nn.ReLU(inplace=True)) 72 | self.add_module('conv', nn.Conv2d(num_input_features, num_output_features, 73 | kernel_size=1, stride=1, bias=False)) 74 | self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) 75 | 76 | 77 | class DenseNet(nn.Module): 78 | r"""Densenet-BC model class, based on 79 | `"Densely Connected Convolutional Networks" `_ 80 | 81 | Args: 82 | growth_rate (int) - how many filters to add each layer (`k` in paper) 83 | block_config (list of 4 ints) - how many layers in each pooling block 84 | num_init_features (int) - the number of filters to learn in the first convolution layer 85 | bn_size (int) - multiplicative factor for number of bottle neck layers 86 | (i.e. bn_size * k features in the bottleneck layer) 87 | drop_rate (float) - dropout rate after each dense layer 88 | num_classes (int) - number of classification classes 89 | memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, 90 | but slower. Default: *False*. 
See `"paper" `_ 91 | """ 92 | 93 | def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), 94 | num_init_features=64, bn_size=4, drop_rate=0, num_classes=10, memory_efficient=False): 95 | 96 | super(DenseNet, self).__init__() 97 | 98 | # First convolution 99 | self.features = nn.Sequential(OrderedDict([ 100 | ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, 101 | padding=3, bias=False)), 102 | ('norm0', nn.BatchNorm2d(num_init_features)), 103 | ('relu0', nn.ReLU(inplace=True)), 104 | ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), 105 | ])) 106 | 107 | # Each denseblock 108 | num_features = num_init_features 109 | for i, num_layers in enumerate(block_config): 110 | block = _DenseBlock( 111 | num_layers=num_layers, 112 | num_input_features=num_features, 113 | bn_size=bn_size, 114 | growth_rate=growth_rate, 115 | drop_rate=drop_rate, 116 | memory_efficient=memory_efficient 117 | ) 118 | self.features.add_module('denseblock%d' % (i + 1), block) 119 | num_features = num_features + num_layers * growth_rate 120 | if i != len(block_config) - 1: 121 | trans = _Transition(num_input_features=num_features, 122 | num_output_features=num_features // 2) 123 | self.features.add_module('transition%d' % (i + 1), trans) 124 | num_features = num_features // 2 125 | 126 | # Final batch norm 127 | self.features.add_module('norm5', nn.BatchNorm2d(num_features)) 128 | 129 | # Linear layer 130 | self.classifier = nn.Linear(num_features, num_classes) 131 | 132 | # Official init from torch repo. 133 | for m in self.modules(): 134 | if isinstance(m, nn.Conv2d): 135 | nn.init.kaiming_normal_(m.weight) 136 | elif isinstance(m, nn.BatchNorm2d): 137 | nn.init.constant_(m.weight, 1) 138 | nn.init.constant_(m.bias, 0) 139 | elif isinstance(m, nn.Linear): 140 | nn.init.constant_(m.bias, 0) 141 | 142 | def forward(self, x): 143 | features = self.features(x) 144 | out = F.relu(features, inplace=True) 145 | out = F.adaptive_avg_pool2d(out, (1, 1)) 146 | out = torch.flatten(out, 1) 147 | out = self.classifier(out) 148 | return out 149 | 150 | 151 | def _load_state_dict(model, model_url, progress): 152 | # '.'s are no longer allowed in module names, but previous _DenseLayer 153 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 154 | # They are also in the checkpoints in model_urls. This pattern is used 155 | # to find such keys. 156 | pattern = re.compile( 157 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 158 | 159 | state_dict = load_state_dict_from_url(model_url, progress=progress) 160 | for key in list(state_dict.keys()): 161 | res = pattern.match(key) 162 | if res: 163 | new_key = res.group(1) + res.group(2) 164 | state_dict[new_key] = state_dict[key] 165 | del state_dict[key] 166 | model.load_state_dict(state_dict) 167 | 168 | 169 | def _densenet(arch, growth_rate, block_config, num_init_features, pretrained, progress, 170 | **kwargs): 171 | model = DenseNet(growth_rate, block_config, num_init_features, **kwargs) 172 | if pretrained: 173 | _load_state_dict(model, model_urls[arch], progress) 174 | return model 175 | 176 | 177 | def densenet121(pretrained=False, progress=True, **kwargs): 178 | r"""Densenet-121 model from 179 | `"Densely Connected Convolutional Networks" `_ 180 | 181 | Args: 182 | pretrained (bool): If True, returns a model pre-trained on ImageNet 183 | progress (bool): If True, displays a progress bar of the download to stderr 184 | memory_efficient (bool) - If True, uses checkpointing. 
Much more memory efficient, 185 | but slower. Default: *False*. See `"paper" `_ 186 | """ 187 | return _densenet('densenet121', 32, (6, 12, 24, 16), 64, pretrained, progress, 188 | **kwargs) 189 | 190 | 191 | 192 | def densenet161(pretrained=False, progress=True, **kwargs): 193 | r"""Densenet-161 model from 194 | `"Densely Connected Convolutional Networks" `_ 195 | 196 | Args: 197 | pretrained (bool): If True, returns a model pre-trained on ImageNet 198 | progress (bool): If True, displays a progress bar of the download to stderr 199 | memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, 200 | but slower. Default: *False*. See `"paper" `_ 201 | """ 202 | return _densenet('densenet161', 48, (6, 12, 36, 24), 96, pretrained, progress, 203 | **kwargs) 204 | 205 | 206 | 207 | def densenet169(pretrained=False, progress=True, **kwargs): 208 | r"""Densenet-169 model from 209 | `"Densely Connected Convolutional Networks" `_ 210 | 211 | Args: 212 | pretrained (bool): If True, returns a model pre-trained on ImageNet 213 | progress (bool): If True, displays a progress bar of the download to stderr 214 | memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, 215 | but slower. Default: *False*. See `"paper" `_ 216 | """ 217 | return _densenet('densenet169', 32, (6, 12, 32, 32), 64, pretrained, progress, 218 | **kwargs) 219 | 220 | 221 | 222 | def densenet201(pretrained=False, progress=True, **kwargs): 223 | r"""Densenet-201 model from 224 | `"Densely Connected Convolutional Networks" `_ 225 | 226 | Args: 227 | pretrained (bool): If True, returns a model pre-trained on ImageNet 228 | progress (bool): If True, displays a progress bar of the download to stderr 229 | memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, 230 | but slower. Default: *False*. See `"paper" `_ 231 | """ 232 | return _densenet('densenet201', 32, (6, 12, 48, 32), 64, pretrained, progress, 233 | **kwargs) 234 | -------------------------------------------------------------------------------- /archs/cifar10/fc1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class fc1(nn.Module): 5 | 6 | def __init__(self, num_classes=10): 7 | super(fc1, self).__init__() 8 | self.classifier = nn.Sequential( 9 | nn.Linear(3*32*32, 300), 10 | nn.ReLU(inplace=True), 11 | nn.Linear(300, 100), 12 | nn.ReLU(inplace=True), 13 | nn.Linear(100, num_classes), 14 | ) 15 | 16 | def forward(self, x): 17 | x = torch.flatten(x, 1) 18 | x = self.classifier(x) 19 | return x 20 | -------------------------------------------------------------------------------- /archs/cifar10/resnet.py: -------------------------------------------------------------------------------- 1 | '''ResNet in PyTorch. 2 | For Pre-activation ResNet, see 'preact_resnet.py'. 3 | Reference: 4 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 5 | Deep Residual Learning for Image Recognition. 
arXiv:1512.03385 6 | ''' 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | class BasicBlock(nn.Module): 13 | expansion = 1 14 | 15 | def __init__(self, in_planes, planes, stride=1): 16 | super(BasicBlock, self).__init__() 17 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 20 | self.bn2 = nn.BatchNorm2d(planes) 21 | 22 | self.shortcut = nn.Sequential() 23 | if stride != 1 or in_planes != self.expansion*planes: 24 | self.shortcut = nn.Sequential( 25 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 26 | nn.BatchNorm2d(self.expansion*planes) 27 | ) 28 | 29 | def forward(self, x): 30 | out = F.relu(self.bn1(self.conv1(x))) 31 | out = self.bn2(self.conv2(out)) 32 | out += self.shortcut(x) 33 | out = F.relu(out) 34 | return out 35 | 36 | 37 | class Bottleneck(nn.Module): 38 | expansion = 4 39 | 40 | def __init__(self, in_planes, planes, stride=1): 41 | super(Bottleneck, self).__init__() 42 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 43 | self.bn1 = nn.BatchNorm2d(planes) 44 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 45 | self.bn2 = nn.BatchNorm2d(planes) 46 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 47 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 48 | 49 | self.shortcut = nn.Sequential() 50 | if stride != 1 or in_planes != self.expansion*planes: 51 | self.shortcut = nn.Sequential( 52 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 53 | nn.BatchNorm2d(self.expansion*planes) 54 | ) 55 | 56 | def forward(self, x): 57 | out = F.relu(self.bn1(self.conv1(x))) 58 | out = F.relu(self.bn2(self.conv2(out))) 59 | out = self.bn3(self.conv3(out)) 60 | out += self.shortcut(x) 61 | out = F.relu(out) 62 | return out 63 | 64 | 65 | class ResNet(nn.Module): 66 | def __init__(self, block, num_blocks, num_classes=10): 67 | super(ResNet, self).__init__() 68 | self.in_planes = 64 69 | 70 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 71 | self.bn1 = nn.BatchNorm2d(64) 72 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 73 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 74 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 75 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 76 | self.linear = nn.Linear(512*block.expansion, num_classes) 77 | 78 | def _make_layer(self, block, planes, num_blocks, stride): 79 | strides = [stride] + [1]*(num_blocks-1) 80 | layers = [] 81 | for stride in strides: 82 | layers.append(block(self.in_planes, planes, stride)) 83 | self.in_planes = planes * block.expansion 84 | return nn.Sequential(*layers) 85 | 86 | def forward(self, x): 87 | out = F.relu(self.bn1(self.conv1(x))) 88 | out = self.layer1(out) 89 | out = self.layer2(out) 90 | out = self.layer3(out) 91 | out = self.layer4(out) 92 | out = F.avg_pool2d(out, 4) 93 | out = out.view(out.size(0), -1) 94 | out = self.linear(out) 95 | return out 96 | 97 | 98 | def resnet18(): 99 | return ResNet(BasicBlock, [2,2,2,2]) 100 | 101 | def ResNet34(): 102 | return ResNet(BasicBlock, [3,4,6,3]) 103 | 104 | def ResNet50(): 105 | return ResNet(Bottleneck, [3,4,6,3]) 106 | 107 | def ResNet101(): 108 | return 
ResNet(Bottleneck, [3,4,23,3])
109 |
110 | def ResNet152():
111 |     return ResNet(Bottleneck, [3,8,36,3])
112 |
113 |
114 | def test():
115 |     net = resnet18()  # fixed: the 18-layer factory above is named resnet18, not ResNet18
116 |     y = net(torch.randn(1,3,32,32))
117 |     print(y.size())
118 |
119 | # test()
--------------------------------------------------------------------------------
/archs/cifar10/vgg.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | #
4 | # from torchvision.utils import load_state_dict_from_url
5 |
6 |
7 | __all__ = [
8 |     'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
9 |     'vgg19_bn', 'vgg19',
10 | ]
11 |
12 |
13 | model_urls = {
14 |     'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
15 |     'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
16 |     'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
17 |     'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
18 |     'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
19 |     'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
20 |     'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
21 |     'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
22 | }
23 |
24 |
25 | class VGG(nn.Module):
26 |     #ANCHOR Change No. of Classes here.
27 |     def __init__(self, features, num_classes=10, init_weights=True):
28 |         super(VGG, self).__init__()
29 |         self.features = features
30 |         self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
31 |         self.classifier = nn.Sequential(
32 |             nn.Linear(512 * 7 * 7, 4096),
33 |             nn.ReLU(True),
34 |             nn.Dropout(),
35 |             nn.Linear(4096, 4096),
36 |             nn.ReLU(True),
37 |             nn.Dropout(),
38 |             nn.Linear(4096, num_classes),
39 |         )
40 |         if init_weights:
41 |             self._initialize_weights()
42 |
43 |     def forward(self, x):
44 |         x = self.features(x)
45 |         x = self.avgpool(x)
46 |         x = torch.flatten(x, 1)
47 |         x = self.classifier(x)
48 |         return x
49 |
50 |     def _initialize_weights(self):
51 |         for m in self.modules():
52 |             if isinstance(m, nn.Conv2d):
53 |                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
54 |                 if m.bias is not None:
55 |                     nn.init.constant_(m.bias, 0)
56 |             elif isinstance(m, nn.BatchNorm2d):
57 |                 nn.init.constant_(m.weight, 1)
58 |                 nn.init.constant_(m.bias, 0)
59 |             elif isinstance(m, nn.Linear):
60 |                 nn.init.normal_(m.weight, 0, 0.01)
61 |                 nn.init.constant_(m.bias, 0)
62 |
63 |
64 | def make_layers(cfg, batch_norm=False):
65 |     layers = []
66 |     #ANCHOR Change No. of Input channels here.
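    # in_channels stays 3 for RGB inputs such as CIFAR10; a single-channel
    # dataset like MNIST would use 1. Each integer in `cfgs` below is the
    # output width of a 3x3 conv, and 'M' inserts a 2x2 max-pool.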
67 | in_channels = 3 68 | for v in cfg: 69 | if v == 'M': 70 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 71 | else: 72 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 73 | if batch_norm: 74 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 75 | else: 76 | layers += [conv2d, nn.ReLU(inplace=True)] 77 | in_channels = v 78 | return nn.Sequential(*layers) 79 | 80 | 81 | cfgs = { 82 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 83 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 84 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 85 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 86 | } 87 | 88 | 89 | def _vgg(arch, cfg, batch_norm, pretrained, progress, **kwargs): 90 | if pretrained: 91 | kwargs['init_weights'] = False 92 | model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs) 93 | #if pretrained: 94 | #state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) 95 | #model.load_state_dict(state_dict) 96 | return model 97 | 98 | 99 | def vgg11(pretrained=False, progress=True, **kwargs): 100 | r"""VGG 11-layer model (configuration "A") from 101 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 102 | 103 | Args: 104 | pretrained (bool): If True, returns a model pre-trained on ImageNet 105 | progress (bool): If True, displays a progress bar of the download to stderr 106 | """ 107 | return _vgg('vgg11', 'A', False, pretrained, progress, **kwargs) 108 | 109 | 110 | 111 | def vgg11_bn(pretrained=False, progress=True, **kwargs): 112 | r"""VGG 11-layer model (configuration "A") with batch normalization 113 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 114 | 115 | Args: 116 | pretrained (bool): If True, returns a model pre-trained on ImageNet 117 | progress (bool): If True, displays a progress bar of the download to stderr 118 | """ 119 | return _vgg('vgg11_bn', 'A', True, pretrained, progress, **kwargs) 120 | 121 | 122 | 123 | def vgg13(pretrained=False, progress=True, **kwargs): 124 | r"""VGG 13-layer model (configuration "B") 125 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 126 | 127 | Args: 128 | pretrained (bool): If True, returns a model pre-trained on ImageNet 129 | progress (bool): If True, displays a progress bar of the download to stderr 130 | """ 131 | return _vgg('vgg13', 'B', False, pretrained, progress, **kwargs) 132 | 133 | 134 | 135 | def vgg13_bn(pretrained=False, progress=True, **kwargs): 136 | r"""VGG 13-layer model (configuration "B") with batch normalization 137 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 138 | 139 | Args: 140 | pretrained (bool): If True, returns a model pre-trained on ImageNet 141 | progress (bool): If True, displays a progress bar of the download to stderr 142 | """ 143 | return _vgg('vgg13_bn', 'B', True, pretrained, progress, **kwargs) 144 | 145 | 146 | 147 | def vgg16(pretrained=False, progress=True, **kwargs): 148 | r"""VGG 16-layer model (configuration "D") 149 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 150 | 151 | Args: 152 | pretrained (bool): If True, returns a model pre-trained on ImageNet 153 | progress (bool): If True, displays a progress bar of the download to stderr 154 | """ 155 | return _vgg('vgg16', 'D', False, pretrained, progress, **kwargs) 156 | 157 | 158 | 159 | 
def vgg16_bn(pretrained=False, progress=True, **kwargs): 160 | r"""VGG 16-layer model (configuration "D") with batch normalization 161 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 162 | 163 | Args: 164 | pretrained (bool): If True, returns a model pre-trained on ImageNet 165 | progress (bool): If True, displays a progress bar of the download to stderr 166 | """ 167 | return _vgg('vgg16_bn', 'D', True, pretrained, progress, **kwargs) 168 | 169 | 170 | 171 | def vgg19(pretrained=False, progress=True, **kwargs): 172 | r"""VGG 19-layer model (configuration "E") 173 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 174 | 175 | Args: 176 | pretrained (bool): If True, returns a model pre-trained on ImageNet 177 | progress (bool): If True, displays a progress bar of the download to stderr 178 | """ 179 | return _vgg('vgg19', 'E', False, pretrained, progress, **kwargs) 180 | 181 | 182 | 183 | def vgg19_bn(pretrained=False, progress=True, **kwargs): 184 | r"""VGG 19-layer model (configuration 'E') with batch normalization 185 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 186 | 187 | Args: 188 | pretrained (bool): If True, returns a model pre-trained on ImageNet 189 | progress (bool): If True, displays a progress bar of the download to stderr 190 | """ 191 | return _vgg('vgg19_bn', 'E', True, pretrained, progress, **kwargs) 192 | -------------------------------------------------------------------------------- /archs/cifar100/AlexNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | __all__ = ['AlexNet', 'alexnet'] 6 | 7 | 8 | model_urls = { 9 | 'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth', 10 | } 11 | 12 | 13 | class AlexNet(nn.Module): 14 | 15 | def __init__(self, num_classes=100): 16 | super(AlexNet, self).__init__() 17 | self.features = nn.Sequential( 18 | nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=2), 19 | nn.ReLU(inplace=True), 20 | nn.MaxPool2d(kernel_size=3, stride=2), 21 | nn.Conv2d(64, 192, kernel_size=5, padding=2), 22 | nn.ReLU(inplace=True), 23 | nn.MaxPool2d(kernel_size=3, stride=2), 24 | nn.Conv2d(192, 384, kernel_size=3, padding=1), 25 | nn.ReLU(inplace=True), 26 | nn.Conv2d(384, 256, kernel_size=3, padding=1), 27 | nn.ReLU(inplace=True), 28 | nn.Conv2d(256, 256, kernel_size=3, padding=1), 29 | nn.ReLU(inplace=True), 30 | nn.MaxPool2d(kernel_size=3, stride=2), 31 | ) 32 | self.avgpool = nn.AdaptiveAvgPool2d((6, 6)) 33 | self.classifier = nn.Sequential( 34 | nn.Dropout(), 35 | nn.Linear(256 * 6 * 6, 4096), 36 | nn.ReLU(inplace=True), 37 | nn.Dropout(), 38 | nn.Linear(4096, 4096), 39 | nn.ReLU(inplace=True), 40 | nn.Linear(4096, num_classes), 41 | ) 42 | 43 | def forward(self, x): 44 | x = self.features(x) 45 | x = self.avgpool(x) 46 | x = torch.flatten(x, 1) 47 | x = self.classifier(x) 48 | return x 49 | -------------------------------------------------------------------------------- /archs/cifar100/LeNet5.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as func 3 | 4 | 5 | class LeNet5(nn.Module): 6 | def __init__(self, num_classes=100): 7 | super(LeNet5, self).__init__() 8 | self.conv1 = nn.Conv2d(3, 6, kernel_size=5) 9 | self.conv2 = nn.Conv2d(6, 16, kernel_size=5) 10 | self.fc1 = nn.Linear(16*5*5, 120) 11 | self.fc2 = nn.Linear(120, 84) 12 | self.fc3 = nn.Linear(84, num_classes) 13 | 
14 | def forward(self, x): 15 | x = func.relu(self.conv1(x)) 16 | x = func.max_pool2d(x, 2) 17 | x = func.relu(self.conv2(x)) 18 | x = func.max_pool2d(x, 2) 19 | x = x.view(x.size(0), -1) 20 | x = func.relu(self.fc1(x)) 21 | x = func.relu(self.fc2(x)) 22 | x = self.fc3(x) 23 | return x -------------------------------------------------------------------------------- /archs/cifar100/fc1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class fc1(nn.Module): 5 | 6 | def __init__(self, num_classes=100): 7 | super(fc1, self).__init__() 8 | self.classifier = nn.Sequential( 9 | nn.Linear(3*32*32, 300), 10 | nn.ReLU(inplace=True), 11 | nn.Linear(300, 100), 12 | nn.ReLU(inplace=True), 13 | nn.Linear(100, num_classes), 14 | ) 15 | 16 | def forward(self, x): 17 | x = torch.flatten(x, 1) 18 | x = self.classifier(x) 19 | return x 20 | -------------------------------------------------------------------------------- /archs/cifar100/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 5 | """3x3 convolution with padding""" 6 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 7 | padding=dilation, groups=groups, bias=False, dilation=dilation) 8 | 9 | 10 | def conv1x1(in_planes, out_planes, stride=1): 11 | """1x1 convolution""" 12 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 13 | 14 | 15 | class BasicBlock(nn.Module): 16 | expansion = 1 17 | __constants__ = ['downsample'] 18 | 19 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 20 | base_width=64, dilation=1, norm_layer=None): 21 | super(BasicBlock, self).__init__() 22 | if norm_layer is None: 23 | norm_layer = nn.BatchNorm2d 24 | if groups != 1 or base_width != 64: 25 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 26 | if dilation > 1: 27 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 28 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 29 | self.conv1 = conv3x3(inplanes, planes, stride) 30 | self.bn1 = norm_layer(planes) 31 | self.relu = nn.ReLU(inplace=True) 32 | self.conv2 = conv3x3(planes, planes) 33 | self.bn2 = norm_layer(planes) 34 | self.downsample = downsample 35 | self.stride = stride 36 | 37 | def forward(self, x): 38 | identity = x 39 | 40 | out = self.conv1(x) 41 | out = self.bn1(out) 42 | out = self.relu(out) 43 | 44 | out = self.conv2(out) 45 | out = self.bn2(out) 46 | 47 | if self.downsample is not None: 48 | identity = self.downsample(x) 49 | 50 | out += identity 51 | out = self.relu(out) 52 | 53 | return out 54 | 55 | 56 | class Bottleneck(nn.Module): 57 | expansion = 4 58 | __constants__ = ['downsample'] 59 | 60 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 61 | base_width=64, dilation=1, norm_layer=None): 62 | super(Bottleneck, self).__init__() 63 | if norm_layer is None: 64 | norm_layer = nn.BatchNorm2d 65 | width = int(planes * (base_width / 64.)) * groups 66 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 67 | self.conv1 = conv1x1(inplanes, width) 68 | self.bn1 = norm_layer(width) 69 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 70 | self.bn2 = norm_layer(width) 71 | self.conv3 = conv1x1(width, planes * self.expansion) 72 | self.bn3 = 
norm_layer(planes * self.expansion) 73 | self.relu = nn.ReLU(inplace=True) 74 | self.downsample = downsample 75 | self.stride = stride 76 | 77 | def forward(self, x): 78 | identity = x 79 | 80 | out = self.conv1(x) 81 | out = self.bn1(out) 82 | out = self.relu(out) 83 | 84 | out = self.conv2(out) 85 | out = self.bn2(out) 86 | out = self.relu(out) 87 | 88 | out = self.conv3(out) 89 | out = self.bn3(out) 90 | 91 | if self.downsample is not None: 92 | identity = self.downsample(x) 93 | 94 | out += identity 95 | out = self.relu(out) 96 | 97 | return out 98 | 99 | 100 | class ResNet(nn.Module): 101 | 102 | def __init__(self, block, layers, num_classes=100, zero_init_residual=False, 103 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 104 | norm_layer=None): 105 | super(ResNet, self).__init__() 106 | if norm_layer is None: 107 | norm_layer = nn.BatchNorm2d 108 | self._norm_layer = norm_layer 109 | 110 | self.inplanes = 64 111 | self.dilation = 1 112 | if replace_stride_with_dilation is None: 113 | # each element in the tuple indicates if we should replace 114 | # the 2x2 stride with a dilated convolution instead 115 | replace_stride_with_dilation = [False, False, False] 116 | if len(replace_stride_with_dilation) != 3: 117 | raise ValueError("replace_stride_with_dilation should be None " 118 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 119 | self.groups = groups 120 | self.base_width = width_per_group 121 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 122 | bias=False) 123 | self.bn1 = norm_layer(self.inplanes) 124 | self.relu = nn.ReLU(inplace=True) 125 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 126 | self.layer1 = self._make_layer(block, 64, layers[0]) 127 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 128 | dilate=replace_stride_with_dilation[0]) 129 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 130 | dilate=replace_stride_with_dilation[1]) 131 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 132 | dilate=replace_stride_with_dilation[2]) 133 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 134 | self.fc = nn.Linear(512 * block.expansion, num_classes) 135 | 136 | for m in self.modules(): 137 | if isinstance(m, nn.Conv2d): 138 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 139 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 140 | nn.init.constant_(m.weight, 1) 141 | nn.init.constant_(m.bias, 0) 142 | 143 | # Zero-initialize the last BN in each residual branch, 144 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
145 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 146 | if zero_init_residual: 147 | for m in self.modules(): 148 | if isinstance(m, Bottleneck): 149 | nn.init.constant_(m.bn3.weight, 0) 150 | elif isinstance(m, BasicBlock): 151 | nn.init.constant_(m.bn2.weight, 0) 152 | 153 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 154 | norm_layer = self._norm_layer 155 | downsample = None 156 | previous_dilation = self.dilation 157 | if dilate: 158 | self.dilation *= stride 159 | stride = 1 160 | if stride != 1 or self.inplanes != planes * block.expansion: 161 | downsample = nn.Sequential( 162 | conv1x1(self.inplanes, planes * block.expansion, stride), 163 | norm_layer(planes * block.expansion), 164 | ) 165 | 166 | layers = [] 167 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 168 | self.base_width, previous_dilation, norm_layer)) 169 | self.inplanes = planes * block.expansion 170 | for _ in range(1, blocks): 171 | layers.append(block(self.inplanes, planes, groups=self.groups, 172 | base_width=self.base_width, dilation=self.dilation, 173 | norm_layer=norm_layer)) 174 | 175 | return nn.Sequential(*layers) 176 | 177 | def forward(self, x): 178 | x = self.conv1(x) 179 | x = self.bn1(x) 180 | x = self.relu(x) 181 | x = self.maxpool(x) 182 | 183 | x = self.layer1(x) 184 | x = self.layer2(x) 185 | x = self.layer3(x) 186 | x = self.layer4(x) 187 | 188 | x = self.avgpool(x) 189 | x = torch.flatten(x, 1) 190 | x = self.fc(x) 191 | 192 | return x 193 | 194 | 195 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 196 | model = ResNet(block, layers, **kwargs) 197 | return model 198 | 199 | 200 | def resnet18(pretrained=False, progress=True, **kwargs): 201 | r"""ResNet-18 model from 202 | `"Deep Residual Learning for Image Recognition" `_ 203 | Args: 204 | pretrained (bool): If True, returns a model pre-trained on ImageNet 205 | progress (bool): If True, displays a progress bar of the download to stderr 206 | """ 207 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 208 | **kwargs) 209 | 210 | 211 | def resnet34(pretrained=False, progress=True, **kwargs): 212 | r"""ResNet-34 model from 213 | `"Deep Residual Learning for Image Recognition" `_ 214 | Args: 215 | pretrained (bool): If True, returns a model pre-trained on ImageNet 216 | progress (bool): If True, displays a progress bar of the download to stderr 217 | """ 218 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 219 | **kwargs) 220 | 221 | 222 | def resnet50(pretrained=False, progress=True, **kwargs): 223 | r"""ResNet-50 model from 224 | `"Deep Residual Learning for Image Recognition" `_ 225 | Args: 226 | pretrained (bool): If True, returns a model pre-trained on ImageNet 227 | progress (bool): If True, displays a progress bar of the download to stderr 228 | """ 229 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 230 | **kwargs) 231 | 232 | 233 | def resnet101(pretrained=False, progress=True, **kwargs): 234 | r"""ResNet-101 model from 235 | `"Deep Residual Learning for Image Recognition" `_ 236 | Args: 237 | pretrained (bool): If True, returns a model pre-trained on ImageNet 238 | progress (bool): If True, displays a progress bar of the download to stderr 239 | """ 240 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 241 | **kwargs) 242 | 243 | 244 | def resnet152(pretrained=False, progress=True, **kwargs): 245 | r"""ResNet-152 
model from 246 | `"Deep Residual Learning for Image Recognition" `_ 247 | Args: 248 | pretrained (bool): If True, returns a model pre-trained on ImageNet 249 | progress (bool): If True, displays a progress bar of the download to stderr 250 | """ 251 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 252 | **kwargs) 253 | 254 | 255 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): 256 | r"""ResNeXt-50 32x4d model from 257 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 258 | Args: 259 | pretrained (bool): If True, returns a model pre-trained on ImageNet 260 | progress (bool): If True, displays a progress bar of the download to stderr 261 | """ 262 | kwargs['groups'] = 32 263 | kwargs['width_per_group'] = 4 264 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 265 | pretrained, progress, **kwargs) 266 | 267 | 268 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 269 | r"""ResNeXt-101 32x8d model from 270 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 271 | Args: 272 | pretrained (bool): If True, returns a model pre-trained on ImageNet 273 | progress (bool): If True, displays a progress bar of the download to stderr 274 | """ 275 | kwargs['groups'] = 32 276 | kwargs['width_per_group'] = 8 277 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 278 | pretrained, progress, **kwargs) 279 | 280 | 281 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs): 282 | r"""Wide ResNet-50-2 model from 283 | `"Wide Residual Networks" `_ 284 | The model is the same as ResNet except for the bottleneck number of channels 285 | which is twice larger in every block. The number of channels in outer 1x1 286 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 287 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 288 | Args: 289 | pretrained (bool): If True, returns a model pre-trained on ImageNet 290 | progress (bool): If True, displays a progress bar of the download to stderr 291 | """ 292 | kwargs['width_per_group'] = 64 * 2 293 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], 294 | pretrained, progress, **kwargs) 295 | 296 | 297 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs): 298 | r"""Wide ResNet-101-2 model from 299 | `"Wide Residual Networks" `_ 300 | The model is the same as ResNet except for the bottleneck number of channels 301 | which is twice larger in every block. The number of channels in outer 1x1 302 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 303 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 
304 | Args: 305 | pretrained (bool): If True, returns a model pre-trained on ImageNet 306 | progress (bool): If True, displays a progress bar of the download to stderr 307 | """ 308 | kwargs['width_per_group'] = 64 * 2 309 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 310 | pretrained, progress, **kwargs) -------------------------------------------------------------------------------- /archs/cifar100/vgg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | # 4 | # from torchvision.utils import load_state_dict_from_url 5 | 6 | 7 | __all__ = [ 8 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 9 | 'vgg19_bn', 'vgg19', 10 | ] 11 | 12 | 13 | model_urls = { 14 | 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth', 15 | 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth', 16 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', 17 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', 18 | 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth', 19 | 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth', 20 | 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth', 21 | 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth', 22 | } 23 | 24 | 25 | class VGG(nn.Module): 26 | #ANCHOR Change No. of Classes here. 27 | def __init__(self, features, num_classes=100, init_weights=True): 28 | super(VGG, self).__init__() 29 | self.features = features 30 | self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) 31 | self.classifier = nn.Sequential( 32 | nn.Linear(512 * 7 * 7, 4096), 33 | nn.ReLU(True), 34 | nn.Dropout(), 35 | nn.Linear(4096, 4096), 36 | nn.ReLU(True), 37 | nn.Dropout(), 38 | nn.Linear(4096, num_classes), 39 | ) 40 | if init_weights: 41 | self._initialize_weights() 42 | 43 | def forward(self, x): 44 | x = self.features(x) 45 | x = self.avgpool(x) 46 | x = torch.flatten(x, 1) 47 | x = self.classifier(x) 48 | return x 49 | 50 | def _initialize_weights(self): 51 | for m in self.modules(): 52 | if isinstance(m, nn.Conv2d): 53 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 54 | if m.bias is not None: 55 | nn.init.constant_(m.bias, 0) 56 | elif isinstance(m, nn.BatchNorm2d): 57 | nn.init.constant_(m.weight, 1) 58 | nn.init.constant_(m.bias, 0) 59 | elif isinstance(m, nn.Linear): 60 | nn.init.normal_(m.weight, 0, 0.01) 61 | nn.init.constant_(m.bias, 0) 62 | 63 | 64 | def make_layers(cfg, batch_norm=False): 65 | layers = [] 66 | #ANCHOR Change No. of Input channels here. 
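    # As in the cifar10 version: in_channels = 3 matches RGB inputs; a
    # single-channel dataset would use 1.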
67 | in_channels = 3 68 | for v in cfg: 69 | if v == 'M': 70 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 71 | else: 72 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 73 | if batch_norm: 74 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 75 | else: 76 | layers += [conv2d, nn.ReLU(inplace=True)] 77 | in_channels = v 78 | return nn.Sequential(*layers) 79 | 80 | 81 | cfgs = { 82 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 83 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 84 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 85 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 86 | } 87 | 88 | 89 | def _vgg(arch, cfg, batch_norm, pretrained, progress, **kwargs): 90 | if pretrained: 91 | kwargs['init_weights'] = False 92 | model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs) 93 | #if pretrained: 94 | #state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) 95 | #model.load_state_dict(state_dict) 96 | return model 97 | 98 | 99 | def vgg11(pretrained=False, progress=True, **kwargs): 100 | r"""VGG 11-layer model (configuration "A") from 101 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 102 | 103 | Args: 104 | pretrained (bool): If True, returns a model pre-trained on ImageNet 105 | progress (bool): If True, displays a progress bar of the download to stderr 106 | """ 107 | return _vgg('vgg11', 'A', False, pretrained, progress, **kwargs) 108 | 109 | 110 | 111 | def vgg11_bn(pretrained=False, progress=True, **kwargs): 112 | r"""VGG 11-layer model (configuration "A") with batch normalization 113 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 114 | 115 | Args: 116 | pretrained (bool): If True, returns a model pre-trained on ImageNet 117 | progress (bool): If True, displays a progress bar of the download to stderr 118 | """ 119 | return _vgg('vgg11_bn', 'A', True, pretrained, progress, **kwargs) 120 | 121 | 122 | 123 | def vgg13(pretrained=False, progress=True, **kwargs): 124 | r"""VGG 13-layer model (configuration "B") 125 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 126 | 127 | Args: 128 | pretrained (bool): If True, returns a model pre-trained on ImageNet 129 | progress (bool): If True, displays a progress bar of the download to stderr 130 | """ 131 | return _vgg('vgg13', 'B', False, pretrained, progress, **kwargs) 132 | 133 | 134 | 135 | def vgg13_bn(pretrained=False, progress=True, **kwargs): 136 | r"""VGG 13-layer model (configuration "B") with batch normalization 137 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 138 | 139 | Args: 140 | pretrained (bool): If True, returns a model pre-trained on ImageNet 141 | progress (bool): If True, displays a progress bar of the download to stderr 142 | """ 143 | return _vgg('vgg13_bn', 'B', True, pretrained, progress, **kwargs) 144 | 145 | 146 | 147 | def vgg16(pretrained=False, progress=True, **kwargs): 148 | r"""VGG 16-layer model (configuration "D") 149 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 150 | 151 | Args: 152 | pretrained (bool): If True, returns a model pre-trained on ImageNet 153 | progress (bool): If True, displays a progress bar of the download to stderr 154 | """ 155 | return _vgg('vgg16', 'D', False, pretrained, progress, **kwargs) 156 | 157 | 158 | 159 | 
def vgg16_bn(pretrained=False, progress=True, **kwargs): 160 | r"""VGG 16-layer model (configuration "D") with batch normalization 161 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 162 | 163 | Args: 164 | pretrained (bool): If True, returns a model pre-trained on ImageNet 165 | progress (bool): If True, displays a progress bar of the download to stderr 166 | """ 167 | return _vgg('vgg16_bn', 'D', True, pretrained, progress, **kwargs) 168 | 169 | 170 | 171 | def vgg19(pretrained=False, progress=True, **kwargs): 172 | r"""VGG 19-layer model (configuration "E") 173 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 174 | 175 | Args: 176 | pretrained (bool): If True, returns a model pre-trained on ImageNet 177 | progress (bool): If True, displays a progress bar of the download to stderr 178 | """ 179 | return _vgg('vgg19', 'E', False, pretrained, progress, **kwargs) 180 | 181 | 182 | 183 | def vgg19_bn(pretrained=False, progress=True, **kwargs): 184 | r"""VGG 19-layer model (configuration 'E') with batch normalization 185 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 186 | 187 | Args: 188 | pretrained (bool): If True, returns a model pre-trained on ImageNet 189 | progress (bool): If True, displays a progress bar of the download to stderr 190 | """ 191 | return _vgg('vgg19_bn', 'E', True, pretrained, progress, **kwargs) 192 | -------------------------------------------------------------------------------- /archs/mnist/AlexNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | __all__ = ['AlexNet', 'alexnet'] 6 | 7 | 8 | model_urls = { 9 | 'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth', 10 | } 11 | 12 | 13 | class AlexNet(nn.Module): 14 | 15 | def __init__(self, num_classes=10): 16 | super(AlexNet, self).__init__() 17 | self.features = nn.Sequential( 18 | nn.Conv2d(1, 64, kernel_size=3, stride=2, padding=2), 19 | nn.ReLU(inplace=True), 20 | nn.MaxPool2d(kernel_size=3, stride=2), 21 | nn.Conv2d(64, 192, kernel_size=5, padding=2), 22 | nn.ReLU(inplace=True), 23 | nn.MaxPool2d(kernel_size=3, stride=2), 24 | nn.Conv2d(192, 384, kernel_size=3, padding=1), 25 | nn.ReLU(inplace=True), 26 | nn.Conv2d(384, 256, kernel_size=3, padding=1), 27 | nn.ReLU(inplace=True), 28 | nn.Conv2d(256, 256, kernel_size=3, padding=1), 29 | nn.ReLU(inplace=True), 30 | nn.MaxPool2d(kernel_size=3, stride=2), 31 | ) 32 | self.avgpool = nn.AdaptiveAvgPool2d((6, 6)) 33 | self.classifier = nn.Sequential( 34 | nn.Dropout(), 35 | nn.Linear(256 * 6 * 6, 4096), 36 | nn.ReLU(inplace=True), 37 | nn.Dropout(), 38 | nn.Linear(4096, 4096), 39 | nn.ReLU(inplace=True), 40 | nn.Linear(4096, num_classes), 41 | ) 42 | 43 | def forward(self, x): 44 | x = self.features(x) 45 | x = self.avgpool(x) 46 | x = torch.flatten(x, 1) 47 | x = self.classifier(x) 48 | return x 49 | -------------------------------------------------------------------------------- /archs/mnist/LeNet5.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class LeNet5(nn.Module): 5 | def __init__(self, num_classes=10): 6 | super(LeNet5, self).__init__() 7 | self.features = nn.Sequential( 8 | nn.Conv2d(1, 64, kernel_size=(3, 3), stride=1, padding=1), 9 | nn.ReLU(), 10 | nn.Conv2d(64, 64, kernel_size=(3, 3), stride=1, padding=1), 11 | nn.ReLU(), 12 | nn.MaxPool2d(kernel_size=2), 13 | ) 14 | 
self.classifier = nn.Sequential( 15 | nn.Linear(64*14*14, 256), 16 | nn.ReLU(inplace=True), 17 | nn.Linear(256, 256), 18 | nn.ReLU(inplace=True), 19 | nn.Linear(256, num_classes), 20 | ) 21 | 22 | def forward(self, x): 23 | x = self.features(x) 24 | x = torch.flatten(x, 1) 25 | x = self.classifier(x) 26 | return x 27 | -------------------------------------------------------------------------------- /archs/mnist/fc1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class fc1(nn.Module): 5 | 6 | def __init__(self, num_classes=10): 7 | super(fc1, self).__init__() 8 | self.classifier = nn.Sequential( 9 | nn.Linear(28*28, 300), 10 | nn.ReLU(inplace=True), 11 | nn.Linear(300, 100), 12 | nn.ReLU(inplace=True), 13 | nn.Linear(100, num_classes), 14 | ) 15 | 16 | def forward(self, x): 17 | x = torch.flatten(x, 1) 18 | x = self.classifier(x) 19 | return x 20 | 21 | -------------------------------------------------------------------------------- /archs/mnist/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 5 | """3x3 convolution with padding""" 6 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 7 | padding=dilation, groups=groups, bias=False, dilation=dilation) 8 | 9 | 10 | def conv1x1(in_planes, out_planes, stride=1): 11 | """1x1 convolution""" 12 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 13 | 14 | 15 | class BasicBlock(nn.Module): 16 | expansion = 1 17 | __constants__ = ['downsample'] 18 | 19 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 20 | base_width=64, dilation=1, norm_layer=None): 21 | super(BasicBlock, self).__init__() 22 | if norm_layer is None: 23 | norm_layer = nn.BatchNorm2d 24 | if groups != 1 or base_width != 64: 25 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 26 | if dilation > 1: 27 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 28 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 29 | self.conv1 = conv3x3(inplanes, planes, stride) 30 | self.bn1 = norm_layer(planes) 31 | self.relu = nn.ReLU(inplace=True) 32 | self.conv2 = conv3x3(planes, planes) 33 | self.bn2 = norm_layer(planes) 34 | self.downsample = downsample 35 | self.stride = stride 36 | 37 | def forward(self, x): 38 | identity = x 39 | 40 | out = self.conv1(x) 41 | out = self.bn1(out) 42 | out = self.relu(out) 43 | 44 | out = self.conv2(out) 45 | out = self.bn2(out) 46 | 47 | if self.downsample is not None: 48 | identity = self.downsample(x) 49 | 50 | out += identity 51 | out = self.relu(out) 52 | 53 | return out 54 | 55 | 56 | class Bottleneck(nn.Module): 57 | expansion = 4 58 | __constants__ = ['downsample'] 59 | 60 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 61 | base_width=64, dilation=1, norm_layer=None): 62 | super(Bottleneck, self).__init__() 63 | if norm_layer is None: 64 | norm_layer = nn.BatchNorm2d 65 | width = int(planes * (base_width / 64.)) * groups 66 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 67 | self.conv1 = conv1x1(inplanes, width) 68 | self.bn1 = norm_layer(width) 69 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 70 | self.bn2 = norm_layer(width) 71 | self.conv3 = conv1x1(width, planes * 
self.expansion) 72 | self.bn3 = norm_layer(planes * self.expansion) 73 | self.relu = nn.ReLU(inplace=True) 74 | self.downsample = downsample 75 | self.stride = stride 76 | 77 | def forward(self, x): 78 | identity = x 79 | 80 | out = self.conv1(x) 81 | out = self.bn1(out) 82 | out = self.relu(out) 83 | 84 | out = self.conv2(out) 85 | out = self.bn2(out) 86 | out = self.relu(out) 87 | 88 | out = self.conv3(out) 89 | out = self.bn3(out) 90 | 91 | if self.downsample is not None: 92 | identity = self.downsample(x) 93 | 94 | out += identity 95 | out = self.relu(out) 96 | 97 | return out 98 | 99 | 100 | class ResNet(nn.Module): 101 | 102 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 103 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 104 | norm_layer=None): 105 | super(ResNet, self).__init__() 106 | if norm_layer is None: 107 | norm_layer = nn.BatchNorm2d 108 | self._norm_layer = norm_layer 109 | 110 | self.inplanes = 64 111 | self.dilation = 1 112 | if replace_stride_with_dilation is None: 113 | # each element in the tuple indicates if we should replace 114 | # the 2x2 stride with a dilated convolution instead 115 | replace_stride_with_dilation = [False, False, False] 116 | if len(replace_stride_with_dilation) != 3: 117 | raise ValueError("replace_stride_with_dilation should be None " 118 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 119 | self.groups = groups 120 | self.base_width = width_per_group 121 | self.conv1 = nn.Conv2d(1, self.inplanes, kernel_size=7, stride=2, padding=3, 122 | bias=False) 123 | self.bn1 = norm_layer(self.inplanes) 124 | self.relu = nn.ReLU(inplace=True) 125 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 126 | self.layer1 = self._make_layer(block, 64, layers[0]) 127 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 128 | dilate=replace_stride_with_dilation[0]) 129 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 130 | dilate=replace_stride_with_dilation[1]) 131 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 132 | dilate=replace_stride_with_dilation[2]) 133 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 134 | self.fc = nn.Linear(512 * block.expansion, num_classes) 135 | 136 | for m in self.modules(): 137 | if isinstance(m, nn.Conv2d): 138 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 139 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 140 | nn.init.constant_(m.weight, 1) 141 | nn.init.constant_(m.bias, 0) 142 | 143 | # Zero-initialize the last BN in each residual branch, 144 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
145 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 146 | if zero_init_residual: 147 | for m in self.modules(): 148 | if isinstance(m, Bottleneck): 149 | nn.init.constant_(m.bn3.weight, 0) 150 | elif isinstance(m, BasicBlock): 151 | nn.init.constant_(m.bn2.weight, 0) 152 | 153 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 154 | norm_layer = self._norm_layer 155 | downsample = None 156 | previous_dilation = self.dilation 157 | if dilate: 158 | self.dilation *= stride 159 | stride = 1 160 | if stride != 1 or self.inplanes != planes * block.expansion: 161 | downsample = nn.Sequential( 162 | conv1x1(self.inplanes, planes * block.expansion, stride), 163 | norm_layer(planes * block.expansion), 164 | ) 165 | 166 | layers = [] 167 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 168 | self.base_width, previous_dilation, norm_layer)) 169 | self.inplanes = planes * block.expansion 170 | for _ in range(1, blocks): 171 | layers.append(block(self.inplanes, planes, groups=self.groups, 172 | base_width=self.base_width, dilation=self.dilation, 173 | norm_layer=norm_layer)) 174 | 175 | return nn.Sequential(*layers) 176 | 177 | def forward(self, x): 178 | x = self.conv1(x) 179 | x = self.bn1(x) 180 | x = self.relu(x) 181 | x = self.maxpool(x) 182 | 183 | x = self.layer1(x) 184 | x = self.layer2(x) 185 | x = self.layer3(x) 186 | x = self.layer4(x) 187 | 188 | x = self.avgpool(x) 189 | x = torch.flatten(x, 1) 190 | x = self.fc(x) 191 | 192 | return x 193 | 194 | 195 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 196 | model = ResNet(block, layers, **kwargs) 197 | return model 198 | 199 | 200 | def resnet18(pretrained=False, progress=True, **kwargs): 201 | r"""ResNet-18 model from 202 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_ 203 | Args: 204 | pretrained (bool): If True, returns a model pre-trained on ImageNet 205 | progress (bool): If True, displays a progress bar of the download to stderr 206 | """ 207 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 208 | **kwargs) 209 | 210 | 211 | def resnet34(pretrained=False, progress=True, **kwargs): 212 | r"""ResNet-34 model from 213 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_ 214 | Args: 215 | pretrained (bool): If True, returns a model pre-trained on ImageNet 216 | progress (bool): If True, displays a progress bar of the download to stderr 217 | """ 218 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 219 | **kwargs) 220 | 221 | 222 | def resnet50(pretrained=False, progress=True, **kwargs): 223 | r"""ResNet-50 model from 224 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_ 225 | Args: 226 | pretrained (bool): If True, returns a model pre-trained on ImageNet 227 | progress (bool): If True, displays a progress bar of the download to stderr 228 | """ 229 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 230 | **kwargs) 231 | 232 | 233 | def resnet101(pretrained=False, progress=True, **kwargs): 234 | r"""ResNet-101 model from 235 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_ 236 | Args: 237 | pretrained (bool): If True, returns a model pre-trained on ImageNet 238 | progress (bool): If True, displays a progress bar of the download to stderr 239 | """ 240 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 241 | **kwargs) 242 | 243 | 244 | def resnet152(pretrained=False, progress=True, **kwargs): 245 | r"""ResNet-152
model from 246 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_ 247 | Args: 248 | pretrained (bool): If True, returns a model pre-trained on ImageNet 249 | progress (bool): If True, displays a progress bar of the download to stderr 250 | """ 251 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 252 | **kwargs) 253 | 254 | 255 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): 256 | r"""ResNeXt-50 32x4d model from 257 | `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/abs/1611.05431>`_ 258 | Args: 259 | pretrained (bool): If True, returns a model pre-trained on ImageNet 260 | progress (bool): If True, displays a progress bar of the download to stderr 261 | """ 262 | kwargs['groups'] = 32 263 | kwargs['width_per_group'] = 4 264 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 265 | pretrained, progress, **kwargs) 266 | 267 | 268 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 269 | r"""ResNeXt-101 32x8d model from 270 | `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/abs/1611.05431>`_ 271 | Args: 272 | pretrained (bool): If True, returns a model pre-trained on ImageNet 273 | progress (bool): If True, displays a progress bar of the download to stderr 274 | """ 275 | kwargs['groups'] = 32 276 | kwargs['width_per_group'] = 8 277 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 278 | pretrained, progress, **kwargs) 279 | 280 | 281 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs): 282 | r"""Wide ResNet-50-2 model from 283 | `"Wide Residual Networks" <https://arxiv.org/abs/1605.07146>`_ 284 | The model is the same as ResNet except for the bottleneck number of channels 285 | which is twice as large in every block. The number of channels in outer 1x1 286 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 287 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 288 | Args: 289 | pretrained (bool): If True, returns a model pre-trained on ImageNet 290 | progress (bool): If True, displays a progress bar of the download to stderr 291 | """ 292 | kwargs['width_per_group'] = 64 * 2 293 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], 294 | pretrained, progress, **kwargs) 295 | 296 | 297 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs): 298 | r"""Wide ResNet-101-2 model from 299 | `"Wide Residual Networks" <https://arxiv.org/abs/1605.07146>`_ 300 | The model is the same as ResNet except for the bottleneck number of channels 301 | which is twice as large in every block. The number of channels in outer 1x1 302 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 303 | channels, and in Wide ResNet-50-2 has 2048-1024-2048.
304 | Args: 305 | pretrained (bool): If True, returns a model pre-trained on ImageNet 306 | progress (bool): If True, displays a progress bar of the download to stderr 307 | """ 308 | kwargs['width_per_group'] = 64 * 2 309 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 310 | pretrained, progress, **kwargs) -------------------------------------------------------------------------------- /archs/mnist/vgg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F # needed by the F.relu calls in forward() 4 | 5 | def vgg_block(num_convs, in_channels, num_channels): 6 | layers=[] 7 | for i in range(num_convs): 8 | layers+=[nn.Conv2d(in_channels=in_channels, out_channels=num_channels, kernel_size=3, padding=1)] 9 | in_channels=num_channels 10 | layers +=[nn.ReLU()] 11 | layers +=[nn.MaxPool2d(kernel_size=2, stride=2)] 12 | return nn.Sequential(*layers) 13 | 14 | class vgg16(nn.Module): 15 | def __init__(self, num_classes = 10): 16 | super(vgg16,self).__init__() 17 | self.conv_arch=((1,1,64),(1,64,128),(2,128,256),(2,256,512),(2,512,512)) 18 | layers=[] 19 | for (num_convs,in_channels,num_channels) in self.conv_arch: 20 | layers+=[vgg_block(num_convs,in_channels,num_channels)] 21 | self.features=nn.Sequential(*layers) 22 | self.dense1 = nn.Linear(512*7*7,4096) 23 | self.drop1 = nn.Dropout(0.5) 24 | self.dense2 = nn.Linear(4096, 4096) 25 | self.drop2 = nn.Dropout(0.5) 26 | self.dense3 = nn.Linear(4096, num_classes) 27 | 28 | def forward(self,x): 29 | x=self.features(x) 30 | x=x.view(-1,512*7*7) # NOTE assumes 224x224 inputs (7x7 final feature map); the default 28x28 MNIST pipeline reaches this layer as 512x1x1 31 | x=self.dense3(self.drop2(F.relu(self.dense2(self.drop1(F.relu(self.dense1(x))))))) 32 | return x 33 | -------------------------------------------------------------------------------- /combine_plots.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import seaborn as sns 3 | import numpy as np 4 | import os 5 | from tqdm import tqdm 6 | 7 | 8 | DPI = 1200 9 | prune_iterations = 35 10 | arch_types = ["fc1", "lenet5", "resnet18"] 11 | datasets = ["mnist", "fashionmnist", "cifar10", "cifar100"] 12 | 13 | os.makedirs(f"{os.getcwd()}/plots/lt/combined_plots/", exist_ok=True) # savefig does not create directories 14 | for arch_type in tqdm(arch_types): 15 | for dataset in tqdm(datasets): 16 | d = np.load(f"{os.getcwd()}/dumps/lt/{arch_type}/{dataset}/lt_compression.dat", allow_pickle=True) 17 | b = np.load(f"{os.getcwd()}/dumps/lt/{arch_type}/{dataset}/lt_bestaccuracy.dat", allow_pickle=True) 18 | c = np.load(f"{os.getcwd()}/dumps/lt/{arch_type}/{dataset}/reinit_bestaccuracy.dat", allow_pickle=True) 19 | 20 | #plt.clf() 21 | #sns.set_style('darkgrid') 22 | #plt.style.use('seaborn-darkgrid') 23 | a = np.arange(prune_iterations) 24 | plt.plot(a, b, c="blue", label="Winning tickets") 25 | plt.plot(a, c, c="red", label="Random reinit") 26 | plt.title(f"Test Accuracy vs Weights % ({arch_type} | {dataset})") 27 | plt.xlabel("Weights %") 28 | plt.ylabel("Test accuracy") 29 | plt.xticks(a, d, rotation="vertical") 30 | plt.ylim(0,100) 31 | plt.legend() 32 | plt.grid(color="gray") 33 | 34 | plt.savefig(f"{os.getcwd()}/plots/lt/combined_plots/combined_{arch_type}_{dataset}.png", dpi=DPI, bbox_inches='tight') 35 | plt.close() 36 | #print(f"\n combined_{arch_type}_{dataset} plotted!\n") 37 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # Importing Libraries 2 | import argparse 3 | import copy 4 | import os 5 | import sys 6 | import numpy as np 7 | from tqdm import tqdm 8 |
import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torchvision 12 | import torchvision.transforms as transforms 13 | import torchvision.datasets as datasets 14 | import matplotlib.pyplot as plt 15 | import os 16 | from tensorboardX import SummaryWriter 17 | import torchvision.utils as vutils 18 | import seaborn as sns 19 | import torch.nn.init as init 20 | import pickle 21 | 22 | # Custom Libraries 23 | import utils 24 | 25 | # Tensorboard initialization 26 | writer = SummaryWriter() 27 | 28 | # Plotting Style 29 | sns.set_style('darkgrid') 30 | 31 | # Main 32 | def main(args, ITE=0): 33 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 34 | reinit = True if args.prune_type=="reinit" else False 35 | 36 | # Data Loader 37 | transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))]) 38 | if args.dataset == "mnist": 39 | traindataset = datasets.MNIST('../data', train=True, download=True,transform=transform) 40 | testdataset = datasets.MNIST('../data', train=False, transform=transform) 41 | from archs.mnist import AlexNet, LeNet5, fc1, vgg, resnet 42 | 43 | elif args.dataset == "cifar10": 44 | traindataset = datasets.CIFAR10('../data', train=True, download=True,transform=transform) 45 | testdataset = datasets.CIFAR10('../data', train=False, transform=transform) 46 | from archs.cifar10 import AlexNet, LeNet5, fc1, vgg, resnet, densenet 47 | 48 | elif args.dataset == "fashionmnist": 49 | traindataset = datasets.FashionMNIST('../data', train=True, download=True,transform=transform) 50 | testdataset = datasets.FashionMNIST('../data', train=False, transform=transform) 51 | from archs.mnist import AlexNet, LeNet5, fc1, vgg, resnet 52 | 53 | elif args.dataset == "cifar100": 54 | traindataset = datasets.CIFAR100('../data', train=True, download=True,transform=transform) 55 | testdataset = datasets.CIFAR100('../data', train=False, transform=transform) 56 | from archs.cifar100 import AlexNet, fc1, LeNet5, vgg, resnet 57 | 58 | # If you want to add extra datasets paste here 59 | 60 | else: 61 | print("\nWrong Dataset choice \n") 62 | exit() 63 | 64 | train_loader = torch.utils.data.DataLoader(traindataset, batch_size=args.batch_size, shuffle=True, num_workers=0,drop_last=False) 65 | #train_loader = cycle(train_loader) 66 | test_loader = torch.utils.data.DataLoader(testdataset, batch_size=args.batch_size, shuffle=False, num_workers=0,drop_last=True) 67 | 68 | # Importing Network Architecture 69 | global model 70 | if args.arch_type == "fc1": 71 | model = fc1.fc1().to(device) 72 | elif args.arch_type == "lenet5": 73 | model = LeNet5.LeNet5().to(device) 74 | elif args.arch_type == "alexnet": 75 | model = AlexNet.AlexNet().to(device) 76 | elif args.arch_type == "vgg16": 77 | model = vgg.vgg16().to(device) 78 | elif args.arch_type == "resnet18": 79 | model = resnet.resnet18().to(device) 80 | elif args.arch_type == "densenet121": 81 | model = densenet.densenet121().to(device) 82 | # If you want to add extra model paste here 83 | else: 84 | print("\nWrong Model choice\n") 85 | exit() 86 | 87 | # Weight Initialization 88 | model.apply(weight_init) 89 | 90 | # Copying and Saving Initial State 91 | initial_state_dict = copy.deepcopy(model.state_dict()) 92 | utils.checkdir(f"{os.getcwd()}/saves/{args.arch_type}/{args.dataset}/") 93 | torch.save(model, f"{os.getcwd()}/saves/{args.arch_type}/{args.dataset}/initial_state_dict_{args.prune_type}.pth.tar") 94 | 95 | # Making Initial Mask 96 | make_mask(model) 97 | 98 | # 
Optimizer and Loss 99 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-4) 100 | criterion = nn.CrossEntropyLoss() # Default was F.nll_loss 101 | 102 | # Layer Looper 103 | for name, param in model.named_parameters(): 104 | print(name, param.size()) 105 | 106 | # Pruning 107 | # NOTE The first pruning iteration trains the unpruned network (no compression) 108 | bestacc = 0.0 109 | best_accuracy = 0 110 | ITERATION = args.prune_iterations 111 | comp = np.zeros(ITERATION,float) 112 | bestacc = np.zeros(ITERATION,float) 113 | step = 0 114 | all_loss = np.zeros(args.end_iter,float) 115 | all_accuracy = np.zeros(args.end_iter,float) 116 | 117 | 118 | for _ite in range(args.start_iter, ITERATION): 119 | if not _ite == 0: 120 | prune_by_percentile(args.prune_percent, resample=resample, reinit=reinit) 121 | if reinit: 122 | model.apply(weight_init) 123 | #if args.arch_type == "fc1": 124 | # model = fc1.fc1().to(device) 125 | #elif args.arch_type == "lenet5": 126 | # model = LeNet5.LeNet5().to(device) 127 | #elif args.arch_type == "alexnet": 128 | # model = AlexNet.AlexNet().to(device) 129 | #elif args.arch_type == "vgg16": 130 | # model = vgg.vgg16().to(device) 131 | #elif args.arch_type == "resnet18": 132 | # model = resnet.resnet18().to(device) 133 | #elif args.arch_type == "densenet121": 134 | # model = densenet.densenet121().to(device) 135 | #else: 136 | # print("\nWrong Model choice\n") 137 | # exit() 138 | step = 0 139 | for name, param in model.named_parameters(): 140 | if 'weight' in name: 141 | weight_dev = param.device 142 | param.data = torch.from_numpy(param.data.cpu().numpy() * mask[step]).to(weight_dev) 143 | step = step + 1 144 | step = 0 145 | else: 146 | original_initialization(mask, initial_state_dict) 147 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-4) 148 | print(f"\n--- Pruning Level [{ITE}:{_ite}/{ITERATION}]: ---") 149 | 150 | # Print the table of Nonzeros in each layer 151 | comp1 = utils.print_nonzeros(model) 152 | comp[_ite] = comp1 153 | pbar = tqdm(range(args.end_iter)) 154 | 155 | for iter_ in pbar: 156 | 157 | # Frequency for Testing 158 | if iter_ % args.valid_freq == 0: 159 | accuracy = test(model, test_loader, criterion) 160 | 161 | # Save Weights 162 | if accuracy > best_accuracy: 163 | best_accuracy = accuracy 164 | utils.checkdir(f"{os.getcwd()}/saves/{args.arch_type}/{args.dataset}/") 165 | torch.save(model,f"{os.getcwd()}/saves/{args.arch_type}/{args.dataset}/{_ite}_model_{args.prune_type}.pth.tar") 166 | 167 | # Training 168 | loss = train(model, train_loader, optimizer, criterion) 169 | all_loss[iter_] = loss 170 | all_accuracy[iter_] = accuracy 171 | 172 | # Frequency for Printing Accuracy and Loss 173 | if iter_ % args.print_freq == 0: 174 | pbar.set_description( 175 | f'Train Epoch: {iter_}/{args.end_iter} Loss: {loss:.6f} Accuracy: {accuracy:.2f}% Best Accuracy: {best_accuracy:.2f}%') 176 | 177 | writer.add_scalar('Accuracy/test', best_accuracy, comp1) 178 | bestacc[_ite]=best_accuracy 179 | 180 | # Plotting Loss (Training), Accuracy (Testing), Iteration Curve 181 | #NOTE Loss is computed for every iteration while Accuracy is computed only for every {args.valid_freq} iterations. Therefore Accuracy saved is constant during the uncomputed iterations. 182 | #NOTE Normalized the loss to [0,100] for ease of plotting alongside accuracy.
183 | plt.plot(np.arange(1,(args.end_iter)+1), 100*(all_loss - np.min(all_loss))/np.ptp(all_loss).astype(float), c="blue", label="Loss") 184 | plt.plot(np.arange(1,(args.end_iter)+1), all_accuracy, c="red", label="Accuracy") 185 | plt.title(f"Loss Vs Accuracy Vs Iterations ({args.dataset},{args.arch_type})") 186 | plt.xlabel("Iterations") 187 | plt.ylabel("Loss and Accuracy") 188 | plt.legend() 189 | plt.grid(color="gray") 190 | utils.checkdir(f"{os.getcwd()}/plots/lt/{args.arch_type}/{args.dataset}/") 191 | plt.savefig(f"{os.getcwd()}/plots/lt/{args.arch_type}/{args.dataset}/{args.prune_type}_LossVsAccuracy_{comp1}.png", dpi=1200) 192 | plt.close() 193 | 194 | # Dump Plot values 195 | utils.checkdir(f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/") 196 | all_loss.dump(f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.prune_type}_all_loss_{comp1}.dat") 197 | all_accuracy.dump(f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.prune_type}_all_accuracy_{comp1}.dat") 198 | 199 | # Dumping mask 200 | utils.checkdir(f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/") 201 | with open(f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.prune_type}_mask_{comp1}.pkl", 'wb') as fp: 202 | pickle.dump(mask, fp) 203 | 204 | # Resetting variables to 0 205 | best_accuracy = 0 206 | all_loss = np.zeros(args.end_iter,float) 207 | all_accuracy = np.zeros(args.end_iter,float) 208 | 209 | # Dumping Values for Plotting 210 | utils.checkdir(f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/") 211 | comp.dump(f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.prune_type}_compression.dat") 212 | bestacc.dump(f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.prune_type}_bestaccuracy.dat") 213 | 214 | # Plotting 215 | a = np.arange(args.prune_iterations) 216 | plt.plot(a, bestacc, c="blue", label="Winning tickets") 217 | plt.title(f"Test Accuracy vs Unpruned Weights Percentage ({args.dataset},{args.arch_type})") 218 | plt.xlabel("Unpruned Weights Percentage") 219 | plt.ylabel("Test accuracy") 220 | plt.xticks(a, comp, rotation="vertical") 221 | plt.ylim(0,100) 222 | plt.legend() 223 | plt.grid(color="gray") 224 | utils.checkdir(f"{os.getcwd()}/plots/lt/{args.arch_type}/{args.dataset}/") 225 | plt.savefig(f"{os.getcwd()}/plots/lt/{args.arch_type}/{args.dataset}/{args.prune_type}_AccuracyVsWeights.png", dpi=1200) 226 | plt.close() 227 | 228 | # Function for Training 229 | def train(model, train_loader, optimizer, criterion): 230 | EPS = 1e-6 231 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 232 | model.train() 233 | for batch_idx, (imgs, targets) in enumerate(train_loader): 234 | optimizer.zero_grad() 235 | #imgs, targets = next(train_loader) 236 | imgs, targets = imgs.to(device), targets.to(device) 237 | output = model(imgs) 238 | train_loss = criterion(output, targets) 239 | train_loss.backward() 240 | 241 | # Freezing Pruned weights by making their gradients Zero (a minimal standalone sketch of this mask-and-freeze step appears after utils.py below) 242 | for name, p in model.named_parameters(): 243 | if 'weight' in name: 244 | tensor = p.data.cpu().numpy() 245 | grad_tensor = p.grad.data.cpu().numpy() 246 | grad_tensor = np.where(abs(tensor) < EPS, 0, grad_tensor) # abs() so that negative surviving weights also keep their gradients 247 | p.grad.data = torch.from_numpy(grad_tensor).to(device) 248 | optimizer.step() 249 | return train_loss.item() 250 | 251 | # Function for Testing 252 | def test(model, test_loader, criterion): 253 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 254 | model.eval() 255 | test_loss = 0 256 | correct = 0 257 |
with torch.no_grad(): 258 | for data, target in test_loader: 259 | data, target = data.to(device), target.to(device) 260 | output = model(data) 261 | test_loss += F.cross_entropy(output, target, reduction='sum').item() # sum up batch loss (the model emits raw logits, so cross_entropy rather than nll_loss) 262 | pred = output.data.max(1, keepdim=True)[1] # get the index of the max logit 263 | correct += pred.eq(target.data.view_as(pred)).sum().item() 264 | test_loss /= len(test_loader.dataset) 265 | accuracy = 100. * correct / len(test_loader.dataset) 266 | return accuracy 267 | 268 | # Prune by Percentile module 269 | def prune_by_percentile(percent, resample=False, reinit=False,**kwargs): 270 | global step 271 | global mask 272 | global model 273 | 274 | # Calculate percentile value 275 | step = 0 276 | for name, param in model.named_parameters(): 277 | 278 | # We do not prune bias term 279 | if 'weight' in name: 280 | tensor = param.data.cpu().numpy() 281 | alive = tensor[np.nonzero(tensor)] # flattened array of nonzero values 282 | percentile_value = np.percentile(abs(alive), percent) 283 | 284 | # Convert Tensors to numpy and calculate 285 | weight_dev = param.device 286 | new_mask = np.where(abs(tensor) < percentile_value, 0, mask[step]) 287 | 288 | # Apply new weight and mask 289 | param.data = torch.from_numpy(tensor * new_mask).to(weight_dev) 290 | mask[step] = new_mask 291 | step += 1 292 | step = 0 293 | 294 | # Function to make an empty mask of the same size as the model 295 | def make_mask(model): 296 | global step 297 | global mask 298 | step = 0 299 | for name, param in model.named_parameters(): 300 | if 'weight' in name: 301 | step = step + 1 302 | mask = [None]* step 303 | step = 0 304 | for name, param in model.named_parameters(): 305 | if 'weight' in name: 306 | tensor = param.data.cpu().numpy() 307 | mask[step] = np.ones_like(tensor) 308 | step = step + 1 309 | step = 0 310 | 311 | def original_initialization(mask_temp, initial_state_dict): 312 | global model 313 | 314 | step = 0 315 | for name, param in model.named_parameters(): 316 | if "weight" in name: 317 | weight_dev = param.device 318 | param.data = torch.from_numpy(mask_temp[step] * initial_state_dict[name].cpu().numpy()).to(weight_dev) 319 | step = step + 1 320 | if "bias" in name: 321 | param.data = initial_state_dict[name] 322 | step = 0 323 | 324 | # Function for Initialization 325 | def weight_init(m): 326 | ''' 327 | Usage: 328 | model = Model() 329 | model.apply(weight_init) 330 | ''' 331 | if isinstance(m, nn.Conv1d): 332 | init.normal_(m.weight.data) 333 | if m.bias is not None: 334 | init.normal_(m.bias.data) 335 | elif isinstance(m, nn.Conv2d): 336 | init.xavier_normal_(m.weight.data) 337 | if m.bias is not None: 338 | init.normal_(m.bias.data) 339 | elif isinstance(m, nn.Conv3d): 340 | init.xavier_normal_(m.weight.data) 341 | if m.bias is not None: 342 | init.normal_(m.bias.data) 343 | elif isinstance(m, nn.ConvTranspose1d): 344 | init.normal_(m.weight.data) 345 | if m.bias is not None: 346 | init.normal_(m.bias.data) 347 | elif isinstance(m, nn.ConvTranspose2d): 348 | init.xavier_normal_(m.weight.data) 349 | if m.bias is not None: 350 | init.normal_(m.bias.data) 351 | elif isinstance(m, nn.ConvTranspose3d): 352 | init.xavier_normal_(m.weight.data) 353 | if m.bias is not None: 354 | init.normal_(m.bias.data) 355 | elif isinstance(m, nn.BatchNorm1d): 356 | init.normal_(m.weight.data, mean=1, std=0.02) 357 | init.constant_(m.bias.data, 0) 358 | elif isinstance(m, nn.BatchNorm2d): 359 | init.normal_(m.weight.data, mean=1, std=0.02) 360 | init.constant_(m.bias.data,
0) 361 | elif isinstance(m, nn.BatchNorm3d): 362 | init.normal_(m.weight.data, mean=1, std=0.02) 363 | init.constant_(m.bias.data, 0) 364 | elif isinstance(m, nn.Linear): 365 | init.xavier_normal_(m.weight.data) 366 | init.normal_(m.bias.data) 367 | elif isinstance(m, nn.LSTM): 368 | for param in m.parameters(): 369 | if len(param.shape) >= 2: 370 | init.orthogonal_(param.data) 371 | else: 372 | init.normal_(param.data) 373 | elif isinstance(m, nn.LSTMCell): 374 | for param in m.parameters(): 375 | if len(param.shape) >= 2: 376 | init.orthogonal_(param.data) 377 | else: 378 | init.normal_(param.data) 379 | elif isinstance(m, nn.GRU): 380 | for param in m.parameters(): 381 | if len(param.shape) >= 2: 382 | init.orthogonal_(param.data) 383 | else: 384 | init.normal_(param.data) 385 | elif isinstance(m, nn.GRUCell): 386 | for param in m.parameters(): 387 | if len(param.shape) >= 2: 388 | init.orthogonal_(param.data) 389 | else: 390 | init.normal_(param.data) 391 | 392 | 393 | if __name__=="__main__": 394 | 395 | #from gooey import Gooey 396 | #@Gooey 397 | 398 | # Argument Parser 399 | parser = argparse.ArgumentParser() 400 | parser.add_argument("--lr", default=1.2e-3, type=float, help="Learning rate") 401 | parser.add_argument("--batch_size", default=60, type=int) 402 | parser.add_argument("--start_iter", default=0, type=int) 403 | parser.add_argument("--end_iter", default=100, type=int) 404 | parser.add_argument("--print_freq", default=1, type=int) 405 | parser.add_argument("--valid_freq", default=1, type=int) 406 | parser.add_argument("--resume", action="store_true") 407 | parser.add_argument("--prune_type", default="lt", type=str, help="lt | reinit") 408 | parser.add_argument("--gpu", default="0", type=str) 409 | parser.add_argument("--dataset", default="mnist", type=str, help="mnist | cifar10 | fashionmnist | cifar100") 410 | parser.add_argument("--arch_type", default="fc1", type=str, help="fc1 | lenet5 | alexnet | vgg16 | resnet18 | densenet121") 411 | parser.add_argument("--prune_percent", default=10, type=int, help="Pruning percent") 412 | parser.add_argument("--prune_iterations", default=35, type=int, help="Pruning iterations count") 413 | 414 | 415 | args = parser.parse_args() 416 | 417 | 418 | os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" 419 | os.environ["CUDA_VISIBLE_DEVICES"]=args.gpu 420 | 421 | 422 | #FIXME resample 423 | resample = False 424 | 425 | # Looping Entire process 426 | #for i in range(0, 5): 427 | main(args, ITE=1) 428 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | cycler==0.10.0 2 | kiwisolver==1.1.0 3 | matplotlib==3.1.1 4 | numpy==1.17.2 5 | pandas==0.25.1 6 | Pillow==6.2.0 7 | protobuf==3.9.2 8 | pyparsing==2.4.2 9 | python-dateutil==2.8.0 10 | pytz==2019.2 11 | scipy==1.3.1 12 | seaborn==0.9.0 13 | six==1.12.0 14 | tensorboardX==1.8 15 | torch==1.2.0 16 | torchvision==0.4.0 17 | tqdm==4.36.1 18 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | #ANCHOR Libraries 2 | import numpy as np 3 | import torch 4 | import os 5 | import seaborn as sns 6 | import matplotlib.pyplot as plt 7 | import copy 8 | 9 | #ANCHOR Print table of zeros and non-zeros count 10 | def print_nonzeros(model): 11 | nonzero = total = 0 12 | for name, p in model.named_parameters(): 13 | tensor = p.data.cpu().numpy() 14 | nz_count
= np.count_nonzero(tensor) 15 | total_params = np.prod(tensor.shape) 16 | nonzero += nz_count 17 | total += total_params 18 | print(f'{name:20} | nonzeros = {nz_count:7} / {total_params:7} ({100 * nz_count / total_params:6.2f}%) | total_pruned = {total_params - nz_count :7} | shape = {tensor.shape}') 19 | print(f'alive: {nonzero}, pruned: {total - nonzero}, total: {total}, Compression rate: {total/nonzero:10.2f}x ({100 * (total-nonzero) / total:6.2f}% pruned)') 20 | return (round((nonzero/total)*100,1)) 21 | 22 | def original_initialization(mask_temp, initial_state_dict): # NOTE unused duplicate of main.py's version; relies on a global `model` that utils.py never defines 23 | global model 24 | 25 | step = 0 26 | for name, param in model.named_parameters(): 27 | if "weight" in name: 28 | weight_dev = param.device 29 | param.data = torch.from_numpy(mask_temp[step] * initial_state_dict[name].cpu().numpy()).to(weight_dev) 30 | step = step + 1 31 | if "bias" in name: 32 | param.data = initial_state_dict[name] 33 | step = 0 34 | 35 | 36 | 37 | 38 | #ANCHOR Checks if the directory exists and, if not, creates it 39 | def checkdir(directory): 40 | if not os.path.exists(directory): 41 | os.makedirs(directory) 42 | 43 | #FIXME — currently unused; an example call is sketched at the end of this dump 44 | def plot_train_test_stats(stats, 45 | epoch_num, 46 | key1='train', 47 | key2='test', 48 | key1_label=None, 49 | key2_label=None, 50 | xlabel=None, 51 | ylabel=None, 52 | title=None, 53 | yscale=None, 54 | ylim_bottom=None, 55 | ylim_top=None, 56 | savefig=None, 57 | sns_style='darkgrid' 58 | ): 59 | 60 | assert len(stats[key1]) == epoch_num, "len(stats['{}'])({}) != epoch_num({})".format(key1, len(stats[key1]), epoch_num) 61 | assert len(stats[key2]) == epoch_num, "len(stats['{}'])({}) != epoch_num({})".format(key2, len(stats[key2]), epoch_num) 62 | 63 | plt.clf() 64 | sns.set_style(sns_style) 65 | x_ticks = np.arange(epoch_num) 66 | 67 | plt.plot(x_ticks, stats[key1], label=key1_label) 68 | plt.plot(x_ticks, stats[key2], label=key2_label) 69 | 70 | if xlabel is not None: 71 | plt.xlabel(xlabel) 72 | if ylabel is not None: 73 | plt.ylabel(ylabel) 74 | 75 | if title is not None: 76 | plt.title(title) 77 | 78 | if yscale is not None: 79 | plt.yscale(yscale) 80 | 81 | if ylim_bottom is not None: 82 | plt.ylim(bottom=ylim_bottom) 83 | if ylim_top is not None: 84 | plt.ylim(top=ylim_top) 85 | 86 | plt.legend(bbox_to_anchor=(1.04,0.5), loc="center left", borderaxespad=0, fancybox=True) 87 | 88 | if savefig is not None: 89 | plt.savefig(savefig, bbox_inches='tight') 90 | else: 91 | plt.show() --------------------------------------------------------------------------------
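A few illustrative sketches follow; none of this code is part of the repository itself. First, the mask-and-freeze cycle that `prune_by_percentile` and `train` in `main.py` implement, reduced to a single toy layer (the layer shape and the 50% pruning rate here are made up for illustration):
```
import numpy as np
import torch
import torch.nn as nn

layer = nn.Linear(4, 3)  # toy layer, illustrative only
tensor = layer.weight.data.cpu().numpy()
mask = np.ones_like(tensor)

# Prune: zero out the 50% of surviving weights with the smallest magnitude.
alive = tensor[np.nonzero(tensor)]
threshold = np.percentile(abs(alive), 50)
mask = np.where(abs(tensor) < threshold, 0, mask)
layer.weight.data = torch.from_numpy(tensor * mask)

# One training step with pruned weights frozen, mirroring train() in main.py.
loss = layer(torch.randn(2, 4)).sum()
loss.backward()
layer.weight.grad *= torch.from_numpy(mask).float()  # zero gradients where mask == 0
```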
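Second, the `archs/mnist` models assume 1×28×28 inputs (`LeNet5`'s first `nn.Linear` expects the 64×14×14 feature map left after its single 2×2 max-pool, and `fc1` flattens to 28*28). A quick forward-pass check like the one below can catch shape mismatches before a full run when porting an arch to another dataset; it is meant to be run from the repository root:
```
import torch
from archs.mnist import LeNet5, fc1

x = torch.randn(60, 1, 28, 28)   # batch_size=60, MNIST-shaped input
print(LeNet5.LeNet5()(x).shape)  # expect torch.Size([60, 10])
print(fc1.fc1()(x).shape)        # expect torch.Size([60, 10])
```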
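Third, the `.dat` files that `main.py` writes are NumPy arrays serialized with `ndarray.dump`, and the masks are ordinary pickles, so both can be reloaded outside `combine_plots.py`. A sketch, assuming a finished default run (`lt` / `fc1` / `mnist`) and that the compression value embedded in the mask filename matches the last entry of the compression array:
```
import pickle
import numpy as np

base = "dumps/lt/fc1/mnist"  # relative to the repository root

# Compression (% of weights remaining) and best accuracy per pruning level.
comp = np.load(f"{base}/lt_compression.dat", allow_pickle=True)
best = np.load(f"{base}/lt_bestaccuracy.dat", allow_pickle=True)
print(list(zip(comp, best)))

# Per-layer binary masks pickled at the final pruning level.
with open(f"{base}/lt_mask_{comp[-1]}.pkl", "rb") as fp:
    mask = pickle.load(fp)
print([m.shape for m in mask])
```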
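Finally, `utils.plot_train_test_stats` is never called by `main.py`; a hypothetical invocation with made-up statistics (the accuracy values and output filename below are invented for the example) would look like this:
```
import numpy as np
from utils import plot_train_test_stats

stats = {"train": np.linspace(0.60, 0.99, 10),   # made-up accuracies
         "test": np.linspace(0.55, 0.95, 10)}
plot_train_test_stats(stats, epoch_num=10,
                      key1_label="train accuracy", key2_label="test accuracy",
                      xlabel="epoch", ylabel="accuracy",
                      title="illustrative data only",
                      savefig="stats_demo.png")
```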