├── LICENSE
├── README.md
├── extension
│   └── extension.cpp
├── mpl
│   ├── __init__.py
│   ├── autograd
│   │   ├── __init__.py
│   │   └── functions.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── alexnet.py
│   │   ├── base_model.py
│   │   ├── densenet.py
│   │   ├── googlenet.py
│   │   ├── inception.py
│   │   ├── leaf.py
│   │   ├── lenet.py
│   │   ├── resnet.py
│   │   ├── squeezenet.py
│   │   ├── utils.py
│   │   └── vgg.py
│   ├── nn
│   │   ├── __init__.py
│   │   ├── conv2d.py
│   │   └── linear.py
│   ├── optim
│   │   ├── __init__.py
│   │   └── sgd.py
│   └── utils
│       ├── __init__.py
│       └── save_load.py
├── requirements.txt
└── setup.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Yuang Jiang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ModelPruningLibrary (Updated 3/3/2021)
2 | ## Plan for the Next Version
3 | We plan to further complete ModelPruningLibrary with the following:
4 | 1. C++ implementations of conv2d with groups > 1 and of depthwise conv2d, as well as the models from `torchvision.models` that are currently missing.
5 | 2. more of the optimizers available in `torch.optim`.
6 | 3. well-known pruning algorithms such as SNIP [[1]](#1).
7 | 4. tools for federated learning (e.g., well-known datasets for FL).
8 |
9 | Suggestions/comments are welcome!
10 |
11 | ## Description
12 | This is a PyTorch-based library that implements
13 | 1. model pruning: various magnitude-based pruning algorithms (by percentage, random pruning, etc.);
14 | 2. conv2d module with **sparse kernels** as well as fully-connected module implementations;
15 | 3. an SGD optimizer designed for our sparse modules;
16 | 4. two types of save/load functionality for sparse tensors, chosen automatically according to the tensor's density (fraction of non-zero entries): if the density is below 1/32, we save value-index pairs; otherwise, we use a bitmap to save the sparse tensor (see the sketch below).
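
For intuition, with 32-bit values and 32-bit indices a value-index pair costs density × 64 bits per tensor entry, while a bitmap costs 1 + density × 32 bits, so the pair format wins exactly when density < 1/32. Below is a minimal sketch of this selection rule with a hypothetical `choose_format` helper (illustrative only; the actual implementation lives in `mpl/utils/save_load.py` and its helper names may differ):

```python3
import torch

def choose_format(x: torch.Tensor) -> str:
    # density = fraction of non-zero entries
    density = (x != 0).sum().item() / x.numel()
    # below 1/32, value-index pairs are smaller; otherwise a bitmap wins
    return "value-index pairs" if density < 1.0 / 32 else "bitmap"
```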
17 |
18 | It is originally from the following paper:
19 | - Jiang, Y., Wang, S., Valls, V., Ko, B. J., Lee, W. H., Leung, K. K. & Tassiulas, L. (2019). [Model pruning enables efficient federated learning on edge devices](https://arxiv.org/pdf/1909.12326.pdf). arXiv preprint arXiv:1909.12326.
20 |
21 | When using this code for scientific publications, please kindly cite the above paper.
22 |
23 | The library consists of the following components:
24 | * **setup.py**: installs the C++ extension and the `mpl` (model pruning library) package
25 | * **extension**: the `extension.cpp` C++ file extends the current PyTorch implementation with **sparse kernels** (the installed module is called `sparse_conv2d`). Note that we only extend PyTorch's slow, CPU version of conv2d forward/backward with groups = 1 and dilation = 1 (see PyTorch's C++ code [here](https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/ConvolutionMM2d.cpp)). In other words, we do not use acceleration packages such as MKL (which are not available on the Raspberry Pis used in our paper's experiments), so please do not compare the speed of our implementation against such packages.
26 | * **autograd**: the `AddmmFunction` and `SparseConv2dFunction` autograd functions provide the forward and backward passes for our customized modules.
27 | * **models**: this is similar to `torchvision`'s implementations ([link](https://github.com/pytorch/vision/tree/master/torchvision/models)). Note that we do not implement mnasnet, mobilenet, and shufflenetv2, since those models use groups > 1. We also implement popular models such as those in [leaf](https://github.com/TalwalkarLab/leaf/tree/master/models).
28 | * **nn**: `conv2d.py` and `linear.py` implement the prunable modules and their `to_sparse` functionalities.
29 | * **optim**: implements a compatible version of the SGD optimizer.
30 |
31 | Our code has been validated on Ubuntu 20.04. Contact me if you encounter any issues!
32 |
33 | ## Examples
34 |
35 | ### Setup Library:
36 | ```shell
37 | sudo python3 setup.py install
38 | ```
39 |
40 |
41 |
42 | ### Importing and Using Model
43 | ```python3
44 | from mpl.models import conv2
45 |
46 | model = conv2()
47 | print(model)
48 | ```
49 |
50 | output:
51 | ```
52 | Conv2(
53 | (features): Sequential(
54 | (0): DenseConv2d(1, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
55 | (1): ReLU(inplace=True)
56 | (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
57 | (3): DenseConv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
58 | (4): ReLU(inplace=True)
59 | (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
60 | )
61 | (classifier): Sequential(
62 | (0): DenseLinear(in_features=3136, out_features=2048, bias=True)
63 | (1): ReLU(inplace=True)
64 | (2): DenseLinear(in_features=2048, out_features=62, bias=True)
65 | )
66 | )
67 | ```
68 |
69 | ### Model Pruning:
70 | ```python3
71 | import mpl.models
72 |
73 | model = mpl.models.conv2()
74 | print("Before pruning:")
75 | model.calc_num_prunable_params(display=True)
76 |
77 | print("After pruning:")
78 | model.prune_by_pct([0.1, 0, None, 0.9])
79 | model.calc_num_prunable_params(display=True)
80 | ```
81 | output:
82 | ```
83 | Before pruning:
84 | Layer name: features.0. remaining/all: 832/832 = 1.0
85 | Layer name: features.3. remaining/all: 51264/51264 = 1.0
86 | Layer name: classifier.0. remaining/all: 6424576/6424576 = 1.0
87 | Layer name: classifier.2. remaining/all: 127038/127038 = 1.0
88 | Total: remaining/all: 6603710/6603710 = 1.0
89 | After pruning:
90 | Layer name: features.0. remaining/all: 752/832 = 0.9038461538461539
91 | Layer name: features.3. remaining/all: 51264/51264 = 1.0
92 | Layer name: classifier.0. remaining/all: 6424576/6424576 = 1.0
93 | Layer name: classifier.2. remaining/all: 12760/127038 = 0.10044238731718069
94 | Total: remaining/all: 6489352/6603710 = 0.9826827646883343
95 | ```
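Here `prune_by_pct` takes one pruning fraction per prunable layer (in the order shown above); an entry of `None` (or `0`) leaves the corresponding layer untouched.
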
96 | ### Dense to Sparse Conversion:
97 | ```python3
98 | from mpl.models import conv2
99 |
100 | model = conv2()
101 | print(model.to_sparse())
102 | ```
103 | output:
104 | ```
105 | Conv2(
106 | (features): Sequential(
107 | (0): SparseConv2d(1, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=True)
108 | (1): ReLU(inplace=True)
109 | (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
110 | (3): SparseConv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=True)
111 | (4): ReLU(inplace=True)
112 | (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
113 | )
114 | (classifier): Sequential(
115 | (0): SparseLinear(in_features=3136, out_features=2048, bias=True)
116 | (1): ReLU(inplace=True)
117 | (2): SparseLinear(in_features=2048, out_features=62, bias=True)
118 | )
119 | )
120 | ```
121 | Note that `DenseConv2d` and `DenseLinear` layers are converted to `SparseConv2d` and `SparseLinear` layers, respectively.
122 |
123 | ### SGD Training with a Sparse Model:
124 | ```python3
125 | from mpl.models import conv2
126 | from mpl.optim import SGD
127 | import torch
128 |
129 | inp = torch.rand(size=(10, 1, 28, 28))
130 | model = conv2().to_sparse()
131 | optimizer = SGD(model.parameters(), lr=0.01)
132 | optimizer.zero_grad()
133 | model(inp).sum().backward()
134 | optimizer.step()
135 | ```
136 |
137 | ### Save/Load a Tensor:
138 | ```python3
139 | from mpl.utils.save_load import save, load
140 | import torch
141 |
142 | torch.manual_seed(0)
143 | x = torch.randn(size=(1000, 1000))
144 | mask = torch.rand_like(x) <= 0.5
145 | x = (x * mask).to_sparse()
146 | save(x, "sparse_x.pt")
147 |
148 | x_loaded = load("sparse_x.pt")
149 | ```
150 | Using our implementation, the `sparse_x.pt` file is 2.1 MB, while the default `torch.save` results in a 10 MB file (about 4.8× larger).
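As a rough sanity check (the density here is about 0.5): the bitmap format stores 1,000,000 mask bits ≈ 0.125 MB plus roughly 500,000 non-zero float32 values ≈ 2 MB, matching the 2.1 MB above, whereas `torch.save` on a COO sparse tensor keeps two 64-bit indices and one 32-bit value per non-zero entry, i.e. roughly 20 B × 500,000 ≈ 10 MB.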
151 |
152 | ## References
153 | [1]
154 | Lee, N., Ajanthan, T. & Torr, P. H. S. (2018). SNIP: Single-shot network pruning based on connection sensitivity. arXiv preprint arXiv:1810.02340.
155 |
--------------------------------------------------------------------------------
/extension/extension.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 | #include <algorithm>
4 | #include <cassert>
5 | #include <cstring>
6 |
7 | #define Tensor torch::Tensor
8 | #define IntArrayRef at::IntArrayRef
9 |
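// unfolded2d_acc folds ("col2im") the unfolded buffer back into image layout by
// accumulation; the backward pass uses it to assemble grad_input from fgrad_input.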
10 | template <typename scalar_t>
11 | static void unfolded2d_acc(
12 | scalar_t* finput_data,
13 | scalar_t* input_data,
14 | int64_t kH,
15 | int64_t kW,
16 | int64_t dH,
17 | int64_t dW,
18 | int64_t padH,
19 | int64_t padW,
20 | int64_t n_input_plane,
21 | int64_t input_height,
22 | int64_t input_width,
23 | int64_t output_height,
24 | int64_t output_width) {
25 | #pragma omp parallel for
26 | for (auto nip = 0; nip < n_input_plane; nip++) {
27 | int64_t kw, kh, y, x;
28 | int64_t ix, iy;
29 | for (kh = 0; kh < kH; kh++) {
30 | for (kw = 0; kw < kW; kw++) {
31 | scalar_t* src = finput_data +
32 | nip * ((size_t)kH * kW * output_height * output_width) +
33 | kh * ((size_t)kW * output_height * output_width) +
34 | kw * ((size_t)output_height * output_width);
35 | scalar_t* dst =
36 | input_data + nip * ((size_t)input_height * input_width);
37 | if (padW > 0 || padH > 0) {
38 | int64_t lpad, rpad;
39 | for (y = 0; y < output_height; y++) {
40 | iy = (int64_t)y * dH - padH + kh;
41 | if (iy < 0 || iy >= input_height) {
42 | } else {
43 | for (x = 0; x < output_width; x++) {
44 | ix = (int64_t)x * dW - padW + kw;
45 | if (ix < 0 || ix >= input_width) {
46 | } else {
47 | scalar_t* dst_slice = dst + (size_t)iy * input_width + ix;
48 | *dst_slice = *dst_slice + src[(size_t)y * output_width + x];
49 | }
50 | }
51 | }
52 | }
53 | } else {
54 | for (y = 0; y < output_height; y++) {
55 | iy = (int64_t)y * dH + kh;
56 | ix = 0 + kw;
57 | for (x = 0; x < output_width; x++) {
58 | scalar_t* dst_slice =
59 | dst + (size_t)iy * input_width + ix + x * dW;
60 | *dst_slice = *dst_slice + src[(size_t)y * output_width + x];
61 | }
62 | }
63 | }
64 | }
65 | }
66 | }
67 | }
68 |
69 | void unfolded2d_acc_kernel(
70 | Tensor& finput,
71 | Tensor& input,
72 | int64_t kH,
73 | int64_t kW,
74 | int64_t dH,
75 | int64_t dW,
76 | int64_t padH,
77 | int64_t padW,
78 | int64_t n_input_plane,
79 | int64_t input_height,
80 | int64_t input_width,
81 | int64_t output_height,
82 | int64_t output_width) {
83 | // This function assumes that
84 | // output_height*dH does not overflow a int64_t
85 | // output_width*dW does not overflow a int64_t
86 |
87 | auto input_data = (float*) input.data_ptr();
88 | auto finput_data =(float*) finput.data_ptr();
89 |
90 | unfolded2d_acc(
91 | finput_data,
92 | input_data,
93 | kH,
94 | kW,
95 | dH,
96 | dW,
97 | padH,
98 | padW,
99 | n_input_plane,
100 | input_height,
101 | input_width,
102 | output_height,
103 | output_width);
104 | }
105 |
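// unfolded2d_copy expands ("im2col") the input image into the unfolded buffer so that
// the convolution reduces to one matrix multiplication per sample.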
106 | template <typename scalar_t>
107 | static void unfolded2d_copy(
108 | scalar_t* input_data,
109 | scalar_t* finput_data,
110 | int64_t kH,
111 | int64_t kW,
112 | int64_t dH,
113 | int64_t dW,
114 | int64_t padH,
115 | int64_t padW,
116 | int64_t n_input_plane,
117 | int64_t input_height,
118 | int64_t input_width,
119 | int64_t output_height,
120 | int64_t output_width) {
121 |
122 | auto start = 0;
123 | auto end = (int64_t)n_input_plane * kH * kW;
124 | #pragma omp parallel for
125 | for (auto k = start; k < end; k++) {
126 | int64_t nip = k / (kH * kW);
127 | int64_t rest = k % (kH * kW);
128 | int64_t kh = rest / kW;
129 | int64_t kw = rest % kW;
130 | int64_t x, y;
131 | int64_t ix, iy;
132 | scalar_t* dst = finput_data +
133 | nip * ((size_t)kH * kW * output_height * output_width) +
134 | kh * ((size_t)kW * output_height * output_width) +
135 | kw * ((size_t)output_height * output_width);
136 | scalar_t* src =
137 | input_data + nip * ((size_t)input_height * input_width);
138 | if (padW > 0 || padH > 0) {
139 | int64_t lpad, rpad;
140 | for (y = 0; y < output_height; y++) {
141 | iy = (int64_t)y * dH - padH + kh;
142 | if (iy < 0 || iy >= input_height) {
143 | memset(
144 | dst + (size_t)y * output_width,
145 | 0,
146 | sizeof(scalar_t) * output_width);
147 | } else {
148 | if (dW == 1) {
149 | ix = 0 - padW + kw;
150 | lpad = std::max<int64_t>(0, padW - kw);
151 | rpad = std::max<int64_t>(0, padW - (kW - kw - 1));
152 | if (output_width - rpad - lpad <= 0) {
153 | memset(
154 | dst + (size_t)y * output_width,
155 | 0,
156 | sizeof(scalar_t) * output_width);
157 | } else {
158 | if (lpad > 0)
159 | memset(
160 | dst + (size_t)y * output_width,
161 | 0,
162 | sizeof(scalar_t) * lpad);
163 | memcpy(
164 | dst + (size_t)y * output_width + lpad,
165 | src + (size_t)iy * input_width + ix + lpad,
166 | sizeof(scalar_t) * (output_width - rpad - lpad));
167 | if (rpad > 0)
168 | memset(
169 | dst + (size_t)y * output_width + output_width - rpad,
170 | 0,
171 | sizeof(scalar_t) * rpad);
172 | }
173 | } else {
174 | for (x = 0; x < output_width; x++) {
175 | ix = (int64_t)x * dW - padW + kw;
176 | if (ix < 0 || ix >= input_width)
177 | memset(
178 | dst + (size_t)y * output_width + x,
179 | 0,
180 | sizeof(scalar_t) * 1);
181 | else
182 | memcpy(
183 | dst + (size_t)y * output_width + x,
184 | src + (size_t)iy * input_width + ix,
185 | sizeof(scalar_t) * (1));
186 | }
187 | }
188 | }
189 | }
190 | } else {
191 | for (y = 0; y < output_height; y++) {
192 | iy = (int64_t)y * dH + kh;
193 | ix = 0 + kw;
194 | if (dW == 1)
195 | memcpy(
196 | dst + (size_t)y * output_width,
197 | src + (size_t)iy * input_width + ix,
198 | sizeof(scalar_t) * output_width);
199 | else {
200 | for (x = 0; x < output_width; x++)
201 | memcpy(
202 | dst + (size_t)y * output_width + x,
203 | src + (size_t)iy * input_width + ix + (int64_t)x * dW,
204 | sizeof(scalar_t) * (1));
205 | }
206 | }
207 | }
208 | }
209 | }
210 |
211 | void unfolded2d_copy_kernel(
212 | Tensor& finput,
213 | Tensor& input,
214 | int64_t kH,
215 | int64_t kW,
216 | int64_t dH,
217 | int64_t dW,
218 | int64_t padH,
219 | int64_t padW,
220 | int64_t n_input_plane,
221 | int64_t input_height,
222 | int64_t input_width,
223 | int64_t output_height,
224 | int64_t output_width) {
225 |
226 | auto input_data = (float*) input.data_ptr();
227 | auto finput_data =(float*) finput.data_ptr();
228 |
229 | unfolded2d_copy(
230 | input_data,
231 | finput_data,
232 | kH,
233 | kW,
234 | dH,
235 | dW,
236 | padH,
237 | padW,
238 | n_input_plane,
239 | input_height,
240 | input_width,
241 | output_height,
242 | output_width);
243 | }
244 |
245 | static void slow_conv2d_update_output_frame(
246 | Tensor& input,
247 | Tensor& output,
248 | const Tensor& weight,
249 | const Tensor& bias,
250 | Tensor& finput,
251 | int64_t kernel_height,
252 | int64_t kernel_width,
253 | int64_t stride_height,
254 | int64_t stride_width,
255 | int64_t pad_height,
256 | int64_t pad_width,
257 | int64_t n_input_plane,
258 | int64_t input_height,
259 | int64_t input_width,
260 | int64_t n_output_plane,
261 | int64_t output_height,
262 | int64_t output_width) {
263 |
264 | unfolded2d_copy_kernel(
265 | finput,
266 | input,
267 | kernel_height,
268 | kernel_width,
269 | stride_height,
270 | stride_width,
271 | pad_height,
272 | pad_width,
273 | n_input_plane,
274 | input_height,
275 | input_width,
276 | output_height,
277 | output_width);
278 |
279 |
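// im2col GEMM: output2d (n_output_plane x outH*outW) =
//   bias + weight (n_output_plane x n_input_plane*kH*kW) * finput (n_input_plane*kH*kW x outH*outW)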
280 | auto output2d =
281 | output.reshape({n_output_plane, output_height * output_width});
282 | if (bias.defined()) {
283 | for (int64_t i = 0; i < n_output_plane; i++) {
284 | output[i].fill_(bias[i].item());
285 | }
286 | } else {
287 | output.zero_();
288 | }
289 | output2d.addmm_(weight, finput, 1, 1);
290 | }
291 |
292 | std::tuple<Tensor&, Tensor&, Tensor&> slow_conv2d_forward_out_cpu(
293 | Tensor& output,
294 | Tensor& finput,
295 | Tensor& fgrad_input,
296 | const Tensor& self,
297 | const Tensor& weight_,
298 | IntArrayRef kernel_size,
299 | const Tensor& bias,
300 | IntArrayRef stride,
301 | IntArrayRef padding) {
302 | const int64_t kernel_height = kernel_size[0];
303 | const int64_t kernel_width = kernel_size[1];
304 | const int64_t pad_height = padding[0];
305 | const int64_t pad_width = padding[1];
306 | const int64_t stride_height = stride[0];
307 | const int64_t stride_width = stride[1];
308 |
309 |
310 | assert(weight_.dim()==2);
311 | const Tensor weight_2d = weight_;
312 |
313 |
314 | const Tensor input = self.contiguous();
315 | const int64_t ndim = input.dim();
316 | const int64_t dim_planes = 1;
317 | const int64_t dim_height = 2;
318 | const int64_t dim_width = 3;
319 |
320 | const int64_t n_input_plane = input.size(dim_planes);
321 | const int64_t input_height = input.size(dim_height);
322 | const int64_t input_width = input.size(dim_width);
323 | const int64_t n_output_plane = weight_2d.size(0);
324 | const int64_t output_height =
325 | (input_height + 2 * pad_height - kernel_height) / stride_height + 1;
326 | const int64_t output_width =
327 | (input_width + 2 * pad_width - kernel_width) / stride_width + 1;
328 |
329 | const int64_t batch_size = input.size(0);
330 |
331 |
332 | finput.resize_({batch_size,
333 | n_input_plane * kernel_height * kernel_width,
334 | output_height * output_width});
335 | output.resize_({batch_size, n_output_plane, output_height, output_width});
336 |
337 | at::NoGradGuard no_grad;
338 | at::AutoNonVariableTypeMode non_variable_type_mode(true);
339 |
340 | #pragma omp parallel for
341 | for (int64_t t = 0; t < batch_size; t++) {
342 | Tensor input_t = input[t];
343 | Tensor output_t = output[t];
344 | Tensor finput_t = finput[t];
345 | slow_conv2d_update_output_frame(
346 | input_t,
347 | output_t,
348 | weight_2d,
349 | bias,
350 | finput_t,
351 | kernel_height,
352 | kernel_width,
353 | stride_height,
354 | stride_width,
355 | pad_height,
356 | pad_width,
357 | n_input_plane,
358 | input_height,
359 | input_width,
360 | n_output_plane,
361 | output_height,
362 | output_width);
363 | }
364 |
365 | return std::tuple<Tensor&, Tensor&, Tensor&>(output, finput, fgrad_input);
366 | }
367 |
368 |
369 | std::tuple<Tensor, Tensor, Tensor> slow_conv2d_forward_cpu(
370 | const Tensor& self,
371 | const Tensor& weight,
372 | IntArrayRef kernel_size,
373 | const Tensor& bias,
374 | IntArrayRef stride,
375 | IntArrayRef padding) {
376 |
377 | auto output = at::empty({0}, self.options());
378 | auto finput = at::empty({0}, self.options());
379 | auto fgrad_input = at::empty({0}, self.options());
380 |
381 | slow_conv2d_forward_out_cpu(
382 | output,
383 | finput,
384 | fgrad_input,
385 | self,
386 | weight,
387 | kernel_size,
388 | bias,
389 | stride,
390 | padding);
391 | return std::make_tuple(output, finput, fgrad_input);
392 | }
393 |
394 | void slow_conv2d_backward_update_grad_input_frame(
395 | Tensor& grad_input,
396 | const Tensor& grad_output,
397 | const Tensor& weight,
398 | Tensor& fgrad_input,
399 | int64_t kernel_height,
400 | int64_t kernel_width,
401 | int64_t stride_height,
402 | int64_t stride_width,
403 | int64_t pad_height,
404 | int64_t pad_width) {
405 | auto grad_output_2d = grad_output.reshape(
406 | {grad_output.size(0), grad_output.size(1) * grad_output.size(2)});
407 | fgrad_input.addmm_(weight, grad_output_2d, 0, 1);
408 | grad_input.zero_();
409 | unfolded2d_acc_kernel(
410 | fgrad_input,
411 | grad_input,
412 | kernel_height,
413 | kernel_width,
414 | stride_height,
415 | stride_width,
416 | pad_height,
417 | pad_width,
418 | grad_input.size(0),
419 | grad_input.size(1),
420 | grad_input.size(2),
421 | grad_output.size(1),
422 | grad_output.size(2));
423 | }
424 |
425 | void slow_conv2d_backward_out_cpu_template(
426 | Tensor& grad_input,
427 | const Tensor& grad_output_,
428 | const Tensor& input_,
429 | const Tensor& weight_,
430 | const Tensor& finput,
431 | Tensor& fgrad_input,
432 | IntArrayRef kernel_size,
433 | IntArrayRef stride,
434 | IntArrayRef padding) {
435 | const int64_t kernel_height = kernel_size[0];
436 | const int64_t kernel_width = kernel_size[1];
437 | const int64_t pad_height = padding[0];
438 | const int64_t pad_width = padding[1];
439 | const int64_t stride_height = stride[0];
440 | const int64_t stride_width = stride[1];
441 |
442 | assert(weight_.dim() == 2);
443 | const Tensor weight = weight_;
444 |
445 |
446 | const Tensor input = input_.contiguous();
447 | const Tensor grad_output = grad_output_.contiguous();
448 | grad_input.resize_as_(input);
449 | fgrad_input.resize_as_(finput);
450 | fgrad_input.zero_();
451 | Tensor tw = weight.transpose(0, 1);
452 | if(tw.is_sparse() && !tw.is_coalesced()){
453 | tw = tw.coalesce();
454 | }
455 | const Tensor tweight = tw;
456 | const int64_t batch_size = input.size(0);
457 | #pragma omp parallel for
458 | for (int64_t t = 0; t < batch_size; t++) {
459 | Tensor grad_input_t = grad_input[t];
460 | Tensor grad_output_t = grad_output[t];
461 | Tensor fgrad_input_t = fgrad_input[t];
462 | slow_conv2d_backward_update_grad_input_frame(
463 | grad_input_t,
464 | grad_output_t,
465 | tweight,
466 | fgrad_input_t,
467 | kernel_height,
468 | kernel_width,
469 | stride_height,
470 | stride_width,
471 | pad_height,
472 | pad_width);
473 | }
474 | }
475 |
476 | void slow_conv2d_backward_parameters_frame(
477 | Tensor& grad_weight,
478 | Tensor& grad_bias,
479 | Tensor& grad_output,
480 | const Tensor& finput) {
481 | auto grad_output_2d = grad_output.view(
482 | {grad_output.size(0), grad_output.size(1) * grad_output.size(2)});
483 | if (grad_weight.defined()) {
484 | const Tensor tfinput = finput.transpose(0, 1);
485 | grad_weight.addmm_(grad_output_2d, tfinput);
486 | }
487 |
488 | if (grad_bias.defined()) {
489 | AT_DISPATCH_FLOATING_TYPES_AND(
490 | at::ScalarType::BFloat16,
491 | grad_output.scalar_type(),
492 | "slow_conv2d_backward_parameters",
493 | [&] {
494 | auto grad_output_2d_acc = grad_output_2d.accessor<scalar_t, 2>();
495 | auto grad_bias_acc = grad_bias.accessor<scalar_t, 1>();
496 | const auto sz = grad_output_2d.size(1);
497 | for (int64_t i = 0; i < grad_bias.size(0); i++) {
498 | scalar_t sum = 0;
499 | for (int64_t k = 0; k < sz; k++) {
500 | sum = sum + grad_output_2d_acc[i][k];
501 | }
502 | grad_bias_acc[i] = grad_bias_acc[i] + sum;
503 | }
504 | });
505 | }
506 | }
507 |
508 | static void slow_conv2d_backward_parameters_out_cpu_template(
509 | Tensor& grad_weight,
510 | Tensor& grad_bias,
511 | const Tensor& input_,
512 | const Tensor& grad_output_,
513 | const Tensor& finput,
514 | Tensor fgrad_input,
515 | IntArrayRef kernel_size,
516 | IntArrayRef stride,
517 | IntArrayRef padding) {
518 |
519 | const int64_t kernel_height = kernel_size[0];
520 | const int64_t kernel_width = kernel_size[1];
521 | const int64_t pad_height = padding[0];
522 | const int64_t pad_width = padding[1];
523 | const int64_t stride_height = stride[0];
524 | const int64_t stride_width = stride[1];
525 |
526 | Tensor grad_weight_2d = grad_weight;
527 |
528 | auto input = input_.contiguous();
529 | auto grad_output = grad_output_.contiguous();
530 |
531 | const int64_t batch_size = input.size(0);
532 | for (int64_t t = 0; t < batch_size; t++) {
533 | Tensor grad_output_t = grad_output[t];
534 | Tensor finput_t;
535 | if (grad_weight_2d.defined()) {
536 | finput_t = finput[t];
537 | }
538 |
539 | slow_conv2d_backward_parameters_frame(
540 | grad_weight_2d, grad_bias, grad_output_t, finput_t);
541 | }
542 | }
543 |
544 | std::tuple<Tensor&, Tensor&, Tensor&> slow_conv2d_backward_out_cpu(
545 | Tensor& grad_input,
546 | Tensor& grad_weight,
547 | Tensor& grad_bias,
548 | const Tensor& grad_output,
549 | const Tensor& self,
550 | const Tensor& weight,
551 | IntArrayRef kernel_size,
552 | IntArrayRef stride,
553 | IntArrayRef padding,
554 | const Tensor& finput,
555 | const Tensor& fgrad_input) {
556 | if (grad_input.defined()) {
557 | slow_conv2d_backward_out_cpu_template(
558 | grad_input,
559 | grad_output,
560 | self,
561 | weight,
562 | finput,
563 | const_cast<Tensor&>(fgrad_input),
564 | kernel_size,
565 | stride,
566 | padding);
567 | }
568 |
569 | if (grad_weight.defined()) {
570 | grad_weight.resize_(weight.sizes());
571 | grad_weight.zero_();
572 | }
573 |
574 | if (grad_bias.defined()) {
575 | grad_bias.resize_({grad_output.size(1)});
576 | grad_bias.zero_();
577 | }
578 | if (grad_weight.defined() || grad_bias.defined()) {
579 | slow_conv2d_backward_parameters_out_cpu_template(
580 | grad_weight,
581 | grad_bias,
582 | self,
583 | grad_output,
584 | finput,
585 | fgrad_input,
586 | kernel_size,
587 | stride,
588 | padding);
589 | }
590 |
591 | return std::tuple<Tensor&, Tensor&, Tensor&>(
592 | grad_input, grad_weight, grad_bias);
593 | }
594 |
595 | std::tuple<Tensor, Tensor, Tensor> slow_conv2d_backward_cpu(
596 | const Tensor& grad_output,
597 | const Tensor& self,
598 | const Tensor& weight,
599 | IntArrayRef kernel_size,
600 | IntArrayRef stride,
601 | IntArrayRef padding,
602 | const Tensor& finput,
603 | const Tensor& fgrad_input,
604 | std::array<bool, 3> output_mask) {
605 | Tensor grad_input;
606 | Tensor grad_weight;
607 | Tensor grad_bias;
608 |
609 | if (output_mask[0]) {
610 | grad_input = at::empty({0}, grad_output.options());
611 | }
612 |
613 | if (output_mask[1]) {
614 | grad_weight = at::empty({0}, grad_output.options());
615 | }
616 |
617 | if (output_mask[2]) {
618 | grad_bias = at::empty({0}, grad_output.options());
619 | }
620 | slow_conv2d_backward_out_cpu(
621 | grad_input,
622 | grad_weight,
623 | grad_bias,
624 | grad_output,
625 | self,
626 | weight,
627 | kernel_size,
628 | stride,
629 | padding,
630 | finput,
631 | fgrad_input);
632 | return std::make_tuple(grad_input, grad_weight, grad_bias);
633 | }
634 |
635 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
636 | m.def("forward", &slow_conv2d_forward_cpu, "Conv Forward");
637 | m.def("backward", &slow_conv2d_backward_cpu, "Conv Backward");
638 | }
639 |
640 |
641 |
--------------------------------------------------------------------------------
/mpl/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils.save_load import save, load
2 | import mpl.autograd
3 | import mpl.models
4 | import mpl.nn
5 | import mpl.optim
6 |
--------------------------------------------------------------------------------
/mpl/autograd/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiangyuang/ModelPruningLibrary/9c8ba5a3c5d118f37768d5d42254711f48d88745/mpl/autograd/__init__.py
--------------------------------------------------------------------------------
/mpl/autograd/functions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.sparse as sparse
3 | from warnings import warn
4 |
5 | sparse_conv2d_imported = True
6 | try:
7 | import sparse_conv2d
8 | except ImportError:
9 | warn("The sparse_conv2d module is NOT imported. Using default conv2d functions for compatibility.")
10 | sparse_conv2d_imported = False
11 |
12 |
13 | class AddmmFunction(torch.autograd.Function):
14 | @staticmethod
15 | def forward(ctx, bias, weight: sparse.FloatTensor, dense_weight_placeholder, inp):
16 | if bias is None:
17 | out = sparse.mm(weight, inp)
18 | else:
19 | out = sparse.addmm(bias, weight, inp)
20 | ctx.save_for_backward(bias, weight, inp)
21 | return out
22 |
23 | @staticmethod
24 | def backward(ctx, grad_output):
25 | bias, weight, inp = ctx.saved_tensors
26 | grad_bias = grad_input = None
27 | if bias is not None:
28 | grad_bias = grad_output.sum(1).reshape((-1, 1))
29 | grad_weight = grad_output.mm(inp.t())
30 | if ctx.needs_input_grad[3]:
31 | grad_input = torch.mm(weight.t(), grad_output)
32 |
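# The dense gradient for the sparse `weight` (input 1) is returned in the slot of
# `dense_weight_placeholder` (input 2); the sparse weight itself receives None.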
33 | return grad_bias, None, grad_weight, grad_input
34 |
35 |
36 | if sparse_conv2d_imported:
37 | class SparseConv2dFunction(torch.autograd.Function):
38 | @staticmethod
39 | def forward(ctx, inp, weight, dense_weight_placeholder, kernel_size, bias, stride, padding):
40 | out, f_input, fgrad_input = sparse_conv2d.forward(inp, weight, kernel_size, bias, stride, padding)
41 | ctx.save_for_backward(inp, weight, f_input, fgrad_input)
42 | ctx.kernel_size = kernel_size
43 | ctx.stride = stride
44 | ctx.padding = padding
45 | return out
46 |
47 | @staticmethod
48 | def backward(ctx, grad_output):
49 | grad_input, grad_weight, grad_bias = sparse_conv2d.backward(grad_output,
50 | ctx.saved_tensors[0],
51 | ctx.saved_tensors[1],
52 | ctx.kernel_size,
53 | ctx.stride,
54 | ctx.padding,
55 | ctx.saved_tensors[2],
56 | ctx.saved_tensors[3],
57 | (True, True, True))
58 | return grad_input, None, grad_weight, None, grad_bias, None, None
59 |
60 |
61 | class DenseConv2dFunction(torch.autograd.Function):
62 | @staticmethod
63 | def forward(ctx, inp, weight, kernel_size, bias, stride, padding):
64 | weight2d = weight.data.reshape((weight.size(0), -1))
65 | out, f_input, fgrad_input = sparse_conv2d.forward(inp, weight2d, kernel_size, bias, stride, padding)
66 | ctx.save_for_backward(inp, weight2d, f_input, fgrad_input, weight)
67 | ctx.kernel_size = kernel_size
68 | ctx.stride = stride
69 | ctx.padding = padding
70 | return out
71 |
72 | @staticmethod
73 | def backward(ctx, grad_output):
74 | grad_input, grad_weight2d, grad_bias = sparse_conv2d.backward(grad_output,
75 | ctx.saved_tensors[0],
76 | ctx.saved_tensors[1],
77 | ctx.kernel_size,
78 | ctx.stride,
79 | ctx.padding,
80 | ctx.saved_tensors[2],
81 | ctx.saved_tensors[3],
82 | (True, True, True))
83 | grad_weight = grad_weight2d.reshape_as(ctx.saved_tensors[4])
84 | return grad_input, grad_weight, None, grad_bias, None, None
85 |
86 | else:
87 | class SparseConv2dFunction(torch.autograd.Function):
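# Fallback used when the sparse_conv2d C++ extension is unavailable: `apply` is overridden
# directly, so autograd is handled by torch.nn.functional.conv2d instead of a custom backward.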
88 | @staticmethod
89 | def apply(inp, weight, dense_weight_placeholder, kernel_size, bias, stride, padding):
90 | size_4d = (weight.size(0), -1, *kernel_size)
91 | with torch.no_grad():
92 | dense_weight_placeholder.zero_()
93 | dense_weight_placeholder.add_(weight.to_dense())
94 | return torch.nn.functional.conv2d(inp, dense_weight_placeholder.view(size_4d), bias, stride, padding)
95 |
96 |
97 | class DenseConv2dFunction(torch.autograd.Function):
98 | @staticmethod
99 | def apply(inp, weight, kernel_size, bias, stride, padding):
100 | size_4d = (weight.size(0), -1, *kernel_size)
101 | return torch.nn.functional.conv2d(inp, weight.reshape(size_4d), bias, stride, padding)
102 |
--------------------------------------------------------------------------------
/mpl/models/__init__.py:
--------------------------------------------------------------------------------
1 | # networks implemented by torchvision
2 | from .alexnet import *
3 | from .resnet import *
4 | from .vgg import *
5 | from .squeezenet import *
6 | from .inception import *
7 | from .densenet import *
8 | from .googlenet import *
9 | # from .mobilenet import *
10 | # from .mnasnet import *
11 | # from .shufflenetv2 import *
12 |
13 | # additional model implementations
14 | from .lenet import *
15 | from .leaf import *
16 |
--------------------------------------------------------------------------------
/mpl/models/alexnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision.models
3 | from typing import Any
4 |
5 | from .base_model import BaseModel
6 |
7 | __all__ = ['AlexNet', 'alexnet']
8 |
9 |
10 | class AlexNet(BaseModel):
11 | def __init__(self, model: torchvision.models.AlexNet):
12 | super(AlexNet, self).__init__()
13 | self.clone_from_model(model)
14 | self.process_layers()
15 |
16 | def process_layers(self):
17 | self.collect_prunable_layers()
18 | self.convert_eligible_layers()
19 | self.collect_prunable_layers()
20 |
21 | def forward(self, x: torch.Tensor) -> torch.Tensor:
22 | x = self.features(x)
23 | x = self.avgpool(x)
24 | x = torch.flatten(x, 1)
25 | x = self.classifier(x)
26 | return x
27 |
28 |
29 | def alexnet(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> AlexNet:
30 | return AlexNet(torchvision.models.alexnet(pretrained, progress, **kwargs))
31 |
--------------------------------------------------------------------------------
/mpl/models/base_model.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Union, Sized, List, Tuple
3 | from copy import deepcopy
4 |
5 | import torch
6 | from torch import nn as nn
7 |
8 | from ..nn.linear import DenseLinear
9 | from ..nn.conv2d import DenseConv2d
10 | from .utils import collect_leaf_modules, is_parameterized
11 |
12 |
13 | class BaseModel(nn.Module, ABC):
14 | def __init__(self):
15 | super(BaseModel, self).__init__()
16 |
17 | self.prunable_layers: list = []
18 | self.prunable_layer_prefixes: list = []
19 |
20 | def clone_from_model(self, original_model: nn.Module = None):
21 | # copying all submodules from original model
22 | for name, module in original_model._modules.items():
23 | self.add_module(name, deepcopy(module))
24 |
25 | def collect_prunable_layers(self) -> None:
26 | self.prunable_layers, self.prunable_layer_prefixes = self.find_layers(lambda x: is_parameterized(x))
27 |
28 | def convert_eligible_layers(self):
29 | # changing all conv2d and linear layers to customized ones
30 | for module_name, old_module in zip(self.prunable_layer_prefixes, self.prunable_layers):
31 | if isinstance(old_module, nn.Linear):
32 | self.set_module_by_name(module_name, DenseLinear.from_linear(old_module))
33 | elif isinstance(old_module, nn.Conv2d):
34 | self.set_module_by_name(module_name, DenseConv2d.from_conv2d(old_module))
35 |
36 | def find_layers(self, criterion) -> Tuple[List, List]:
37 | layers, names = [], []
38 | collect_leaf_modules(self, criterion, layers, names)
39 | return layers, names
40 |
41 | @abstractmethod
42 | def forward(self, inputs) -> torch.Tensor:
43 | pass
44 |
45 | def prune_by_threshold(self, thr_arg: Union[int, float, Sized]):
46 | prunable_layers = self.prunable_layers
47 | if isinstance(thr_arg, Sized):
48 | assert len(prunable_layers) == len(thr_arg)
49 | else:
50 | thr_arg = [thr_arg] * len(prunable_layers)
51 | for thr, layer in zip(thr_arg, prunable_layers):
52 | if thr is not None:
53 | layer.prune_by_threshold(thr)
54 |
55 | return self
56 |
57 | def prune_by_rank(self, rank_arg: Union[int, float, Sized]):
58 | prunable_layers = self.prunable_layers
59 | if isinstance(rank_arg, Sized):
60 | assert len(prunable_layers) == len(rank_arg)
61 | else:
62 | rank_arg = [rank_arg] * len(prunable_layers)
63 | for rank, layer in zip(rank_arg, prunable_layers):
64 | if rank is not None:
65 | layer.prune_by_rank(rank)
66 |
67 | return self
68 |
69 | def prune_by_pct(self, pct_arg: Union[int, float, Sized]):
70 | prunable_layers = self.prunable_layers
71 | if isinstance(pct_arg, Sized):
72 | assert len(prunable_layers) == len(pct_arg)
73 | else:
74 | pct_arg = [pct_arg] * len(prunable_layers)
75 | for pct, layer in zip(pct_arg, prunable_layers):
76 | if pct is not None:
77 | layer.prune_by_pct(pct)
78 |
79 | return self
80 |
81 | def random_prune_by_pct(self, pct_arg: Union[int, float, Sized]):
82 | prunable_layers = self.prunable_layers
83 | if isinstance(pct_arg, Sized):
84 | assert len(prunable_layers) == len(pct_arg)
85 | else:
86 | pct_arg = [pct_arg] * len(prunable_layers)
87 | for pct, layer in zip(pct_arg, prunable_layers):
88 | if pct is not None:
89 | layer.random_prune_by_pct(pct)
90 |
91 | return self
92 |
93 | def calc_num_prunable_params(self, count_bias=True, display=False):
94 | total_param_in_use = 0
95 | total_param = 0
96 | for layer, layer_prefix in zip(self.prunable_layers, self.prunable_layer_prefixes):
97 | num_bias = layer.bias.nelement() if layer.bias is not None and count_bias else 0
98 | num_weight = layer.num_weight
99 | num_params_in_use = num_weight + num_bias
100 | num_params = layer.weight.nelement() + num_bias
101 | total_param_in_use += num_params_in_use
102 | total_param += num_params
103 |
104 | if display:
105 | print("Layer name: {}. remaining/all: {}/{} = {}".format(layer_prefix, num_params_in_use, num_params,
106 | num_params_in_use / num_params))
107 | if display:
108 | print("Total: remaining/all: {}/{} = {}".format(total_param_in_use, total_param,
109 | total_param_in_use / total_param))
110 | return total_param_in_use, total_param
111 |
112 | def nnz(self, count_bias=True):
113 | # number of parameters in use in prunable layers
114 | return self.calc_num_prunable_params(count_bias=count_bias)[0]
115 |
116 | def nelement(self, count_bias=True):
117 | # number of all parameters in prunable layers
118 | return self.calc_num_prunable_params(count_bias=count_bias)[1]
119 |
120 | def density(self, count_bias=True):
121 | total_param_in_use, total_param = self.calc_num_prunable_params(count_bias=count_bias)
122 | return total_param_in_use / total_param
123 |
124 | def _get_module_by_list(self, module_names: List):
125 | module = self
126 | for name in module_names:
127 | module = getattr(module, name)
128 | return module
129 |
130 | def get_module_by_name(self, module_name: str):
131 | return self._get_module_by_list(module_name.split('.'))
132 |
133 | def set_module_by_name(self, module_name: str, new_module):
134 | splits = module_name.split('.')
135 | self._get_module_by_list(splits[:-1]).__setattr__(splits[-1], new_module)
136 |
137 | def get_mask_by_name(self, param_name: str):
138 | if param_name.endswith("bias"): # todo
139 | return None
140 | module = self._get_module_by_list(param_name.split('.')[:-1])
141 | return module.mask if hasattr(module, "mask") else None
142 |
143 | @torch.no_grad()
144 | def reinit_from_model(self, final_model):
145 | assert isinstance(final_model, self.__class__)
146 | for self_layer, layer in zip(self.prunable_layers, final_model.prunable_layers):
147 | self_layer.mask = layer.mask.clone().to(self_layer.mask.device)
148 |
149 | def to_sparse(self):
150 | self_copy = deepcopy(self)
151 | for module_name, old_module in zip(self.prunable_layer_prefixes, self.prunable_layers):
152 | self_copy.set_module_by_name(module_name, old_module.to_sparse())
153 | self_copy.collect_prunable_layers()
154 | return self_copy
155 |
156 | def to(self, *args, **kwargs):
157 | device = torch._C._nn._parse_to(*args, **kwargs)[0]
158 | if device is not None:
159 | # move masks to device
160 | for m in self.prunable_layers:
161 | m.move_data(device)
162 | return super(BaseModel, self).to(*args, **kwargs)
163 |
--------------------------------------------------------------------------------
/mpl/models/densenet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import Tensor
4 | import torchvision.models
5 | from typing import Any
6 |
7 | from .base_model import BaseModel
8 |
9 | __all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161']
10 |
11 |
12 | class DenseNet(BaseModel):
13 | def __init__(self, model: torchvision.models.DenseNet):
14 | super(DenseNet, self).__init__()
15 | self.clone_from_model(model)
16 | self.process_layers()
17 |
18 | def collect_prunable_layers(self) -> None:
19 | """
20 | Remove transition layers from the prunable layers.
21 | """
22 | super(DenseNet, self).collect_prunable_layers()
23 | keep_indices = []
24 | for layer_idx, name in enumerate(self.prunable_layer_prefixes):
25 | if "transition" not in name:
26 | keep_indices.append(layer_idx)
27 |
28 | self.prunable_layer_prefixes = [self.prunable_layer_prefixes[idx] for idx in keep_indices]
29 | self.prunable_layers = [self.prunable_layers[idx] for idx in keep_indices]
30 |
31 | def process_layers(self):
32 | self.collect_prunable_layers()
33 | self.convert_eligible_layers()
34 | self.collect_prunable_layers()
35 |
36 | def forward(self, x: Tensor) -> Tensor:
37 | features = self.features(x)
38 | out = F.relu(features, inplace=True)
39 | out = F.adaptive_avg_pool2d(out, (1, 1))
40 | out = torch.flatten(out, 1)
41 | out = self.classifier(out)
42 | return out
43 |
44 |
45 | def densenet121(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet:
46 | return DenseNet(torchvision.models.densenet121(pretrained, progress, **kwargs))
47 |
48 |
49 | def densenet161(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet:
50 | return DenseNet(torchvision.models.densenet161(pretrained, progress, **kwargs))
51 |
52 |
53 | def densenet169(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet:
54 | return DenseNet(torchvision.models.densenet169(pretrained, progress, **kwargs))
55 |
56 |
57 | def densenet201(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet:
58 | return DenseNet(torchvision.models.densenet201(pretrained, progress, **kwargs))
59 |
--------------------------------------------------------------------------------
/mpl/models/googlenet.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | import torch
3 | import torch.nn as nn
4 | from torch import Tensor
5 | import torchvision.models
6 | from torchvision.models import GoogLeNetOutputs, _GoogLeNetOutputs
7 | from typing import Optional, Tuple, Any
8 |
9 | from .base_model import BaseModel
10 |
11 | __all__ = ['GoogLeNet', 'googlenet']
12 |
13 |
14 | class GoogLeNet(BaseModel):
15 | def __init__(self, model: torchvision.models.GoogLeNet):
16 | super(GoogLeNet, self).__init__()
17 | self.clone_from_model(model)
18 |
19 | self.aux_logits = model.aux_logits
20 | self.transform_input = model.transform_input
21 | if not hasattr(self, "aux1"):
22 | self.aux1 = model.aux1
23 | if not hasattr(self, "aux2"):
24 | self.aux2 = model.aux2
25 |
26 | self.process_layers()
27 |
28 | def process_layers(self):
29 | self.collect_prunable_layers()
30 | self.convert_eligible_layers()
31 | self.collect_prunable_layers()
32 |
33 | def _initialize_weights(self) -> None:
34 | for m in self.modules():
35 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
36 | import scipy.stats as stats
37 | X = stats.truncnorm(-2, 2, scale=0.01)
38 | values = torch.as_tensor(X.rvs(m.weight.numel()), dtype=m.weight.dtype)
39 | values = values.view(m.weight.size())
40 | with torch.no_grad():
41 | m.weight.copy_(values)
42 | elif isinstance(m, nn.BatchNorm2d):
43 | nn.init.constant_(m.weight, 1)
44 | nn.init.constant_(m.bias, 0)
45 |
46 | def _transform_input(self, x: Tensor) -> Tensor:
47 | if self.transform_input:
48 | x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
49 | x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
50 | x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
51 | x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
52 | return x
53 |
54 | def _forward(self, x: Tensor) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
55 | # N x 3 x 224 x 224
56 | x = self.conv1(x)
57 | # N x 64 x 112 x 112
58 | x = self.maxpool1(x)
59 | # N x 64 x 56 x 56
60 | x = self.conv2(x)
61 | # N x 64 x 56 x 56
62 | x = self.conv3(x)
63 | # N x 192 x 56 x 56
64 | x = self.maxpool2(x)
65 |
66 | # N x 192 x 28 x 28
67 | x = self.inception3a(x)
68 | # N x 256 x 28 x 28
69 | x = self.inception3b(x)
70 | # N x 480 x 28 x 28
71 | x = self.maxpool3(x)
72 | # N x 480 x 14 x 14
73 | x = self.inception4a(x)
74 | # N x 512 x 14 x 14
75 | aux1: Optional[Tensor] = None
76 | if self.aux1 is not None:
77 | if self.training:
78 | aux1 = self.aux1(x)
79 |
80 | x = self.inception4b(x)
81 | # N x 512 x 14 x 14
82 | x = self.inception4c(x)
83 | # N x 512 x 14 x 14
84 | x = self.inception4d(x)
85 | # N x 528 x 14 x 14
86 | aux2: Optional[Tensor] = None
87 | if self.aux2 is not None:
88 | if self.training:
89 | aux2 = self.aux2(x)
90 |
91 | x = self.inception4e(x)
92 | # N x 832 x 14 x 14
93 | x = self.maxpool4(x)
94 | # N x 832 x 7 x 7
95 | x = self.inception5a(x)
96 | # N x 832 x 7 x 7
97 | x = self.inception5b(x)
98 | # N x 1024 x 7 x 7
99 |
100 | x = self.avgpool(x)
101 | # N x 1024 x 1 x 1
102 | x = torch.flatten(x, 1)
103 | # N x 1024
104 | x = self.dropout(x)
105 | x = self.fc(x)
106 | # N x 1000 (num_classes)
107 | return x, aux2, aux1
108 |
109 | @torch.jit.unused
110 | def eager_outputs(self, x: Tensor, aux2: Tensor, aux1: Optional[Tensor]) -> GoogLeNetOutputs:
111 | if self.training and self.aux_logits:
112 | return _GoogLeNetOutputs(x, aux2, aux1)
113 | else:
114 | return x # type: ignore[return-value]
115 |
116 | def forward(self, x: Tensor) -> GoogLeNetOutputs:
117 | x = self._transform_input(x)
118 | x, aux1, aux2 = self._forward(x)
119 | aux_defined = self.training and self.aux_logits
120 | if torch.jit.is_scripting():
121 | if not aux_defined:
122 | warnings.warn("Scripted GoogleNet always returns GoogleNetOutputs Tuple")
123 | return GoogLeNetOutputs(x, aux2, aux1)
124 | else:
125 | return self.eager_outputs(x, aux2, aux1)
126 |
127 |
128 | def googlenet(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> GoogLeNet:
129 | return GoogLeNet(torchvision.models.googlenet(pretrained, progress, **kwargs))
130 |
--------------------------------------------------------------------------------
/mpl/models/inception.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | import torch
3 | from torch import Tensor
4 | import torchvision.models
5 | from torchvision.models.inception import InceptionOutputs
6 | from typing import Any, Optional, Tuple
7 |
8 | from .base_model import BaseModel
9 |
10 | __all__ = ['Inception3', 'inception_v3']
11 |
12 |
13 | class Inception3(BaseModel):
14 | def __init__(self, model: torchvision.models.Inception3):
15 | super(Inception3, self).__init__()
16 | self.clone_from_model(model)
17 |
18 | self.aux_logits = model.aux_logits
19 | self.transform_input = model.transform_input
20 |
21 | self.process_layers()
22 |
23 | def process_layers(self):
24 | self.collect_prunable_layers()
25 | self.convert_eligible_layers()
26 | self.collect_prunable_layers()
27 |
28 | def _transform_input(self, x: Tensor) -> Tensor:
29 | if self.transform_input:
30 | x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
31 | x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
32 | x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
33 | x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
34 | return x
35 |
36 | def _forward(self, x: Tensor) -> Tuple[Tensor, Optional[Tensor]]:
37 | # N x 3 x 299 x 299
38 | x = self.Conv2d_1a_3x3(x)
39 | # N x 32 x 149 x 149
40 | x = self.Conv2d_2a_3x3(x)
41 | # N x 32 x 147 x 147
42 | x = self.Conv2d_2b_3x3(x)
43 | # N x 64 x 147 x 147
44 | x = self.maxpool1(x)
45 | # N x 64 x 73 x 73
46 | x = self.Conv2d_3b_1x1(x)
47 | # N x 80 x 73 x 73
48 | x = self.Conv2d_4a_3x3(x)
49 | # N x 192 x 71 x 71
50 | x = self.maxpool2(x)
51 | # N x 192 x 35 x 35
52 | x = self.Mixed_5b(x)
53 | # N x 256 x 35 x 35
54 | x = self.Mixed_5c(x)
55 | # N x 288 x 35 x 35
56 | x = self.Mixed_5d(x)
57 | # N x 288 x 35 x 35
58 | x = self.Mixed_6a(x)
59 | # N x 768 x 17 x 17
60 | x = self.Mixed_6b(x)
61 | # N x 768 x 17 x 17
62 | x = self.Mixed_6c(x)
63 | # N x 768 x 17 x 17
64 | x = self.Mixed_6d(x)
65 | # N x 768 x 17 x 17
66 | x = self.Mixed_6e(x)
67 | # N x 768 x 17 x 17
68 | aux: Optional[Tensor] = None
69 | if self.AuxLogits is not None:
70 | if self.training:
71 | aux = self.AuxLogits(x)
72 | # N x 768 x 17 x 17
73 | x = self.Mixed_7a(x)
74 | # N x 1280 x 8 x 8
75 | x = self.Mixed_7b(x)
76 | # N x 2048 x 8 x 8
77 | x = self.Mixed_7c(x)
78 | # N x 2048 x 8 x 8
79 | # Adaptive average pooling
80 | x = self.avgpool(x)
81 | # N x 2048 x 1 x 1
82 | x = self.dropout(x)
83 | # N x 2048 x 1 x 1
84 | x = torch.flatten(x, 1)
85 | # N x 2048
86 | x = self.fc(x)
87 | # N x 1000 (num_classes)
88 | return x, aux
89 |
90 | @torch.jit.unused
91 | def eager_outputs(self, x: Tensor, aux: Optional[Tensor]) -> InceptionOutputs:
92 | if self.training and self.aux_logits:
93 | return InceptionOutputs(x, aux)
94 | else:
95 | return x # type: ignore[return-value]
96 |
97 | def forward(self, x: Tensor) -> InceptionOutputs:
98 | x = self._transform_input(x)
99 | x, aux = self._forward(x)
100 | aux_defined = self.training and self.aux_logits
101 | if torch.jit.is_scripting():
102 | if not aux_defined:
103 | warnings.warn("Scripted Inception3 always returns Inception3 Tuple")
104 | return InceptionOutputs(x, aux)
105 | else:
106 | return self.eager_outputs(x, aux)
107 |
108 |
109 | def inception_v3(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> "Inception3":
110 | return Inception3(torchvision.models.inception_v3(pretrained, progress, **kwargs))
111 |
--------------------------------------------------------------------------------
/mpl/models/leaf.py:
--------------------------------------------------------------------------------
1 | from torch import nn as nn
2 |
3 | from .base_model import BaseModel
4 | from ..nn.conv2d import DenseConv2d
5 | from ..nn.linear import DenseLinear
6 |
7 | __all__ = ["Conv2", "conv2", "Conv4", "conv4"]
8 |
9 |
10 | class Conv2(BaseModel):
11 | def __init__(self):
12 | super(Conv2, self).__init__()
13 | self.features = nn.Sequential(DenseConv2d(1, 32, kernel_size=5, padding=2), # 32x28x28
14 | nn.ReLU(inplace=True),
15 | nn.MaxPool2d(2, stride=2), # 32x14x14
16 | DenseConv2d(32, 64, kernel_size=5, padding=2), # 64x14x14
17 | nn.ReLU(inplace=True),
18 | nn.MaxPool2d(2, stride=2)) # 64x7x7
19 |
20 | self.classifier = nn.Sequential(DenseLinear(64 * 7 * 7, 2048),
21 | nn.ReLU(inplace=True),
22 | DenseLinear(2048, 62))
23 | self.collect_prunable_layers()
24 |
25 | def forward(self, inp):
26 | out = self.features(inp)
27 | out = out.view(out.size(0), -1)
28 | out = self.classifier(out)
29 | return out
30 |
31 |
32 | class Conv4(BaseModel):
33 | def __init__(self):
34 | super(Conv4, self).__init__()
35 | self.features = nn.Sequential(DenseConv2d(3, 32, kernel_size=3, padding=1),
36 | nn.BatchNorm2d(32),
37 | nn.MaxPool2d(2),
38 | DenseConv2d(32, 32, kernel_size=3, padding=1),
39 | nn.BatchNorm2d(32),
40 | nn.MaxPool2d(2),
41 | DenseConv2d(32, 32, kernel_size=3, padding=2),
42 | nn.BatchNorm2d(32),
43 | nn.MaxPool2d(2),
44 | DenseConv2d(32, 32, kernel_size=3, padding=2),
45 | nn.BatchNorm2d(32),
46 | nn.MaxPool2d(2))
47 |
48 | self.classifier = DenseLinear(in_features=32 * 6 * 6, out_features=2)
49 |
50 | def forward(self, inp):
51 | out = self.features(inp)
52 | out = out.view(out.size(0), -1)
53 | out = self.classifier(out)
54 | return out
55 |
56 |
57 | def conv2() -> Conv2:
58 | return Conv2()
59 |
60 |
61 | def conv4() -> Conv4:
62 | return Conv4()
63 |
64 | # TODO: define pretrain etc.
65 |
--------------------------------------------------------------------------------
/mpl/models/lenet.py:
--------------------------------------------------------------------------------
1 | from torch import nn as nn
2 | from .base_model import BaseModel
3 |
4 | from ..nn.linear import DenseLinear
5 |
6 | __all__ = ["LeNet5", "lenet5"]
7 |
8 |
9 | class LeNet5(BaseModel):
10 | def __init__(self):
11 | super(LeNet5, self).__init__()
12 | self.classifier = nn.Sequential(DenseLinear(784, 300),
13 | nn.ReLU(inplace=True),
14 | DenseLinear(300, 100),
15 | nn.ReLU(inplace=True),
16 | DenseLinear(100, 10))
17 |
18 | self.collect_prunable_layers()
19 |
20 | def forward(self, inputs):
21 | return self.classifier(inputs)
22 |
23 |
24 | def lenet5() -> LeNet5:
25 | return LeNet5()
26 |
--------------------------------------------------------------------------------
/mpl/models/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 | from torch import nn
4 | import torchvision.models
5 | from torchvision.models.resnet import conv1x1, BasicBlock, Bottleneck
6 | from typing import Type, Any, Union
7 |
8 | from .base_model import BaseModel
9 |
10 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
11 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
12 | 'wide_resnet50_2', 'wide_resnet101_2']
13 |
14 |
15 | class ResNet(BaseModel):
16 | def __init__(self, model: torchvision.models.ResNet):
17 | super(ResNet, self).__init__()
18 | self.clone_from_model(model)
19 | self.process_layers()
20 |
21 | def process_layers(self):
22 | self.collect_prunable_layers()
23 | self.convert_eligible_layers()
24 | self.collect_prunable_layers()
25 |
26 | def collect_prunable_layers(self) -> None:
27 | """
28 | Remove downsample layers from the prunable layers.
29 | """
30 | super(ResNet, self).collect_prunable_layers()
31 | keep_indices = []
32 | for layer_idx, name in enumerate(self.prunable_layer_prefixes):
33 | if "downsample" not in name:
34 | keep_indices.append(layer_idx)
35 |
36 | self.prunable_layer_prefixes = [self.prunable_layer_prefixes[idx] for idx in keep_indices]
37 | self.prunable_layers = [self.prunable_layers[idx] for idx in keep_indices]
38 |
39 | def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int,
40 | stride: int = 1, dilate: bool = False) -> nn.Sequential:
41 | norm_layer = self._norm_layer
42 | downsample = None
43 | previous_dilation = self.dilation
44 | if dilate:
45 | self.dilation *= stride
46 | stride = 1
47 | if stride != 1 or self.inplanes != planes * block.expansion:
48 | downsample = nn.Sequential(
49 | conv1x1(self.inplanes, planes * block.expansion, stride),
50 | norm_layer(planes * block.expansion),
51 | )
52 |
53 | layers = []
54 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
55 | self.base_width, previous_dilation, norm_layer))
56 | self.inplanes = planes * block.expansion
57 | for _ in range(1, blocks):
58 | layers.append(block(self.inplanes, planes, groups=self.groups,
59 | base_width=self.base_width, dilation=self.dilation,
60 | norm_layer=norm_layer))
61 |
62 | return nn.Sequential(*layers)
63 |
64 | def _forward_impl(self, x: Tensor) -> Tensor:
65 | # See note [TorchScript super()]
66 | x = self.conv1(x)
67 | x = self.bn1(x)
68 | x = self.relu(x)
69 | x = self.maxpool(x)
70 |
71 | x = self.layer1(x)
72 | x = self.layer2(x)
73 | x = self.layer3(x)
74 | x = self.layer4(x)
75 |
76 | x = self.avgpool(x)
77 | x = torch.flatten(x, 1)
78 | x = self.fc(x)
79 |
80 | return x
81 |
82 | def forward(self, x: Tensor) -> Tensor:
83 | return self._forward_impl(x)
84 |
85 |
86 | def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
87 | return ResNet(torchvision.models.resnet18(pretrained, progress, **kwargs))
88 |
89 |
90 | def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
91 | return ResNet(torchvision.models.resnet34(pretrained, progress, **kwargs))
92 |
93 |
94 | def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
95 | return ResNet(torchvision.models.resnet50(pretrained, progress, **kwargs))
96 |
97 |
98 | def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
99 | return ResNet(torchvision.models.resnet101(pretrained, progress, **kwargs))
100 |
101 |
102 | def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
103 | return ResNet(torchvision.models.resnet152(pretrained, progress, **kwargs))
104 |
105 |
106 | def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
107 | return ResNet(torchvision.models.resnext50_32x4d(pretrained, progress, **kwargs))
108 |
109 |
110 | def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
111 | return ResNet(torchvision.models.resnext101_32x8d(pretrained, progress, **kwargs))
112 |
113 |
114 | def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
115 | return ResNet(torchvision.models.wide_resnet50_2(pretrained, progress, **kwargs))
116 |
117 |
118 | def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
119 | return ResNet(torchvision.models.wide_resnet101_2(pretrained, progress, **kwargs))
120 |
--------------------------------------------------------------------------------
/mpl/models/squeezenet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision.models
3 | from typing import Any
4 |
5 | from .base_model import BaseModel
6 |
7 |
8 | class SqueezeNet(BaseModel):
9 | def __init__(self, model: torchvision.models.SqueezeNet):
10 | super(SqueezeNet, self).__init__()
11 | self.clone_from_model(model)
12 | self.process_layers()
13 |
14 | def process_layers(self):
15 | self.collect_prunable_layers()
16 | self.convert_eligible_layers()
17 | self.collect_prunable_layers()
18 |
19 | def forward(self, x: torch.Tensor) -> torch.Tensor:
20 | x = self.features(x)
21 | x = self.classifier(x)
22 | return torch.flatten(x, 1)
23 |
24 |
25 | def squeezenet1_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> SqueezeNet:
26 | return SqueezeNet(torchvision.models.squeezenet1_0(pretrained, progress, **kwargs))
27 |
28 |
29 | def squeezenet1_1(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> SqueezeNet:
30 | return SqueezeNet(torchvision.models.squeezenet1_1(pretrained, progress, **kwargs))
31 |
--------------------------------------------------------------------------------
/mpl/models/utils.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from ..nn.conv2d import DenseConv2d, SparseConv2d
3 | from ..nn.linear import DenseLinear, SparseLinear
4 |
5 | from typing import Callable
6 |
7 |
8 | def is_prunable_fc(layer):
9 | return isinstance(layer, DenseLinear) or isinstance(layer, SparseLinear)
10 |
11 |
12 | def is_prunable_conv(layer):
13 | return isinstance(layer, DenseConv2d) or isinstance(layer, SparseConv2d)
14 |
15 |
16 | def is_prunable(layer):
17 | return is_prunable_fc(layer) or is_prunable_conv(layer)
18 |
19 |
20 | def is_parameterized(layer):
21 | return is_prunable(layer) or isinstance(layer, nn.Linear) or isinstance(layer, nn.Conv2d)
22 |
23 |
24 | def collect_leaf_modules(module, criterion: Callable, layers: list, names: list, prefix: str = ""):
25 | for key, submodule in module._modules.items():
26 | new_prefix = prefix
27 | if prefix != "":
28 | new_prefix += '.'
29 | new_prefix += key
30 | # is leaf and satisfies criterion
31 | if submodule is not None:
32 | if len(submodule._modules.keys()) == 0 and criterion(submodule):
33 | layers.append(submodule)
34 | names.append(new_prefix)
35 | collect_leaf_modules(submodule, criterion, layers, names, prefix=new_prefix)
36 |
--------------------------------------------------------------------------------
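For clarity, a small sketch (not from the repository) of how `collect_leaf_modules` walks a module tree and records dotted name prefixes; the toy `nn.Sequential` and the choice of `is_parameterized` as the criterion are illustrative assumptions.

    import torch.nn as nn
    from mpl.models.utils import collect_leaf_modules, is_parameterized

    net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(),
                        nn.Sequential(nn.Flatten(), nn.Linear(8, 2)))
    layers, names = [], []
    collect_leaf_modules(net, is_parameterized, layers, names)
    print(names)  # ['0', '2.1'] -- dotted prefixes of the parameterized leaf modules
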
/mpl/models/vgg.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torchvision.models
4 | from typing import Any
5 |
6 | from .base_model import BaseModel
7 |
8 | __all__ = [
9 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
10 | 'vgg19_bn', 'vgg19',
11 | ]
12 |
13 |
14 | class VGG(BaseModel):
15 | def __init__(self, model: torchvision.models.VGG):
16 | super(VGG, self).__init__()
17 | self.clone_from_model(model)
18 | self.process_layers()
19 |
20 | def process_layers(self):
21 | self.collect_prunable_layers()
22 | self.convert_eligible_layers()
23 | self.collect_prunable_layers()
24 |
25 | def forward(self, x):
26 | x = self.features(x)
27 | x = self.avgpool(x)
28 | x = torch.flatten(x, 1)
29 | x = self.classifier(x)
30 | return x
31 |
32 |
33 | def vgg11(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
34 | return VGG(torchvision.models.vgg11(pretrained, progress, **kwargs))
35 |
36 |
37 | def vgg11_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
38 | return VGG(torchvision.models.vgg11_bn(pretrained, progress, **kwargs))
39 |
40 |
41 | def vgg13(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
42 | return VGG(torchvision.models.vgg13(pretrained, progress, **kwargs))
43 |
44 |
45 | def vgg13_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
46 | return VGG(torchvision.models.vgg13_bn(pretrained, progress, **kwargs))
47 |
48 |
49 | def vgg16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
50 | return VGG(torchvision.models.vgg16(pretrained, progress, **kwargs))
51 |
52 |
53 | def vgg16_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
54 | return VGG(torchvision.models.vgg16_bn(pretrained, progress, **kwargs))
55 |
56 |
57 | def vgg19(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
58 | return VGG(torchvision.models.vgg19(pretrained, progress, **kwargs))
59 |
60 |
61 | def vgg19_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
62 | return VGG(torchvision.models.vgg19_bn(pretrained, progress, **kwargs))
63 |
--------------------------------------------------------------------------------
/mpl/nn/__init__.py:
--------------------------------------------------------------------------------
1 | from .conv2d import *
2 | from .linear import *
3 |
--------------------------------------------------------------------------------
/mpl/nn/conv2d.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn.modules.utils import _pair
4 | from ..autograd.functions import SparseConv2dFunction
5 |
6 | from typing import Union, Tuple
7 |
8 | __all__ = ["SparseConv2d", "DenseConv2d"]
9 |
10 |
11 | class SparseConv2d(nn.Module):
12 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, weight, bias, mask):
13 | super(SparseConv2d, self).__init__()
14 | kernel_size = _pair(kernel_size)
15 | stride = _pair(stride)
16 | padding = _pair(padding)
17 | self.in_channels = in_channels
18 | self.out_channels = out_channels
19 | self.kernel_size = kernel_size
20 | self.stride = stride
21 | self.padding = padding
22 |
23 | self.weight = nn.Parameter(weight.clone(), requires_grad=False)
24 | self.mask = mask.clone()
25 | self.dense_weight_placeholder = nn.Parameter(torch.empty(size=self.weight.size()))
26 | self.dense_weight_placeholder.is_placeholder = True
27 |
28 | self.weight.dense = self.dense_weight_placeholder
29 | self.weight.mask = self.mask
30 | self.weight.is_sparse_param = True
31 |
32 | if bias is None:
33 | self.bias = torch.zeros(size=(out_channels,))
34 | else:
35 | self.bias = nn.Parameter(bias.clone())
36 |
37 | def forward(self, inp):
38 | return SparseConv2dFunction.apply(inp, self.weight, self.dense_weight_placeholder, self.kernel_size,
39 | self.bias, self.stride, self.padding)
40 |
41 | def __repr__(self):
42 | return "SparseConv2d({}, {}, kernel_size={}, " \
43 | "stride={}, padding={}, bias={})".format(self.in_channels,
44 | self.out_channels,
45 | self.kernel_size,
46 | self.stride,
47 | self.padding,
48 | not torch.equal(self.bias, torch.zeros_like(self.bias)))
49 |
50 | def __str__(self):
51 | return self.__repr__()
52 |
53 |
54 | class DenseConv2d(nn.Conv2d):
55 | def __init__(self, in_channels, out_channels, kernel_size, stride: Union[int, Tuple] = 1,
56 | padding: Union[int, Tuple] = 0,
57 | dilation: Union[int, Tuple] = 1, groups=1, bias=True,
58 | padding_mode='zeros'):
59 | max_dilation = dilation if isinstance(dilation, int) else max(dilation)
60 | if max_dilation > 1:
61 | raise NotImplementedError("Dilation > 1 not implemented")
62 | if groups > 1:
63 | raise NotImplementedError("Groups > 1 not implemented")
64 | super(DenseConv2d, self).__init__(in_channels, out_channels, kernel_size, stride, padding,
65 | dilation, groups, bias, padding_mode)
66 | self.mask = torch.ones_like(self.weight, dtype=torch.bool, device=self.weight.device)
67 |
68 | def forward(self, inp):
69 | return self._conv_forward(inp, self.weight * self.mask)
70 |
71 | def prune_by_threshold(self, thr):
72 | self.mask *= (torch.abs(self.weight) >= thr)
73 |
74 | def prune_by_rank(self, rank):
75 | if rank == 0:
76 | return
77 | weights_val = self.weight[self.mask == 1]
78 | sorted_abs_weights = torch.sort(torch.abs(weights_val))[0]
79 | thr = sorted_abs_weights[rank]
80 | self.prune_by_threshold(thr)
81 |
82 | def prune_by_pct(self, pct):
83 | if pct == 0:
84 | return
85 | prune_idx = int(self.num_weight * pct)
86 | self.prune_by_rank(prune_idx)
87 |
88 | def random_prune_by_pct(self, pct):
89 | prune_idx = int(self.num_weight * pct)
90 |         rand = torch.rand(size=self.mask.size(), device=self.mask.device)  # float noise; self.mask is bool
91 | rand_val = rand[self.mask == 1]
92 | sorted_abs_rand = torch.sort(rand_val)[0]
93 | thr = sorted_abs_rand[prune_idx]
94 | self.mask *= (rand >= thr)
95 |
96 | @classmethod
97 | def from_conv2d(cls, conv2d_module: nn.Conv2d):
98 | new_conv2d = cls(conv2d_module.in_channels, conv2d_module.out_channels, conv2d_module.kernel_size,
99 | conv2d_module.stride, conv2d_module.padding, conv2d_module.dilation, conv2d_module.groups,
100 | bias=conv2d_module.bias is not None,
101 | padding_mode=conv2d_module.padding_mode)
102 |
103 | new_conv2d.weight = nn.Parameter(conv2d_module.weight.clone())
104 | if conv2d_module.bias is not None:
105 | new_conv2d.bias = nn.Parameter(conv2d_module.bias.clone())
106 |
107 | return new_conv2d
108 |
109 | # This method will always remove zero elements, even if you wish to keep zeros in the sparse form
110 | def to_sparse(self):
111 | masked_weight = self.weight * self.mask
112 | mask = (masked_weight != 0.).view(self.out_channels, -1)
113 | weight = masked_weight.view(self.out_channels, -1).to_sparse()
114 | return SparseConv2d(self.in_channels, self.out_channels, self.kernel_size, self.stride, self.padding, weight,
115 | self.bias, mask)
116 |
117 | def move_data(self, device: torch.device):
118 | self.mask = self.mask.to(device)
119 |
120 | @property
121 | def num_weight(self):
122 | return torch.sum(self.mask).int().item()
123 |
--------------------------------------------------------------------------------
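A minimal sketch (not from the repository) of the dense-to-sparse workflow defined above: wrap an existing `nn.Conv2d`, prune by magnitude, then convert to the sparse representation. The shapes and the 90% pruning fraction are arbitrary, and the forward pass assumes the sparse_conv2d C++ extension from setup.py has been built.

    import torch
    import torch.nn as nn
    from mpl.nn.conv2d import DenseConv2d

    conv = DenseConv2d.from_conv2d(nn.Conv2d(3, 16, kernel_size=3, padding=1))
    conv.prune_by_pct(0.9)                      # mask out the 90% smallest-magnitude weights
    print(conv.num_weight)                      # number of surviving weights
    sparse_conv = conv.to_sparse()              # SparseConv2d holding a 2-D sparse kernel matrix
    out = sparse_conv(torch.randn(1, 3, 8, 8))  # forward through the sparse kernel
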
/mpl/nn/linear.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.sparse as sparse
4 | from ..autograd.functions import AddmmFunction
5 |
6 | __all__ = ["SparseLinear", "DenseLinear"]
7 |
8 |
9 | class SparseLinear(nn.Module):
10 | __constants__ = ['in_features', 'out_features']
11 |
12 | def __init__(self, weight: sparse.FloatTensor, bias, mask):
13 | super(SparseLinear, self).__init__()
14 | if not weight.is_sparse:
15 | raise ValueError("Weight must be sparse")
16 | elif weight._nnz() > 0 and not weight.is_coalesced():
17 | raise ValueError("Weight must be coalesced")
18 |
19 | self.in_features = weight.size(1)
20 | self.out_features = weight.size(0)
21 |
22 | # in order to add to optimizer
23 | self.weight = nn.Parameter(weight.data.clone(), requires_grad=False)
24 | self.mask = mask.clone()
25 | # Don't move after creation to make it a leaf
26 | self.dense_weight_placeholder = nn.Parameter(torch.empty(size=self.weight.size(), device=self.weight.device))
27 | self.dense_weight_placeholder.is_placeholder = True
28 |
29 | # create links
30 | self.weight.dense = self.dense_weight_placeholder
31 | self.weight.mask = self.mask
32 | self.weight.is_sparse_param = True
33 |
34 | if bias is None:
35 | self.register_parameter('bias', None)
36 | else:
37 | assert bias.size() == torch.Size((weight.size(0), 1))
38 | self.bias = nn.Parameter(bias.data.clone())
39 |
40 | def _sparse_masked_select_abs(self, sparse_tensor: sparse.FloatTensor, thr):
41 | indices = sparse_tensor._indices()
42 | values = sparse_tensor._values()
43 | prune_mask = torch.abs(values) >= thr
44 | return torch.sparse_coo_tensor(indices=indices.masked_select(prune_mask).reshape(2, -1),
45 | values=values.masked_select(prune_mask),
46 | size=[self.out_features, self.in_features]).coalesce()
47 |
48 | def prune_by_threshold(self, thr):
49 | self.weight = nn.Parameter(self._sparse_masked_select_abs(self.weight, thr))
50 |
51 | def prune_by_rank(self, rank):
52 | weight_val = self.weight._values()
53 | sorted_abs_weight = torch.sort(torch.abs(weight_val))[0]
54 | thr = sorted_abs_weight[rank]
55 | self.prune_by_threshold(thr)
56 |
57 | def prune_by_pct(self, pct):
58 | if pct == 0:
59 | return
60 | prune_idx = int(self.weight._nnz() * pct)
61 | self.prune_by_rank(prune_idx)
62 |
63 | def move_data(self, device: torch.device):
64 | self.weight = self.weight.to(device)
65 |
66 | def forward(self, inp: torch.Tensor):
67 | return AddmmFunction.apply(self.bias, self.weight, self.dense_weight_placeholder, inp.t()).t()
68 |
69 | @property
70 | def num_weight(self) -> int:
71 | return self.weight._nnz()
72 |
73 | def __repr__(self):
74 | return "SparseLinear(in_features={}, out_features={}, bias={})".format(self.in_features, self.out_features,
75 | self.bias is not None)
76 |
77 | def __str__(self):
78 | return self.__repr__()
79 |
80 |
81 | class DenseLinear(nn.Linear):
82 | def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
83 | super(DenseLinear, self).__init__(in_features, out_features, bias)
84 | self.mask = torch.ones_like(self.weight, dtype=torch.bool, device=self.weight.device)
85 |
86 | def forward(self, inp: torch.Tensor):
87 | return nn.functional.linear(inp, self.weight * self.mask, self.bias)
88 |
89 | def prune_by_threshold(self, thr):
90 | self.mask *= (self.weight.abs() >= thr)
91 |
92 | def prune_by_rank(self, rank):
93 | if rank == 0:
94 | return
95 | weight_val = self.weight[self.mask == 1.]
96 | sorted_abs_weight = weight_val.abs().sort()[0]
97 | thr = sorted_abs_weight[rank]
98 | self.prune_by_threshold(thr)
99 |
100 | def prune_by_pct(self, pct):
101 | prune_idx = int(self.num_weight * pct)
102 | self.prune_by_rank(prune_idx)
103 |
104 | def random_prune_by_pct(self, pct):
105 | prune_idx = int(self.num_weight * pct)
106 | rand = torch.rand(size=self.mask.size(), device=self.mask.device)
107 | rand_val = rand[self.mask == 1]
108 | sorted_abs_rand = rand_val.sort()[0]
109 | thr = sorted_abs_rand[prune_idx]
110 | self.mask *= (rand >= thr)
111 |
112 | @classmethod
113 | def from_linear(cls, linear_module: nn.Linear):
114 | new_linear = cls(linear_module.in_features, linear_module.out_features,
115 | bias=linear_module.bias is not None)
116 | new_linear.weight = nn.Parameter(linear_module.weight.clone())
117 | if linear_module.bias is not None:
118 | new_linear.bias = nn.Parameter(linear_module.bias.clone())
119 |
120 | return new_linear
121 |
122 | # This method will always remove zero elements, even if you wish to keep zeros in the sparse form
123 | def to_sparse(self) -> SparseLinear:
124 | sparse_bias = None if self.bias is None else self.bias.reshape((-1, 1))
125 | masked_weight = self.weight * self.mask
126 | mask = masked_weight != 0.
127 | return SparseLinear(masked_weight.to_sparse(), sparse_bias, mask)
128 |
129 | def move_data(self, device: torch.device):
130 | self.mask = self.mask.to(device)
131 |
132 | def to(self, *args, **kwargs):
133 | device = torch._C._nn._parse_to(*args, **kwargs)[0]
134 |
135 | if device is not None:
136 | self.move_data(device)
137 |
138 | return super(DenseLinear, self).to(*args, **kwargs)
139 |
140 | @property
141 | def num_weight(self) -> int:
142 | return self.mask.sum().item()
143 |
--------------------------------------------------------------------------------
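Similarly, a short sketch (not from the repository) for the linear modules above; the layer sizes and the 50% pruning fraction are arbitrary, and the forward pass goes through the custom AddmmFunction.

    import torch
    import torch.nn as nn
    from mpl.nn.linear import DenseLinear

    fc = DenseLinear.from_linear(nn.Linear(in_features=20, out_features=5))
    fc.prune_by_pct(0.5)               # mask out the 50% smallest-magnitude weights
    sparse_fc = fc.to_sparse()         # SparseLinear: COO weight plus an (out_features, 1) bias
    y = sparse_fc(torch.randn(4, 20))  # forward on the transposed input, then transposed back
    print(y.shape)                     # expected: torch.Size([4, 5])
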
/mpl/optim/__init__.py:
--------------------------------------------------------------------------------
1 | from .sgd import SGD
2 |
--------------------------------------------------------------------------------
/mpl/optim/sgd.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.optim as optim
3 |
4 |
5 | class SGD(optim.SGD):
6 | @torch.no_grad()
7 | def step(self, closure=None):
8 | """Performs a single optimization step.
9 |
10 | Arguments:
11 | closure (callable, optional): A closure that reevaluates the model
12 | and returns the loss.
13 | """
14 | loss = None
15 | if closure is not None:
16 | with torch.enable_grad():
17 | loss = closure()
18 |
19 | for group in self.param_groups:
20 | weight_decay = group['weight_decay']
21 | momentum = group['momentum']
22 | dampening = group['dampening']
23 | nesterov = group['nesterov']
24 |
25 | for p in group['params']:
26 |                 # skip 1) dense placeholders for sparse params, 2) dense params with no grad, and
27 |                 # 3) sparse params whose dense placeholder has no grad
28 |                 if hasattr(p, "is_placeholder") or (
29 |                         p.grad is None and (not hasattr(p, "is_sparse_param") or p.dense.grad is None)):
30 |                     # no gradient information to apply for this parameter
31 |                     continue
32 | # if p.grad is None:
33 | # if not hasattr(p, "is_sparse_param"):
34 | # # dense param with None grad
35 | # continue
36 | # elif p.dense.grad is None:
37 | # # sparse param with None grad
38 | # continue
39 |
40 | if hasattr(p, "is_sparse_param"):
41 | d_p = p.dense.grad.masked_select(p.mask)
42 | p = p._values()
43 | else:
44 | d_p = p.grad
45 |
46 | if weight_decay != 0:
47 | d_p = d_p.add(p, alpha=weight_decay)
48 | if momentum != 0:
49 | param_state = self.state[p]
50 | if 'momentum_buffer' not in param_state:
51 | buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
52 | else:
53 | buf = param_state['momentum_buffer']
54 | buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
55 | if nesterov:
56 | d_p = d_p.add(buf, alpha=momentum)
57 | else:
58 | d_p = buf
59 |
60 | p.add_(d_p, alpha=-group['lr'])
61 |
62 | return loss
63 |
--------------------------------------------------------------------------------
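A sketch (not from the repository) of one training step with this optimizer on a sparse layer: gradients accumulate on the dense placeholder, and `step` masks them back onto the sparse values. The layer sizes, learning rate, momentum, and loss are arbitrary choices.

    import torch
    import torch.nn as nn
    from mpl.nn.linear import DenseLinear
    from mpl.optim import SGD

    layer = DenseLinear.from_linear(nn.Linear(10, 2))
    layer.prune_by_pct(0.5)
    sparse_layer = layer.to_sparse()

    optimizer = SGD(sparse_layer.parameters(), lr=0.1, momentum=0.9)
    loss = sparse_layer(torch.randn(8, 10)).pow(2).mean()
    loss.backward()    # fills dense_weight_placeholder.grad (and bias.grad)
    optimizer.step()   # updates the sparse weight values in place, skipping the placeholder
    optimizer.zero_grad()
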
/mpl/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiangyuang/ModelPruningLibrary/9c8ba5a3c5d118f37768d5d42254711f48d88745/mpl/utils/__init__.py
--------------------------------------------------------------------------------
/mpl/utils/save_load.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import copyreg
3 | import io
4 | import torch
5 | import numpy as np
6 | from PIL import Image
7 |
8 | import warnings
9 |
10 | bytes_types = (bytes, bytearray)
11 |
12 |
13 | # mode False: store value-index pairs (density < 1/32)
14 | # mode True: store a bitmap plus values (density >= 1/32)
15 | def get_mode(x) -> bool:
16 |     return x._nnz() / x.nelement() >= 1 / 32
17 |
18 |
19 | def get_int_type(max_val: int):
20 | assert max_val >= 0
21 | max_uint8 = 1 << 8
22 | max_int16 = 1 << 15
23 | max_int32 = 1 << 31
24 | if max_val < max_uint8:
25 | return torch.uint8
26 | elif max_val < max_int16:
27 | return torch.int16
28 | elif max_val < max_int32:
29 | return torch.int32
30 | else:
31 | return torch.int64
32 |
33 |
34 | def sparse_coo_from_indices(indices, values, size):
35 | mask = torch.zeros(size=size, dtype=torch.bool)
36 | mask[indices.tolist()] = True
37 | tensor = torch.sparse_coo_tensor(indices.to(torch.long), values, size).coalesce()
38 | tensor.mask = mask
39 | return tensor
40 |
41 |
42 | def sparse_coo_from_values_bitmap(bitmap, values, size):
43 | mask = torch.from_numpy(np.array(bitmap, np.uint8, copy=False))
44 | indices = mask.nonzero().t()
45 | tensor = torch.sparse_coo_tensor(indices.to(torch.long), values, size).coalesce()
46 | tensor.mask = mask
47 | return tensor
48 |
49 |
50 | def rebuild_dispatcher(mode, arg0, arg1, arg2):
51 | if mode is False:
52 | return sparse_coo_from_indices(arg0, arg1, arg2)
53 | else:
54 | return sparse_coo_from_values_bitmap(arg0, arg1, arg2)
55 |
56 |
57 | def args_dispatcher(mode, x) -> tuple:
58 | # supports only 2 dimensional tensors
59 | if mode is False:
60 | int_type = get_int_type(torch.max(x._indices()).item())
61 | return mode, x._indices().to(int_type), x._values(), x.size()
62 | else:
63 | bitmap = torch.zeros(size=x.size(), dtype=torch.bool)
64 | bitmap[x._indices().tolist()] = True
65 | # print(bitmap.size(), bitmap)
66 | # print(np.uint8(bitmap.numpy()))
67 | bitmap = Image.fromarray(bitmap.numpy())
68 | assert bitmap.mode == "1"
69 | return mode, bitmap, x._values(), x.size()
70 |
71 |
72 | def reduce(x: torch.Tensor):
73 | if x.is_sparse:
74 | assert x.ndim == 2, "Only 2-dimensional tensors are supported"
75 | mode = get_mode(x)
76 | return rebuild_dispatcher, args_dispatcher(mode, x)
77 | else:
78 | return x.__reduce_ex__(pickle.DEFAULT_PROTOCOL)
79 |
80 |
81 | # register custom reduce function for sparse tensors
82 | copyreg.pickle(torch.Tensor, reduce)
83 |
84 |
85 | def dumps(obj):
86 | f = io.BytesIO()
87 | pickle.dump(obj, f)
88 | res = f.getvalue()
89 | assert isinstance(res, bytes_types)
90 | return res
91 |
92 |
93 | def loads(res):
94 | return pickle.loads(res)
95 |
96 |
97 | def save(obj, f):
98 | # disabling warnings from torch.Tensor's reduce function. See issue: https://github.com/pytorch/pytorch/issues/38597
99 | with warnings.catch_warnings():
100 | warnings.simplefilter("ignore")
101 | with open(f, "wb") as opened_f:
102 | pickle.dump(obj, opened_f)
103 |
104 |
105 | def load(f):
106 | with open(f, 'rb') as opened_f:
107 | return pickle.load(opened_f)
108 |
--------------------------------------------------------------------------------
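A sketch (not from the repository) of the save/load round trip defined above; the file name and tensor are arbitrary. Importing the module registers the custom reducer for `torch.Tensor`, and `get_mode` decides automatically between value-index pairs and a bitmap based on density.

    import torch
    from mpl.utils import save_load

    dense = torch.randn(64, 64)
    dense[dense.abs() < 2.] = 0.            # zero out most entries
    sparse = dense.to_sparse().coalesce()   # reduce() supports only 2-D sparse tensors

    save_load.save(sparse, "weight.pkl")
    restored = save_load.load("weight.pkl")
    print(torch.equal(restored.to_dense(), dense))  # expected: True
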
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch~=1.7.1
2 | numpy~=1.19.2
3 | Pillow~=7.2.0
4 | torchvision~=0.8.2
5 | setuptools~=41.2.0
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | from setuptools import setup, find_packages
3 | from torch.utils import cpp_extension
4 |
5 | # Set up sparse conv2d extension
6 | setup(name="sparse_conv2d",
7 | ext_modules=[cpp_extension.CppExtension("sparse_conv2d",
8 | [os.path.join("extension", "extension.cpp")],
9 | extra_compile_args=["-std=c++14", "-fopenmp"])],
10 | cmdclass={"build_ext": cpp_extension.BuildExtension})
11 |
12 | # Set up mpl (model pruning library)
13 | with open("README.md", "r", encoding="utf-8") as fh:
14 | long_description = fh.read()
15 |
16 | DEPENDENCIES = ['torch', 'torchvision']
17 |
18 | setup(name='mpl',
19 | version='0.0.1',
20 | description="Model Pruning Library",
21 | long_description=long_description,
22 | author="Yuang Jiang",
23 | author_email="yuang.jiang@yale.edu",
24 | url="https://github.com/jiangyuang/ModelPruningLibrary",
25 | packages=find_packages(),
26 | install_requires=DEPENDENCIES,
27 | )
28 |
--------------------------------------------------------------------------------