├── LICENSE
├── README.md
├── extension
│   └── extension.cpp
├── mpl
│   ├── __init__.py
│   ├── autograd
│   │   ├── __init__.py
│   │   └── functions.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── alexnet.py
│   │   ├── base_model.py
│   │   ├── densenet.py
│   │   ├── googlenet.py
│   │   ├── inception.py
│   │   ├── leaf.py
│   │   ├── lenet.py
│   │   ├── resnet.py
│   │   ├── squeezenet.py
│   │   ├── utils.py
│   │   └── vgg.py
│   ├── nn
│   │   ├── __init__.py
│   │   ├── conv2d.py
│   │   └── linear.py
│   ├── optim
│   │   ├── __init__.py
│   │   └── sgd.py
│   └── utils
│       ├── __init__.py
│       └── save_load.py
├── requirements.txt
└── setup.py

/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2020 Yuang Jiang
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ModelPruningLibrary (Updated 3/3/2021)
2 | ## Plan for the Next Version
3 | We plan to further complete ModelPruningLibrary with the following:
4 | 1. a c++ implementation of conv2d with groups > 1 and depthwise conv2d, as well as the models from `torchvision.models` that are still missing.
5 | 2. more optimizers, matching those in `torch.optim`.
6 | 3. well-known pruning algorithms such as SNIP [[1]](#1).
7 | 4. tools for federated learning (e.g. well-known datasets for FL).
8 | 
9 | Suggestions/comments are welcome!
10 | 
11 | ## Description
12 | This is a PyTorch-based library that implements
13 | 1. model pruning: various magnitude-based pruning algorithms (pruning by percentage, random pruning, etc.);
14 | 2. a conv2d module with **sparse kernels**, as well as a fully-connected module;
15 | 3. an SGD optimizer designed for our sparse modules;
16 | 4. two save/load formats for sparse tensors, chosen automatically according to the tensor's density (fraction of non-zero entries): if the density is below 1/32, we save value-index pairs; otherwise, we use a bitmap (see the sketch below).
17 | 
18 | The library originates from the following paper:
19 | - Jiang, Y., Wang, S., Valls, V., Ko, B. J., Lee, W. H., Leung, K. K. & Tassiulas, L. (2019). [Model pruning enables efficient federated learning on edge devices](https://arxiv.org/pdf/1909.12326.pdf). arXiv preprint arXiv:1909.12326.
20 | 
21 | When using this code in scientific publications, please kindly cite the above paper.
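To make the density rule in item 4 of the description concrete, here is a minimal sketch of how the format would be chosen. It is not part of the library; `choose_format` is an illustrative helper only.

```python3
import torch

def choose_format(tensor: torch.Tensor) -> str:
    # Density = fraction of non-zero entries, as described in item 4 above.
    density = (tensor != 0).float().mean().item()
    # Below 1/32, value-index pairs take less space; otherwise a bitmap does.
    return "value-index pairs" if density < 1 / 32 else "bitmap"

x = torch.randn(1000, 1000)
x = x * (torch.rand_like(x) <= 0.01)  # keep roughly 1% of the entries
print(choose_format(x))               # -> "value-index pairs"
```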
22 | 
23 | The library consists of the following components:
24 | * **setup.py**: installs the c++ extension and the `mpl` (model pruning library) module.
25 | * **extension**: the `extension.cpp` c++ file extends the current PyTorch implementation with **sparse kernels** (the installed module is called `sparse_conv2d`). Please note that we only extend PyTorch's slow, cpu version of conv2d forward/backward with no groups and dilation = 1 (see PyTorch's c++ code [here](https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/ConvolutionMM2d.cpp)). In other words, we do not use acceleration packages such as MKL (which are not available on the Raspberry Pi devices used in our paper's experiments), so please do not compare the speed of our implementation against such packages.
26 | * **autograd**: the `AddmmFunction` and `SparseConv2dFunction` functions provide the forward and backward passes for our customized modules.
27 | * **models**: this is similar to `torchvision`'s implementations ([link](https://github.com/pytorch/vision/tree/master/torchvision/models)). Note that we do not implement mnasnet, mobilenet and shufflenetv2, since they use groups > 1. We also implement popular models such as those in [leaf](https://github.com/TalwalkarLab/leaf/tree/master/models).
28 | * **nn**: `conv2d.py` and `linear.py` implement the prunable modules and their `to_sparse` functionalities.
29 | * **optim**: implements a compatible version of the SGD optimizer.
30 | 
31 | Our code has been validated on Ubuntu 20.04. Contact me if you encounter any issues!
32 | 
33 | ## Examples
34 | 
35 | ### Setup Library:
36 | ```shell
37 | sudo python3 setup.py install
38 | ```
39 | 
40 | 
41 | 
42 | ### Importing and Using a Model
43 | ```python3
44 | from mpl.models import conv2
45 | 
46 | model = conv2()
47 | print(model)
48 | ```
49 | 
50 | output:
51 | ```
52 | Conv2(
53 |   (features): Sequential(
54 |     (0): DenseConv2d(1, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
55 |     (1): ReLU(inplace=True)
56 |     (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
57 |     (3): DenseConv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
58 |     (4): ReLU(inplace=True)
59 |     (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
60 |   )
61 |   (classifier): Sequential(
62 |     (0): DenseLinear(in_features=3136, out_features=2048, bias=True)
63 |     (1): ReLU(inplace=True)
64 |     (2): DenseLinear(in_features=2048, out_features=62, bias=True)
65 |   )
66 | )
67 | ```
68 | 
69 | ### Model Pruning:
70 | ```python3
71 | import mpl.models
72 | 
73 | model = mpl.models.conv2()
74 | print("Before pruning:")
75 | model.calc_num_prunable_params(display=True)
76 | 
77 | print("After pruning:")
78 | model.prune_by_pct([0.1, 0, None, 0.9])
79 | model.calc_num_prunable_params(display=True)
80 | ```
81 | output:
82 | ```
83 | Before pruning:
84 | Layer name: features.0. remaining/all: 832/832 = 1.0
85 | Layer name: features.3. remaining/all: 51264/51264 = 1.0
86 | Layer name: classifier.0. remaining/all: 6424576/6424576 = 1.0
87 | Layer name: classifier.2. remaining/all: 127038/127038 = 1.0
88 | Total: remaining/all: 6603710/6603710 = 1.0
89 | After pruning:
90 | Layer name: features.0. remaining/all: 752/832 = 0.9038461538461539
91 | Layer name: features.3. remaining/all: 51264/51264 = 1.0
92 | Layer name: classifier.0. remaining/all: 6424576/6424576 = 1.0
93 | Layer name: classifier.2. remaining/all: 12760/127038 = 0.10044238731718069
94 | Total: remaining/all: 6489352/6603710 = 0.9826827646883343
95 | ```
96 | ### Dense to Sparse Conversion:
97 | ```python3
98 | from mpl.models import conv2
99 | 
100 | model = conv2()
101 | print(model.to_sparse())
102 | ```
103 | output:
104 | ```
105 | Conv2(
106 |   (features): Sequential(
107 |     (0): SparseConv2d(1, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=True)
108 |     (1): ReLU(inplace=True)
109 |     (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
110 |     (3): SparseConv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=True)
111 |     (4): ReLU(inplace=True)
112 |     (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
113 |   )
114 |   (classifier): Sequential(
115 |     (0): SparseLinear(in_features=3136, out_features=2048, bias=True)
116 |     (1): ReLU(inplace=True)
117 |     (2): SparseLinear(in_features=2048, out_features=62, bias=True)
118 |   )
119 | )
120 | ```
121 | Note that `DenseConv2d` and `DenseLinear` layers are converted to `SparseConv2d` and `SparseLinear` layers, respectively.
122 | 
123 | ### SGD Training with a Sparse Model:
124 | ```python3
125 | from mpl.models import conv2
126 | from mpl.optim import SGD
127 | import torch
128 | 
129 | inp = torch.rand(size=(10, 1, 28, 28))
130 | model = conv2().to_sparse()
131 | optimizer = SGD(model.parameters(), lr=0.01)
132 | optimizer.zero_grad()
133 | model(inp).sum().backward()
134 | optimizer.step()
135 | ```
136 | 
137 | ### Save/Load a Tensor:
138 | ```python3
139 | from mpl.utils.save_load import save, load
140 | import torch
141 | 
142 | torch.manual_seed(0)
143 | x = torch.randn(size=(1000, 1000))
144 | mask = torch.rand_like(x) <= 0.5
145 | x = (x * mask).to_sparse()
146 | save(x, "sparse_x.pt")
147 | 
148 | x_loaded = load("sparse_x.pt")
149 | ```
150 | With our implementation, the `sparse_x.pt` file is 2.1 MB, while the default `torch.save` produces a 10 MB file (roughly 4.8x larger); see the sketch after the references below.
151 | 
152 | ## References
153 | [1]
154 | Lee, Namhoon, Thalaiyasingam Ajanthan, and Philip H. S. Torr. "SNIP: Single-shot network pruning based on connection sensitivity." arXiv preprint arXiv:1810.02340 (2018).
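For completeness, the file sizes quoted in the Save/Load example can be checked with the short script below. This is a sketch rather than part of the library: `sparse_x_default.pt` is an illustrative file name, and the sizes in the comments are simply the figures reported above.

```python3
import os

import torch

from mpl.utils.save_load import save

torch.manual_seed(0)
x = torch.randn(size=(1000, 1000))
mask = torch.rand_like(x) <= 0.5
x = (x * mask).to_sparse()

save(x, "sparse_x.pt")                # the library's density-aware format
torch.save(x, "sparse_x_default.pt")  # default PyTorch serialization

print(os.path.getsize("sparse_x.pt"))          # expected around 2.1 MB
print(os.path.getsize("sparse_x_default.pt"))  # expected around 10 MB
```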
155 | -------------------------------------------------------------------------------- /extension/extension.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #define Tensor torch::Tensor 8 | #define IntArrayRef at::IntArrayRef 9 | 10 | template 11 | static void unfolded2d_acc( 12 | scalar_t* finput_data, 13 | scalar_t* input_data, 14 | int64_t kH, 15 | int64_t kW, 16 | int64_t dH, 17 | int64_t dW, 18 | int64_t padH, 19 | int64_t padW, 20 | int64_t n_input_plane, 21 | int64_t input_height, 22 | int64_t input_width, 23 | int64_t output_height, 24 | int64_t output_width) { 25 | #pragma omp parallel for 26 | for (auto nip = 0; nip < n_input_plane; nip++) { 27 | int64_t kw, kh, y, x; 28 | int64_t ix, iy; 29 | for (kh = 0; kh < kH; kh++) { 30 | for (kw = 0; kw < kW; kw++) { 31 | scalar_t* src = finput_data + 32 | nip * ((size_t)kH * kW * output_height * output_width) + 33 | kh * ((size_t)kW * output_height * output_width) + 34 | kw * ((size_t)output_height * output_width); 35 | scalar_t* dst = 36 | input_data + nip * ((size_t)input_height * input_width); 37 | if (padW > 0 || padH > 0) { 38 | int64_t lpad, rpad; 39 | for (y = 0; y < output_height; y++) { 40 | iy = (int64_t)y * dH - padH + kh; 41 | if (iy < 0 || iy >= input_height) { 42 | } else { 43 | for (x = 0; x < output_width; x++) { 44 | ix = (int64_t)x * dW - padW + kw; 45 | if (ix < 0 || ix >= input_width) { 46 | } else { 47 | scalar_t* dst_slice = dst + (size_t)iy * input_width + ix; 48 | *dst_slice = *dst_slice + src[(size_t)y * output_width + x]; 49 | } 50 | } 51 | } 52 | } 53 | } else { 54 | for (y = 0; y < output_height; y++) { 55 | iy = (int64_t)y * dH + kh; 56 | ix = 0 + kw; 57 | for (x = 0; x < output_width; x++) { 58 | scalar_t* dst_slice = 59 | dst + (size_t)iy * input_width + ix + x * dW; 60 | *dst_slice = *dst_slice + src[(size_t)y * output_width + x]; 61 | } 62 | } 63 | } 64 | } 65 | } 66 | } 67 | } 68 | 69 | void unfolded2d_acc_kernel( 70 | Tensor& finput, 71 | Tensor& input, 72 | int64_t kH, 73 | int64_t kW, 74 | int64_t dH, 75 | int64_t dW, 76 | int64_t padH, 77 | int64_t padW, 78 | int64_t n_input_plane, 79 | int64_t input_height, 80 | int64_t input_width, 81 | int64_t output_height, 82 | int64_t output_width) { 83 | // This function assumes that 84 | // output_height*dH does not overflow a int64_t 85 | // output_width*dW does not overflow a int64_t 86 | 87 | auto input_data = (float*) input.data_ptr(); 88 | auto finput_data =(float*) finput.data_ptr(); 89 | 90 | unfolded2d_acc( 91 | finput_data, 92 | input_data, 93 | kH, 94 | kW, 95 | dH, 96 | dW, 97 | padH, 98 | padW, 99 | n_input_plane, 100 | input_height, 101 | input_width, 102 | output_height, 103 | output_width); 104 | } 105 | 106 | template 107 | static void unfolded2d_copy( 108 | scalar_t* input_data, 109 | scalar_t* finput_data, 110 | int64_t kH, 111 | int64_t kW, 112 | int64_t dH, 113 | int64_t dW, 114 | int64_t padH, 115 | int64_t padW, 116 | int64_t n_input_plane, 117 | int64_t input_height, 118 | int64_t input_width, 119 | int64_t output_height, 120 | int64_t output_width) { 121 | 122 | auto start = 0; 123 | auto end = (int64_t)n_input_plane * kH * kW; 124 | #pragma omp parallel for 125 | for (auto k = start; k < end; k++) { 126 | int64_t nip = k / (kH * kW); 127 | int64_t rest = k % (kH * kW); 128 | int64_t kh = rest / kW; 129 | int64_t kw = rest % kW; 130 | int64_t x, y; 131 | int64_t ix, iy; 132 | scalar_t* dst = finput_data + 133 | nip * ((size_t)kH * kW * 
output_height * output_width) + 134 | kh * ((size_t)kW * output_height * output_width) + 135 | kw * ((size_t)output_height * output_width); 136 | scalar_t* src = 137 | input_data + nip * ((size_t)input_height * input_width); 138 | if (padW > 0 || padH > 0) { 139 | int64_t lpad, rpad; 140 | for (y = 0; y < output_height; y++) { 141 | iy = (int64_t)y * dH - padH + kh; 142 | if (iy < 0 || iy >= input_height) { 143 | memset( 144 | dst + (size_t)y * output_width, 145 | 0, 146 | sizeof(scalar_t) * output_width); 147 | } else { 148 | if (dW == 1) { 149 | ix = 0 - padW + kw; 150 | lpad = std::max(0, padW - kw); 151 | rpad = std::max(0, padW - (kW - kw - 1)); 152 | if (output_width - rpad - lpad <= 0) { 153 | memset( 154 | dst + (size_t)y * output_width, 155 | 0, 156 | sizeof(scalar_t) * output_width); 157 | } else { 158 | if (lpad > 0) 159 | memset( 160 | dst + (size_t)y * output_width, 161 | 0, 162 | sizeof(scalar_t) * lpad); 163 | memcpy( 164 | dst + (size_t)y * output_width + lpad, 165 | src + (size_t)iy * input_width + ix + lpad, 166 | sizeof(scalar_t) * (output_width - rpad - lpad)); 167 | if (rpad > 0) 168 | memset( 169 | dst + (size_t)y * output_width + output_width - rpad, 170 | 0, 171 | sizeof(scalar_t) * rpad); 172 | } 173 | } else { 174 | for (x = 0; x < output_width; x++) { 175 | ix = (int64_t)x * dW - padW + kw; 176 | if (ix < 0 || ix >= input_width) 177 | memset( 178 | dst + (size_t)y * output_width + x, 179 | 0, 180 | sizeof(scalar_t) * 1); 181 | else 182 | memcpy( 183 | dst + (size_t)y * output_width + x, 184 | src + (size_t)iy * input_width + ix, 185 | sizeof(scalar_t) * (1)); 186 | } 187 | } 188 | } 189 | } 190 | } else { 191 | for (y = 0; y < output_height; y++) { 192 | iy = (int64_t)y * dH + kh; 193 | ix = 0 + kw; 194 | if (dW == 1) 195 | memcpy( 196 | dst + (size_t)y * output_width, 197 | src + (size_t)iy * input_width + ix, 198 | sizeof(scalar_t) * output_width); 199 | else { 200 | for (x = 0; x < output_width; x++) 201 | memcpy( 202 | dst + (size_t)y * output_width + x, 203 | src + (size_t)iy * input_width + ix + (int64_t)x * dW, 204 | sizeof(scalar_t) * (1)); 205 | } 206 | } 207 | } 208 | } 209 | } 210 | 211 | void unfolded2d_copy_kernel( 212 | Tensor& finput, 213 | Tensor& input, 214 | int64_t kH, 215 | int64_t kW, 216 | int64_t dH, 217 | int64_t dW, 218 | int64_t padH, 219 | int64_t padW, 220 | int64_t n_input_plane, 221 | int64_t input_height, 222 | int64_t input_width, 223 | int64_t output_height, 224 | int64_t output_width) { 225 | 226 | auto input_data = (float*) input.data_ptr(); 227 | auto finput_data =(float*) finput.data_ptr(); 228 | 229 | unfolded2d_copy( 230 | input_data, 231 | finput_data, 232 | kH, 233 | kW, 234 | dH, 235 | dW, 236 | padH, 237 | padW, 238 | n_input_plane, 239 | input_height, 240 | input_width, 241 | output_height, 242 | output_width); 243 | } 244 | 245 | static void slow_conv2d_update_output_frame( 246 | Tensor& input, 247 | Tensor& output, 248 | const Tensor& weight, 249 | const Tensor& bias, 250 | Tensor& finput, 251 | int64_t kernel_height, 252 | int64_t kernel_width, 253 | int64_t stride_height, 254 | int64_t stride_width, 255 | int64_t pad_height, 256 | int64_t pad_width, 257 | int64_t n_input_plane, 258 | int64_t input_height, 259 | int64_t input_width, 260 | int64_t n_output_plane, 261 | int64_t output_height, 262 | int64_t output_width) { 263 | 264 | unfolded2d_copy_kernel( 265 | finput, 266 | input, 267 | kernel_height, 268 | kernel_width, 269 | stride_height, 270 | stride_width, 271 | pad_height, 272 | pad_width, 273 | n_input_plane, 
274 | input_height, 275 | input_width, 276 | output_height, 277 | output_width); 278 | 279 | 280 | auto output2d = 281 | output.reshape({n_output_plane, output_height * output_width}); 282 | if (bias.defined()) { 283 | for (int64_t i = 0; i < n_output_plane; i++) { 284 | output[i].fill_(bias[i].item()); 285 | } 286 | } else { 287 | output.zero_(); 288 | } 289 | output2d.addmm_(weight, finput, 1, 1); 290 | } 291 | 292 | std::tuple slow_conv2d_forward_out_cpu( 293 | Tensor& output, 294 | Tensor& finput, 295 | Tensor& fgrad_input, 296 | const Tensor& self, 297 | const Tensor& weight_, 298 | IntArrayRef kernel_size, 299 | const Tensor& bias, 300 | IntArrayRef stride, 301 | IntArrayRef padding) { 302 | const int64_t kernel_height = kernel_size[0]; 303 | const int64_t kernel_width = kernel_size[1]; 304 | const int64_t pad_height = padding[0]; 305 | const int64_t pad_width = padding[1]; 306 | const int64_t stride_height = stride[0]; 307 | const int64_t stride_width = stride[1]; 308 | 309 | 310 | assert(weight_.dim()==2); 311 | const Tensor weight_2d = weight_; 312 | 313 | 314 | const Tensor input = self.contiguous(); 315 | const int64_t ndim = input.dim(); 316 | const int64_t dim_planes = 1; 317 | const int64_t dim_height = 2; 318 | const int64_t dim_width = 3; 319 | 320 | const int64_t n_input_plane = input.size(dim_planes); 321 | const int64_t input_height = input.size(dim_height); 322 | const int64_t input_width = input.size(dim_width); 323 | const int64_t n_output_plane = weight_2d.size(0); 324 | const int64_t output_height = 325 | (input_height + 2 * pad_height - kernel_height) / stride_height + 1; 326 | const int64_t output_width = 327 | (input_width + 2 * pad_width - kernel_width) / stride_width + 1; 328 | 329 | const int64_t batch_size = input.size(0); 330 | 331 | 332 | finput.resize_({batch_size, 333 | n_input_plane * kernel_height * kernel_width, 334 | output_height * output_width}); 335 | output.resize_({batch_size, n_output_plane, output_height, output_width}); 336 | 337 | at::NoGradGuard no_grad; 338 | at::AutoNonVariableTypeMode non_variable_type_mode(true); 339 | 340 | #pragma omp parallel for 341 | for (int64_t t = 0; t < batch_size; t++) { 342 | Tensor input_t = input[t]; 343 | Tensor output_t = output[t]; 344 | Tensor finput_t = finput[t]; 345 | slow_conv2d_update_output_frame( 346 | input_t, 347 | output_t, 348 | weight_2d, 349 | bias, 350 | finput_t, 351 | kernel_height, 352 | kernel_width, 353 | stride_height, 354 | stride_width, 355 | pad_height, 356 | pad_width, 357 | n_input_plane, 358 | input_height, 359 | input_width, 360 | n_output_plane, 361 | output_height, 362 | output_width); 363 | } 364 | 365 | return std::tuple(output, finput, fgrad_input); 366 | } 367 | 368 | 369 | std::tuple slow_conv2d_forward_cpu( 370 | const Tensor& self, 371 | const Tensor& weight, 372 | IntArrayRef kernel_size, 373 | const Tensor& bias, 374 | IntArrayRef stride, 375 | IntArrayRef padding) { 376 | 377 | auto output = at::empty({0}, self.options()); 378 | auto finput = at::empty({0}, self.options()); 379 | auto fgrad_input = at::empty({0}, self.options()); 380 | 381 | slow_conv2d_forward_out_cpu( 382 | output, 383 | finput, 384 | fgrad_input, 385 | self, 386 | weight, 387 | kernel_size, 388 | bias, 389 | stride, 390 | padding); 391 | return std::make_tuple(output, finput, fgrad_input); 392 | } 393 | 394 | void slow_conv2d_backward_update_grad_input_frame( 395 | Tensor& grad_input, 396 | const Tensor& grad_output, 397 | const Tensor& weight, 398 | Tensor& fgrad_input, 399 | int64_t 
kernel_height, 400 | int64_t kernel_width, 401 | int64_t stride_height, 402 | int64_t stride_width, 403 | int64_t pad_height, 404 | int64_t pad_width) { 405 | auto grad_output_2d = grad_output.reshape( 406 | {grad_output.size(0), grad_output.size(1) * grad_output.size(2)}); 407 | fgrad_input.addmm_(weight, grad_output_2d, 0, 1); 408 | grad_input.zero_(); 409 | unfolded2d_acc_kernel( 410 | fgrad_input, 411 | grad_input, 412 | kernel_height, 413 | kernel_width, 414 | stride_height, 415 | stride_width, 416 | pad_height, 417 | pad_width, 418 | grad_input.size(0), 419 | grad_input.size(1), 420 | grad_input.size(2), 421 | grad_output.size(1), 422 | grad_output.size(2)); 423 | } 424 | 425 | void slow_conv2d_backward_out_cpu_template( 426 | Tensor& grad_input, 427 | const Tensor& grad_output_, 428 | const Tensor& input_, 429 | const Tensor& weight_, 430 | const Tensor& finput, 431 | Tensor& fgrad_input, 432 | IntArrayRef kernel_size, 433 | IntArrayRef stride, 434 | IntArrayRef padding) { 435 | const int64_t kernel_height = kernel_size[0]; 436 | const int64_t kernel_width = kernel_size[1]; 437 | const int64_t pad_height = padding[0]; 438 | const int64_t pad_width = padding[1]; 439 | const int64_t stride_height = stride[0]; 440 | const int64_t stride_width = stride[1]; 441 | 442 | assert(weight_.dim() == 2); 443 | const Tensor weight = weight_; 444 | 445 | 446 | const Tensor input = input_.contiguous(); 447 | const Tensor grad_output = grad_output_.contiguous(); 448 | grad_input.resize_as_(input); 449 | fgrad_input.resize_as_(finput); 450 | fgrad_input.zero_(); 451 | Tensor tw = weight.transpose(0, 1); 452 | if(tw.is_sparse() && !tw.is_coalesced()){ 453 | tw = tw.coalesce(); 454 | } 455 | const Tensor tweight = tw; 456 | const int64_t batch_size = input.size(0); 457 | #pragma omp parallel for 458 | for (int64_t t = 0; t < batch_size; t++) { 459 | Tensor grad_input_t = grad_input[t]; 460 | Tensor grad_output_t = grad_output[t]; 461 | Tensor fgrad_input_t = fgrad_input[t]; 462 | slow_conv2d_backward_update_grad_input_frame( 463 | grad_input_t, 464 | grad_output_t, 465 | tweight, 466 | fgrad_input_t, 467 | kernel_height, 468 | kernel_width, 469 | stride_height, 470 | stride_width, 471 | pad_height, 472 | pad_width); 473 | } 474 | } 475 | 476 | void slow_conv2d_backward_parameters_frame( 477 | Tensor& grad_weight, 478 | Tensor& grad_bias, 479 | Tensor& grad_output, 480 | const Tensor& finput) { 481 | auto grad_output_2d = grad_output.view( 482 | {grad_output.size(0), grad_output.size(1) * grad_output.size(2)}); 483 | if (grad_weight.defined()) { 484 | const Tensor tfinput = finput.transpose(0, 1); 485 | grad_weight.addmm_(grad_output_2d, tfinput); 486 | } 487 | 488 | if (grad_bias.defined()) { 489 | AT_DISPATCH_FLOATING_TYPES_AND( 490 | at::ScalarType::BFloat16, 491 | grad_output.scalar_type(), 492 | "slow_conv2d_backward_parameters", 493 | [&] { 494 | auto grad_output_2d_acc = grad_output_2d.accessor(); 495 | auto grad_bias_acc = grad_bias.accessor(); 496 | const auto sz = grad_output_2d.size(1); 497 | for (int64_t i = 0; i < grad_bias.size(0); i++) { 498 | scalar_t sum = 0; 499 | for (int64_t k = 0; k < sz; k++) { 500 | sum = sum + grad_output_2d_acc[i][k]; 501 | } 502 | grad_bias_acc[i] = grad_bias_acc[i] + sum; 503 | } 504 | }); 505 | } 506 | } 507 | 508 | static void slow_conv2d_backward_parameters_out_cpu_template( 509 | Tensor& grad_weight, 510 | Tensor& grad_bias, 511 | const Tensor& input_, 512 | const Tensor& grad_output_, 513 | const Tensor& finput, 514 | Tensor fgrad_input, 515 | IntArrayRef 
kernel_size, 516 | IntArrayRef stride, 517 | IntArrayRef padding) { 518 | 519 | const int64_t kernel_height = kernel_size[0]; 520 | const int64_t kernel_width = kernel_size[1]; 521 | const int64_t pad_height = padding[0]; 522 | const int64_t pad_width = padding[1]; 523 | const int64_t stride_height = stride[0]; 524 | const int64_t stride_width = stride[1]; 525 | 526 | Tensor grad_weight_2d = grad_weight; 527 | 528 | auto input = input_.contiguous(); 529 | auto grad_output = grad_output_.contiguous(); 530 | 531 | const int64_t batch_size = input.size(0); 532 | for (int64_t t = 0; t < batch_size; t++) { 533 | Tensor grad_output_t = grad_output[t]; 534 | Tensor finput_t; 535 | if (grad_weight_2d.defined()) { 536 | finput_t = finput[t]; 537 | } 538 | 539 | slow_conv2d_backward_parameters_frame( 540 | grad_weight_2d, grad_bias, grad_output_t, finput_t); 541 | } 542 | } 543 | 544 | std::tuple slow_conv2d_backward_out_cpu( 545 | Tensor& grad_input, 546 | Tensor& grad_weight, 547 | Tensor& grad_bias, 548 | const Tensor& grad_output, 549 | const Tensor& self, 550 | const Tensor& weight, 551 | IntArrayRef kernel_size, 552 | IntArrayRef stride, 553 | IntArrayRef padding, 554 | const Tensor& finput, 555 | const Tensor& fgrad_input) { 556 | if (grad_input.defined()) { 557 | slow_conv2d_backward_out_cpu_template( 558 | grad_input, 559 | grad_output, 560 | self, 561 | weight, 562 | finput, 563 | const_cast(fgrad_input), 564 | kernel_size, 565 | stride, 566 | padding); 567 | } 568 | 569 | if (grad_weight.defined()) { 570 | grad_weight.resize_(weight.sizes()); 571 | grad_weight.zero_(); 572 | } 573 | 574 | if (grad_bias.defined()) { 575 | grad_bias.resize_({grad_output.size(1)}); 576 | grad_bias.zero_(); 577 | } 578 | if (grad_weight.defined() || grad_bias.defined()) { 579 | slow_conv2d_backward_parameters_out_cpu_template( 580 | grad_weight, 581 | grad_bias, 582 | self, 583 | grad_output, 584 | finput, 585 | fgrad_input, 586 | kernel_size, 587 | stride, 588 | padding); 589 | } 590 | 591 | return std::tuple( 592 | grad_input, grad_weight, grad_bias); 593 | } 594 | 595 | std::tuple slow_conv2d_backward_cpu( 596 | const Tensor& grad_output, 597 | const Tensor& self, 598 | const Tensor& weight, 599 | IntArrayRef kernel_size, 600 | IntArrayRef stride, 601 | IntArrayRef padding, 602 | const Tensor& finput, 603 | const Tensor& fgrad_input, 604 | std::array output_mask) { 605 | Tensor grad_input; 606 | Tensor grad_weight; 607 | Tensor grad_bias; 608 | 609 | if (output_mask[0]) { 610 | grad_input = at::empty({0}, grad_output.options()); 611 | } 612 | 613 | if (output_mask[1]) { 614 | grad_weight = at::empty({0}, grad_output.options()); 615 | } 616 | 617 | if (output_mask[2]) { 618 | grad_bias = at::empty({0}, grad_output.options()); 619 | } 620 | slow_conv2d_backward_out_cpu( 621 | grad_input, 622 | grad_weight, 623 | grad_bias, 624 | grad_output, 625 | self, 626 | weight, 627 | kernel_size, 628 | stride, 629 | padding, 630 | finput, 631 | fgrad_input); 632 | return std::make_tuple(grad_input, grad_weight, grad_bias); 633 | } 634 | 635 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 636 | m.def("forward", &slow_conv2d_forward_cpu, "Conv Forward"); 637 | m.def("backward", &slow_conv2d_backward_cpu, "Conv Backward"); 638 | } 639 | 640 | 641 | -------------------------------------------------------------------------------- /mpl/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils.save_load import save, load 2 | import mpl.autograd 3 | import mpl.models 4 | import 
mpl.nn 5 | import mpl.optim 6 | -------------------------------------------------------------------------------- /mpl/autograd/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiangyuang/ModelPruningLibrary/9c8ba5a3c5d118f37768d5d42254711f48d88745/mpl/autograd/__init__.py -------------------------------------------------------------------------------- /mpl/autograd/functions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.sparse as sparse 3 | from warnings import warn 4 | 5 | sparse_conv2d_imported = True 6 | try: 7 | import sparse_conv2d 8 | except ImportError: 9 | warn("The sparse_conv2d module is NOT imported. Using default conv2d functions for compatibility.") 10 | sparse_conv2d_imported = False 11 | 12 | 13 | class AddmmFunction(torch.autograd.Function): 14 | @staticmethod 15 | def forward(ctx, bias, weight: sparse.FloatTensor, dense_weight_placeholder, inp): 16 | if bias is None: 17 | out = sparse.mm(weight, inp) 18 | else: 19 | out = sparse.addmm(bias, weight, inp) 20 | ctx.save_for_backward(bias, weight, inp) 21 | return out 22 | 23 | @staticmethod 24 | def backward(ctx, grad_output): 25 | bias, weight, inp = ctx.saved_tensors 26 | grad_bias = grad_input = None 27 | if bias is not None: 28 | grad_bias = grad_output.sum(1).reshape((-1, 1)) 29 | grad_weight = grad_output.mm(inp.t()) 30 | if ctx.needs_input_grad[3]: 31 | grad_input = torch.mm(weight.t(), grad_output) 32 | 33 | return grad_bias, None, grad_weight, grad_input 34 | 35 | 36 | if sparse_conv2d_imported: 37 | class SparseConv2dFunction(torch.autograd.Function): 38 | @staticmethod 39 | def forward(ctx, inp, weight, dense_weight_placeholder, kernel_size, bias, stride, padding): 40 | out, f_input, fgrad_input = sparse_conv2d.forward(inp, weight, kernel_size, bias, stride, padding) 41 | ctx.save_for_backward(inp, weight, f_input, fgrad_input) 42 | ctx.kernel_size = kernel_size 43 | ctx.stride = stride 44 | ctx.padding = padding 45 | return out 46 | 47 | @staticmethod 48 | def backward(ctx, grad_output): 49 | grad_input, grad_weight, grad_bias = sparse_conv2d.backward(grad_output, 50 | ctx.saved_tensors[0], 51 | ctx.saved_tensors[1], 52 | ctx.kernel_size, 53 | ctx.stride, 54 | ctx.padding, 55 | ctx.saved_tensors[2], 56 | ctx.saved_tensors[3], 57 | (True, True, True)) 58 | return grad_input, None, grad_weight, None, grad_bias, None, None 59 | 60 | 61 | class DenseConv2dFunction(torch.autograd.Function): 62 | @staticmethod 63 | def forward(ctx, inp, weight, kernel_size, bias, stride, padding): 64 | weight2d = weight.data.reshape((weight.size(0), -1)) 65 | out, f_input, fgrad_input = sparse_conv2d.forward(inp, weight2d, kernel_size, bias, stride, padding) 66 | ctx.save_for_backward(inp, weight2d, f_input, fgrad_input, weight) 67 | ctx.kernel_size = kernel_size 68 | ctx.stride = stride 69 | ctx.padding = padding 70 | return out 71 | 72 | @staticmethod 73 | def backward(ctx, grad_output): 74 | grad_input, grad_weight2d, grad_bias = sparse_conv2d.backward(grad_output, 75 | ctx.saved_tensors[0], 76 | ctx.saved_tensors[1], 77 | ctx.kernel_size, 78 | ctx.stride, 79 | ctx.padding, 80 | ctx.saved_tensors[2], 81 | ctx.saved_tensors[3], 82 | (True, True, True)) 83 | grad_weight = grad_weight2d.reshape_as(ctx.saved_tensors[4]) 84 | return grad_input, grad_weight, None, grad_bias, None, None 85 | 86 | else: 87 | class SparseConv2dFunction(torch.autograd.Function): 88 | @staticmethod 89 | def 
apply(inp, weight, dense_weight_placeholder, kernel_size, bias, stride, padding): 90 | size_4d = (weight.size(0), -1, *kernel_size) 91 | with torch.no_grad(): 92 | dense_weight_placeholder.zero_() 93 | dense_weight_placeholder.add_(weight.to_dense()) 94 | return torch.nn.functional.conv2d(inp, dense_weight_placeholder.view(size_4d), bias, stride, padding) 95 | 96 | 97 | class DenseConv2dFunction(torch.autograd.Function): 98 | @staticmethod 99 | def apply(inp, weight, kernel_size, bias, stride, padding): 100 | size_4d = (weight.size(0), -1, *kernel_size) 101 | return torch.nn.functional.conv2d(inp, weight.reshape(size_4d), bias, stride, padding) 102 | -------------------------------------------------------------------------------- /mpl/models/__init__.py: -------------------------------------------------------------------------------- 1 | # networks implemented by torchvision 2 | from .alexnet import * 3 | from .resnet import * 4 | from .vgg import * 5 | from .squeezenet import * 6 | from .inception import * 7 | from .densenet import * 8 | from .googlenet import * 9 | # from .mobilenet import * 10 | # from .mnasnet import * 11 | # from .shufflenetv2 import * 12 | 13 | # additional model implementations 14 | from .lenet import * 15 | from .leaf import * 16 | -------------------------------------------------------------------------------- /mpl/models/alexnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.models 3 | from typing import Any 4 | 5 | from .base_model import BaseModel 6 | 7 | __all__ = ['AlexNet', 'alexnet'] 8 | 9 | 10 | class AlexNet(BaseModel): 11 | def __init__(self, model: torchvision.models.AlexNet): 12 | super(AlexNet, self).__init__() 13 | self.clone_from_model(model) 14 | self.process_layers() 15 | 16 | def process_layers(self): 17 | self.collect_prunable_layers() 18 | self.convert_eligible_layers() 19 | self.collect_prunable_layers() 20 | 21 | def forward(self, x: torch.Tensor) -> torch.Tensor: 22 | x = self.features(x) 23 | x = self.avgpool(x) 24 | x = torch.flatten(x, 1) 25 | x = self.classifier(x) 26 | return x 27 | 28 | 29 | def alexnet(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> AlexNet: 30 | return AlexNet(torchvision.models.alexnet(pretrained, progress, **kwargs)) 31 | -------------------------------------------------------------------------------- /mpl/models/base_model.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Union, Sized, List, Tuple 3 | from copy import deepcopy 4 | 5 | import torch 6 | from torch import nn as nn 7 | 8 | from ..nn.linear import DenseLinear 9 | from ..nn.conv2d import DenseConv2d 10 | from .utils import collect_leaf_modules, is_parameterized 11 | 12 | 13 | class BaseModel(nn.Module, ABC): 14 | def __init__(self): 15 | super(BaseModel, self).__init__() 16 | 17 | self.prunable_layers: list = [] 18 | self.prunable_layer_prefixes: list = [] 19 | 20 | def clone_from_model(self, original_model: nn.Module = None): 21 | # copying all submodules from original model 22 | for name, module in original_model._modules.items(): 23 | self.add_module(name, deepcopy(module)) 24 | 25 | def collect_prunable_layers(self) -> None: 26 | self.prunable_layers, self.prunable_layer_prefixes = self.find_layers(lambda x: is_parameterized(x)) 27 | 28 | def convert_eligible_layers(self): 29 | # changing all conv2d and linear layers to customized ones 30 | for module_name, 
old_module in zip(self.prunable_layer_prefixes, self.prunable_layers): 31 | if isinstance(old_module, nn.Linear): 32 | self.set_module_by_name(module_name, DenseLinear.from_linear(old_module)) 33 | elif isinstance(old_module, nn.Conv2d): 34 | self.set_module_by_name(module_name, DenseConv2d.from_conv2d(old_module)) 35 | 36 | def find_layers(self, criterion) -> Tuple[List, List]: 37 | layers, names = [], [] 38 | collect_leaf_modules(self, criterion, layers, names) 39 | return layers, names 40 | 41 | @abstractmethod 42 | def forward(self, inputs) -> torch.Tensor: 43 | pass 44 | 45 | def prune_by_threshold(self, thr_arg: Union[int, float, Sized]): 46 | prunable_layers = self.prunable_layers 47 | if isinstance(thr_arg, Sized): 48 | assert len(prunable_layers) == len(thr_arg) 49 | else: 50 | thr_arg = [thr_arg] * len(prunable_layers) 51 | for thr, layer in zip(thr_arg, prunable_layers): 52 | if thr is not None: 53 | layer.prune_by_threshold(thr) 54 | 55 | return self 56 | 57 | def prune_by_rank(self, rank_arg: Union[int, float, Sized]): 58 | prunable_layers = self.prunable_layers 59 | if isinstance(rank_arg, Sized): 60 | assert len(prunable_layers) == len(rank_arg) 61 | else: 62 | rank_arg = [rank_arg] * len(prunable_layers) 63 | for rank, layer in zip(rank_arg, prunable_layers): 64 | if rank is not None: 65 | layer.prune_by_rank(rank) 66 | 67 | return self 68 | 69 | def prune_by_pct(self, pct_arg: Union[int, float, Sized]): 70 | prunable_layers = self.prunable_layers 71 | if isinstance(pct_arg, Sized): 72 | assert len(prunable_layers) == len(pct_arg) 73 | else: 74 | pct_arg = [pct_arg] * len(prunable_layers) 75 | for pct, layer in zip(pct_arg, prunable_layers): 76 | if pct is not None: 77 | layer.prune_by_pct(pct) 78 | 79 | return self 80 | 81 | def random_prune_by_pct(self, pct_arg: Union[int, float, Sized]): 82 | prunable_layers = self.prunable_layers 83 | if isinstance(pct_arg, Sized): 84 | assert len(prunable_layers) == len(pct_arg) 85 | else: 86 | pct_arg = [pct_arg] * len(prunable_layers) 87 | for pct, layer in zip(pct_arg, prunable_layers): 88 | if pct is not None: 89 | layer.random_prune_by_pct(pct) 90 | 91 | return self 92 | 93 | def calc_num_prunable_params(self, count_bias=True, display=False): 94 | total_param_in_use = 0 95 | total_param = 0 96 | for layer, layer_prefix in zip(self.prunable_layers, self.prunable_layer_prefixes): 97 | num_bias = layer.bias.nelement() if layer.bias is not None and count_bias else 0 98 | num_weight = layer.num_weight 99 | num_params_in_use = num_weight + num_bias 100 | num_params = layer.weight.nelement() + num_bias 101 | total_param_in_use += num_params_in_use 102 | total_param += num_params 103 | 104 | if display: 105 | print("Layer name: {}. 
remaining/all: {}/{} = {}".format(layer_prefix, num_params_in_use, num_params, 106 | num_params_in_use / num_params)) 107 | if display: 108 | print("Total: remaining/all: {}/{} = {}".format(total_param_in_use, total_param, 109 | total_param_in_use / total_param)) 110 | return total_param_in_use, total_param 111 | 112 | def nnz(self, count_bias=True): 113 | # number of parameters in use in prunable layers 114 | return self.calc_num_prunable_params(count_bias=count_bias)[0] 115 | 116 | def nelement(self, count_bias=True): 117 | # number of all parameters in prunable layers 118 | return self.calc_num_prunable_params(count_bias=count_bias)[1] 119 | 120 | def density(self, count_bias=True): 121 | total_param_in_use, total_param = self.calc_num_prunable_params(count_bias=count_bias) 122 | return total_param_in_use / total_param 123 | 124 | def _get_module_by_list(self, module_names: List): 125 | module = self 126 | for name in module_names: 127 | module = getattr(module, name) 128 | return module 129 | 130 | def get_module_by_name(self, module_name: str): 131 | return self._get_module_by_list(module_name.split('.')) 132 | 133 | def set_module_by_name(self, module_name: str, new_module): 134 | splits = module_name.split('.') 135 | self._get_module_by_list(splits[:-1]).__setattr__(splits[-1], new_module) 136 | 137 | def get_mask_by_name(self, param_name: str): 138 | if param_name.endswith("bias"): # todo 139 | return None 140 | module = self._get_module_by_list(param_name.split('.')[:-1]) 141 | return module.mask if hasattr(module, "mask") else None 142 | 143 | @torch.no_grad() 144 | def reinit_from_model(self, final_model): 145 | assert isinstance(final_model, self.__class__) 146 | for self_layer, layer in zip(self.prunable_layers, final_model.prunable_layers): 147 | self_layer.mask = layer.mask.clone().to(self_layer.mask.device) 148 | 149 | def to_sparse(self): 150 | self_copy = deepcopy(self) 151 | for module_name, old_module in zip(self.prunable_layer_prefixes, self.prunable_layers): 152 | self_copy.set_module_by_name(module_name, old_module.to_sparse()) 153 | self.collect_prunable_layers() 154 | return self_copy 155 | 156 | def to(self, *args, **kwargs): 157 | device = torch._C._nn._parse_to(*args, **kwargs)[0] 158 | if device is not None: 159 | # move masks to device 160 | for m in self.prunable_layers: 161 | m.move_data(device) 162 | return super(BaseModel, self).to(*args, **kwargs) 163 | -------------------------------------------------------------------------------- /mpl/models/densenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import Tensor 4 | import torchvision.models 5 | from typing import Any 6 | 7 | from .base_model import BaseModel 8 | 9 | __all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161'] 10 | 11 | 12 | class DenseNet(BaseModel): 13 | def __init__(self, model: torchvision.models.DenseNet): 14 | super(DenseNet, self).__init__() 15 | self.clone_from_model(model) 16 | self.process_layers() 17 | 18 | def collect_prunable_layers(self) -> None: 19 | """ 20 | removed transition layers from prunable layers 21 | """ 22 | super(DenseNet, self).collect_prunable_layers() 23 | keep_indices = [] 24 | for layer_idx, name in enumerate(self.prunable_layer_prefixes): 25 | if "transition" not in name: 26 | keep_indices.append(layer_idx) 27 | 28 | self.prunable_layer_prefixes = [self.prunable_layer_prefixes[idx] for idx in keep_indices] 29 | 
self.prunable_layers = [self.prunable_layers[idx] for idx in keep_indices] 30 | 31 | def process_layers(self): 32 | self.collect_prunable_layers() 33 | self.convert_eligible_layers() 34 | self.collect_prunable_layers() 35 | 36 | def forward(self, x: Tensor) -> Tensor: 37 | features = self.features(x) 38 | out = F.relu(features, inplace=True) 39 | out = F.adaptive_avg_pool2d(out, (1, 1)) 40 | out = torch.flatten(out, 1) 41 | out = self.classifier(out) 42 | return out 43 | 44 | 45 | def densenet121(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet: 46 | return DenseNet(torchvision.models.densenet121(pretrained, progress, **kwargs)) 47 | 48 | 49 | def densenet161(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet: 50 | return DenseNet(torchvision.models.densenet161(pretrained, progress, **kwargs)) 51 | 52 | 53 | def densenet169(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet: 54 | return DenseNet(torchvision.models.densenet169(pretrained, progress, **kwargs)) 55 | 56 | 57 | def densenet201(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet: 58 | return DenseNet(torchvision.models.densenet201(pretrained, progress, **kwargs)) 59 | -------------------------------------------------------------------------------- /mpl/models/googlenet.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import torch 3 | import torch.nn as nn 4 | from torch import Tensor 5 | import torchvision.models 6 | from torchvision.models import GoogLeNetOutputs, _GoogLeNetOutputs 7 | from typing import Optional, Tuple, Any 8 | 9 | from .base_model import BaseModel 10 | 11 | __all__ = ['GoogLeNet', 'googlenet'] 12 | 13 | 14 | class GoogLeNet(BaseModel): 15 | def __init__(self, model: torchvision.models.GoogLeNet): 16 | super(GoogLeNet, self).__init__() 17 | self.clone_from_model(model) 18 | 19 | self.aux_logits = model.aux_logits 20 | self.transform_input = model.transform_input 21 | if not hasattr(self, "aux1"): 22 | self.aux1 = model.aux1 23 | if not hasattr(self, "aux2"): 24 | self.aux1 = model.aux2 25 | 26 | self.process_layers() 27 | 28 | def process_layers(self): 29 | self.collect_prunable_layers() 30 | self.convert_eligible_layers() 31 | self.collect_prunable_layers() 32 | 33 | def _initialize_weights(self) -> None: 34 | for m in self.modules(): 35 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): 36 | import scipy.stats as stats 37 | X = stats.truncnorm(-2, 2, scale=0.01) 38 | values = torch.as_tensor(X.rvs(m.weight.numel()), dtype=m.weight.dtype) 39 | values = values.view(m.weight.size()) 40 | with torch.no_grad(): 41 | m.weight.copy_(values) 42 | elif isinstance(m, nn.BatchNorm2d): 43 | nn.init.constant_(m.weight, 1) 44 | nn.init.constant_(m.bias, 0) 45 | 46 | def _transform_input(self, x: Tensor) -> Tensor: 47 | if self.transform_input: 48 | x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5 49 | x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5 50 | x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5 51 | x = torch.cat((x_ch0, x_ch1, x_ch2), 1) 52 | return x 53 | 54 | def _forward(self, x: Tensor) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: 55 | # N x 3 x 224 x 224 56 | x = self.conv1(x) 57 | # N x 64 x 112 x 112 58 | x = self.maxpool1(x) 59 | # N x 64 x 56 x 56 60 | x = self.conv2(x) 61 | # N x 64 x 56 x 56 62 | x = self.conv3(x) 63 | # N x 192 x 56 x 56 64 | x = 
self.maxpool2(x) 65 | 66 | # N x 192 x 28 x 28 67 | x = self.inception3a(x) 68 | # N x 256 x 28 x 28 69 | x = self.inception3b(x) 70 | # N x 480 x 28 x 28 71 | x = self.maxpool3(x) 72 | # N x 480 x 14 x 14 73 | x = self.inception4a(x) 74 | # N x 512 x 14 x 14 75 | aux1: Optional[Tensor] = None 76 | if self.aux1 is not None: 77 | if self.training: 78 | aux1 = self.aux1(x) 79 | 80 | x = self.inception4b(x) 81 | # N x 512 x 14 x 14 82 | x = self.inception4c(x) 83 | # N x 512 x 14 x 14 84 | x = self.inception4d(x) 85 | # N x 528 x 14 x 14 86 | aux2: Optional[Tensor] = None 87 | if self.aux2 is not None: 88 | if self.training: 89 | aux2 = self.aux2(x) 90 | 91 | x = self.inception4e(x) 92 | # N x 832 x 14 x 14 93 | x = self.maxpool4(x) 94 | # N x 832 x 7 x 7 95 | x = self.inception5a(x) 96 | # N x 832 x 7 x 7 97 | x = self.inception5b(x) 98 | # N x 1024 x 7 x 7 99 | 100 | x = self.avgpool(x) 101 | # N x 1024 x 1 x 1 102 | x = torch.flatten(x, 1) 103 | # N x 1024 104 | x = self.dropout(x) 105 | x = self.fc(x) 106 | # N x 1000 (num_classes) 107 | return x, aux2, aux1 108 | 109 | @torch.jit.unused 110 | def eager_outputs(self, x: Tensor, aux2: Tensor, aux1: Optional[Tensor]) -> GoogLeNetOutputs: 111 | if self.training and self.aux_logits: 112 | return _GoogLeNetOutputs(x, aux2, aux1) 113 | else: 114 | return x # type: ignore[return-value] 115 | 116 | def forward(self, x: Tensor) -> GoogLeNetOutputs: 117 | x = self._transform_input(x) 118 | x, aux1, aux2 = self._forward(x) 119 | aux_defined = self.training and self.aux_logits 120 | if torch.jit.is_scripting(): 121 | if not aux_defined: 122 | warnings.warn("Scripted GoogleNet always returns GoogleNetOutputs Tuple") 123 | return GoogLeNetOutputs(x, aux2, aux1) 124 | else: 125 | return self.eager_outputs(x, aux2, aux1) 126 | 127 | 128 | def googlenet(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> GoogLeNet: 129 | return GoogLeNet(torchvision.models.googlenet(pretrained, progress, **kwargs)) 130 | -------------------------------------------------------------------------------- /mpl/models/inception.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import torch 3 | from torch import Tensor 4 | import torchvision.models 5 | from torchvision.models.inception import InceptionOutputs 6 | from typing import Any, Optional, Tuple 7 | 8 | from .base_model import BaseModel 9 | 10 | __all__ = ['Inception3', 'inception_v3'] 11 | 12 | 13 | class Inception3(BaseModel): 14 | def __init__(self, model: torchvision.models.Inception3): 15 | super(Inception3, self).__init__() 16 | self.clone_from_model(model) 17 | 18 | self.aux_logits = model.aux_logits 19 | self.transform_input = model.transform_input 20 | 21 | self.process_layers() 22 | 23 | def process_layers(self): 24 | self.collect_prunable_layers() 25 | self.convert_eligible_layers() 26 | self.collect_prunable_layers() 27 | 28 | def _transform_input(self, x: Tensor) -> Tensor: 29 | if self.transform_input: 30 | x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5 31 | x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5 32 | x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5 33 | x = torch.cat((x_ch0, x_ch1, x_ch2), 1) 34 | return x 35 | 36 | def _forward(self, x: Tensor) -> Tuple[Tensor, Optional[Tensor]]: 37 | # N x 3 x 299 x 299 38 | x = self.Conv2d_1a_3x3(x) 39 | # N x 32 x 149 x 149 40 | x = self.Conv2d_2a_3x3(x) 41 | # N x 32 x 147 x 147 42 | x = self.Conv2d_2b_3x3(x) 
43 | # N x 64 x 147 x 147 44 | x = self.maxpool1(x) 45 | # N x 64 x 73 x 73 46 | x = self.Conv2d_3b_1x1(x) 47 | # N x 80 x 73 x 73 48 | x = self.Conv2d_4a_3x3(x) 49 | # N x 192 x 71 x 71 50 | x = self.maxpool2(x) 51 | # N x 192 x 35 x 35 52 | x = self.Mixed_5b(x) 53 | # N x 256 x 35 x 35 54 | x = self.Mixed_5c(x) 55 | # N x 288 x 35 x 35 56 | x = self.Mixed_5d(x) 57 | # N x 288 x 35 x 35 58 | x = self.Mixed_6a(x) 59 | # N x 768 x 17 x 17 60 | x = self.Mixed_6b(x) 61 | # N x 768 x 17 x 17 62 | x = self.Mixed_6c(x) 63 | # N x 768 x 17 x 17 64 | x = self.Mixed_6d(x) 65 | # N x 768 x 17 x 17 66 | x = self.Mixed_6e(x) 67 | # N x 768 x 17 x 17 68 | aux: Optional[Tensor] = None 69 | if self.AuxLogits is not None: 70 | if self.training: 71 | aux = self.AuxLogits(x) 72 | # N x 768 x 17 x 17 73 | x = self.Mixed_7a(x) 74 | # N x 1280 x 8 x 8 75 | x = self.Mixed_7b(x) 76 | # N x 2048 x 8 x 8 77 | x = self.Mixed_7c(x) 78 | # N x 2048 x 8 x 8 79 | # Adaptive average pooling 80 | x = self.avgpool(x) 81 | # N x 2048 x 1 x 1 82 | x = self.dropout(x) 83 | # N x 2048 x 1 x 1 84 | x = torch.flatten(x, 1) 85 | # N x 2048 86 | x = self.fc(x) 87 | # N x 1000 (num_classes) 88 | return x, aux 89 | 90 | @torch.jit.unused 91 | def eager_outputs(self, x: Tensor, aux: Optional[Tensor]) -> InceptionOutputs: 92 | if self.training and self.aux_logits: 93 | return InceptionOutputs(x, aux) 94 | else: 95 | return x # type: ignore[return-value] 96 | 97 | def forward(self, x: Tensor) -> InceptionOutputs: 98 | x = self._transform_input(x) 99 | x, aux = self._forward(x) 100 | aux_defined = self.training and self.aux_logits 101 | if torch.jit.is_scripting(): 102 | if not aux_defined: 103 | warnings.warn("Scripted Inception3 always returns Inception3 Tuple") 104 | return InceptionOutputs(x, aux) 105 | else: 106 | return self.eager_outputs(x, aux) 107 | 108 | 109 | def inception_v3(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> "Inception3": 110 | return Inception3(torchvision.models.inception_v3(pretrained, progress, **kwargs)) 111 | -------------------------------------------------------------------------------- /mpl/models/leaf.py: -------------------------------------------------------------------------------- 1 | from torch import nn as nn 2 | 3 | from .base_model import BaseModel 4 | from ..nn.conv2d import DenseConv2d 5 | from ..nn.linear import DenseLinear 6 | 7 | __all__ = ["Conv2", "conv2", "Conv4", "conv4"] 8 | 9 | 10 | class Conv2(BaseModel): 11 | def __init__(self): 12 | super(Conv2, self).__init__() 13 | self.features = nn.Sequential(DenseConv2d(1, 32, kernel_size=5, padding=2), # 32x28x28 14 | nn.ReLU(inplace=True), 15 | nn.MaxPool2d(2, stride=2), # 32x14x14 16 | DenseConv2d(32, 64, kernel_size=5, padding=2), # 64x14x14 17 | nn.ReLU(inplace=True), 18 | nn.MaxPool2d(2, stride=2)) # 64x7x7 19 | 20 | self.classifier = nn.Sequential(DenseLinear(64 * 7 * 7, 2048), 21 | nn.ReLU(inplace=True), 22 | DenseLinear(2048, 62)) 23 | self.collect_prunable_layers() 24 | 25 | def forward(self, inp): 26 | out = self.features(inp) 27 | out = out.view(out.size(0), -1) 28 | out = self.classifier(out) 29 | return out 30 | 31 | 32 | class Conv4(BaseModel): 33 | def __init__(self): 34 | super(Conv4, self).__init__() 35 | self.features = nn.Sequential(DenseConv2d(3, 32, kernel_size=3, padding=1), 36 | nn.BatchNorm2d(32), 37 | nn.MaxPool2d(2), 38 | DenseConv2d(32, 32, kernel_size=3, padding=1), 39 | nn.BatchNorm2d(32), 40 | nn.MaxPool2d(2), 41 | DenseConv2d(32, 32, kernel_size=3, padding=2), 42 | nn.BatchNorm2d(32), 43 | 
nn.MaxPool2d(2), 44 | DenseConv2d(32, 32, kernel_size=3, padding=2), 45 | nn.BatchNorm2d(32), 46 | nn.MaxPool2d(2)) 47 | 48 | self.classifier = DenseLinear(in_features=32 * 6 * 6, out_features=2) 49 | 50 | def forward(self, inp): 51 | out = self.features(inp) 52 | out = out.view(out.size(0), -1) 53 | out = self.classifier(out) 54 | return out 55 | 56 | 57 | def conv2() -> Conv2: 58 | return Conv2() 59 | 60 | 61 | def conv4() -> Conv4: 62 | return Conv4() 63 | 64 | # TODO: define pretrain etc. 65 | -------------------------------------------------------------------------------- /mpl/models/lenet.py: -------------------------------------------------------------------------------- 1 | from torch import nn as nn 2 | from .base_model import BaseModel 3 | 4 | from ..nn.linear import DenseLinear 5 | 6 | __all__ = ["LeNet5", "lenet5"] 7 | 8 | 9 | class LeNet5(BaseModel): 10 | def __init__(self): 11 | super(LeNet5, self).__init__() 12 | self.classifier = nn.Sequential(DenseLinear(784, 300), 13 | nn.ReLU(inplace=True), 14 | DenseLinear(300, 100), 15 | nn.ReLU(inplace=True), 16 | DenseLinear(100, 10)) 17 | 18 | self.collect_prunable_layers() 19 | 20 | def forward(self, inputs): 21 | return self.classifier(inputs) 22 | 23 | 24 | def lenet5() -> LeNet5: 25 | return LeNet5() 26 | -------------------------------------------------------------------------------- /mpl/models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from torch import nn 4 | import torchvision.models 5 | from torchvision.models.resnet import conv1x1, BasicBlock, Bottleneck 6 | from typing import Type, Any, Union 7 | 8 | from .base_model import BaseModel 9 | 10 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 11 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 12 | 'wide_resnet50_2', 'wide_resnet101_2'] 13 | 14 | 15 | class ResNet(BaseModel): 16 | def __init__(self, model: torchvision.models.ResNet): 17 | super(ResNet, self).__init__() 18 | self.clone_from_model(model) 19 | self.process_layers() 20 | 21 | def process_layers(self): 22 | self.collect_prunable_layers() 23 | self.convert_eligible_layers() 24 | self.collect_prunable_layers() 25 | 26 | def collect_prunable_layers(self) -> None: 27 | """ 28 | removed transition layers from prunable layers 29 | """ 30 | super(ResNet, self).collect_prunable_layers() 31 | keep_indices = [] 32 | for layer_idx, name in enumerate(self.prunable_layer_prefixes): 33 | if "downsample" not in name: 34 | keep_indices.append(layer_idx) 35 | 36 | self.prunable_layer_prefixes = [self.prunable_layer_prefixes[idx] for idx in keep_indices] 37 | self.prunable_layers = [self.prunable_layers[idx] for idx in keep_indices] 38 | 39 | def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, 40 | stride: int = 1, dilate: bool = False) -> nn.Sequential: 41 | norm_layer = self._norm_layer 42 | downsample = None 43 | previous_dilation = self.dilation 44 | if dilate: 45 | self.dilation *= stride 46 | stride = 1 47 | if stride != 1 or self.inplanes != planes * block.expansion: 48 | downsample = nn.Sequential( 49 | conv1x1(self.inplanes, planes * block.expansion, stride), 50 | norm_layer(planes * block.expansion), 51 | ) 52 | 53 | layers = [] 54 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 55 | self.base_width, previous_dilation, norm_layer)) 56 | self.inplanes = planes * block.expansion 57 | for _ in range(1, blocks): 58 | 
layers.append(block(self.inplanes, planes, groups=self.groups, 59 | base_width=self.base_width, dilation=self.dilation, 60 | norm_layer=norm_layer)) 61 | 62 | return nn.Sequential(*layers) 63 | 64 | def _forward_impl(self, x: Tensor) -> Tensor: 65 | # See note [TorchScript super()] 66 | x = self.conv1(x) 67 | x = self.bn1(x) 68 | x = self.relu(x) 69 | x = self.maxpool(x) 70 | 71 | x = self.layer1(x) 72 | x = self.layer2(x) 73 | x = self.layer3(x) 74 | x = self.layer4(x) 75 | 76 | x = self.avgpool(x) 77 | x = torch.flatten(x, 1) 78 | x = self.fc(x) 79 | 80 | return x 81 | 82 | def forward(self, x: Tensor) -> Tensor: 83 | return self._forward_impl(x) 84 | 85 | 86 | def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 87 | return ResNet(torchvision.models.resnet18(pretrained, progress, **kwargs)) 88 | 89 | 90 | def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 91 | return ResNet(torchvision.models.resnet34(pretrained, progress, **kwargs)) 92 | 93 | 94 | def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 95 | return ResNet(torchvision.models.resnet50(pretrained, progress, **kwargs)) 96 | 97 | 98 | def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 99 | return ResNet(torchvision.models.resnet101(pretrained, progress, **kwargs)) 100 | 101 | 102 | def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 103 | return ResNet(torchvision.models.resnet152(pretrained, progress, **kwargs)) 104 | 105 | 106 | def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 107 | return ResNet(torchvision.models.resnext50_32x4d(pretrained, progress, **kwargs)) 108 | 109 | 110 | def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 111 | return ResNet(torchvision.models.resnext101_32x8d(pretrained, progress, **kwargs)) 112 | 113 | 114 | def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 115 | return ResNet(torchvision.models.wide_resnet50_2(pretrained, progress, **kwargs)) 116 | 117 | 118 | def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 119 | return ResNet(torchvision.models.wide_resnet101_2(pretrained, progress, **kwargs)) 120 | -------------------------------------------------------------------------------- /mpl/models/squeezenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.models 3 | from typing import Any 4 | 5 | from .base_model import BaseModel 6 | 7 | 8 | class SqueezeNet(BaseModel): 9 | def __init__(self, model: torchvision.models.SqueezeNet): 10 | super(SqueezeNet, self).__init__() 11 | self.clone_from_model(model) 12 | self.process_layers() 13 | 14 | def process_layers(self): 15 | self.collect_prunable_layers() 16 | self.convert_eligible_layers() 17 | self.collect_prunable_layers() 18 | 19 | def forward(self, x: torch.Tensor) -> torch.Tensor: 20 | x = self.features(x) 21 | x = self.classifier(x) 22 | return torch.flatten(x, 1) 23 | 24 | 25 | def squeezenet1_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> SqueezeNet: 26 | return SqueezeNet(torchvision.models.squeezenet1_0(pretrained, progress, **kwargs)) 27 | 28 | 29 | def squeezenet1_1(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> SqueezeNet: 30 | return 
SqueezeNet(torchvision.models.squeezenet1_1(pretrained, progress, **kwargs)) 31 | -------------------------------------------------------------------------------- /mpl/models/utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from ..nn.conv2d import DenseConv2d, SparseConv2d 3 | from ..nn.linear import DenseLinear, SparseLinear 4 | 5 | from typing import Callable 6 | 7 | 8 | def is_prunable_fc(layer): 9 | return isinstance(layer, DenseLinear) or isinstance(layer, SparseLinear) 10 | 11 | 12 | def is_prunable_conv(layer): 13 | return isinstance(layer, DenseConv2d) or isinstance(layer, SparseConv2d) 14 | 15 | 16 | def is_prunable(layer): 17 | return is_prunable_fc(layer) or is_prunable_conv(layer) 18 | 19 | 20 | def is_parameterized(layer): 21 | return is_prunable(layer) or isinstance(layer, nn.Linear) or isinstance(layer, nn.Conv2d) 22 | 23 | 24 | def collect_leaf_modules(module, criterion: Callable, layers: list, names: list, prefix: str = ""): 25 | for key, submodule in module._modules.items(): 26 | new_prefix = prefix 27 | if prefix != "": 28 | new_prefix += '.' 29 | new_prefix += key 30 | # is leaf and satisfies criterion 31 | if submodule is not None: 32 | if len(submodule._modules.keys()) == 0 and criterion(submodule): 33 | layers.append(submodule) 34 | names.append(new_prefix) 35 | collect_leaf_modules(submodule, criterion, layers, names, prefix=new_prefix) 36 | -------------------------------------------------------------------------------- /mpl/models/vgg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.models 4 | from typing import Any 5 | 6 | from .base_model import BaseModel 7 | 8 | __all__ = [ 9 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 10 | 'vgg19_bn', 'vgg19', 11 | ] 12 | 13 | 14 | class VGG(BaseModel): 15 | def __init__(self, model: torchvision.models.VGG): 16 | super(VGG, self).__init__() 17 | self.clone_from_model(model) 18 | self.process_layers() 19 | 20 | def process_layers(self): 21 | self.collect_prunable_layers() 22 | self.convert_eligible_layers() 23 | self.collect_prunable_layers() 24 | 25 | def forward(self, x): 26 | x = self.features(x) 27 | x = self.avgpool(x) 28 | x = torch.flatten(x, 1) 29 | x = self.classifier(x) 30 | return x 31 | 32 | 33 | def vgg11(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: 34 | return VGG(torchvision.models.vgg11(pretrained, progress, **kwargs)) 35 | 36 | 37 | def vgg11_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: 38 | return VGG(torchvision.models.vgg11_bn(pretrained, progress, **kwargs)) 39 | 40 | 41 | def vgg13(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: 42 | return VGG(torchvision.models.vgg13(pretrained, progress, **kwargs)) 43 | 44 | 45 | def vgg13_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: 46 | return VGG(torchvision.models.vgg13_bn(pretrained, progress, **kwargs)) 47 | 48 | 49 | def vgg16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: 50 | return VGG(torchvision.models.vgg16(pretrained, progress, **kwargs)) 51 | 52 | 53 | def vgg16_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: 54 | return VGG(torchvision.models.vgg16_bn(pretrained, progress, **kwargs)) 55 | 56 | 57 | def vgg19(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: 58 | return 
VGG(torchvision.models.vgg19(pretrained, progress, **kwargs)) 59 | 60 | 61 | def vgg19_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: 62 | return VGG(torchvision.models.vgg19_bn(pretrained, progress, **kwargs)) 63 | -------------------------------------------------------------------------------- /mpl/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from .conv2d import * 2 | from .linear import * 3 | -------------------------------------------------------------------------------- /mpl/nn/conv2d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.modules.utils import _pair 4 | from ..autograd.functions import SparseConv2dFunction 5 | 6 | from typing import Union, Tuple 7 | 8 | __all__ = ["SparseConv2d", "DenseConv2d"] 9 | 10 | 11 | class SparseConv2d(nn.Module): 12 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, weight, bias, mask): 13 | super(SparseConv2d, self).__init__() 14 | kernel_size = _pair(kernel_size) 15 | stride = _pair(stride) 16 | padding = _pair(padding) 17 | self.in_channels = in_channels 18 | self.out_channels = out_channels 19 | self.kernel_size = kernel_size 20 | self.stride = stride 21 | self.padding = padding 22 | 23 | self.weight = nn.Parameter(weight.clone(), requires_grad=False) 24 | self.mask = mask.clone() 25 | self.dense_weight_placeholder = nn.Parameter(torch.empty(size=self.weight.size())) 26 | self.dense_weight_placeholder.is_placeholder = True 27 | 28 | self.weight.dense = self.dense_weight_placeholder 29 | self.weight.mask = self.mask 30 | self.weight.is_sparse_param = True 31 | 32 | if bias is None: 33 | self.bias = torch.zeros(size=(out_channels,)) 34 | else: 35 | self.bias = nn.Parameter(bias.clone()) 36 | 37 | def forward(self, inp): 38 | return SparseConv2dFunction.apply(inp, self.weight, self.dense_weight_placeholder, self.kernel_size, 39 | self.bias, self.stride, self.padding) 40 | 41 | def __repr__(self): 42 | return "SparseConv2d({}, {}, kernel_size={}, " \ 43 | "stride={}, padding={}, bias={})".format(self.in_channels, 44 | self.out_channels, 45 | self.kernel_size, 46 | self.stride, 47 | self.padding, 48 | not torch.equal(self.bias, torch.zeros_like(self.bias))) 49 | 50 | def __str__(self): 51 | return self.__repr__() 52 | 53 | 54 | class DenseConv2d(nn.Conv2d): 55 | def __init__(self, in_channels, out_channels, kernel_size, stride: Union[int, Tuple] = 1, 56 | padding: Union[int, Tuple] = 0, 57 | dilation: Union[int, Tuple] = 1, groups=1, bias=True, 58 | padding_mode='zeros'): 59 | max_dilation = dilation if isinstance(dilation, int) else max(dilation) 60 | if max_dilation > 1: 61 | raise NotImplementedError("Dilation > 1 not implemented") 62 | if groups > 1: 63 | raise NotImplementedError("Groups > 1 not implemented") 64 | super(DenseConv2d, self).__init__(in_channels, out_channels, kernel_size, stride, padding, 65 | dilation, groups, bias, padding_mode) 66 | self.mask = torch.ones_like(self.weight, dtype=torch.bool, device=self.weight.device) 67 | 68 | def forward(self, inp): 69 | return self._conv_forward(inp, self.weight * self.mask) 70 | 71 | def prune_by_threshold(self, thr): 72 | self.mask *= (torch.abs(self.weight) >= thr) 73 | 74 | def prune_by_rank(self, rank): 75 | if rank == 0: 76 | return 77 | weights_val = self.weight[self.mask == 1] 78 | sorted_abs_weights = torch.sort(torch.abs(weights_val))[0] 79 | thr = sorted_abs_weights[rank] 80 | 
self.prune_by_threshold(thr) 81 | 82 | def prune_by_pct(self, pct): 83 | if pct == 0: 84 | return 85 | prune_idx = int(self.num_weight * pct) 86 | self.prune_by_rank(prune_idx) 87 | 88 | def random_prune_by_pct(self, pct): 89 | prune_idx = int(self.num_weight * pct) 90 | rand = torch.rand(size=self.mask.size(), device=self.mask.device) 91 | rand_val = rand[self.mask == 1] 92 | sorted_abs_rand = torch.sort(rand_val)[0] 93 | thr = sorted_abs_rand[prune_idx] 94 | self.mask *= (rand >= thr) 95 | 96 | @classmethod 97 | def from_conv2d(cls, conv2d_module: nn.Conv2d): 98 | new_conv2d = cls(conv2d_module.in_channels, conv2d_module.out_channels, conv2d_module.kernel_size, 99 | conv2d_module.stride, conv2d_module.padding, conv2d_module.dilation, conv2d_module.groups, 100 | bias=conv2d_module.bias is not None, 101 | padding_mode=conv2d_module.padding_mode) 102 | 103 | new_conv2d.weight = nn.Parameter(conv2d_module.weight.clone()) 104 | if conv2d_module.bias is not None: 105 | new_conv2d.bias = nn.Parameter(conv2d_module.bias.clone()) 106 | 107 | return new_conv2d 108 | 109 | # This method will always remove zero elements, even if you wish to keep zeros in the sparse form 110 | def to_sparse(self): 111 | masked_weight = self.weight * self.mask 112 | mask = (masked_weight != 0.).view(self.out_channels, -1) 113 | weight = masked_weight.view(self.out_channels, -1).to_sparse() 114 | return SparseConv2d(self.in_channels, self.out_channels, self.kernel_size, self.stride, self.padding, weight, 115 | self.bias, mask) 116 | 117 | def move_data(self, device: torch.device): 118 | self.mask = self.mask.to(device) 119 | 120 | @property 121 | def num_weight(self): 122 | return torch.sum(self.mask).int().item() 123 | -------------------------------------------------------------------------------- /mpl/nn/linear.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.sparse as sparse 4 | from ..autograd.functions import AddmmFunction 5 | 6 | __all__ = ["SparseLinear", "DenseLinear"] 7 | 8 | 9 | class SparseLinear(nn.Module): 10 | __constants__ = ['in_features', 'out_features'] 11 | 12 | def __init__(self, weight: sparse.FloatTensor, bias, mask): 13 | super(SparseLinear, self).__init__() 14 | if not weight.is_sparse: 15 | raise ValueError("Weight must be sparse") 16 | elif weight._nnz() > 0 and not weight.is_coalesced(): 17 | raise ValueError("Weight must be coalesced") 18 | 19 | self.in_features = weight.size(1) 20 | self.out_features = weight.size(0) 21 | 22 | # in order to add to optimizer 23 | self.weight = nn.Parameter(weight.data.clone(), requires_grad=False) 24 | self.mask = mask.clone() 25 | # Don't move after creation to make it a leaf 26 | self.dense_weight_placeholder = nn.Parameter(torch.empty(size=self.weight.size(), device=self.weight.device)) 27 | self.dense_weight_placeholder.is_placeholder = True 28 | 29 | # create links 30 | self.weight.dense = self.dense_weight_placeholder 31 | self.weight.mask = self.mask 32 | self.weight.is_sparse_param = True 33 | 34 | if bias is None: 35 | self.register_parameter('bias', None) 36 | else: 37 | assert bias.size() == torch.Size((weight.size(0), 1)) 38 | self.bias = nn.Parameter(bias.data.clone()) 39 | 40 | def _sparse_masked_select_abs(self, sparse_tensor: sparse.FloatTensor, thr): 41 | indices = sparse_tensor._indices() 42 | values = sparse_tensor._values() 43 | prune_mask = torch.abs(values) >= thr 44 | return 
torch.sparse_coo_tensor(indices=indices.masked_select(prune_mask).reshape(2, -1), 45 | values=values.masked_select(prune_mask), 46 | size=[self.out_features, self.in_features]).coalesce() 47 | 48 | def prune_by_threshold(self, thr): 49 | self.weight = nn.Parameter(self._sparse_masked_select_abs(self.weight, thr)) 50 | 51 | def prune_by_rank(self, rank): 52 | weight_val = self.weight._values() 53 | sorted_abs_weight = torch.sort(torch.abs(weight_val))[0] 54 | thr = sorted_abs_weight[rank] 55 | self.prune_by_threshold(thr) 56 | 57 | def prune_by_pct(self, pct): 58 | if pct == 0: 59 | return 60 | prune_idx = int(self.weight._nnz() * pct) 61 | self.prune_by_rank(prune_idx) 62 | 63 | def move_data(self, device: torch.device): 64 | self.weight = self.weight.to(device) 65 | 66 | def forward(self, inp: torch.Tensor): 67 | return AddmmFunction.apply(self.bias, self.weight, self.dense_weight_placeholder, inp.t()).t() 68 | 69 | @property 70 | def num_weight(self) -> int: 71 | return self.weight._nnz() 72 | 73 | def __repr__(self): 74 | return "SparseLinear(in_features={}, out_features={}, bias={})".format(self.in_features, self.out_features, 75 | self.bias is not None) 76 | 77 | def __str__(self): 78 | return self.__repr__() 79 | 80 | 81 | class DenseLinear(nn.Linear): 82 | def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None: 83 | super(DenseLinear, self).__init__(in_features, out_features, bias) 84 | self.mask = torch.ones_like(self.weight, dtype=torch.bool, device=self.weight.device) 85 | 86 | def forward(self, inp: torch.Tensor): 87 | return nn.functional.linear(inp, self.weight * self.mask, self.bias) 88 | 89 | def prune_by_threshold(self, thr): 90 | self.mask *= (self.weight.abs() >= thr) 91 | 92 | def prune_by_rank(self, rank): 93 | if rank == 0: 94 | return 95 | weight_val = self.weight[self.mask == 1.] 96 | sorted_abs_weight = weight_val.abs().sort()[0] 97 | thr = sorted_abs_weight[rank] 98 | self.prune_by_threshold(thr) 99 | 100 | def prune_by_pct(self, pct): 101 | prune_idx = int(self.num_weight * pct) 102 | self.prune_by_rank(prune_idx) 103 | 104 | def random_prune_by_pct(self, pct): 105 | prune_idx = int(self.num_weight * pct) 106 | rand = torch.rand(size=self.mask.size(), device=self.mask.device) 107 | rand_val = rand[self.mask == 1] 108 | sorted_abs_rand = rand_val.sort()[0] 109 | thr = sorted_abs_rand[prune_idx] 110 | self.mask *= (rand >= thr) 111 | 112 | @classmethod 113 | def from_linear(cls, linear_module: nn.Linear): 114 | new_linear = cls(linear_module.in_features, linear_module.out_features, 115 | bias=linear_module.bias is not None) 116 | new_linear.weight = nn.Parameter(linear_module.weight.clone()) 117 | if linear_module.bias is not None: 118 | new_linear.bias = nn.Parameter(linear_module.bias.clone()) 119 | 120 | return new_linear 121 | 122 | # This method will always remove zero elements, even if you wish to keep zeros in the sparse form 123 | def to_sparse(self) -> SparseLinear: 124 | sparse_bias = None if self.bias is None else self.bias.reshape((-1, 1)) 125 | masked_weight = self.weight * self.mask 126 | mask = masked_weight != 0. 
127 | return SparseLinear(masked_weight.to_sparse(), sparse_bias, mask) 128 | 129 | def move_data(self, device: torch.device): 130 | self.mask = self.mask.to(device) 131 | 132 | def to(self, *args, **kwargs): 133 | device = torch._C._nn._parse_to(*args, **kwargs)[0] 134 | 135 | if device is not None: 136 | self.move_data(device) 137 | 138 | return super(DenseLinear, self).to(*args, **kwargs) 139 | 140 | @property 141 | def num_weight(self) -> int: 142 | return self.mask.sum().item() 143 | -------------------------------------------------------------------------------- /mpl/optim/__init__.py: -------------------------------------------------------------------------------- 1 | from .sgd import SGD 2 | -------------------------------------------------------------------------------- /mpl/optim/sgd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | 4 | 5 | class SGD(optim.SGD): 6 | @torch.no_grad() 7 | def step(self, closure=None): 8 | """Performs a single optimization step. 9 | 10 | Arguments: 11 | closure (callable, optional): A closure that reevaluates the model 12 | and returns the loss. 13 | """ 14 | loss = None 15 | if closure is not None: 16 | with torch.enable_grad(): 17 | loss = closure() 18 | 19 | for group in self.param_groups: 20 | weight_decay = group['weight_decay'] 21 | momentum = group['momentum'] 22 | dampening = group['dampening'] 23 | nesterov = group['nesterov'] 24 | 25 | for p in group['params']: 26 | # exclude 1) dense param with None grad and 2) dense placeholders for sparse params, and 27 | # 3) sparse param with None grad 28 | if hasattr(p, "is_placeholder") or ( 29 | p.grad is None and (not hasattr(p, "is_sparse_param") or p.dense.grad is None)): 30 | # dense placeholder 31 | continue 32 | # if p.grad is None: 33 | # if not hasattr(p, "is_sparse_param"): 34 | # # dense param with None grad 35 | # continue 36 | # elif p.dense.grad is None: 37 | # # sparse param with None grad 38 | # continue 39 | 40 | if hasattr(p, "is_sparse_param"): 41 | d_p = p.dense.grad.masked_select(p.mask) 42 | p = p._values() 43 | else: 44 | d_p = p.grad 45 | 46 | if weight_decay != 0: 47 | d_p = d_p.add(p, alpha=weight_decay) 48 | if momentum != 0: 49 | param_state = self.state[p] 50 | if 'momentum_buffer' not in param_state: 51 | buf = param_state['momentum_buffer'] = torch.clone(d_p).detach() 52 | else: 53 | buf = param_state['momentum_buffer'] 54 | buf.mul_(momentum).add_(d_p, alpha=1 - dampening) 55 | if nesterov: 56 | d_p = d_p.add(buf, alpha=momentum) 57 | else: 58 | d_p = buf 59 | 60 | p.add_(d_p, alpha=-group['lr']) 61 | 62 | return loss 63 | -------------------------------------------------------------------------------- /mpl/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiangyuang/ModelPruningLibrary/9c8ba5a3c5d118f37768d5d42254711f48d88745/mpl/utils/__init__.py -------------------------------------------------------------------------------- /mpl/utils/save_load.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import copyreg 3 | import io 4 | import torch 5 | import numpy as np 6 | from PIL import Image 7 | 8 | import warnings 9 | 10 | bytes_types = (bytes, bytearray) 11 | 12 | 13 | # mode 0: store indices 14 | # mode 1: store bitmap 15 | def get_mode(x) -> bool: 16 | return False if x._nnz() / x.nelement() < 1 / 32 else True 17 | 18 | 19 | def 
get_int_type(max_val: int): 20 | assert max_val >= 0 21 | max_uint8 = 1 << 8 22 | max_int16 = 1 << 15 23 | max_int32 = 1 << 31 24 | if max_val < max_uint8: 25 | return torch.uint8 26 | elif max_val < max_int16: 27 | return torch.int16 28 | elif max_val < max_int32: 29 | return torch.int32 30 | else: 31 | return torch.int64 32 | 33 | 34 | def sparse_coo_from_indices(indices, values, size): 35 | mask = torch.zeros(size=size, dtype=torch.bool) 36 | mask[indices.tolist()] = True 37 | tensor = torch.sparse_coo_tensor(indices.to(torch.long), values, size).coalesce() 38 | tensor.mask = mask 39 | return tensor 40 | 41 | 42 | def sparse_coo_from_values_bitmap(bitmap, values, size): 43 | mask = torch.from_numpy(np.array(bitmap, np.uint8, copy=False)) 44 | indices = mask.nonzero().t() 45 | tensor = torch.sparse_coo_tensor(indices.to(torch.long), values, size).coalesce() 46 | tensor.mask = mask 47 | return tensor 48 | 49 | 50 | def rebuild_dispatcher(mode, arg0, arg1, arg2): 51 | if mode is False: 52 | return sparse_coo_from_indices(arg0, arg1, arg2) 53 | else: 54 | return sparse_coo_from_values_bitmap(arg0, arg1, arg2) 55 | 56 | 57 | def args_dispatcher(mode, x) -> tuple: 58 | # supports only 2 dimensional tensors 59 | if mode is False: 60 | int_type = get_int_type(torch.max(x._indices()).item()) 61 | return mode, x._indices().to(int_type), x._values(), x.size() 62 | else: 63 | bitmap = torch.zeros(size=x.size(), dtype=torch.bool) 64 | bitmap[x._indices().tolist()] = True 65 | # print(bitmap.size(), bitmap) 66 | # print(np.uint8(bitmap.numpy())) 67 | bitmap = Image.fromarray(bitmap.numpy()) 68 | assert bitmap.mode == "1" 69 | return mode, bitmap, x._values(), x.size() 70 | 71 | 72 | def reduce(x: torch.Tensor): 73 | if x.is_sparse: 74 | assert x.ndim == 2, "Only 2-dimensional tensors are supported" 75 | mode = get_mode(x) 76 | return rebuild_dispatcher, args_dispatcher(mode, x) 77 | else: 78 | return x.__reduce_ex__(pickle.DEFAULT_PROTOCOL) 79 | 80 | 81 | # register custom reduce function for sparse tensors 82 | copyreg.pickle(torch.Tensor, reduce) 83 | 84 | 85 | def dumps(obj): 86 | f = io.BytesIO() 87 | pickle.dump(obj, f) 88 | res = f.getvalue() 89 | assert isinstance(res, bytes_types) 90 | return res 91 | 92 | 93 | def loads(res): 94 | return pickle.loads(res) 95 | 96 | 97 | def save(obj, f): 98 | # disabling warnings from torch.Tensor's reduce function. 
See issue: https://github.com/pytorch/pytorch/issues/38597 99 | with warnings.catch_warnings(): 100 | warnings.simplefilter("ignore") 101 | with open(f, "wb") as opened_f: 102 | pickle.dump(obj, opened_f) 103 | 104 | 105 | def load(f): 106 | with open(f, 'rb') as opened_f: 107 | return pickle.load(opened_f) 108 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch~=1.7.1 2 | numpy~=1.19.2 3 | Pillow~=7.2.0 4 | torchvision~=0.8.2 5 | setuptools~=41.2.0 -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import setup, find_packages 3 | from torch.utils import cpp_extension 4 | 5 | # Set up sparse conv2d extension 6 | setup(name="sparse_conv2d", 7 | ext_modules=[cpp_extension.CppExtension("sparse_conv2d", 8 | [os.path.join("extension", "extension.cpp")], 9 | extra_compile_args=["-std=c++14", "-fopenmp"])], 10 | cmdclass={"build_ext": cpp_extension.BuildExtension}) 11 | 12 | # Set up mpl (model pruning library) 13 | with open("README.md", "r", encoding="utf-8") as fh: 14 | long_description = fh.read() 15 | 16 | DEPENDENCIES = ['torch', 'torchvision'] 17 | 18 | setup(name='mpl', 19 | version='0.0.1', 20 | description="Model Pruning Library", 21 | long_description=long_description, 22 | author="Yuang Jiang", 23 | author_email="yuang.jiang@yale.edu", 24 | url="https://github.com/jiangyuang/ModelPruningLibrary", 25 | packages=find_packages(), 26 | install_requires=DEPENDENCIES, 27 | ) 28 | --------------------------------------------------------------------------------
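The components above are designed to compose: `DenseLinear`/`DenseConv2d` keep a boolean pruning mask, `to_sparse()` turns a masked layer into its sparse counterpart, `mpl.optim.SGD` updates sparse parameters through their dense-gradient placeholders, and `mpl.utils.save_load` picks value-index or bitmap storage per sparse tensor depending on its density (threshold 1/32). Below is a minimal sketch of that workflow for a single `DenseLinear` layer; the batch shapes and the output file name are illustrative assumptions rather than part of the library.

```python3
import torch
import torch.nn.functional as F

from mpl.nn.linear import DenseLinear
from mpl.optim import SGD
from mpl.utils import save_load

# Dense layer -> magnitude pruning (keep ~10% of weights) -> sparse layer.
layer = DenseLinear(784, 10)
layer.prune_by_pct(0.9)
sparse_layer = layer.to_sparse()

# The customized SGD skips dense placeholders and applies the masked
# placeholder gradients to the values of sparse parameters.
optimizer = SGD(sparse_layer.parameters(), lr=0.1)

inputs = torch.randn(8, 784)            # toy batch (shapes are assumptions)
targets = torch.randint(0, 10, (8,))
loss = F.cross_entropy(sparse_layer(inputs), targets)

optimizer.zero_grad()
loss.backward()
optimizer.step()

# Density-aware serialization of the 2-D sparse weight tensor.
save_load.save(sparse_layer.weight.data, "sparse_weight.pkl")
restored = save_load.load("sparse_weight.pkl")
```

Since `to_sparse()` drops exact zeros, the saved tensor stores only the surviving value-index pairs, or a bitmap once its density reaches 1/32.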