├── .gitignore
├── README.md
├── archs
│   ├── cifar10
│   │   ├── AlexNet.py
│   │   ├── LeNet5.py
│   │   ├── densenet.py
│   │   ├── fc1.py
│   │   ├── resnet.py
│   │   └── vgg.py
│   ├── cifar100
│   │   ├── AlexNet.py
│   │   ├── LeNet5.py
│   │   ├── fc1.py
│   │   ├── resnet.py
│   │   └── vgg.py
│   └── mnist
│       ├── AlexNet.py
│       ├── LeNet5.py
│       ├── fc1.py
│       ├── resnet.py
│       └── vgg.py
├── combine_plots.py
├── main.py
├── requirements.txt
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | /data
2 | /dumps
3 | /plots
4 | /runs
5 | /saves
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Lottery Ticket Hypothesis in Pytorch
3 |
4 | This repository contains a **PyTorch** implementation of the paper [The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks](https://arxiv.org/abs/1803.03635) by [Jonathan Frankle](https://github.com/jfrankle) and [Michael Carbin](https://people.csail.mit.edu/mcarbin/) that can be **easily adapted to any model/dataset**.
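The paper's core procedure is iterative magnitude pruning with resetting. After each training cycle, the smallest-magnitude surviving weights are masked out. The survivors are then either reset to their original initialization (what `--prune_type=lt` stands for) or randomly re-initialized (`reinit`).

The sketch below illustrates one such round. It is not this repository's implementation, since the pruning logic in `main.py`/`utils.py` is not shown here. Several things in it are assumptions made for illustration:
- the helper names `prune_by_magnitude` and `rewind`;
- per-tensor pruning of weight matrices only, with biases left untouched;
- the tiny MLP used as the model.

```python
import copy

import torch
import torch.nn as nn


def prune_by_magnitude(model, masks, prune_percent):
    """Mask out the smallest `prune_percent`% of the surviving weights in each masked tensor."""
    for name, param in model.named_parameters():
        if name not in masks:
            continue  # tensors without a mask (here: biases) are never pruned
        w = param.detach()
        alive = w[masks[name].bool()].abs()
        k = int(alive.numel() * prune_percent / 100)
        if k == 0:
            continue
        # k-th smallest surviving magnitude; everything at or below it is pruned
        threshold = torch.kthvalue(alive, k).values
        # A weight stays alive only if it was alive before AND is above the threshold
        masks[name] = masks[name] * (w.abs() > threshold).float()
    return masks


def rewind(model, init_state, masks, prune_type="lt"):
    """'lt': reset to the saved initial weights; 'reinit': draw fresh random weights.
    In both cases the pruned positions are then zeroed with the mask."""
    if prune_type == "lt":
        model.load_state_dict(init_state)
    else:
        for module in model.modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                module.reset_parameters()
    with torch.no_grad():
        for name, param in model.named_parameters():
            if name in masks:
                param.mul_(masks[name])
    return model


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(784, 300), nn.ReLU(), nn.Linear(300, 10))
init_state = copy.deepcopy(model.state_dict())  # must be saved before any training
masks = {n: torch.ones_like(p) for n, p in model.named_parameters() if n.endswith("weight")}

# ... train `model` for one cycle here (omitted) ...
masks = prune_by_magnitude(model, masks, prune_percent=10)
model = rewind(model, init_state, masks, prune_type="lt")
print(f"first-layer weights kept: {masks['0.weight'].mean().item():.2f}")
```

Repeating train → prune → rewind produces progressively sparser networks. In a full implementation, the mask would also have to be re-applied during training (to the weights after each optimizer step, or to the gradients) so that pruned weights stay at zero.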
5 |
6 | ## Requirements
7 | ```
8 | pip3 install -r requirements.txt
9 | ```
10 | ## How to run the code?
11 | ### Using datasets/architectures included with this repository:
12 | ```
13 | python3 main.py --prune_type=lt --arch_type=fc1 --dataset=mnist --prune_percent=10 --prune_iterations=35
14 | ```
15 | - `--prune_type`: Type of pruning
16 |   - Options: `lt` - Lottery Ticket Hypothesis, `reinit` - Random reinitialization
17 |   - Default: `lt`
18 | - `--arch_type`: Type of architecture
19 |   - Options: `fc1` - Simple fully connected network, `lenet5` - LeNet5, `AlexNet` - AlexNet, `resnet18` - ResNet18, `vgg16` - VGG16
20 |   - Default: `fc1`
21 | - `--dataset`: Choice of dataset
22 |   - Options: `mnist`, `fashionmnist`, `cifar10`, `cifar100`
23 |   - Default: `mnist`
24 | - `--prune_percent`: Percentage of weights to prune after each cycle
25 |   - Default: `10`
26 | - `--prune_iterations`: Number of pruning cycles to run
27 |   - Default: `35`
28 | - `--lr`: Learning rate
29 |   - Default: `1.2e-3`
30 | - `--batch_size`: Batch size
31 |   - Default: `60`
32 | - `--end_iter`: Number of epochs
33 |   - Default: `100`
34 | - `--print_freq`: How often to print accuracy and loss
35 |   - Default: `1`
36 | - `--valid_freq`: How often to run validation
37 |   - Default: `1`
38 | - `--gpu`: Which GPU the program should use
39 |   - Default: `0`
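The flags above map naturally onto `argparse`. The sketch below is a hypothetical parser that only mirrors the documented names and defaults. It is not copied from `main.py`, and the `choices` lists and the string type of `--gpu` are assumptions.

It also includes a small helper for the sparsity arithmetic. Suppose each of the `prune_iterations` cycles removes `prune_percent`% of the weights that survived the previous cycle, as in the paper. Then the defaults (10% over 35 cycles) leave about 0.9^35 ≈ 2.5% of the weights.

```python
import argparse


def build_parser():
    # Hypothetical parser: flag names and defaults follow the list above,
    # but the choices and types are assumptions, not copied from main.py.
    p = argparse.ArgumentParser(description="Lottery Ticket Hypothesis experiments (sketch)")
    p.add_argument("--prune_type", default="lt", choices=["lt", "reinit"])
    p.add_argument("--arch_type", default="fc1",
                   choices=["fc1", "lenet5", "AlexNet", "resnet18", "vgg16"])
    p.add_argument("--dataset", default="mnist",
                   choices=["mnist", "fashionmnist", "cifar10", "cifar100"])
    p.add_argument("--prune_percent", type=int, default=10)
    p.add_argument("--prune_iterations", type=int, default=35)
    p.add_argument("--lr", type=float, default=1.2e-3)
    p.add_argument("--batch_size", type=int, default=60)
    p.add_argument("--end_iter", type=int, default=100)
    p.add_argument("--print_freq", type=int, default=1)
    p.add_argument("--valid_freq", type=int, default=1)
    p.add_argument("--gpu", type=str, default="0")
    return p


def remaining_fraction(prune_percent, prune_iterations):
    """Fraction of weights still unpruned if every cycle removes `prune_percent`%
    of the weights that survived the previous cycle."""
    return (1 - prune_percent / 100) ** prune_iterations


if __name__ == "__main__":
    args = build_parser().parse_args(["--dataset=cifar10", "--arch_type=vgg16"])
    print(args.dataset, args.arch_type, args.prune_percent)  # cifar10 vgg16 10
    print(f"{remaining_fraction(args.prune_percent, args.prune_iterations):.3f}")  # 0.025
```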
40 | ### Using datasets/architectures that are not included with this repository:
41 | - Adding a new architecture:
42 |   - For example, suppose you want to add an architecture named `new_model` that is compatible with the `mnist` dataset.
43 |   - Go to the `/archs/mnist/` directory and create a file `new_model.py`.
44 |   - Paste your **PyTorch-compatible** model into `new_model.py`.
45 |   - **IMPORTANT**: Make sure the *input size*, *number of classes*, *number of channels* and *batch size* in your `new_model.py` match the dataset you are adding it for (in this case, `mnist`).
46 |   - Open `main.py`, go to `line 36` and look for the comment `# Data Loader`. Find the block for your dataset (in this case, `mnist`) and add `new_model` to the end of the line `from archs.mnist import AlexNet, LeNet5, fc1, vgg, resnet`.
47 |   - Go to `line 82` and add the following:
48 |     ```
49 |     elif args.arch_type == "new_model":
50 |         model = new_model.new_model_name().to(device)
51 |     ```
52 |     Here, `new_model_name()` is the name of the model class you defined inside `new_model.py`.
53 | - Adding a new dataset:
54 |   - For example, suppose you want to add a dataset named `new_dataset` that is compatible with the `fc1` architecture.
55 |   - Go to `/archs`, create a directory named `new_dataset`, and add a file named `fc1.py` to it (or copy `fc1.py` from an existing dataset folder).
56 |   - **IMPORTANT**: Make sure the *input size*, *number of classes*, *number of channels* and *batch size* in your `fc1.py` match the dataset you are adding (in this case, `new_dataset`).
57 |   - Open `main.py`, go to `line 58` and add the following:
58 |     ```
59 |     elif args.dataset == "new_dataset":
60 |         traindataset = datasets.new_dataset('../data', train=True, download=True, transform=transform)
61 |         testdataset = datasets.new_dataset('../data', train=False, transform=transform)
62 |         from archs.new_dataset import fc1
63 |     ```
64 | **Note** that, as of now, you can only add datasets that are [natively available in PyTorch](https://pytorch.org/docs/stable/torchvision/datasets.html).
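To illustrate the steps for adding an architecture, here is a hypothetical `archs/mnist/new_model.py`. The layer sizes are arbitrary choices. The lowercase class name `new_model_name` is unusual for Python but simply matches the placeholder in the instructions.

The model is sized for MNIST (1 input channel, 28x28 images, 10 classes). The `__main__` block runs a dummy batch through it as a quick check that the input and output shapes fit the dataset.

```python
import torch
import torch.nn as nn


class new_model_name(nn.Module):
    """Hypothetical small CNN for MNIST: 1 input channel, 28x28 images, 10 classes."""

    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),   # 1 input channel: MNIST is grayscale
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),                               # 28x28 -> 14x14
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),                               # 14x14 -> 7x7
        )
        # The flattened size depends on the input size, so it must match the dataset
        self.classifier = nn.Linear(32 * 7 * 7, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        return self.classifier(x)


if __name__ == "__main__":
    # Shape check with a dummy batch of shape (batch, channels, height, width)
    out = new_model_name()(torch.randn(4, 1, 28, 28))
    print(out.shape)  # torch.Size([4, 10])
```

With such a file in place, the `elif args.arch_type == "new_model":` branch in `main.py` would instantiate it as `new_model.new_model_name()`.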
65 |
66 | ## How to combine the plots of various `prune_type`?
67 | - Open `combine_plots.py` and add/remove the datasets/architectures whose combined plots you want to generate (*this assumes you have already run `main.py` for those datasets/architectures and produced the weights*).
68 | - Run `python3 combine_plots.py`.
69 | - Go to `/plots/lt/combined_plots/` to see the graphs.
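Because `combine_plots.py` can only combine results that already exist, it may help to generate the `main.py` runs up front. The helper below is hypothetical and rests on two assumptions:
- the combined plots compare `--prune_type=lt` against `--prune_type=reinit` for each dataset/architecture pair;
- the `DATASETS`/`ARCHS` lists are placeholders that you edit to match `combine_plots.py`.

It only prints the commands (`shlex.join` needs Python 3.8+). Swap in `subprocess.run(cmd, check=True)` to execute them.

```python
import itertools
import shlex

# Placeholder lists (assumptions): edit them to match what combine_plots.py plots
PRUNE_TYPES = ["lt", "reinit"]
DATASETS = ["mnist", "cifar10"]
ARCHS = ["fc1", "lenet5"]


def build_commands(prune_types, datasets, archs):
    """Return one main.py argument list per (prune_type, dataset, arch) combination."""
    cmds = []
    for prune_type, dataset, arch in itertools.product(prune_types, datasets, archs):
        cmds.append(["python3", "main.py", f"--prune_type={prune_type}",
                     f"--arch_type={arch}", f"--dataset={dataset}"])
    return cmds


if __name__ == "__main__":
    for cmd in build_commands(PRUNE_TYPES, DATASETS, ARCHS):
        print(shlex.join(cmd))  # use subprocess.run(cmd, check=True) to actually run it
```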
70 |
71 | Kindly [raise an issue](https://github.com/rahulvigneswaran/Lottery-Ticket-Hypothesis-in-Pytorch/issues) if you have any problem with the instructions.
72 |
73 |
74 | ## Datasets and Architectures that were already tested
75 |
76 | | | fc1 | LeNet5 | AlexNet | VGG16 | Resnet18 |
77 | |--------------|:------------------:|:---------------------:|:----------------------:|:--------------------:|:------------------------:|
78 | | MNIST | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
79 | | CIFAR10 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
80 | | FashionMNIST | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
81 | | CIFAR100 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
82 |
83 |
84 | ## Repository Structure
85 | ```
86 | Lottery-Ticket-Hypothesis-in-Pytorch
87 | ├── archs
88 | │   ├── cifar10
89 | │   │   ├── AlexNet.py
90 | │   │   ├── densenet.py
91 | │   │   ├── fc1.py
92 | │   │   ├── LeNet5.py
93 | │   │   ├── resnet.py
94 | │   │   └── vgg.py
95 | │   ├── cifar100
96 | │   │   ├── AlexNet.py
97 | │   │   ├── fc1.py
98 | │   │   ├── LeNet5.py
99 | │   │   ├── resnet.py
100 | │   │   └── vgg.py
101 | │   └── mnist
102 | │       ├── AlexNet.py
103 | │       ├── fc1.py
104 | │       ├── LeNet5.py
105 | │       ├── resnet.py
106 | │       └── vgg.py
107 | ├── combine_plots.py
108 | ├── dumps
109 | ├── main.py
110 | ├── plots
111 | ├── README.md
112 | ├── requirements.txt
113 | ├── saves
114 | └── utils.py
115 |
116 | ```
117 |
118 | ## Interesting papers related to the Lottery Ticket Hypothesis that I enjoyed
119 | - [Deconstructing Lottery Tickets: Zeros, Signs, and the Supermask](https://eng.uber.com/deconstructing-lottery-tickets/)
120 |
121 | ## Acknowledgement
122 | Parts of the code were borrowed from [ktkth5](https://github.com/ktkth5/lottery-ticket-hyopothesis).
123 |
124 | ## Issues / Want to Contribute?
125 | Open a new issue or submit a pull request if you are facing any difficulty with the code base or want to contribute to it.
126 |
128 |
129 |
130 |
131 |
--------------------------------------------------------------------------------
/archs/cifar10/AlexNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | __all__ = ['AlexNet']
6 |
7 |
8 | model_urls = {
9 |     'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth',
10 | }
11 |
12 |
13 | class AlexNet(nn.Module):
14 |
15 |     def __init__(self, num_classes=10):
16 |         super(AlexNet, self).__init__()
17 |         self.features = nn.Sequential(
18 |             nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=2),
19 |             nn.ReLU(inplace=True),
20 |             nn.MaxPool2d(kernel_size=3, stride=2),
21 |             nn.Conv2d(64, 192, kernel_size=5, padding=2),
22 |             nn.ReLU(inplace=True),
23 |             nn.MaxPool2d(kernel_size=3, stride=2),
24 |             nn.Conv2d(192, 384, kernel_size=3, padding=1),
25 |             nn.ReLU(inplace=True),
26 |             nn.Conv2d(384, 256, kernel_size=3, padding=1),
27 |             nn.ReLU(inplace=True),
28 |             nn.Conv2d(256, 256, kernel_size=3, padding=1),
29 |             nn.ReLU(inplace=True),
30 |             nn.MaxPool2d(kernel_size=3, stride=2),
31 |         )
32 |         self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
33 |         self.classifier = nn.Sequential(
34 |             nn.Dropout(),
35 |             nn.Linear(256 * 6 * 6, 4096),
36 |             nn.ReLU(inplace=True),
37 |             nn.Dropout(),
38 |             nn.Linear(4096, 4096),
39 |             nn.ReLU(inplace=True),
40 |             nn.Linear(4096, num_classes),
41 |         )
42 |
43 |     def forward(self, x):
44 |         x = self.features(x)
45 |         x = self.avgpool(x)
46 |         x = torch.flatten(x, 1)
47 |         x = self.classifier(x)
48 |         return x
49 |
--------------------------------------------------------------------------------
/archs/cifar10/LeNet5.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as func
3 |
4 |
5 | class LeNet5(nn.Module):
6 |     def __init__(self, num_classes=10):
7 |         super(LeNet5, self).__init__()
8 |         self.conv1 = nn.Conv2d(3, 6, kernel_size=5)
9 |         self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
10 |         self.fc1 = nn.Linear(16*5*5, 120)
11 |         self.fc2 = nn.Linear(120, 84)
12 |         self.fc3 = nn.Linear(84, num_classes)
13 |
14 |     def forward(self, x):
15 |         x = func.relu(self.conv1(x))
16 |         x = func.max_pool2d(x, 2)
17 |         x = func.relu(self.conv2(x))
18 |         x = func.max_pool2d(x, 2)
19 |         x = x.view(x.size(0), -1)
20 |         x = func.relu(self.fc1(x))
21 |         x = func.relu(self.fc2(x))
22 |         x = self.fc3(x)
23 |         return x
--------------------------------------------------------------------------------
/archs/cifar10/densenet.py:
--------------------------------------------------------------------------------
1 | import re
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch.utils.checkpoint as cp
6 | from collections import OrderedDict
7 | from torch.hub import load_state_dict_from_url
8 | def _bn_function_factory(norm, relu, conv):
9 |     def bn_function(*inputs):
10 |         concated_features = torch.cat(inputs, 1)
11 |         bottleneck_output = conv(relu(norm(concated_features)))
12 |         return bottleneck_output
13 |
14 |     return bn_function
15 |
16 |
17 | class _DenseLayer(nn.Sequential):
18 |     def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, memory_efficient=False):
19 |         super(_DenseLayer, self).__init__()
20 |         self.add_module('norm1', nn.BatchNorm2d(num_input_features))
21 |         self.add_module('relu1', nn.ReLU(inplace=True))
22 |         self.add_module('conv1', nn.Conv2d(num_input_features, bn_size *
23 |                                            growth_rate, kernel_size=1, stride=1,
24 |                                            bias=False))
25 |         self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate))
26 |         self.add_module('relu2', nn.ReLU(inplace=True))
27 |         self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
28 |                                            kernel_size=3, stride=1, padding=1,
29 |                                            bias=False))
30 |         self.drop_rate = drop_rate
31 |         self.memory_efficient = memory_efficient
32 |
33 |     def forward(self, *prev_features):
34 |         bn_function = _bn_function_factory(self.norm1, self.relu1, self.conv1)
35 |         if self.memory_efficient and any(prev_feature.requires_grad for prev_feature in prev_features):
36 |             bottleneck_output = cp.checkpoint(bn_function, *prev_features)
37 |         else:
38 |             bottleneck_output = bn_function(*prev_features)
39 |         new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))
40 |         if self.drop_rate > 0:
41 |             new_features = F.dropout(new_features, p=self.drop_rate,
42 |                                      training=self.training)
43 |         return new_features
44 |
45 |
46 | class _DenseBlock(nn.Module):
47 |     def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, memory_efficient=False):
48 |         super(_DenseBlock, self).__init__()
49 |         for i in range(num_layers):
50 |             layer = _DenseLayer(
51 |                 num_input_features + i * growth_rate,
52 |                 growth_rate=growth_rate,
53 |                 bn_size=bn_size,
54 |                 drop_rate=drop_rate,
55 |                 memory_efficient=memory_efficient,
56 |             )
57 |             self.add_module('denselayer%d' % (i + 1), layer)
58 |
59 |     def forward(self, init_features):
60 |         features = [init_features]
61 |         for name, layer in self.named_children():
62 |             new_features = layer(*features)
63 |             features.append(new_features)
64 |         return torch.cat(features, 1)
65 |
66 |
67 | class _Transition(nn.Sequential):
68 |     def __init__(self, num_input_features, num_output_features):
69 |         super(_Transition, self).__init__()
70 |         self.add_module('norm', nn.BatchNorm2d(num_input_features))
71 |         self.add_module('relu', nn.ReLU(inplace=True))
72 |         self.add_module('conv', nn.Conv2d(num_input_features, num_output_features,
73 |                                           kernel_size=1, stride=1, bias=False))
74 |         self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))
75 |
76 |
77 | class DenseNet(nn.Module):
78 |     r"""Densenet-BC model class, based on
79 |     "Densely Connected Convolutional Networks".
80 |
81 |     Args:
82 |         growth_rate (int) - how many filters to add each layer (`k` in paper)
83 |         block_config (list of 4 ints) - how many layers in each pooling block
84 |         num_init_features (int) - the number of filters to learn in the first convolution layer
85 |         bn_size (int) - multiplicative factor for number of bottleneck layers
86 |           (i.e. bn_size * k features in the bottleneck layer)
87 |         drop_rate (float) - dropout rate after each dense layer
88 |         num_classes (int) - number of classification classes
89 |         memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient,
90 |           but slower. Default: *False*. See the paper "Memory-Efficient Implementation of DenseNets".
91 |     """
92 |
93 |     def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
94 |                  num_init_features=64, bn_size=4, drop_rate=0, num_classes=10, memory_efficient=False):
95 |
96 |         super(DenseNet, self).__init__()
97 |
98 |         # First convolution
99 |         self.features = nn.Sequential(OrderedDict([
100 |             ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2,
101 |                                 padding=3, bias=False)),
102 |             ('norm0', nn.BatchNorm2d(num_init_features)),
103 |             ('relu0', nn.ReLU(inplace=True)),
104 |             ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
105 |         ]))
106 |
107 |         # Each denseblock
108 |         num_features = num_init_features
109 |         for i, num_layers in enumerate(block_config):
110 |             block = _DenseBlock(
111 |                 num_layers=num_layers,
112 |                 num_input_features=num_features,
113 |                 bn_size=bn_size,
114 |                 growth_rate=growth_rate,
115 |                 drop_rate=drop_rate,
116 |                 memory_efficient=memory_efficient
117 |             )
118 |             self.features.add_module('denseblock%d' % (i + 1), block)
119 |             num_features = num_features + num_layers * growth_rate
120 |             if i != len(block_config) - 1:
121 |                 trans = _Transition(num_input_features=num_features,
122 |                                     num_output_features=num_features // 2)
123 |                 self.features.add_module('transition%d' % (i + 1), trans)
124 |                 num_features = num_features // 2
125 |
126 |         # Final batch norm
127 |         self.features.add_module('norm5', nn.BatchNorm2d(num_features))
128 |
129 |         # Linear layer
130 |         self.classifier = nn.Linear(num_features, num_classes)
131 |
132 |         # Official init from torch repo.
133 |         for m in self.modules():
134 |             if isinstance(m, nn.Conv2d):
135 |                 nn.init.kaiming_normal_(m.weight)
136 |             elif isinstance(m, nn.BatchNorm2d):
137 |                 nn.init.constant_(m.weight, 1)
138 |                 nn.init.constant_(m.bias, 0)
139 |             elif isinstance(m, nn.Linear):
140 |                 nn.init.constant_(m.bias, 0)
141 |
142 |     def forward(self, x):
143 |         features = self.features(x)
144 |         out = F.relu(features, inplace=True)
145 |         out = F.adaptive_avg_pool2d(out, (1, 1))
146 |         out = torch.flatten(out, 1)
147 |         out = self.classifier(out)
148 |         return out
149 |
150 |
151 | def _load_state_dict(model, model_url, progress):
152 |     # '.'s are no longer allowed in module names, but previous _DenseLayer
153 |     # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
154 |     # They are also in the checkpoints in model_urls. This pattern is used
155 |     # to find such keys.
156 |     pattern = re.compile(
157 |         r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
158 |
159 |     state_dict = load_state_dict_from_url(model_url, progress=progress)
160 |     for key in list(state_dict.keys()):
161 |         res = pattern.match(key)
162 |         if res:
163 |             new_key = res.group(1) + res.group(2)
164 |             state_dict[new_key] = state_dict[key]
165 |             del state_dict[key]
166 |     model.load_state_dict(state_dict)
167 |
168 |
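169 | def _densenet(arch, growth_rate, block_config, num_init_features, pretrained, progress,
170 |               **kwargs):
171 |     model = DenseNet(growth_rate, block_config, num_init_features, **kwargs)
172 |     if pretrained:  # no `model_urls` table is defined in this file, and ImageNet checkpoints have a 1000-class classifier
173 |         raise NotImplementedError("pretrained weights are not available for this CIFAR-10 DenseNet")
174 |     return model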
175 |
176 |
177 | def densenet121(pretrained=False, progress=True, **kwargs):
178 |     r"""Densenet-121 model from
179 |     "Densely Connected Convolutional Networks".
180 |
181 |     Args:
182 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
183 |         progress (bool): If True, displays a progress bar of the download to stderr
184 |         memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient,
185 |           but slower. Default: *False*. See the paper "Memory-Efficient Implementation of DenseNets".
186 |     """
187 |     return _densenet('densenet121', 32, (6, 12, 24, 16), 64, pretrained, progress,
188 |                      **kwargs)
189 |
190 |
191 |
192 | def densenet161(pretrained=False, progress=True, **kwargs):
193 |     r"""Densenet-161 model from
194 |     "Densely Connected Convolutional Networks".
195 |
196 |     Args:
197 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
198 |         progress (bool): If True, displays a progress bar of the download to stderr
199 |         memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient,
200 |           but slower. Default: *False*. See the paper "Memory-Efficient Implementation of DenseNets".
201 |     """
202 |     return _densenet('densenet161', 48, (6, 12, 36, 24), 96, pretrained, progress,
203 |                      **kwargs)
204 |
205 |
206 |
207 | def densenet169(pretrained=False, progress=True, **kwargs):
208 |     r"""Densenet-169 model from
209 |     "Densely Connected Convolutional Networks".
210 |
211 |     Args:
212 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
213 |         progress (bool): If True, displays a progress bar of the download to stderr
214 |         memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient,
215 |           but slower. Default: *False*. See the paper "Memory-Efficient Implementation of DenseNets".
216 |     """
217 |     return _densenet('densenet169', 32, (6, 12, 32, 32), 64, pretrained, progress,
218 |                      **kwargs)
219 |
220 |
221 |
222 | def densenet201(pretrained=False, progress=True, **kwargs):
223 |     r"""Densenet-201 model from
224 |     "Densely Connected Convolutional Networks".
225 |
226 |     Args:
227 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
228 |         progress (bool): If True, displays a progress bar of the download to stderr
229 |         memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient,
230 |           but slower. Default: *False*. See the paper "Memory-Efficient Implementation of DenseNets".
231 |     """
232 |     return _densenet('densenet201', 32, (6, 12, 48, 32), 64, pretrained, progress,
233 |                      **kwargs)
234 |
--------------------------------------------------------------------------------
/archs/cifar10/fc1.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class fc1(nn.Module):
5 |
6 |     def __init__(self, num_classes=10):
7 |         super(fc1, self).__init__()
8 |         self.classifier = nn.Sequential(
9 |             nn.Linear(3*32*32, 300),
10 |             nn.ReLU(inplace=True),
11 |             nn.Linear(300, 100),
12 |             nn.ReLU(inplace=True),
13 |             nn.Linear(100, num_classes),
14 |         )
15 |
16 |     def forward(self, x):
17 |         x = torch.flatten(x, 1)
18 |         x = self.classifier(x)
19 |         return x
20 |
--------------------------------------------------------------------------------
/archs/cifar10/resnet.py:
--------------------------------------------------------------------------------
1 | '''ResNet in PyTorch.
2 | CIFAR-style variant: 3x3 stem convolution with stride 1 and no initial max-pooling.
3 | Reference:
4 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
5 | Deep Residual Learning for Image Recognition. arXiv:1512.03385
6 | '''
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 |
12 | class BasicBlock(nn.Module):
13 |     expansion = 1
14 |
15 |     def __init__(self, in_planes, planes, stride=1):
16 |         super(BasicBlock, self).__init__()
17 |         self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
18 |         self.bn1 = nn.BatchNorm2d(planes)
19 |         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
20 |         self.bn2 = nn.BatchNorm2d(planes)
21 |
22 |         self.shortcut = nn.Sequential()
23 |         if stride != 1 or in_planes != self.expansion*planes:
24 |             self.shortcut = nn.Sequential(
25 |                 nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
26 |                 nn.BatchNorm2d(self.expansion*planes)
27 |             )
28 |
29 |     def forward(self, x):
30 |         out = F.relu(self.bn1(self.conv1(x)))
31 |         out = self.bn2(self.conv2(out))
32 |         out += self.shortcut(x)
33 |         out = F.relu(out)
34 |         return out
35 |
36 |
37 | class Bottleneck(nn.Module):
38 |     expansion = 4
39 |
40 |     def __init__(self, in_planes, planes, stride=1):
41 |         super(Bottleneck, self).__init__()
42 |         self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
43 |         self.bn1 = nn.BatchNorm2d(planes)
44 |         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
45 |         self.bn2 = nn.BatchNorm2d(planes)
46 |         self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
47 |         self.bn3 = nn.BatchNorm2d(self.expansion*planes)
48 |
49 |         self.shortcut = nn.Sequential()
50 |         if stride != 1 or in_planes != self.expansion*planes:
51 |             self.shortcut = nn.Sequential(
52 |                 nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
53 |                 nn.BatchNorm2d(self.expansion*planes)
54 |             )
55 |
56 |     def forward(self, x):
57 |         out = F.relu(self.bn1(self.conv1(x)))
58 |         out = F.relu(self.bn2(self.conv2(out)))
59 |         out = self.bn3(self.conv3(out))
60 |         out += self.shortcut(x)
61 |         out = F.relu(out)
62 |         return out
63 |
64 |
65 | class ResNet(nn.Module):
66 |     def __init__(self, block, num_blocks, num_classes=10):
67 |         super(ResNet, self).__init__()
68 |         self.in_planes = 64
69 |
70 |         self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
71 |         self.bn1 = nn.BatchNorm2d(64)
72 |         self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
73 |         self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
74 |         self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
75 |         self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
76 |         self.linear = nn.Linear(512*block.expansion, num_classes)
77 |
78 |     def _make_layer(self, block, planes, num_blocks, stride):
79 |         strides = [stride] + [1]*(num_blocks-1)
80 |         layers = []
81 |         for stride in strides:
82 |             layers.append(block(self.in_planes, planes, stride))
83 |             self.in_planes = planes * block.expansion
84 |         return nn.Sequential(*layers)
85 |
86 |     def forward(self, x):
87 |         out = F.relu(self.bn1(self.conv1(x)))
88 |         out = self.layer1(out)
89 |         out = self.layer2(out)
90 |         out = self.layer3(out)
91 |         out = self.layer4(out)
92 |         out = F.avg_pool2d(out, 4)
93 |         out = out.view(out.size(0), -1)
94 |         out = self.linear(out)
95 |         return out
96 |
97 |
98 | def resnet18():
99 |     return ResNet(BasicBlock, [2,2,2,2])
100 |
101 | def ResNet34():
102 |     return ResNet(BasicBlock, [3,4,6,3])
103 |
104 | def ResNet50():
105 |     return ResNet(Bottleneck, [3,4,6,3])
106 |
107 | def ResNet101():
108 |     return ResNet(Bottleneck, [3,4,23,3])
109 |
110 | def ResNet152():
111 |     return ResNet(Bottleneck, [3,8,36,3])
112 |
113 |
114 | def test():
115 |     net = resnet18()
116 |     y = net(torch.randn(1,3,32,32))
117 |     print(y.size())
118 |
119 | # test()
--------------------------------------------------------------------------------
/archs/cifar10/vgg.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | #
4 | # from torchvision.utils import load_state_dict_from_url
5 |
6 |
7 | __all__ = [
8 |     'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
9 |     'vgg19_bn', 'vgg19',
10 | ]
11 |
12 |
13 | model_urls = {
14 |     'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
15 |     'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
16 |     'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
17 |     'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
18 |     'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
19 |     'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
20 |     'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
21 |     'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
22 | }
23 |
24 |
25 | class VGG(nn.Module):
26 |     #ANCHOR Change No. of Classes here.
27 |     def __init__(self, features, num_classes=10, init_weights=True):
28 |         super(VGG, self).__init__()
29 |         self.features = features
30 |         self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
31 |         self.classifier = nn.Sequential(
32 |             nn.Linear(512 * 7 * 7, 4096),
33 |             nn.ReLU(True),
34 |             nn.Dropout(),
35 |             nn.Linear(4096, 4096),
36 |             nn.ReLU(True),
37 |             nn.Dropout(),
38 |             nn.Linear(4096, num_classes),
39 |         )
40 |         if init_weights:
41 |             self._initialize_weights()
42 |
43 |     def forward(self, x):
44 |         x = self.features(x)
45 |         x = self.avgpool(x)
46 |         x = torch.flatten(x, 1)
47 |         x = self.classifier(x)
48 |         return x
49 |
50 |     def _initialize_weights(self):
51 |         for m in self.modules():
52 |             if isinstance(m, nn.Conv2d):
53 |                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
54 |                 if m.bias is not None:
55 |                     nn.init.constant_(m.bias, 0)
56 |             elif isinstance(m, nn.BatchNorm2d):
57 |                 nn.init.constant_(m.weight, 1)
58 |                 nn.init.constant_(m.bias, 0)
59 |             elif isinstance(m, nn.Linear):
60 |                 nn.init.normal_(m.weight, 0, 0.01)
61 |                 nn.init.constant_(m.bias, 0)
62 |
63 |
64 | def make_layers(cfg, batch_norm=False):
65 |     layers = []
66 |     #ANCHOR Change No. of Input channels here.
67 |     in_channels = 3
68 |     for v in cfg:
69 |         if v == 'M':
70 |             layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
71 |         else:
72 |             conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
73 |             if batch_norm:
74 |                 layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
75 |             else:
76 |                 layers += [conv2d, nn.ReLU(inplace=True)]
77 |             in_channels = v
78 |     return nn.Sequential(*layers)
79 |
80 |
81 | cfgs = {
82 |     'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
83 |     'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
84 |     'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
85 |     'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
86 | }
87 |
88 |
89 | def _vgg(arch, cfg, batch_norm, pretrained, progress, **kwargs):
90 |     if pretrained:  # loading ImageNet checkpoints is disabled in this copy (see the commented-out code below)
91 |         raise NotImplementedError("pretrained weights are not loaded for this CIFAR-10 VGG")
92 |     model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
93 |     # if pretrained:
94 |     #     state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
95 |     #     model.load_state_dict(state_dict)
96 |     return model
97 |
98 |
99 | def vgg11(pretrained=False, progress=True, **kwargs):
100 |     r"""VGG 11-layer model (configuration "A") from
101 |     "Very Deep Convolutional Networks For Large-Scale Image Recognition".
102 |
103 |     Args:
104 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
105 |         progress (bool): If True, displays a progress bar of the download to stderr
106 |     """
107 |     return _vgg('vgg11', 'A', False, pretrained, progress, **kwargs)
108 |
109 |
110 |
111 | def vgg11_bn(pretrained=False, progress=True, **kwargs):
112 |     r"""VGG 11-layer model (configuration "A") with batch normalization
113 |     "Very Deep Convolutional Networks For Large-Scale Image Recognition".
114 |
115 |     Args:
116 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
117 |         progress (bool): If True, displays a progress bar of the download to stderr
118 |     """
119 |     return _vgg('vgg11_bn', 'A', True, pretrained, progress, **kwargs)
120 |
121 |
122 |
123 | def vgg13(pretrained=False, progress=True, **kwargs):
124 |     r"""VGG 13-layer model (configuration "B")
125 |     "Very Deep Convolutional Networks For Large-Scale Image Recognition".
126 |
127 |     Args:
128 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
129 |         progress (bool): If True, displays a progress bar of the download to stderr
130 |     """
131 |     return _vgg('vgg13', 'B', False, pretrained, progress, **kwargs)
132 |
133 |
134 |
135 | def vgg13_bn(pretrained=False, progress=True, **kwargs):
136 |     r"""VGG 13-layer model (configuration "B") with batch normalization
137 |     "Very Deep Convolutional Networks For Large-Scale Image Recognition".
138 |
139 |     Args:
140 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
141 |         progress (bool): If True, displays a progress bar of the download to stderr
142 |     """
143 |     return _vgg('vgg13_bn', 'B', True, pretrained, progress, **kwargs)
144 |
145 |
146 |
147 | def vgg16(pretrained=False, progress=True, **kwargs):
148 |     r"""VGG 16-layer model (configuration "D")
149 |     "Very Deep Convolutional Networks For Large-Scale Image Recognition".
150 |
151 |     Args:
152 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
153 |         progress (bool): If True, displays a progress bar of the download to stderr
154 |     """
155 |     return _vgg('vgg16', 'D', False, pretrained, progress, **kwargs)
156 |
157 |
158 |
159 | def vgg16_bn(pretrained=False, progress=True, **kwargs):
160 |     r"""VGG 16-layer model (configuration "D") with batch normalization
161 |     "Very Deep Convolutional Networks For Large-Scale Image Recognition".
162 |
163 |     Args:
164 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
165 |         progress (bool): If True, displays a progress bar of the download to stderr
166 |     """
167 |     return _vgg('vgg16_bn', 'D', True, pretrained, progress, **kwargs)
168 |
169 |
170 |
171 | def vgg19(pretrained=False, progress=True, **kwargs):
172 |     r"""VGG 19-layer model (configuration "E")
173 |     "Very Deep Convolutional Networks For Large-Scale Image Recognition".
174 |
175 |     Args:
176 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
177 |         progress (bool): If True, displays a progress bar of the download to stderr
178 |     """
179 |     return _vgg('vgg19', 'E', False, pretrained, progress, **kwargs)
180 |
181 |
182 |
183 | def vgg19_bn(pretrained=False, progress=True, **kwargs):
184 |     r"""VGG 19-layer model (configuration "E") with batch normalization
185 |     "Very Deep Convolutional Networks For Large-Scale Image Recognition".
186 |
187 |     Args:
188 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
189 |         progress (bool): If True, displays a progress bar of the download to stderr
190 |     """
191 |     return _vgg('vgg19_bn', 'E', True, pretrained, progress, **kwargs)
192 |
--------------------------------------------------------------------------------
/archs/cifar100/AlexNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | __all__ = ['AlexNet']
6 |
7 |
8 | model_urls = {
9 |     'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth',
10 | }
11 |
12 |
13 | class AlexNet(nn.Module):
14 |
15 |     def __init__(self, num_classes=100):
16 |         super(AlexNet, self).__init__()
17 |         self.features = nn.Sequential(
18 |             nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=2),
19 |             nn.ReLU(inplace=True),
20 |             nn.MaxPool2d(kernel_size=3, stride=2),
21 |             nn.Conv2d(64, 192, kernel_size=5, padding=2),
22 |             nn.ReLU(inplace=True),
23 |             nn.MaxPool2d(kernel_size=3, stride=2),
24 |             nn.Conv2d(192, 384, kernel_size=3, padding=1),
25 |             nn.ReLU(inplace=True),
26 |             nn.Conv2d(384, 256, kernel_size=3, padding=1),
27 |             nn.ReLU(inplace=True),
28 |             nn.Conv2d(256, 256, kernel_size=3, padding=1),
29 |             nn.ReLU(inplace=True),
30 |             nn.MaxPool2d(kernel_size=3, stride=2),
31 |         )
32 |         self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
33 |         self.classifier = nn.Sequential(
34 |             nn.Dropout(),
35 |             nn.Linear(256 * 6 * 6, 4096),
36 |             nn.ReLU(inplace=True),
37 |             nn.Dropout(),
38 |             nn.Linear(4096, 4096),
39 |             nn.ReLU(inplace=True),
40 |             nn.Linear(4096, num_classes),
41 |         )
42 |
43 |     def forward(self, x):
44 |         x = self.features(x)
45 |         x = self.avgpool(x)
46 |         x = torch.flatten(x, 1)
47 |         x = self.classifier(x)
48 |         return x
49 |
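To see what `AlexNet.features` does to a small CIFAR-sized image, you can apply the output-size formula PyTorch documents for `Conv2d` and `MaxPool2d` (with `ceil_mode=False`) one layer at a time. The sketch below is plain Python with no torch dependency. `out_size`, `ALEXNET_FEATURES` and `trace` are hypothetical helper names, and the layer list was transcribed by hand from the class above, so it should be treated as an illustration rather than an authoritative shape check.

```python
def out_size(n, kernel, stride=1, padding=0, dilation=1):
    """Output height/width of Conv2d or MaxPool2d (ceil_mode=False)."""
    return (n + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


# (name, kernel_size, stride, padding) of each layer in AlexNet.features;
# the ReLUs are omitted because they do not change the shape.
ALEXNET_FEATURES = [
    ("conv1", 3, 2, 2),
    ("pool1", 3, 2, 0),
    ("conv2", 5, 1, 2),
    ("pool2", 3, 2, 0),
    ("conv3", 3, 1, 1),
    ("conv4", 3, 1, 1),
    ("conv5", 3, 1, 1),
    ("pool3", 3, 2, 0),
]


def trace(n):
    """Return [(layer_name, spatial_size)] for an n x n input."""
    sizes = []
    for name, kernel, stride, padding in ALEXNET_FEATURES:
        n = out_size(n, kernel, stride, padding)
        sizes.append((name, n))
    return sizes


for name, n in trace(32):
    print(f"{name}: {n}x{n}")
```

For a 32x32 image the trace ends at 1x1. The 256x6x6 tensor that `AdaptiveAvgPool2d((6, 6))` passes to the classifier therefore holds only 256 distinct values, each repeated 36 times.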
--------------------------------------------------------------------------------
/archs/cifar100/LeNet5.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as func
3 |
4 |
5 | class LeNet5(nn.Module):
6 | def __init__(self, num_classes=100):
7 | super(LeNet5, self).__init__()
8 | self.conv1 = nn.Conv2d(3, 6, kernel_size=5)
9 | self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
10 | self.fc1 = nn.Linear(16*5*5, 120)
11 | self.fc2 = nn.Linear(120, 84)
12 | self.fc3 = nn.Linear(84, num_classes)
13 |
14 | def forward(self, x):
15 | x = func.relu(self.conv1(x))
16 | x = func.max_pool2d(x, 2)
17 | x = func.relu(self.conv2(x))
18 | x = func.max_pool2d(x, 2)
19 | x = x.view(x.size(0), -1)
20 | x = func.relu(self.fc1(x))
21 | x = func.relu(self.fc2(x))
22 | x = self.fc3(x)
23 | return x
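`fc1` is hard-coded to `16*5*5` input features, and that value only holds for one input size. You can check the arithmetic without torch. Each unpadded 5x5 convolution shrinks the side by 4, and `func.max_pool2d(x, 2)` halves it, rounding down, because the stride defaults to the kernel size. `conv_out` and `flat_features` are hypothetical helpers written for this note.

```python
def conv_out(n, kernel, stride=1, padding=0):
    """Output side length of an nn.Conv2d with dilation=1."""
    return (n + 2 * padding - kernel) // stride + 1


def flat_features(side, channels=16):
    """Size of x.view(x.size(0), -1) in LeNet5.forward for a side x side input."""
    for _ in range(2):  # conv1/conv2 (kernel_size=5), each followed by max_pool2d(x, 2)
        side = conv_out(side, kernel=5)
        side = side // 2
    return channels * side * side


print(flat_features(32))  # CIFAR-sized input
print(flat_features(28))  # MNIST-sized input
```

A 32x32 input gives 400 = 16*5*5, which matches `nn.Linear(16*5*5, 120)`. A 28x28 input would produce 256 features and fail at `fc1` with a shape mismatch, so this variant assumes CIFAR-sized images.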
--------------------------------------------------------------------------------
/archs/cifar100/fc1.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class fc1(nn.Module):
5 |
6 | def __init__(self, num_classes=100):
7 | super(fc1, self).__init__()
8 | self.classifier = nn.Sequential(
9 | nn.Linear(3*32*32, 300),
10 | nn.ReLU(inplace=True),
11 | nn.Linear(300, 100),
12 | nn.ReLU(inplace=True),
13 | nn.Linear(100, num_classes),
14 | )
15 |
16 | def forward(self, x):
17 | x = torch.flatten(x, 1)
18 | x = self.classifier(x)
19 | return x
20 |
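This `fc1` is a 300-100 multilayer perceptron. It differs from the MNIST `fc1` only in its flattened input size (3x32x32 = 3072 instead of 28x28 = 784) and its output width. Counting parameters by hand shows where most of the weights sit. `linear_params` is a hypothetical helper, and it assumes every `nn.Linear` keeps its default `bias=True`, as the class above does.

```python
def linear_params(widths, bias=True):
    """Parameter count of nn.Linear layers chained as widths[0] -> widths[1] -> ..."""
    total = 0
    for n_in, n_out in zip(widths, widths[1:]):
        total += n_in * n_out + (n_out if bias else 0)
    return total


cifar100_fc1 = [3 * 32 * 32, 300, 100, 100]  # num_classes=100
mnist_fc1 = [28 * 28, 300, 100, 10]          # num_classes=10

print(linear_params(cifar100_fc1), linear_params(mnist_fc1))
print(linear_params(cifar100_fc1[:2]) / linear_params(cifar100_fc1))
```

The CIFAR-100 variant has 962,100 parameters and the MNIST variant has 266,610. In the CIFAR-100 model the first layer alone accounts for about 96% of the total.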
--------------------------------------------------------------------------------
/archs/cifar100/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
5 | """3x3 convolution with padding"""
6 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
7 | padding=dilation, groups=groups, bias=False, dilation=dilation)
8 |
9 |
10 | def conv1x1(in_planes, out_planes, stride=1):
11 | """1x1 convolution"""
12 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
13 |
14 |
15 | class BasicBlock(nn.Module):
16 | expansion = 1
17 | __constants__ = ['downsample']
18 |
19 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
20 | base_width=64, dilation=1, norm_layer=None):
21 | super(BasicBlock, self).__init__()
22 | if norm_layer is None:
23 | norm_layer = nn.BatchNorm2d
24 | if groups != 1 or base_width != 64:
25 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
26 | if dilation > 1:
27 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
28 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
29 | self.conv1 = conv3x3(inplanes, planes, stride)
30 | self.bn1 = norm_layer(planes)
31 | self.relu = nn.ReLU(inplace=True)
32 | self.conv2 = conv3x3(planes, planes)
33 | self.bn2 = norm_layer(planes)
34 | self.downsample = downsample
35 | self.stride = stride
36 |
37 | def forward(self, x):
38 | identity = x
39 |
40 | out = self.conv1(x)
41 | out = self.bn1(out)
42 | out = self.relu(out)
43 |
44 | out = self.conv2(out)
45 | out = self.bn2(out)
46 |
47 | if self.downsample is not None:
48 | identity = self.downsample(x)
49 |
50 | out += identity
51 | out = self.relu(out)
52 |
53 | return out
54 |
55 |
56 | class Bottleneck(nn.Module):
57 | expansion = 4
58 | __constants__ = ['downsample']
59 |
60 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
61 | base_width=64, dilation=1, norm_layer=None):
62 | super(Bottleneck, self).__init__()
63 | if norm_layer is None:
64 | norm_layer = nn.BatchNorm2d
65 | width = int(planes * (base_width / 64.)) * groups
66 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
67 | self.conv1 = conv1x1(inplanes, width)
68 | self.bn1 = norm_layer(width)
69 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
70 | self.bn2 = norm_layer(width)
71 | self.conv3 = conv1x1(width, planes * self.expansion)
72 | self.bn3 = norm_layer(planes * self.expansion)
73 | self.relu = nn.ReLU(inplace=True)
74 | self.downsample = downsample
75 | self.stride = stride
76 |
77 | def forward(self, x):
78 | identity = x
79 |
80 | out = self.conv1(x)
81 | out = self.bn1(out)
82 | out = self.relu(out)
83 |
84 | out = self.conv2(out)
85 | out = self.bn2(out)
86 | out = self.relu(out)
87 |
88 | out = self.conv3(out)
89 | out = self.bn3(out)
90 |
91 | if self.downsample is not None:
92 | identity = self.downsample(x)
93 |
94 | out += identity
95 | out = self.relu(out)
96 |
97 | return out
98 |
99 |
100 | class ResNet(nn.Module):
101 |
102 | def __init__(self, block, layers, num_classes=100, zero_init_residual=False,
103 | groups=1, width_per_group=64, replace_stride_with_dilation=None,
104 | norm_layer=None):
105 | super(ResNet, self).__init__()
106 | if norm_layer is None:
107 | norm_layer = nn.BatchNorm2d
108 | self._norm_layer = norm_layer
109 |
110 | self.inplanes = 64
111 | self.dilation = 1
112 | if replace_stride_with_dilation is None:
113 | # each element in the tuple indicates if we should replace
114 | # the 2x2 stride with a dilated convolution instead
115 | replace_stride_with_dilation = [False, False, False]
116 | if len(replace_stride_with_dilation) != 3:
117 | raise ValueError("replace_stride_with_dilation should be None "
118 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
119 | self.groups = groups
120 | self.base_width = width_per_group
121 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
122 | bias=False)
123 | self.bn1 = norm_layer(self.inplanes)
124 | self.relu = nn.ReLU(inplace=True)
125 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
126 | self.layer1 = self._make_layer(block, 64, layers[0])
127 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
128 | dilate=replace_stride_with_dilation[0])
129 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
130 | dilate=replace_stride_with_dilation[1])
131 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
132 | dilate=replace_stride_with_dilation[2])
133 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
134 | self.fc = nn.Linear(512 * block.expansion, num_classes)
135 |
136 | for m in self.modules():
137 | if isinstance(m, nn.Conv2d):
138 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
139 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
140 | nn.init.constant_(m.weight, 1)
141 | nn.init.constant_(m.bias, 0)
142 |
143 | # Zero-initialize the last BN in each residual branch,
144 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
145 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
146 | if zero_init_residual:
147 | for m in self.modules():
148 | if isinstance(m, Bottleneck):
149 | nn.init.constant_(m.bn3.weight, 0)
150 | elif isinstance(m, BasicBlock):
151 | nn.init.constant_(m.bn2.weight, 0)
152 |
153 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
154 | norm_layer = self._norm_layer
155 | downsample = None
156 | previous_dilation = self.dilation
157 | if dilate:
158 | self.dilation *= stride
159 | stride = 1
160 | if stride != 1 or self.inplanes != planes * block.expansion:
161 | downsample = nn.Sequential(
162 | conv1x1(self.inplanes, planes * block.expansion, stride),
163 | norm_layer(planes * block.expansion),
164 | )
165 |
166 | layers = []
167 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
168 | self.base_width, previous_dilation, norm_layer))
169 | self.inplanes = planes * block.expansion
170 | for _ in range(1, blocks):
171 | layers.append(block(self.inplanes, planes, groups=self.groups,
172 | base_width=self.base_width, dilation=self.dilation,
173 | norm_layer=norm_layer))
174 |
175 | return nn.Sequential(*layers)
176 |
177 | def forward(self, x):
178 | x = self.conv1(x)
179 | x = self.bn1(x)
180 | x = self.relu(x)
181 | x = self.maxpool(x)
182 |
183 | x = self.layer1(x)
184 | x = self.layer2(x)
185 | x = self.layer3(x)
186 | x = self.layer4(x)
187 |
188 | x = self.avgpool(x)
189 | x = torch.flatten(x, 1)
190 | x = self.fc(x)
191 |
192 | return x
193 |
194 |
195 | def _resnet(arch, block, layers, pretrained, progress, **kwargs):
196 | model = ResNet(block, layers, **kwargs)
197 | return model
198 |
199 |
200 | def resnet18(pretrained=False, progress=True, **kwargs):
201 | r"""ResNet-18 model from
202 | `"Deep Residual Learning for Image Recognition" `_
203 | Args:
204 | pretrained (bool): If True, returns a model pre-trained on ImageNet
205 | progress (bool): If True, displays a progress bar of the download to stderr
206 | """
207 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
208 | **kwargs)
209 |
210 |
211 | def resnet34(pretrained=False, progress=True, **kwargs):
212 | r"""ResNet-34 model from
213 | `"Deep Residual Learning for Image Recognition" `_
214 | Args:
215 | pretrained (bool): If True, returns a model pre-trained on ImageNet
216 | progress (bool): If True, displays a progress bar of the download to stderr
217 | """
218 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
219 | **kwargs)
220 |
221 |
222 | def resnet50(pretrained=False, progress=True, **kwargs):
223 | r"""ResNet-50 model from
224 | `"Deep Residual Learning for Image Recognition" `_
225 | Args:
226 | pretrained (bool): If True, returns a model pre-trained on ImageNet
227 | progress (bool): If True, displays a progress bar of the download to stderr
228 | """
229 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
230 | **kwargs)
231 |
232 |
233 | def resnet101(pretrained=False, progress=True, **kwargs):
234 | r"""ResNet-101 model from
235 | `"Deep Residual Learning for Image Recognition" `_
236 | Args:
237 | pretrained (bool): If True, returns a model pre-trained on ImageNet
238 | progress (bool): If True, displays a progress bar of the download to stderr
239 | """
240 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
241 | **kwargs)
242 |
243 |
244 | def resnet152(pretrained=False, progress=True, **kwargs):
245 | r"""ResNet-152 model from
246 | `"Deep Residual Learning for Image Recognition" `_
247 | Args:
248 | pretrained (bool): If True, returns a model pre-trained on ImageNet
249 | progress (bool): If True, displays a progress bar of the download to stderr
250 | """
251 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
252 | **kwargs)
253 |
254 |
255 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
256 | r"""ResNeXt-50 32x4d model from
257 | `"Aggregated Residual Transformations for Deep Neural Networks" `_
258 | Args:
259 | pretrained (bool): If True, returns a model pre-trained on ImageNet
260 | progress (bool): If True, displays a progress bar of the download to stderr
261 | """
262 | kwargs['groups'] = 32
263 | kwargs['width_per_group'] = 4
264 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
265 | pretrained, progress, **kwargs)
266 |
267 |
268 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
269 | r"""ResNeXt-101 32x8d model from
270 | `"Aggregated Residual Transformations for Deep Neural Networks" `_
271 | Args:
272 | pretrained (bool): If True, returns a model pre-trained on ImageNet
273 | progress (bool): If True, displays a progress bar of the download to stderr
274 | """
275 | kwargs['groups'] = 32
276 | kwargs['width_per_group'] = 8
277 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
278 | pretrained, progress, **kwargs)
279 |
280 |
281 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
282 | r"""Wide ResNet-50-2 model from
283 | `"Wide Residual Networks" `_
284 | The model is the same as ResNet except for the bottleneck number of channels
285 | which is twice as large in every block. The number of channels in outer 1x1
286 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
287 | channels, and in Wide ResNet-50-2 has 2048-1024-2048.
288 | Args:
289 | pretrained (bool): If True, returns a model pre-trained on ImageNet
290 | progress (bool): If True, displays a progress bar of the download to stderr
291 | """
292 | kwargs['width_per_group'] = 64 * 2
293 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
294 | pretrained, progress, **kwargs)
295 |
296 |
297 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
298 | r"""Wide ResNet-101-2 model from
299 | `"Wide Residual Networks" `_
300 | The model is the same as ResNet except for the bottleneck number of channels
301 | which is twice as large in every block. The number of channels in outer 1x1
302 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
303 | channels, and in Wide ResNet-50-2 has 2048-1024-2048.
304 | Args:
305 | pretrained (bool): If True, returns a model pre-trained on ImageNet
306 | progress (bool): If True, displays a progress bar of the download to stderr
307 | """
308 | kwargs['width_per_group'] = 64 * 2
309 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
310 | pretrained, progress, **kwargs)
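The `wide_resnet*` and `resnext*` constructors differ from plain ResNet only in the `width_per_group` and `groups` keyword arguments. `Bottleneck.__init__` turns these into the inner width with `width = int(planes * (base_width / 64.)) * groups`. The sketch below evaluates that expression in plain Python and reproduces the channel counts quoted in the Wide ResNet docstrings. `bottleneck_channels` is a hypothetical name, and `expansion=4` comes from `Bottleneck.expansion`.

```python
def bottleneck_channels(planes, base_width=64, groups=1, expansion=4):
    """Return (inner width, output channels) of one Bottleneck block,
    using the same expression as Bottleneck.__init__ above."""
    width = int(planes * (base_width / 64.0)) * groups
    return width, planes * expansion


# The last stage of each network uses planes=512 (self.layer4 in ResNet.__init__).
configs = {
    "resnet50": dict(),
    "wide_resnet50_2": dict(base_width=64 * 2),
    "resnext50_32x4d": dict(base_width=4, groups=32),
}
for name, kwargs in configs.items():
    width, out_ch = bottleneck_channels(512, **kwargs)
    print(f"{name}: {out_ch}-{width}-{out_ch}")
```

ResNet-50 gives 2048-512-2048 and Wide ResNet-50-2 gives 2048-1024-2048, which matches the docstring. ResNeXt-50 32x4d reaches the same inner width of 1024, but its `conv2` splits that width into 32 groups.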
--------------------------------------------------------------------------------
/archs/cifar100/vgg.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | #
4 | # from torch.hub import load_state_dict_from_url
5 |
6 |
7 | __all__ = [
8 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
9 | 'vgg19_bn', 'vgg19',
10 | ]
11 |
12 |
13 | model_urls = {
14 | 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
15 | 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
16 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
17 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
18 | 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
19 | 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
20 | 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
21 | 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
22 | }
23 |
24 |
25 | class VGG(nn.Module):
26 | #ANCHOR Change No. of Classes here.
27 | def __init__(self, features, num_classes=100, init_weights=True):
28 | super(VGG, self).__init__()
29 | self.features = features
30 | self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
31 | self.classifier = nn.Sequential(
32 | nn.Linear(512 * 7 * 7, 4096),
33 | nn.ReLU(True),
34 | nn.Dropout(),
35 | nn.Linear(4096, 4096),
36 | nn.ReLU(True),
37 | nn.Dropout(),
38 | nn.Linear(4096, num_classes),
39 | )
40 | if init_weights:
41 | self._initialize_weights()
42 |
43 | def forward(self, x):
44 | x = self.features(x)
45 | x = self.avgpool(x)
46 | x = torch.flatten(x, 1)
47 | x = self.classifier(x)
48 | return x
49 |
50 | def _initialize_weights(self):
51 | for m in self.modules():
52 | if isinstance(m, nn.Conv2d):
53 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
54 | if m.bias is not None:
55 | nn.init.constant_(m.bias, 0)
56 | elif isinstance(m, nn.BatchNorm2d):
57 | nn.init.constant_(m.weight, 1)
58 | nn.init.constant_(m.bias, 0)
59 | elif isinstance(m, nn.Linear):
60 | nn.init.normal_(m.weight, 0, 0.01)
61 | nn.init.constant_(m.bias, 0)
62 |
63 |
64 | def make_layers(cfg, batch_norm=False):
65 | layers = []
66 | #ANCHOR Change No. of Input channels here.
67 | in_channels = 3
68 | for v in cfg:
69 | if v == 'M':
70 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
71 | else:
72 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
73 | if batch_norm:
74 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
75 | else:
76 | layers += [conv2d, nn.ReLU(inplace=True)]
77 | in_channels = v
78 | return nn.Sequential(*layers)
79 |
80 |
81 | cfgs = {
82 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
83 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
84 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
85 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
86 | }
87 |
88 |
89 | def _vgg(arch, cfg, batch_norm, pretrained, progress, **kwargs):
90 | if pretrained:
91 | kwargs['init_weights'] = False
92 | model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
93 | #if pretrained:
94 | #state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
95 | #model.load_state_dict(state_dict)
96 | return model
97 |
98 |
99 | def vgg11(pretrained=False, progress=True, **kwargs):
100 | r"""VGG 11-layer model (configuration "A") from
101 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_
102 |
103 | Args:
104 | pretrained (bool): If True, returns a model pre-trained on ImageNet
105 | progress (bool): If True, displays a progress bar of the download to stderr
106 | """
107 | return _vgg('vgg11', 'A', False, pretrained, progress, **kwargs)
108 |
109 |
110 |
111 | def vgg11_bn(pretrained=False, progress=True, **kwargs):
112 | r"""VGG 11-layer model (configuration "A") with batch normalization
113 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_
114 |
115 | Args:
116 | pretrained (bool): If True, returns a model pre-trained on ImageNet
117 | progress (bool): If True, displays a progress bar of the download to stderr
118 | """
119 | return _vgg('vgg11_bn', 'A', True, pretrained, progress, **kwargs)
120 |
121 |
122 |
123 | def vgg13(pretrained=False, progress=True, **kwargs):
124 | r"""VGG 13-layer model (configuration "B")
125 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_
126 |
127 | Args:
128 | pretrained (bool): If True, returns a model pre-trained on ImageNet
129 | progress (bool): If True, displays a progress bar of the download to stderr
130 | """
131 | return _vgg('vgg13', 'B', False, pretrained, progress, **kwargs)
132 |
133 |
134 |
135 | def vgg13_bn(pretrained=False, progress=True, **kwargs):
136 | r"""VGG 13-layer model (configuration "B") with batch normalization
137 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_
138 |
139 | Args:
140 | pretrained (bool): If True, returns a model pre-trained on ImageNet
141 | progress (bool): If True, displays a progress bar of the download to stderr
142 | """
143 | return _vgg('vgg13_bn', 'B', True, pretrained, progress, **kwargs)
144 |
145 |
146 |
147 | def vgg16(pretrained=False, progress=True, **kwargs):
148 | r"""VGG 16-layer model (configuration "D")
149 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_
150 |
151 | Args:
152 | pretrained (bool): If True, returns a model pre-trained on ImageNet
153 | progress (bool): If True, displays a progress bar of the download to stderr
154 | """
155 | return _vgg('vgg16', 'D', False, pretrained, progress, **kwargs)
156 |
157 |
158 |
159 | def vgg16_bn(pretrained=False, progress=True, **kwargs):
160 | r"""VGG 16-layer model (configuration "D") with batch normalization
161 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_
162 |
163 | Args:
164 | pretrained (bool): If True, returns a model pre-trained on ImageNet
165 | progress (bool): If True, displays a progress bar of the download to stderr
166 | """
167 | return _vgg('vgg16_bn', 'D', True, pretrained, progress, **kwargs)
168 |
169 |
170 |
171 | def vgg19(pretrained=False, progress=True, **kwargs):
172 | r"""VGG 19-layer model (configuration "E")
173 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_
174 |
175 | Args:
176 | pretrained (bool): If True, returns a model pre-trained on ImageNet
177 | progress (bool): If True, displays a progress bar of the download to stderr
178 | """
179 | return _vgg('vgg19', 'E', False, pretrained, progress, **kwargs)
180 |
181 |
182 |
183 | def vgg19_bn(pretrained=False, progress=True, **kwargs):
184 | r"""VGG 19-layer model (configuration 'E') with batch normalization
185 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_
186 |
187 | Args:
188 | pretrained (bool): If True, returns a model pre-trained on ImageNet
189 | progress (bool): If True, displays a progress bar of the download to stderr
190 | """
191 | return _vgg('vgg19_bn', 'E', True, pretrained, progress, **kwargs)
192 |
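Each `cfgs` entry lists conv output channels, and `'M'` marks a 2x2 max-pool. The model names count weight layers: the 3x3 convolutions in the config plus the three `nn.Linear` layers of `VGG.classifier`. The plain-Python check below confirms that A/B/D/E map to 11/13/16/19 layers and shows what the five pooling stages do to a 32x32 image. It has no torch dependency, `vgg_depth` and `feature_side` are hypothetical helpers, and the configs are copied from the dict above.

```python
cfgs = {
    'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


def vgg_depth(cfg, n_linear=3):
    """Conv layers in cfg plus the Linear layers of VGG.classifier."""
    return sum(1 for v in cfg if v != 'M') + n_linear


def feature_side(cfg, side):
    """Spatial side after make_layers(cfg): padded 3x3 convs keep the size,
    each MaxPool2d(kernel_size=2, stride=2) halves it (rounding down)."""
    for v in cfg:
        if v == 'M':
            side //= 2
    return side


for key, cfg in cfgs.items():
    print(key, vgg_depth(cfg), feature_side(cfg, 32), feature_side(cfg, 224))
```

A 32x32 input reaches the classifier as a 1x1 map, so `AdaptiveAvgPool2d((7, 7))` just tiles each of the 512 values 49 times. The pooling stages only produce a genuine 7x7 map at the 224x224 ImageNet resolution.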
--------------------------------------------------------------------------------
/archs/mnist/AlexNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | __all__ = ['AlexNet']
6 |
7 |
8 | model_urls = {
9 | 'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth',
10 | }
11 |
12 |
13 | class AlexNet(nn.Module):
14 |
15 | def __init__(self, num_classes=10):
16 | super(AlexNet, self).__init__()
17 | self.features = nn.Sequential(
18 | nn.Conv2d(1, 64, kernel_size=3, stride=2, padding=2),
19 | nn.ReLU(inplace=True),
20 | nn.MaxPool2d(kernel_size=3, stride=2),
21 | nn.Conv2d(64, 192, kernel_size=5, padding=2),
22 | nn.ReLU(inplace=True),
23 | nn.MaxPool2d(kernel_size=3, stride=2),
24 | nn.Conv2d(192, 384, kernel_size=3, padding=1),
25 | nn.ReLU(inplace=True),
26 | nn.Conv2d(384, 256, kernel_size=3, padding=1),
27 | nn.ReLU(inplace=True),
28 | nn.Conv2d(256, 256, kernel_size=3, padding=1),
29 | nn.ReLU(inplace=True),
30 | nn.MaxPool2d(kernel_size=3, stride=2),
31 | )
32 | self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
33 | self.classifier = nn.Sequential(
34 | nn.Dropout(),
35 | nn.Linear(256 * 6 * 6, 4096),
36 | nn.ReLU(inplace=True),
37 | nn.Dropout(),
38 | nn.Linear(4096, 4096),
39 | nn.ReLU(inplace=True),
40 | nn.Linear(4096, num_classes),
41 | )
42 |
43 | def forward(self, x):
44 | x = self.features(x)
45 | x = self.avgpool(x)
46 | x = torch.flatten(x, 1)
47 | x = self.classifier(x)
48 | return x
49 |
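For a 28x28 MNIST image, the stride-2 first convolution and the three stride-2 max-pools in `features` shrink the map to 1x1 (28 -> 15 -> 7 -> 7 -> 3 -> 3 -> 1). That is smaller than the `(6, 6)` target of `self.avgpool`. For output cell `i` out of `out` cells over an input of length `in`, `nn.AdaptiveAvgPool2d` averages the window `[floor(i*in/out), ceil((i+1)*in/out))`. The sketch below reproduces that indexing rule along one dimension in plain Python to show the effect. The helper names are hypothetical, and this illustrates the rule rather than PyTorch's implementation.

```python
def adaptive_windows(in_size, out_size):
    """[start, end) input indices averaged by each output cell of an adaptive
    average pool along one dimension: floor(i*in/out) .. ceil((i+1)*in/out)."""
    windows = []
    for i in range(out_size):
        start = (i * in_size) // out_size
        end = ((i + 1) * in_size + out_size - 1) // out_size  # integer ceil
        windows.append((start, end))
    return windows


def adaptive_avg_pool_1d(values, out_size):
    """Average each window returned by adaptive_windows."""
    pooled = []
    for start, end in adaptive_windows(len(values), out_size):
        window = values[start:end]
        pooled.append(sum(window) / len(window))
    return pooled


print(adaptive_windows(1, 6))          # 1x1 feature map -> 6 output cells
print(adaptive_avg_pool_1d([0.5], 6))  # the single value is repeated
```

Across both axes, all 36 cells copy the same activation. `nn.Linear(256 * 6 * 6, 4096)` therefore receives 9,216 inputs that carry only 256 distinct values per image.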
--------------------------------------------------------------------------------
/archs/mnist/LeNet5.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class LeNet5(nn.Module):
5 | def __init__(self, num_classes=10):
6 | super(LeNet5, self).__init__()
7 | self.features = nn.Sequential(
8 | nn.Conv2d(1, 64, kernel_size=(3, 3), stride=1, padding=1),
9 | nn.ReLU(),
10 | nn.Conv2d(64, 64, kernel_size=(3, 3), stride=1, padding=1),
11 | nn.ReLU(),
12 | nn.MaxPool2d(kernel_size=2),
13 | )
14 | self.classifier = nn.Sequential(
15 | nn.Linear(64*14*14, 256),
16 | nn.ReLU(inplace=True),
17 | nn.Linear(256, 256),
18 | nn.ReLU(inplace=True),
19 | nn.Linear(256, num_classes),
20 | )
21 |
22 | def forward(self, x):
23 | x = self.features(x)
24 | x = torch.flatten(x, 1)
25 | x = self.classifier(x)
26 | return x
27 |
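The CIFAR-100 `LeNet5` uses 5x5 convolutions with 6 and 16 channels. This MNIST variant instead uses two padded 3x3 convolutions with 64 channels, one 2x2 max-pool, and a 256-256 fully connected head. Counting its parameters in plain Python shows where `64*14*14` comes from and how much of the model sits in the first `nn.Linear`. The helpers are hypothetical, and they assume the default `bias=True`, as the layers above use.

```python
def conv_side(n, kernel=3, stride=1, padding=1):
    """Output side of nn.Conv2d; kernel 3 with padding 1 keeps the size."""
    return (n + 2 * padding - kernel) // stride + 1


def conv_params(c_in, c_out, kernel=3):
    return c_in * c_out * kernel * kernel + c_out


def linear_params(n_in, n_out):
    return n_in * n_out + n_out


side = conv_side(conv_side(28))  # two padded 3x3 convs: still 28
side //= 2                       # MaxPool2d(kernel_size=2): 14
flat = 64 * side * side

counts = {
    "conv1": conv_params(1, 64),
    "conv2": conv_params(64, 64),
    "fc1": linear_params(flat, 256),
    "fc2": linear_params(256, 256),
    "fc3": linear_params(256, 10),
}
total = sum(counts.values())
print(flat, total, round(counts["fc1"] / total, 3))
```

The first fully connected layer holds 3,211,520 of the 3,317,450 parameters, about 97%.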
--------------------------------------------------------------------------------
/archs/mnist/fc1.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class fc1(nn.Module):
5 |
6 | def __init__(self, num_classes=10):
7 | super(fc1, self).__init__()
8 | self.classifier = nn.Sequential(
9 | nn.Linear(28*28, 300),
10 | nn.ReLU(inplace=True),
11 | nn.Linear(300, 100),
12 | nn.ReLU(inplace=True),
13 | nn.Linear(100, num_classes),
14 | )
15 |
16 | def forward(self, x):
17 | x = torch.flatten(x, 1)
18 | x = self.classifier(x)
19 | return x
20 |
21 |
--------------------------------------------------------------------------------
/archs/mnist/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
5 | """3x3 convolution with padding"""
6 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
7 | padding=dilation, groups=groups, bias=False, dilation=dilation)
8 |
9 |
10 | def conv1x1(in_planes, out_planes, stride=1):
11 | """1x1 convolution"""
12 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
13 |
14 |
15 | class BasicBlock(nn.Module):
16 | expansion = 1
17 | __constants__ = ['downsample']
18 |
19 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
20 | base_width=64, dilation=1, norm_layer=None):
21 | super(BasicBlock, self).__init__()
22 | if norm_layer is None:
23 | norm_layer = nn.BatchNorm2d
24 | if groups != 1 or base_width != 64:
25 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
26 | if dilation > 1:
27 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
28 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
29 | self.conv1 = conv3x3(inplanes, planes, stride)
30 | self.bn1 = norm_layer(planes)
31 | self.relu = nn.ReLU(inplace=True)
32 | self.conv2 = conv3x3(planes, planes)
33 | self.bn2 = norm_layer(planes)
34 | self.downsample = downsample
35 | self.stride = stride
36 |
37 | def forward(self, x):
38 | identity = x
39 |
40 | out = self.conv1(x)
41 | out = self.bn1(out)
42 | out = self.relu(out)
43 |
44 | out = self.conv2(out)
45 | out = self.bn2(out)
46 |
47 | if self.downsample is not None:
48 | identity = self.downsample(x)
49 |
50 | out += identity
51 | out = self.relu(out)
52 |
53 | return out
54 |
55 |
56 | class Bottleneck(nn.Module):
57 | expansion = 4
58 | __constants__ = ['downsample']
59 |
60 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
61 | base_width=64, dilation=1, norm_layer=None):
62 | super(Bottleneck, self).__init__()
63 | if norm_layer is None:
64 | norm_layer = nn.BatchNorm2d
65 | width = int(planes * (base_width / 64.)) * groups
66 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
67 | self.conv1 = conv1x1(inplanes, width)
68 | self.bn1 = norm_layer(width)
69 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
70 | self.bn2 = norm_layer(width)
71 | self.conv3 = conv1x1(width, planes * self.expansion)
72 | self.bn3 = norm_layer(planes * self.expansion)
73 | self.relu = nn.ReLU(inplace=True)
74 | self.downsample = downsample
75 | self.stride = stride
76 |
77 | def forward(self, x):
78 | identity = x
79 |
80 | out = self.conv1(x)
81 | out = self.bn1(out)
82 | out = self.relu(out)
83 |
84 | out = self.conv2(out)
85 | out = self.bn2(out)
86 | out = self.relu(out)
87 |
88 | out = self.conv3(out)
89 | out = self.bn3(out)
90 |
91 | if self.downsample is not None:
92 | identity = self.downsample(x)
93 |
94 | out += identity
95 | out = self.relu(out)
96 |
97 | return out
98 |
99 |
100 | class ResNet(nn.Module):
101 |
102 | def __init__(self, block, layers, num_classes=10, zero_init_residual=False,
103 | groups=1, width_per_group=64, replace_stride_with_dilation=None,
104 | norm_layer=None):
105 | super(ResNet, self).__init__()
106 | if norm_layer is None:
107 | norm_layer = nn.BatchNorm2d
108 | self._norm_layer = norm_layer
109 |
110 | self.inplanes = 64
111 | self.dilation = 1
112 | if replace_stride_with_dilation is None:
113 | # each element in the tuple indicates if we should replace
114 | # the 2x2 stride with a dilated convolution instead
115 | replace_stride_with_dilation = [False, False, False]
116 | if len(replace_stride_with_dilation) != 3:
117 | raise ValueError("replace_stride_with_dilation should be None "
118 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
119 | self.groups = groups
120 | self.base_width = width_per_group
121 | self.conv1 = nn.Conv2d(1, self.inplanes, kernel_size=7, stride=2, padding=3,
122 | bias=False)
123 | self.bn1 = norm_layer(self.inplanes)
124 | self.relu = nn.ReLU(inplace=True)
125 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
126 | self.layer1 = self._make_layer(block, 64, layers[0])
127 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
128 | dilate=replace_stride_with_dilation[0])
129 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
130 | dilate=replace_stride_with_dilation[1])
131 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
132 | dilate=replace_stride_with_dilation[2])
133 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
134 | self.fc = nn.Linear(512 * block.expansion, num_classes)
135 |
136 | for m in self.modules():
137 | if isinstance(m, nn.Conv2d):
138 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
139 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
140 | nn.init.constant_(m.weight, 1)
141 | nn.init.constant_(m.bias, 0)
142 |
143 | # Zero-initialize the last BN in each residual branch,
144 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
145 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
146 | if zero_init_residual:
147 | for m in self.modules():
148 | if isinstance(m, Bottleneck):
149 | nn.init.constant_(m.bn3.weight, 0)
150 | elif isinstance(m, BasicBlock):
151 | nn.init.constant_(m.bn2.weight, 0)
152 |
153 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
154 | norm_layer = self._norm_layer
155 | downsample = None
156 | previous_dilation = self.dilation
157 | if dilate:
158 | self.dilation *= stride
159 | stride = 1
160 | if stride != 1 or self.inplanes != planes * block.expansion:
161 | downsample = nn.Sequential(
162 | conv1x1(self.inplanes, planes * block.expansion, stride),
163 | norm_layer(planes * block.expansion),
164 | )
165 |
166 | layers = []
167 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
168 | self.base_width, previous_dilation, norm_layer))
169 | self.inplanes = planes * block.expansion
170 | for _ in range(1, blocks):
171 | layers.append(block(self.inplanes, planes, groups=self.groups,
172 | base_width=self.base_width, dilation=self.dilation,
173 | norm_layer=norm_layer))
174 |
175 | return nn.Sequential(*layers)
176 |
177 | def forward(self, x):
178 | x = self.conv1(x)
179 | x = self.bn1(x)
180 | x = self.relu(x)
181 | x = self.maxpool(x)
182 |
183 | x = self.layer1(x)
184 | x = self.layer2(x)
185 | x = self.layer3(x)
186 | x = self.layer4(x)
187 |
188 | x = self.avgpool(x)
189 | x = torch.flatten(x, 1)
190 | x = self.fc(x)
191 |
192 | return x
193 |
194 |
195 | def _resnet(arch, block, layers, pretrained, progress, **kwargs):
196 | model = ResNet(block, layers, **kwargs)
197 | return model
198 |
199 |
200 | def resnet18(pretrained=False, progress=True, **kwargs):
201 | r"""ResNet-18 model from
202 | `"Deep Residual Learning for Image Recognition" `_
203 | Args:
204 | pretrained (bool): If True, returns a model pre-trained on ImageNet
205 | progress (bool): If True, displays a progress bar of the download to stderr
206 | """
207 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
208 | **kwargs)
209 |
210 |
211 | def resnet34(pretrained=False, progress=True, **kwargs):
212 |     r"""ResNet-34 model from
213 |     "Deep Residual Learning for Image Recognition".
214 |     Args:
215 |         pretrained (bool): Ignored here; `_resnet` always returns a randomly initialized model
216 |         progress (bool): Ignored here; no weights are downloaded
217 | """
218 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
219 | **kwargs)
220 |
221 |
222 | def resnet50(pretrained=False, progress=True, **kwargs):
223 |     r"""ResNet-50 model from
224 |     "Deep Residual Learning for Image Recognition".
225 |     Args:
226 |         pretrained (bool): Ignored here; `_resnet` always returns a randomly initialized model
227 |         progress (bool): Ignored here; no weights are downloaded
228 | """
229 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
230 | **kwargs)
231 |
232 |
233 | def resnet101(pretrained=False, progress=True, **kwargs):
234 |     r"""ResNet-101 model from
235 |     "Deep Residual Learning for Image Recognition".
236 |     Args:
237 |         pretrained (bool): Ignored here; `_resnet` always returns a randomly initialized model
238 |         progress (bool): Ignored here; no weights are downloaded
239 | """
240 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
241 | **kwargs)
242 |
243 |
244 | def resnet152(pretrained=False, progress=True, **kwargs):
245 |     r"""ResNet-152 model from
246 |     "Deep Residual Learning for Image Recognition".
247 |     Args:
248 |         pretrained (bool): Ignored here; `_resnet` always returns a randomly initialized model
249 |         progress (bool): Ignored here; no weights are downloaded
250 | """
251 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
252 | **kwargs)
253 |
254 |
255 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
256 |     r"""ResNeXt-50 32x4d model from
257 |     "Aggregated Residual Transformations for Deep Neural Networks".
258 |     Args:
259 |         pretrained (bool): Ignored here; `_resnet` always returns a randomly initialized model
260 |         progress (bool): Ignored here; no weights are downloaded
261 | """
262 | kwargs['groups'] = 32
263 | kwargs['width_per_group'] = 4
264 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
265 | pretrained, progress, **kwargs)
266 |
267 |
268 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
269 |     r"""ResNeXt-101 32x8d model from
270 |     "Aggregated Residual Transformations for Deep Neural Networks".
271 |     Args:
272 |         pretrained (bool): Ignored here; `_resnet` always returns a randomly initialized model
273 |         progress (bool): Ignored here; no weights are downloaded
274 | """
275 | kwargs['groups'] = 32
276 | kwargs['width_per_group'] = 8
277 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
278 | pretrained, progress, **kwargs)
279 |
280 |
281 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
282 |     r"""Wide ResNet-50-2 model from
283 |     "Wide Residual Networks".
284 |     The model is the same as ResNet except that the bottleneck number of channels
285 |     is twice as large in every block. The number of channels in the outer 1x1
286 |     convolutions is the same, e.g. the last block in ResNet-50 has 2048-512-2048
287 |     channels, while in Wide ResNet-50-2 it has 2048-1024-2048.
288 |     Args:
289 |         pretrained (bool): Ignored here; `_resnet` always returns a randomly initialized model
290 |         progress (bool): Ignored here; no weights are downloaded
291 | """
292 | kwargs['width_per_group'] = 64 * 2
293 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
294 | pretrained, progress, **kwargs)
295 |
296 |
297 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
298 |     r"""Wide ResNet-101-2 model from
299 |     "Wide Residual Networks".
300 |     The model is the same as ResNet except that the bottleneck number of channels
301 |     is twice as large in every block. The number of channels in the outer 1x1
302 |     convolutions is the same, e.g. the last block in ResNet-101 has 2048-512-2048
303 |     channels, while in Wide ResNet-101-2 it has 2048-1024-2048.
304 |     Args:
305 |         pretrained (bool): Ignored here; `_resnet` always returns a randomly initialized model
306 |         progress (bool): Ignored here; no weights are downloaded
307 | """
308 | kwargs['width_per_group'] = 64 * 2
309 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
310 | pretrained, progress, **kwargs)
--------------------------------------------------------------------------------
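The depth in each factory name follows from the `layers` list passed to `_resnet`: a `BasicBlock` contributes two 3x3 convolutions, a `Bottleneck` three, plus the stem `conv1` and the final `fc` (by convention the 1x1 projection shortcuts built in `_make_layer` are not counted). The sketch below checks that arithmetic and the bottleneck widths quoted in the Wide ResNet docstrings. It assumes the torchvision `Bottleneck` width formula `int(planes * (base_width / 64.)) * groups` and `expansion = 4`, since the block classes are outside this excerpt; the helper names are hypothetical.

```python
# Hypothetical helpers reproducing the arithmetic behind the ResNet factory names.
CONVS_PER_BLOCK = {"BasicBlock": 2, "Bottleneck": 3}


def resnet_depth(block, layers):
    """Weighted layers = stem conv + convs inside every residual block + final fc."""
    return 1 + CONVS_PER_BLOCK[block] * sum(layers) + 1


def bottleneck_width(planes, groups=1, width_per_group=64):
    """Inner 3x3 width, assuming torchvision's Bottleneck width formula."""
    return int(planes * (width_per_group / 64.0)) * groups


CONFIGS = {
    "resnet18": ("BasicBlock", [2, 2, 2, 2]),
    "resnet34": ("BasicBlock", [3, 4, 6, 3]),
    "resnet50": ("Bottleneck", [3, 4, 6, 3]),
    "resnet101": ("Bottleneck", [3, 4, 23, 3]),
    "resnet152": ("Bottleneck", [3, 8, 36, 3]),
}

for name, (block, layers) in CONFIGS.items():
    print(name, resnet_depth(block, layers))  # 18, 34, 50, 101, 152

# The last stage uses planes=512; with expansion 4 its outer width is 2048.
print(bottleneck_width(512))                                # 512  (ResNet-50)
print(bottleneck_width(512, width_per_group=128))           # 1024 (Wide ResNet-50-2)
print(bottleneck_width(512, groups=32, width_per_group=4))  # 1024 (ResNeXt-50 32x4d)
```

The Wide ResNet factories only double `width_per_group`, so the inner width doubles while the 2048-channel outer convolutions stay the same.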
/archs/mnist/vgg.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | def vgg_block(num_convs, in_channels, num_channels):
6 |     layers = []
7 |     for _ in range(num_convs):
8 |         layers += [nn.Conv2d(in_channels=in_channels, out_channels=num_channels, kernel_size=3, padding=1)]
9 |         layers += [nn.ReLU()]  # non-linearity after every conv, as in VGG
10 |         in_channels = num_channels
11 |     layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
12 |     return nn.Sequential(*layers)
13 |
14 | class vgg16(nn.Module):
15 |     def __init__(self, num_classes=10):
16 |         super(vgg16, self).__init__()
17 |         self.conv_arch = ((1, 1, 64), (1, 64, 128), (2, 128, 256), (2, 256, 512), (2, 512, 512))  # (num_convs, in, out): 8 convs, the VGG-11 layout
18 |         layers = []
19 |         for (num_convs, in_channels, num_channels) in self.conv_arch:
20 |             layers += [vgg_block(num_convs, in_channels, num_channels)]
21 |         self.features = nn.Sequential(*layers)
22 |         self.dense1 = nn.Linear(512 * 7 * 7, 4096)  # 7x7 maps after five 2x2 pools require 224x224 inputs
23 |         self.drop1 = nn.Dropout(0.5)
24 |         self.dense2 = nn.Linear(4096, 4096)
25 |         self.drop2 = nn.Dropout(0.5)
26 |         self.dense3 = nn.Linear(4096, num_classes)
27 |
28 |     def forward(self, x):
29 |         x = torch.flatten(self.features(x), 1)  # (N, 512*7*7) for 224x224 inputs
30 |         x = self.drop1(F.relu(self.dense1(x)))
31 |         x = self.drop2(F.relu(self.dense2(x)))
32 |         return self.dense3(x)
33 |
--------------------------------------------------------------------------------
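`dense1` expects `512*7*7` features, which is what five `MaxPool2d(kernel_size=2, stride=2)` stages produce from a 224x224 input; the 3x3 convolutions with `padding=1` leave the spatial size unchanged. The pure-Python sketch below (function names are hypothetical) traces the spatial size with PyTorch's default floor-mode pooling formula and shows that a native 28x28 MNIST image shrinks to zero at the last pool, so this model only works on resized inputs.

```python
# Hypothetical helpers: trace the spatial size through vgg16.features.
def pool_out(n, kernel=2, stride=2):
    """Output size of MaxPool2d without padding (floor mode, PyTorch's default)."""
    return (n - kernel) // stride + 1


def trace_features(side, num_blocks=5):
    sizes = [side]
    for _ in range(num_blocks):
        # Padded 3x3 convs keep the size; only the pooling layer shrinks it.
        sizes.append(pool_out(sizes[-1]))
    return sizes


print(trace_features(224))  # [224, 112, 56, 28, 14, 7]
print(trace_features(28))   # [28, 14, 7, 3, 1, 0]
```

One option, not what the repository currently does, would be adding `transforms.Resize(224)` ahead of `transforms.ToTensor()` in `main.py` when `--arch_type=vgg16` is used.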
/combine_plots.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import seaborn as sns
3 | import numpy as np
4 | import os
5 | from tqdm import tqdm
6 |
7 |
8 | DPI = 1200
9 | prune_iterations = 35
10 | arch_types = ["fc1", "lenet5", "resnet18"]
11 | datasets = ["mnist", "fashionmnist", "cifar10", "cifar100"]
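12 | os.makedirs(f"{os.getcwd()}/plots/lt/combined_plots/", exist_ok=True)  # savefig does not create missing folders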
13 |
14 | for arch_type in tqdm(arch_types):
15 | for dataset in tqdm(datasets):
16 | d = np.load(f"{os.getcwd()}/dumps/lt/{arch_type}/{dataset}/lt_compression.dat", allow_pickle=True)
17 | b = np.load(f"{os.getcwd()}/dumps/lt/{arch_type}/{dataset}/lt_bestaccuracy.dat", allow_pickle=True)
18 | c = np.load(f"{os.getcwd()}/dumps/lt/{arch_type}/{dataset}/reinit_bestaccuracy.dat", allow_pickle=True)
19 |
20 | #plt.clf()
21 | #sns.set_style('darkgrid')
22 | #plt.style.use('seaborn-darkgrid')
23 | a = np.arange(prune_iterations)
24 | plt.plot(a, b, c="blue", label="Winning tickets")
25 | plt.plot(a, c, c="red", label="Random reinit")
26 | plt.title(f"Test Accuracy vs Weights % ({arch_type} | {dataset})")
27 | plt.xlabel("Weights %")
28 | plt.ylabel("Test accuracy")
29 | plt.xticks(a, d, rotation ="vertical")
30 | plt.ylim(0,100)
31 | plt.legend()
32 | plt.grid(color="gray")
33 |
34 | plt.savefig(f"{os.getcwd()}/plots/lt/combined_plots/combined_{arch_type}_{dataset}.png", dpi=DPI, bbox_inches='tight')
35 | plt.close()
36 | #print(f"\n combined_{arch_type}_{dataset} plotted!\n")
37 |
--------------------------------------------------------------------------------
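`combine_plots.py` reads the `.dat` files that `main.py` writes with `ndarray.dump`. `dump` pickles the array, which is why `np.load` is called with `allow_pickle=True`. The sketch below reproduces that round trip in a temporary directory; the array values are placeholders, not results.

```python
import os
import tempfile

import numpy as np

# Placeholder values standing in for main.py's `bestacc` (one entry per pruning iteration).
bestacc = np.array([98.1, 98.3, 98.4])

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "lt_bestaccuracy.dat")
    bestacc.dump(path)                         # writes a pickle of the array
    loaded = np.load(path, allow_pickle=True)  # pickled files are refused without allow_pickle=True

print(np.array_equal(loaded, bestacc))  # True
```

Since unpickling can execute arbitrary code, only load dump files you produced yourself; `np.save`/`np.load` with `.npy` files avoid pickling for plain numeric arrays.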
/main.py:
--------------------------------------------------------------------------------
1 | # Importing Libraries
2 | import argparse
3 | import copy
4 | import os
5 | import sys
6 | import numpy as np
7 | from tqdm import tqdm
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | import torchvision
12 | import torchvision.transforms as transforms
13 | import torchvision.datasets as datasets
14 | import matplotlib.pyplot as plt
15 | import os
16 | from tensorboardX import SummaryWriter
17 | import torchvision.utils as vutils
18 | import seaborn as sns
19 | import torch.nn.init as init
20 | import pickle
21 |
22 | # Custom Libraries
23 | import utils
24 |
25 | # Tensorboard initialization
26 | writer = SummaryWriter()
27 |
28 | # Plotting Style
29 | sns.set_style('darkgrid')
30 |
31 | # Main
32 | def main(args, ITE=0):
33 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
34 |     reinit = args.prune_type == "reinit"
35 |
36 |     # Data Loader (NOTE: the MNIST mean/std below are applied to every dataset, CIFAR included)
37 | transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])
38 | if args.dataset == "mnist":
39 | traindataset = datasets.MNIST('../data', train=True, download=True,transform=transform)
40 | testdataset = datasets.MNIST('../data', train=False, transform=transform)
41 | from archs.mnist import AlexNet, LeNet5, fc1, vgg, resnet
42 |
43 | elif args.dataset == "cifar10":
44 | traindataset = datasets.CIFAR10('../data', train=True, download=True,transform=transform)
45 | testdataset = datasets.CIFAR10('../data', train=False, transform=transform)
46 | from archs.cifar10 import AlexNet, LeNet5, fc1, vgg, resnet, densenet
47 |
48 | elif args.dataset == "fashionmnist":
49 | traindataset = datasets.FashionMNIST('../data', train=True, download=True,transform=transform)
50 | testdataset = datasets.FashionMNIST('../data', train=False, transform=transform)
51 | from archs.mnist import AlexNet, LeNet5, fc1, vgg, resnet
52 |
53 | elif args.dataset == "cifar100":
54 | traindataset = datasets.CIFAR100('../data', train=True, download=True,transform=transform)
55 | testdataset = datasets.CIFAR100('../data', train=False, transform=transform)
56 | from archs.cifar100 import AlexNet, fc1, LeNet5, vgg, resnet
57 |
58 | # If you want to add extra datasets paste here
59 |
60 | else:
61 |         print("\nWrong Dataset choice \n")
62 |         sys.exit(1)
63 |
64 | train_loader = torch.utils.data.DataLoader(traindataset, batch_size=args.batch_size, shuffle=True, num_workers=0,drop_last=False)
65 | #train_loader = cycle(train_loader)
66 |     test_loader = torch.utils.data.DataLoader(testdataset, batch_size=args.batch_size, shuffle=False, num_workers=0, drop_last=False)  # keep every test sample: accuracy divides by len(dataset)
67 |
68 | # Importing Network Architecture
69 | global model
70 | if args.arch_type == "fc1":
71 | model = fc1.fc1().to(device)
72 | elif args.arch_type == "lenet5":
73 | model = LeNet5.LeNet5().to(device)
74 | elif args.arch_type == "alexnet":
75 | model = AlexNet.AlexNet().to(device)
76 | elif args.arch_type == "vgg16":
77 | model = vgg.vgg16().to(device)
78 | elif args.arch_type == "resnet18":
79 | model = resnet.resnet18().to(device)
80 | elif args.arch_type == "densenet121":
81 | model = densenet.densenet121().to(device)
82 | # If you want to add extra model paste here
83 | else:
84 |         print("\nWrong Model choice\n")
85 |         sys.exit(1)
86 |
87 | # Weight Initialization
88 | model.apply(weight_init)
89 |
90 | # Copying and Saving Initial State
91 | initial_state_dict = copy.deepcopy(model.state_dict())
92 | utils.checkdir(f"{os.getcwd()}/saves/{args.arch_type}/{args.dataset}/")
93 | torch.save(model, f"{os.getcwd()}/saves/{args.arch_type}/{args.dataset}/initial_state_dict_{args.prune_type}.pth.tar")
94 |
95 | # Making Initial Mask
96 | make_mask(model)
97 |
98 | # Optimizer and Loss
99 |     optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-4)
100 | criterion = nn.CrossEntropyLoss() # Default was F.nll_loss
101 |
102 | # Layer Looper
103 | for name, param in model.named_parameters():
104 | print(name, param.size())
105 |
106 | # Pruning
107 | # NOTE First Pruning Iteration is of No Compression
108 | bestacc = 0.0
109 | best_accuracy = 0
110 | ITERATION = args.prune_iterations
111 | comp = np.zeros(ITERATION,float)
112 | bestacc = np.zeros(ITERATION,float)
113 | step = 0
114 | all_loss = np.zeros(args.end_iter,float)
115 | all_accuracy = np.zeros(args.end_iter,float)
116 |
117 |
118 | for _ite in range(args.start_iter, ITERATION):
119 | if not _ite == 0:
120 | prune_by_percentile(args.prune_percent, resample=resample, reinit=reinit)
121 | if reinit:
122 | model.apply(weight_init)
123 | #if args.arch_type == "fc1":
124 | # model = fc1.fc1().to(device)
125 | #elif args.arch_type == "lenet5":
126 | # model = LeNet5.LeNet5().to(device)
127 | #elif args.arch_type == "alexnet":
128 | # model = AlexNet.AlexNet().to(device)
129 | #elif args.arch_type == "vgg16":
130 | # model = vgg.vgg16().to(device)
131 | #elif args.arch_type == "resnet18":
132 | # model = resnet.resnet18().to(device)
133 | #elif args.arch_type == "densenet121":
134 | # model = densenet.densenet121().to(device)
135 | #else:
136 | # print("\nWrong Model choice\n")
137 | # exit()
138 | step = 0
139 | for name, param in model.named_parameters():
140 | if 'weight' in name:
141 | weight_dev = param.device
142 | param.data = torch.from_numpy(param.data.cpu().numpy() * mask[step]).to(weight_dev)
143 | step = step + 1
144 | step = 0
145 | else:
146 | original_initialization(mask, initial_state_dict)
147 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-4)
148 | print(f"\n--- Pruning Level [{ITE}:{_ite}/{ITERATION}]: ---")
149 |
150 | # Print the table of Nonzeros in each layer
151 | comp1 = utils.print_nonzeros(model)
152 | comp[_ite] = comp1
153 | pbar = tqdm(range(args.end_iter))
154 |
155 | for iter_ in pbar:
156 |
157 | # Frequency for Testing
158 | if iter_ % args.valid_freq == 0:
159 | accuracy = test(model, test_loader, criterion)
160 |
161 | # Save Weights
162 | if accuracy > best_accuracy:
163 | best_accuracy = accuracy
164 | utils.checkdir(f"{os.getcwd()}/saves/{args.arch_type}/{args.dataset}/")
165 | torch.save(model,f"{os.getcwd()}/saves/{args.arch_type}/{args.dataset}/{_ite}_model_{args.prune_type}.pth.tar")
166 |
167 | # Training
168 | loss = train(model, train_loader, optimizer, criterion)
169 | all_loss[iter_] = loss
170 | all_accuracy[iter_] = accuracy
171 |
172 | # Frequency for Printing Accuracy and Loss
173 | if iter_ % args.print_freq == 0:
174 | pbar.set_description(
175 | f'Train Epoch: {iter_}/{args.end_iter} Loss: {loss:.6f} Accuracy: {accuracy:.2f}% Best Accuracy: {best_accuracy:.2f}%')
176 |
177 | writer.add_scalar('Accuracy/test', best_accuracy, comp1)
178 | bestacc[_ite]=best_accuracy
179 |
180 | # Plotting Loss (Training), Accuracy (Testing), Iteration Curve
181 | #NOTE Loss is computed for every iteration while Accuracy is computed only for every {args.valid_freq} iterations. Therefore Accuracy saved is constant during the uncomputed iterations.
182 | #NOTE Normalized the accuracy to [0,100] for ease of plotting.
183 | plt.plot(np.arange(1,(args.end_iter)+1), 100*(all_loss - np.min(all_loss))/np.ptp(all_loss).astype(float), c="blue", label="Loss")
184 | plt.plot(np.arange(1,(args.end_iter)+1), all_accuracy, c="red", label="Accuracy")
185 | plt.title(f"Loss Vs Accuracy Vs Iterations ({args.dataset},{args.arch_type})")
186 | plt.xlabel("Iterations")
187 | plt.ylabel("Loss and Accuracy")
188 | plt.legend()
189 | plt.grid(color="gray")
190 | utils.checkdir(f"{os.getcwd()}/plots/lt/{args.arch_type}/{args.dataset}/")
191 | plt.savefig(f"{os.getcwd()}/plots/lt/{args.arch_type}/{args.dataset}/{args.prune_type}_LossVsAccuracy_{comp1}.png", dpi=1200)
192 | plt.close()
193 |
194 | # Dump Plot values
195 | utils.checkdir(f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/")
196 | all_loss.dump(f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.prune_type}_all_loss_{comp1}.dat")
197 | all_accuracy.dump(f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.prune_type}_all_accuracy_{comp1}.dat")
198 |
199 | # Dumping mask
200 | utils.checkdir(f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/")
201 | with open(f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.prune_type}_mask_{comp1}.pkl", 'wb') as fp:
202 | pickle.dump(mask, fp)
203 |
204 | # Making variables into 0
205 | best_accuracy = 0
206 | all_loss = np.zeros(args.end_iter,float)
207 | all_accuracy = np.zeros(args.end_iter,float)
208 |
209 | # Dumping Values for Plotting
210 | utils.checkdir(f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/")
211 | comp.dump(f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.prune_type}_compression.dat")
212 | bestacc.dump(f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.prune_type}_bestaccuracy.dat")
213 |
214 | # Plotting
215 | a = np.arange(args.prune_iterations)
216 | plt.plot(a, bestacc, c="blue", label="Winning tickets")
217 | plt.title(f"Test Accuracy vs Unpruned Weights Percentage ({args.dataset},{args.arch_type})")
218 | plt.xlabel("Unpruned Weights Percentage")
219 | plt.ylabel("test accuracy")
220 | plt.xticks(a, comp, rotation ="vertical")
221 | plt.ylim(0,100)
222 | plt.legend()
223 | plt.grid(color="gray")
224 | utils.checkdir(f"{os.getcwd()}/plots/lt/{args.arch_type}/{args.dataset}/")
225 | plt.savefig(f"{os.getcwd()}/plots/lt/{args.arch_type}/{args.dataset}/{args.prune_type}_AccuracyVsWeights.png", dpi=1200)
226 | plt.close()
227 |
228 | # Function for Training
229 | def train(model, train_loader, optimizer, criterion):
230 | EPS = 1e-6
231 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
232 | model.train()
233 | for batch_idx, (imgs, targets) in enumerate(train_loader):
234 | optimizer.zero_grad()
235 | #imgs, targets = next(train_loader)
236 | imgs, targets = imgs.to(device), targets.to(device)
237 | output = model(imgs)
238 | train_loss = criterion(output, targets)
239 | train_loss.backward()
240 |
241 | # Freezing Pruned weights by making their gradients Zero
242 | for name, p in model.named_parameters():
243 | if 'weight' in name:
244 | tensor = p.data.cpu().numpy()
245 | grad_tensor = p.grad.data.cpu().numpy()
246 |                 grad_tensor = np.where(np.abs(tensor) < EPS, 0, grad_tensor)  # abs(): negative weights are alive too
247 | p.grad.data = torch.from_numpy(grad_tensor).to(device)
248 | optimizer.step()
249 | return train_loss.item()
250 |
251 | # Function for Testing
252 | def test(model, test_loader, criterion):
253 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
254 | model.eval()
255 | test_loss = 0
256 | correct = 0
257 | with torch.no_grad():
258 | for data, target in test_loader:
259 | data, target = data.to(device), target.to(device)
260 | output = model(data)
261 |             test_loss += F.cross_entropy(output, target, reduction='sum').item()  # sum up batch loss (the model outputs logits)
262 |             pred = output.data.max(1, keepdim=True)[1]  # get the index of the max logit
263 | correct += pred.eq(target.data.view_as(pred)).sum().item()
264 | test_loss /= len(test_loader.dataset)
265 | accuracy = 100. * correct / len(test_loader.dataset)
266 | return accuracy
267 |
268 | # Prune by Percentile module
269 | def prune_by_percentile(percent, resample=False, reinit=False,**kwargs):
270 | global step
271 | global mask
272 | global model
273 |
274 | # Calculate percentile value
275 | step = 0
276 | for name, param in model.named_parameters():
277 |
278 | # We do not prune bias term
279 | if 'weight' in name:
280 | tensor = param.data.cpu().numpy()
281 | alive = tensor[np.nonzero(tensor)] # flattened array of nonzero values
282 | percentile_value = np.percentile(abs(alive), percent)
283 |
284 | # Convert Tensors to numpy and calculate
285 | weight_dev = param.device
286 | new_mask = np.where(abs(tensor) < percentile_value, 0, mask[step])
287 |
288 | # Apply new weight and mask
289 | param.data = torch.from_numpy(tensor * new_mask).to(weight_dev)
290 | mask[step] = new_mask
291 | step += 1
292 | step = 0
293 |
294 | # Function to make an empty mask of the same size as the model
295 | def make_mask(model):
296 | global step
297 | global mask
298 | step = 0
299 | for name, param in model.named_parameters():
300 | if 'weight' in name:
301 | step = step + 1
302 | mask = [None]* step
303 | step = 0
304 | for name, param in model.named_parameters():
305 | if 'weight' in name:
306 | tensor = param.data.cpu().numpy()
307 | mask[step] = np.ones_like(tensor)
308 | step = step + 1
309 | step = 0
310 |
311 | def original_initialization(mask_temp, initial_state_dict):
312 | global model
313 |
314 | step = 0
315 | for name, param in model.named_parameters():
316 | if "weight" in name:
317 | weight_dev = param.device
318 | param.data = torch.from_numpy(mask_temp[step] * initial_state_dict[name].cpu().numpy()).to(weight_dev)
319 | step = step + 1
320 | if "bias" in name:
321 |             param.data = initial_state_dict[name].clone()  # copy, so training cannot overwrite the saved initial biases
322 | step = 0
323 |
324 | # Function for Initialization
325 | def weight_init(m):
326 | '''
327 | Usage:
328 | model = Model()
329 | model.apply(weight_init)
330 | '''
331 | if isinstance(m, nn.Conv1d):
332 | init.normal_(m.weight.data)
333 | if m.bias is not None:
334 | init.normal_(m.bias.data)
335 | elif isinstance(m, nn.Conv2d):
336 | init.xavier_normal_(m.weight.data)
337 | if m.bias is not None:
338 | init.normal_(m.bias.data)
339 | elif isinstance(m, nn.Conv3d):
340 | init.xavier_normal_(m.weight.data)
341 | if m.bias is not None:
342 | init.normal_(m.bias.data)
343 | elif isinstance(m, nn.ConvTranspose1d):
344 | init.normal_(m.weight.data)
345 | if m.bias is not None:
346 | init.normal_(m.bias.data)
347 | elif isinstance(m, nn.ConvTranspose2d):
348 | init.xavier_normal_(m.weight.data)
349 | if m.bias is not None:
350 | init.normal_(m.bias.data)
351 | elif isinstance(m, nn.ConvTranspose3d):
352 | init.xavier_normal_(m.weight.data)
353 | if m.bias is not None:
354 | init.normal_(m.bias.data)
355 | elif isinstance(m, nn.BatchNorm1d):
356 | init.normal_(m.weight.data, mean=1, std=0.02)
357 | init.constant_(m.bias.data, 0)
358 | elif isinstance(m, nn.BatchNorm2d):
359 | init.normal_(m.weight.data, mean=1, std=0.02)
360 | init.constant_(m.bias.data, 0)
361 | elif isinstance(m, nn.BatchNorm3d):
362 | init.normal_(m.weight.data, mean=1, std=0.02)
363 | init.constant_(m.bias.data, 0)
364 | elif isinstance(m, nn.Linear):
365 | init.xavier_normal_(m.weight.data)
366 | init.normal_(m.bias.data)
367 | elif isinstance(m, nn.LSTM):
368 | for param in m.parameters():
369 | if len(param.shape) >= 2:
370 | init.orthogonal_(param.data)
371 | else:
372 | init.normal_(param.data)
373 | elif isinstance(m, nn.LSTMCell):
374 | for param in m.parameters():
375 | if len(param.shape) >= 2:
376 | init.orthogonal_(param.data)
377 | else:
378 | init.normal_(param.data)
379 | elif isinstance(m, nn.GRU):
380 | for param in m.parameters():
381 | if len(param.shape) >= 2:
382 | init.orthogonal_(param.data)
383 | else:
384 | init.normal_(param.data)
385 | elif isinstance(m, nn.GRUCell):
386 | for param in m.parameters():
387 | if len(param.shape) >= 2:
388 | init.orthogonal_(param.data)
389 | else:
390 | init.normal_(param.data)
391 |
392 |
393 | if __name__=="__main__":
394 |
395 | #from gooey import Gooey
396 | #@Gooey
397 |
398 |     # Argument Parser
399 | parser = argparse.ArgumentParser()
400 | parser.add_argument("--lr",default= 1.2e-3, type=float, help="Learning rate")
401 | parser.add_argument("--batch_size", default=60, type=int)
402 | parser.add_argument("--start_iter", default=0, type=int)
403 | parser.add_argument("--end_iter", default=100, type=int)
404 | parser.add_argument("--print_freq", default=1, type=int)
405 | parser.add_argument("--valid_freq", default=1, type=int)
406 | parser.add_argument("--resume", action="store_true")
407 | parser.add_argument("--prune_type", default="lt", type=str, help="lt | reinit")
408 | parser.add_argument("--gpu", default="0", type=str)
409 | parser.add_argument("--dataset", default="mnist", type=str, help="mnist | cifar10 | fashionmnist | cifar100")
410 | parser.add_argument("--arch_type", default="fc1", type=str, help="fc1 | lenet5 | alexnet | vgg16 | resnet18 | densenet121")
411 | parser.add_argument("--prune_percent", default=10, type=int, help="Pruning percent")
412 | parser.add_argument("--prune_iterations", default=35, type=int, help="Pruning iterations count")
413 |
414 |
415 | args = parser.parse_args()
416 |
417 |
418 | os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
419 | os.environ["CUDA_VISIBLE_DEVICES"]=args.gpu
420 |
421 |
422 | #FIXME resample
423 | resample = False
424 |
425 | # Looping Entire process
426 | #for i in range(0, 5):
427 | main(args, ITE=1)
428 |
--------------------------------------------------------------------------------
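One pruning round in `main.py` combines three steps: `prune_by_percentile` removes the smallest-magnitude `prune_percent`% of the surviving weights in each layer, `train` zeroes the gradients of pruned weights, and `original_initialization` rewinds the survivors to their initial values. The NumPy sketch below reproduces those steps on a single weight matrix. The function names are hypothetical, the "training" is random noise, and the gradient mask comes from the stored mask instead of the `abs(weight) < EPS` test; that substitution is an assumption of the sketch.

```python
import numpy as np


def prune_layer(weight, mask, percent):
    """Layer-wise magnitude pruning in the style of prune_by_percentile."""
    alive = weight[np.nonzero(weight)]  # only the surviving weights set the threshold
    threshold = np.percentile(np.abs(alive), percent)
    return np.where(np.abs(weight) < threshold, 0, mask)


def freeze_pruned_grads(grad, mask):
    """Zero the gradients of pruned entries so they stay at zero."""
    return grad * mask


def rewind(mask, initial_weight):
    """Lottery-ticket reset: survivors go back to their initial values."""
    return mask * initial_weight


rng = np.random.default_rng(0)
w_init = rng.normal(size=(10, 10))
w_trained = w_init + 0.1 * rng.normal(size=(10, 10))  # stand-in for training
mask = np.ones_like(w_trained)

mask = prune_layer(w_trained * mask, mask, percent=10)
w_next = rewind(mask, w_init)
grad = freeze_pruned_grads(rng.normal(size=(10, 10)), mask)

print(int(mask.sum()))                 # 90 of 100 weights survive the first round
print(np.all(w_next[mask == 0] == 0))  # True: pruned weights stay zero after rewinding
print(np.all(grad[mask == 0] == 0))    # True: and they receive no gradient
```

A comparison without `abs`, such as `tensor < EPS`, would also match every negative weight and freeze roughly half of the live weights, which is why the magnitude is taken before comparing.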
/requirements.txt:
--------------------------------------------------------------------------------
1 | cycler==0.10.0
2 | kiwisolver==1.1.0
3 | matplotlib==3.1.1
4 | numpy==1.17.2
5 | pandas==0.25.1
6 | Pillow==6.2.0
7 | protobuf==3.9.2
8 | pyparsing==2.4.2
9 | python-dateutil==2.8.0
10 | pytz==2019.2
11 | scipy==1.3.1
12 | seaborn==0.9.0
13 | six==1.12.0
14 | tensorboardX==1.8
15 | torch==1.2.0
16 | torchvision==0.4.0
17 | tqdm==4.36.1
18 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | #ANCHOR Libraries
2 | import numpy as np
3 | import torch
4 | import os
5 | import seaborn as sns
6 | import matplotlib.pyplot as plt
7 | import copy
8 |
9 | #ANCHOR Print table of zeros and non-zeros count
10 | def print_nonzeros(model):
11 | nonzero = total = 0
12 | for name, p in model.named_parameters():
13 | tensor = p.data.cpu().numpy()
14 | nz_count = np.count_nonzero(tensor)
15 | total_params = np.prod(tensor.shape)
16 | nonzero += nz_count
17 | total += total_params
18 | print(f'{name:20} | nonzeros = {nz_count:7} / {total_params:7} ({100 * nz_count / total_params:6.2f}%) | total_pruned = {total_params - nz_count :7} | shape = {tensor.shape}')
19 | print(f'alive: {nonzero}, pruned : {total - nonzero}, total: {total}, Compression rate : {total/nonzero:10.2f}x ({100 * (total-nonzero) / total:6.2f}% pruned)')
20 | return (round((nonzero/total)*100,1))
21 |
22 | def original_initialization(model, mask_temp, initial_state_dict):
23 |     # `model` is passed explicitly: utils.py has no module-level `model` to fall back on
24 |
25 |     step = 0
26 |     for name, param in model.named_parameters():
27 |         if "weight" in name:
28 |             weight_dev = param.device
29 |             param.data = torch.from_numpy(mask_temp[step] * initial_state_dict[name].cpu().numpy()).to(weight_dev)
30 |             step = step + 1
31 |         if "bias" in name:
32 |             param.data = initial_state_dict[name].clone()  # copy, so training cannot overwrite the saved initial biases
33 |     step = 0
34 |
35 |
36 |
37 |
38 | #ANCHOR Check whether the directory exists and, if not, create it
39 | def checkdir(directory):
40 |     # exist_ok=True also avoids a race if the directory appears between check and creation
41 |     os.makedirs(directory, exist_ok=True)
42 |
43 | #FIXME
44 | def plot_train_test_stats(stats,
45 | epoch_num,
46 | key1='train',
47 | key2='test',
48 | key1_label=None,
49 | key2_label=None,
50 | xlabel=None,
51 | ylabel=None,
52 | title=None,
53 | yscale=None,
54 | ylim_bottom=None,
55 | ylim_top=None,
56 | savefig=None,
57 | sns_style='darkgrid'
58 | ):
59 |
60 | assert len(stats[key1]) == epoch_num, "len(stats['{}'])({}) != epoch_num({})".format(key1, len(stats[key1]), epoch_num)
61 | assert len(stats[key2]) == epoch_num, "len(stats['{}'])({}) != epoch_num({})".format(key2, len(stats[key2]), epoch_num)
62 |
63 | plt.clf()
64 | sns.set_style(sns_style)
65 | x_ticks = np.arange(epoch_num)
66 |
67 | plt.plot(x_ticks, stats[key1], label=key1_label)
68 | plt.plot(x_ticks, stats[key2], label=key2_label)
69 |
70 | if xlabel is not None:
71 | plt.xlabel(xlabel)
72 | if ylabel is not None:
73 | plt.ylabel(ylabel)
74 |
75 | if title is not None:
76 | plt.title(title)
77 |
78 | if yscale is not None:
79 | plt.yscale(yscale)
80 |
81 | if ylim_bottom is not None:
82 | plt.ylim(bottom=ylim_bottom)
83 | if ylim_top is not None:
84 | plt.ylim(top=ylim_top)
85 |
86 | plt.legend(bbox_to_anchor=(1.04,0.5), loc="center left", borderaxespad=0, fancybox=True)
87 |
88 | if savefig is not None:
89 | plt.savefig(savefig, bbox_inches='tight')
90 | else:
91 | plt.show()
--------------------------------------------------------------------------------
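`print_nonzeros` returns the percentage of non-zero entries across all parameters (biases and BatchNorm parameters included) and prints the compression rate `total / nonzero`. With the default `--prune_percent=10`, each round keeps about 90% of the surviving prunable weights in every layer, so after k rounds roughly `0.9**k` of them remain; the value `print_nonzeros` reports is slightly higher because biases are never pruned, and per-layer rounding makes the schedule approximate. The sketch below computes both quantities; the helper names are hypothetical and NumPy arrays stand in for model parameters.

```python
import numpy as np


def sparsity_stats(arrays):
    """Alive percentage (rounded like print_nonzeros) and compression rate."""
    nonzero = sum(int(np.count_nonzero(a)) for a in arrays)
    total = sum(a.size for a in arrays)
    return round(100 * nonzero / total, 1), total / nonzero


def expected_alive_percent(prune_percent, rounds):
    """Approximate share of prunable weights left after `rounds` pruning rounds."""
    return 100 * (1 - prune_percent / 100) ** rounds


weight = np.ones((4, 5))
weight[:, :2] = 0   # 8 of 20 weights pruned
bias = np.ones(5)   # main.py never prunes biases
print(sparsity_stats([weight, bias]))  # alive 68.0 %, compression about 1.47x

# With 35 iterations the first one is unpruned, so the last follows 34 rounds.
print(round(expected_alive_percent(10, 34), 2))  # 2.78
```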