├── __init__.py ├── mnist_pretrained_lenet5.pkl ├── LICENSE ├── models.py ├── README.md ├── pt_models.py ├── utils.py ├── proj_utils.py ├── sparsity_proj_train.py ├── energy_proj_train.py └── sa_energy_model.py /__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mnist_pretrained_lenet5.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hyang1990/model_based_energy_constrained_compression/HEAD/mnist_pretrained_lenet5.pkl -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Haichuan Yang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import os 2 | import warnings 3 | 4 | import torch 5 | 6 | from sa_energy_model import FixHWConv2d, conv2d_out_dim, SparseConv2d 7 | import torch.nn as nn 8 | from pt_models import myalexnet 9 | from pt_models import mysqueezenet1_0 10 | from torchvision.models import alexnet, squeezenet1_0 11 | 12 | 13 | class MyLeNet5(nn.Module): 14 | def __init__(self, conv_class=FixHWConv2d): 15 | super(MyLeNet5, self).__init__() 16 | h = 32 17 | w = 32 18 | feature_layers = [] 19 | # conv 20 | feature_layers.append(conv_class(h, w, 1, 6, kernel_size=5)) 21 | h = conv2d_out_dim(h, kernel_size=5) 22 | w = conv2d_out_dim(w, kernel_size=5) 23 | feature_layers.append(nn.ReLU(inplace=True)) 24 | # pooling 25 | feature_layers.append(nn.MaxPool2d(kernel_size=2, stride=2)) 26 | h = conv2d_out_dim(h, kernel_size=2, stride=2) 27 | w = conv2d_out_dim(w, kernel_size=2, stride=2) 28 | # conv 29 | feature_layers.append(conv_class(h, w, 6, 16, kernel_size=5)) 30 | h = conv2d_out_dim(h, kernel_size=5) 31 | w = conv2d_out_dim(w, kernel_size=5) 32 | feature_layers.append(nn.ReLU(inplace=True)) 33 | # pooling 34 | feature_layers.append(nn.MaxPool2d(kernel_size=2, stride=2)) 35 | h = conv2d_out_dim(h, kernel_size=2, stride=2) 36 | w = conv2d_out_dim(w, kernel_size=2, stride=2) 37 | 38 | self.features = nn.Sequential(*feature_layers) 39 | 40 | self.classifier = nn.Sequential( 41 | nn.Linear(16 * 5 * 5, 120), 42 | nn.ReLU(inplace=True), 43 | nn.Linear(120, 84), 44 | nn.ReLU(inplace=True), 45 | nn.Linear(84, 10), 46 | ) 47 | 48 | def forward(self, x): 49 | x = self.features(x) 50 | x = x.view(x.size(0), 16 * 5 * 5) 51 | x = self.classifier(x) 52 | return x 53 | 54 | 55 | mnist_pretrained_lenet5_path = os.path.dirname(os.path.realpath(__file__)) + '/mnist_pretrained_lenet5.pkl' 56 | 57 | 58 | def get_net_model(net='alexnet', pretrained_dataset='imagenet', dropout=False, pretrained=True, input_mask=False): 59 | if input_mask: 60 | conv_class = SparseConv2d 61 | else: 62 | conv_class = FixHWConv2d 63 | if net == 'alexnet': 64 | model = myalexnet(pretrained=(pretrained_dataset == 'imagenet') and pretrained, dropout=dropout, conv_class=conv_class) 65 | teacher_model = alexnet(pretrained=(pretrained_dataset == 'imagenet')) 66 | elif net == 'squeezenet': 67 | model = mysqueezenet1_0(pretrained=(pretrained_dataset == 'imagenet') and pretrained, dropout=dropout, conv_class=conv_class) 68 | teacher_model = squeezenet1_0(pretrained=(pretrained_dataset == 'imagenet')) 69 | elif net == 'lenet-5': 70 | model = MyLeNet5(conv_class=conv_class) 71 | if pretrained and pretrained_dataset == 'mnist-32': 72 | model.load_state_dict(torch.load(mnist_pretrained_lenet5_path), strict=False) 73 | teacher_model = MyLeNet5() 74 | if os.path.isfile(mnist_pretrained_lenet5_path): 75 | teacher_model.load_state_dict(torch.load(mnist_pretrained_lenet5_path), strict=False) 76 | else: 77 | warnings.warn('failed to import teacher model!') 78 | else: 79 | raise NotImplementedError 80 | 81 | for p in teacher_model.parameters(): 82 | p.requires_grad = False 83 | teacher_model.eval() 84 | 85 | return model, teacher_model 86 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # model_based_energy_constrained_compression 2 | Code for paper "Energy-Constrained 
Compression for Deep Neural Networks via Weighted Sparse Projection and Layer Input Masking" (https://openreview.net/pdf?id=BylBr3C9K7) 3 | ``` 4 | @inproceedings{yang2018energy, 5 | title={Energy-Constrained Compression for Deep Neural Networks via Weighted Sparse Projection and Layer Input Masking}, 6 | author={Yang, Haichuan and Zhu, Yuhao and Liu, Ji}, 7 | booktitle={ICLR}, 8 | year={2019} 9 | } 10 | ``` 11 | ## Prerequisites 12 | 13 | 14 | ``` 15 | Python (3.6) 16 | PyTorch 1.0 17 | ``` 18 | 19 | To use the ImageNet dataset, download the dataset and move validation images to labeled subfolders (e.g., using https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh) 20 | 21 | ## Training and testing 22 | 23 | 24 | ### example 25 | 26 | To run the training with an energy constraint on AlexNet, 27 | 28 | ``` 29 | python energy_proj_train.py --net alexnet --dataset imagenet --datadir [imagenet-folder with train and val folders] --batch_size 128 --lr 1e-3 --momentum 0.9 --l2wd 1e-4 --proj_int 10 --logdir ./log/path-of-log --num_workers 8 --exp_bdecay --epochs 30 --distill 0.5 --nodp --budget 0.2 30 | ``` 31 | 32 | ### usage 33 | 34 | ``` 35 | usage: energy_proj_train.py [-h] [--net NET] [--dataset DATASET] 36 | [--datadir DATADIR] [--batch_size BATCH_SIZE] 37 | [--val_batch_size VAL_BATCH_SIZE] 38 | [--num_workers NUM_WORKERS] [--epochs EPOCHS] 39 | [--lr LR] [--xlr XLR] [--l2wd L2WD] 40 | [--xl2wd XL2WD] [--momentum MOMENTUM] 41 | [--proj_int PROJ_INT] [--nodp] 42 | [--input_mask] [--randinit] [--pretrain PRETRAIN] 43 | [--eval] [--seed SEED] 44 | [--log_interval LOG_INTERVAL] 45 | [--test_interval TEST_INTERVAL] 46 | [--save_interval SAVE_INTERVAL] [--logdir LOGDIR] 47 | [--distill DISTILL] [--budget BUDGET] 48 | [--exp_bdecay] [--mgpu] [--skip1] 49 | 50 | Model-Based Energy Constrained Training 51 | 52 | optional arguments: 53 | -h, --help show this help message and exit 54 | --net NET network arch 55 | --dataset DATASET dataset used in the experiment 56 | --datadir DATADIR dataset dir on this machine 57 | --batch_size BATCH_SIZE 58 | batch size for training 59 | --val_batch_size VAL_BATCH_SIZE 60 | batch size for evaluation 61 | --num_workers NUM_WORKERS 62 | number of workers for training loader 63 | --epochs EPOCHS number of epochs to train 64 | --lr LR learning rate 65 | --xlr XLR learning rate for input mask 66 | --l2wd L2WD l2 weight decay 67 | --xl2wd XL2WD l2 weight decay (for input mask) 68 | --momentum MOMENTUM momentum 69 | --proj_int PROJ_INT how many batches for each projection 70 | --nodp turn off dropout 71 | --input_mask enable input mask 72 | --randinit use random init 73 | --pretrain PRETRAIN file to load pretrained model 74 | --eval evaluate the test set at the beginning 75 | --seed SEED random seed 76 | --log_interval LOG_INTERVAL 77 | how many batches to wait before logging training 78 | status 79 | --test_interval TEST_INTERVAL 80 | how many epochs to wait before another test 81 | --save_interval SAVE_INTERVAL 82 | how many epochs to wait before saving a model 83 | --logdir LOGDIR folder to save the log 84 | --distill DISTILL distill loss weight 85 | --budget BUDGET energy budget (relative) 86 | --exp_bdecay exponential budget decay 87 | --mgpu enable using multiple gpus 88 | --skip1 skip the first W update 89 | ```
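 90 | 91 | ### mnist example 92 | 93 | For a quick sanity check without ImageNet, the same script also runs LeNet-5 on MNIST (the dataset is downloaded automatically, and the bundled mnist_pretrained_lenet5.pkl initializes both the model and the distillation teacher); the hyperparameters below are illustrative, not the settings used in the paper: 94 | 95 | ``` 96 | python energy_proj_train.py --net lenet-5 --dataset mnist-32 --batch_size 128 --lr 1e-3 --momentum 0.9 --proj_int 10 --epochs 30 --distill 0.5 --budget 0.2 --logdir ./log/lenet5-mnist 97 | ``` -------------------------------------------------------------------------------- /pt_models.py: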
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.model_zoo as model_zoo 3 | from torch import nn as nn 4 | import torch.nn.init as init 5 | 6 | from sa_energy_model import FixHWConv2d, conv2d_out_dim 7 | 8 | model_urls = { 9 | 'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth', 10 | 'squeezenet1_0': 'https://download.pytorch.org/models/squeezenet1_0-a815701f.pth', 11 | } 12 | 13 | ################################################################ 14 | ######################### Alex NET ########################## 15 | ################################################################ 16 | 17 | class MyAlexNet(nn.Module): 18 | def __init__(self, h=224, w=224, conv_class=FixHWConv2d, num_classes=1000, dropout=True): 19 | super(MyAlexNet, self).__init__() 20 | feature_layers = [] 21 | 22 | # conv 23 | feature_layers.append(conv_class(h, w, 3, 64, kernel_size=11, stride=4, padding=2)) 24 | h = conv2d_out_dim(h, kernel_size=11, stride=4, padding=2) 25 | w = conv2d_out_dim(w, kernel_size=11, stride=4, padding=2) 26 | feature_layers.append(nn.ReLU(inplace=True)) 27 | # pooling 28 | feature_layers.append(nn.MaxPool2d(kernel_size=3, stride=2)) 29 | h = conv2d_out_dim(h, kernel_size=3, stride=2) 30 | w = conv2d_out_dim(w, kernel_size=3, stride=2) 31 | 32 | # conv 33 | feature_layers.append(conv_class(h, w, 64, 192, kernel_size=5, padding=2)) 34 | h = conv2d_out_dim(h, kernel_size=5, padding=2) 35 | w = conv2d_out_dim(w, kernel_size=5, padding=2) 36 | feature_layers.append(nn.ReLU(inplace=True)) 37 | # pooling 38 | feature_layers.append(nn.MaxPool2d(kernel_size=3, stride=2)) 39 | h = conv2d_out_dim(h, kernel_size=3, stride=2) 40 | w = conv2d_out_dim(w, kernel_size=3, stride=2) 41 | 42 | # conv 43 | feature_layers.append(conv_class(h, w, 192, 384, kernel_size=3, padding=1)) 44 | h = conv2d_out_dim(h, kernel_size=3, padding=1) 45 | w = conv2d_out_dim(w, kernel_size=3, padding=1) 46 | feature_layers.append(nn.ReLU(inplace=True)) 47 | 48 | # conv 49 | feature_layers.append(conv_class(h, w, 384, 256, kernel_size=3, padding=1)) 50 | h = conv2d_out_dim(h, kernel_size=3, padding=1) 51 | w = conv2d_out_dim(w, kernel_size=3, padding=1) 52 | feature_layers.append(nn.ReLU(inplace=True)) 53 | 54 | # conv 55 | feature_layers.append(conv_class(h, w, 256, 256, kernel_size=3, padding=1)) 56 | h = conv2d_out_dim(h, kernel_size=3, padding=1) 57 | w = conv2d_out_dim(w, kernel_size=3, padding=1) 58 | feature_layers.append(nn.ReLU(inplace=True)) 59 | # pooling 60 | feature_layers.append(nn.MaxPool2d(kernel_size=3, stride=2)) 61 | h = conv2d_out_dim(h, kernel_size=3, stride=2) 62 | w = conv2d_out_dim(w, kernel_size=3, stride=2) 63 | 64 | self.features = nn.Sequential(*feature_layers) 65 | 66 | fc_layers = [nn.Dropout(p=0.5 if dropout else 0.0), 67 | nn.Linear(256 * 6 * 6, 4096), 68 | nn.ReLU(inplace=True), 69 | nn.Dropout(p=0.5 if dropout else 0.0), 70 | nn.Linear(4096, 4096), 71 | nn.ReLU(inplace=True), 72 | nn.Linear(4096, num_classes)] 73 | 74 | self.classifier = nn.Sequential(*fc_layers) 75 | 76 | def forward(self, x): 77 | x = self.features(x) 78 | x = x.view(x.size(0), 256 * 6 * 6) 79 | x = self.classifier(x) 80 | return x 81 | 82 | def get_inhw(self, x): 83 | res = [] 84 | for module in self.features._modules.values(): 85 | if isinstance(module, nn.Conv2d): 86 | res.append((x.size(2), x.size(3))) 87 | x = module(x) 88 | for module in self.classifier._modules.values(): 89 | if isinstance(module, nn.Linear): 90 | res.append((1, 
1)) 91 | return res 92 | 93 | 94 | def myalexnet(pretrained=False, model_root=None, **kwargs): 95 | model = MyAlexNet(**kwargs) 96 | if pretrained: 97 | model.load_state_dict(model_zoo.load_url(model_urls['alexnet'], model_root), strict=False) 98 | return model 99 | 100 | 101 | ################################################################ 102 | ######################## Squeeze NET ######################## 103 | ################################################################ 104 | 105 | class MySqueezeNet(nn.Module): 106 | class MyFire(nn.Module): 107 | 108 | def __init__(self, inplanes, squeeze_planes, 109 | expand1x1_planes, expand3x3_planes, h_in, w_in, conv_class=FixHWConv2d): 110 | super(MySqueezeNet.MyFire, self).__init__() 111 | h = h_in 112 | w = w_in 113 | 114 | self.inplanes = inplanes 115 | self.squeeze = conv_class(h, w, inplanes, squeeze_planes, kernel_size=1) 116 | self.squeeze_activation = nn.ReLU(inplace=True) 117 | h = conv2d_out_dim(h, kernel_size=1) 118 | w = conv2d_out_dim(w, kernel_size=1) 119 | 120 | self.expand1x1 = conv_class(h, w, squeeze_planes, expand1x1_planes, kernel_size=1) 121 | self.expand1x1_activation = nn.ReLU(inplace=True) 122 | self.expand3x3 = conv_class(h, w, squeeze_planes, expand3x3_planes, kernel_size=3, padding=1) 123 | self.expand3x3_activation = nn.ReLU(inplace=True) 124 | h = conv2d_out_dim(h, kernel_size=3, padding=1) 125 | w = conv2d_out_dim(w, kernel_size=3, padding=1) 126 | 127 | def forward(self, x): 128 | x = self.squeeze_activation(self.squeeze(x)) 129 | return torch.cat([ 130 | self.expand1x1_activation(self.expand1x1(x)), 131 | self.expand3x3_activation(self.expand3x3(x)) 132 | ], 1) 133 | 134 | def __init__(self, version=1.0, h=224, w=224, conv_class=FixHWConv2d, num_classes=1000, dropout=True): 135 | MyFire = self.MyFire 136 | super(MySqueezeNet, self).__init__() 137 | if version not in [1.0]: 138 | raise ValueError("Unsupported SqueezeNet version {version}: " 139 | "1.0 expected".format(version=version)) 140 | self.num_classes = num_classes 141 | 142 | feature_layers = [] 143 | # conv 144 | feature_layers.append(conv_class(h, w, 3, 96, kernel_size=7, stride=2)) 145 | h = conv2d_out_dim(h, kernel_size=7, stride=2) 146 | w = conv2d_out_dim(w, kernel_size=7, stride=2) 147 | feature_layers.append(nn.ReLU(inplace=True)) 148 | # pooling 149 | feature_layers.append(nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)) 150 | h = conv2d_out_dim(h, kernel_size=3, stride=2, ceil_mode=True) 151 | w = conv2d_out_dim(w, kernel_size=3, stride=2, ceil_mode=True) 152 | 153 | # fire block 154 | feature_layers.append(MyFire(96, 16, 64, 64, h_in=h, w_in=w, conv_class=conv_class)) 155 | feature_layers.append(MyFire(128, 16, 64, 64, h_in=h, w_in=w, conv_class=conv_class)) 156 | feature_layers.append(MyFire(128, 32, 128, 128, h_in=h, w_in=w, conv_class=conv_class)) 157 | feature_layers.append(nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)) 158 | h = conv2d_out_dim(h, kernel_size=3, stride=2, ceil_mode=True) 159 | w = conv2d_out_dim(w, kernel_size=3, stride=2, ceil_mode=True) 160 | 161 | feature_layers.append(MyFire(256, 32, 128, 128, h_in=h, w_in=w, conv_class=conv_class)) 162 | feature_layers.append(MyFire(256, 48, 192, 192, h_in=h, w_in=w, conv_class=conv_class)) 163 | feature_layers.append(MyFire(384, 48, 192, 192, h_in=h, w_in=w, conv_class=conv_class)) 164 | feature_layers.append(MyFire(384, 64, 256, 256, h_in=h, w_in=w, conv_class=conv_class)) 165 | feature_layers.append(nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)) 166 | h =
conv2d_out_dim(h, kernel_size=3, stride=2, ceil_mode=True) 167 | w = conv2d_out_dim(w, kernel_size=3, stride=2, ceil_mode=True) 168 | 169 | feature_layers.append(MyFire(512, 64, 256, 256, h_in=h, w_in=w, conv_class=conv_class)) 170 | 171 | self.features = nn.Sequential(*feature_layers) 172 | # Final convolution is initialized differently from the rest 173 | final_conv = conv_class(h, w, 512, self.num_classes, kernel_size=1) 174 | self.classifier = nn.Sequential( 175 | nn.Dropout(p=0.5 if dropout else 0.0), 176 | final_conv, 177 | nn.ReLU(inplace=True), 178 | nn.AvgPool2d(13, stride=1) 179 | ) 180 | 181 | for m in self.modules(): 182 | if isinstance(m, nn.Conv2d): 183 | if m is final_conv: 184 | init.normal_(m.weight.data, mean=0.0, std=0.01) 185 | else: 186 | init.kaiming_uniform_(m.weight.data) 187 | if m.bias is not None: 188 | m.bias.data.zero_() 189 | 190 | def forward(self, x): 191 | x = self.features(x) 192 | x = self.classifier(x) 193 | return x.view(x.size(0), self.num_classes) 194 | 195 | 196 | def mysqueezenet1_0(pretrained=False, **kwargs): 197 | r"""SqueezeNet model architecture from the `"SqueezeNet: AlexNet-level 198 | accuracy with 50x fewer parameters and <0.5MB model size" 199 | <https://arxiv.org/abs/1602.07360>`_ paper. 200 | 201 | Args: 202 | pretrained (bool): If True, returns a model pre-trained on ImageNet 203 | """ 204 | model = MySqueezeNet(version=1.0, **kwargs) 205 | if pretrained: 206 | model.load_state_dict(model_zoo.load_url(model_urls['squeezenet1_0']), strict=False) 207 | return model
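 208 | 209 | # A worked example of the height/width bookkeeping used throughout this file 210 | # (a sketch for illustration; conv2d_out_dim lives in sa_energy_model.py and is 211 | # assumed to follow the standard floor((size + 2*padding - kernel_size) / stride) + 1 212 | # rule when ceil_mode is off). For MyAlexNet with a 224x224 input: 213 | #   conv k=11 s=4 p=2 : (224 + 4 - 11) // 4 + 1 = 55 214 | #   maxpool k=3 s=2   : (55 - 3) // 2 + 1 = 27 215 | #   conv k=5 p=2      : 27, then maxpool -> (27 - 3) // 2 + 1 = 13 216 | #   conv k=3 p=1 (x3) : 13, then maxpool -> (13 - 3) // 2 + 1 = 6 217 | # which is why MyAlexNet's classifier expects 256 * 6 * 6 input features. 218 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.utils.data.sampler import SubsetRandomSampler, Sampler 4 | from torchvision import datasets, transforms 5 | import torch.nn.functional as F 6 | 7 | 8 | class SubsetSequentialSampler(Sampler): 9 | r"""Samples elements sequentially from a given list of indices, without replacement.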
10 | 11 | Arguments: 12 | indices (sequence): a sequence of indices 13 | """ 14 | 15 | def __init__(self, indices): 16 | self.indices = indices 17 | 18 | def __iter__(self): 19 | return (self.indices[i] for i in range(len(self.indices))) 20 | 21 | def __len__(self): 22 | return len(self.indices) 23 | 24 | 25 | def get_mnist32(batch_size, val_batch_size, data_root='./mnist_dataset', train=True, val=True, **kwargs): 26 | data_root = os.path.expanduser(os.path.join(data_root, 'mnist-data')) 27 | kwargs.pop('input_size', None) 28 | num_workers = kwargs.setdefault('num_workers', 1) 29 | print("Building MNIST data loader with {} workers".format(num_workers)) 30 | ds = [] 31 | if train: 32 | train_loader = torch.utils.data.DataLoader( 33 | datasets.MNIST(root=data_root, train=True, download=True, 34 | transform=transforms.Compose([ 35 | transforms.Resize(32), 36 | transforms.ToTensor(), 37 | transforms.Normalize((0.1307,), (0.3081,)) 38 | ])), 39 | batch_size=batch_size, shuffle=True, **kwargs) 40 | ds.append(train_loader) 41 | if val: 42 | test_loader = torch.utils.data.DataLoader( 43 | datasets.MNIST(root=data_root, train=False, download=True, 44 | transform=transforms.Compose([ 45 | transforms.Resize(32), 46 | transforms.ToTensor(), 47 | transforms.Normalize((0.1307,), (0.3081,)) 48 | ])), 49 | batch_size=val_batch_size, shuffle=False, **kwargs) 50 | ds.append(test_loader) 51 | 52 | train_loader4eval = torch.utils.data.DataLoader( 53 | datasets.MNIST(root=data_root, train=True, download=True, 54 | transform=transforms.Compose([ 55 | transforms.Resize(32), 56 | transforms.ToTensor(), 57 | transforms.Normalize((0.1307,), (0.3081,)) 58 | ])), 59 | batch_size=val_batch_size, shuffle=False, **kwargs) 60 | ds.append(train_loader4eval) 61 | ds = ds[0] if len(ds) == 1 else ds 62 | return ds 63 | 64 | 65 | def get_data_loaders(data_dir, dataset='imagenet', batch_size=32, val_batch_size=512, num_workers=0, nsubset=-1, 66 | normalize=None): 67 | if dataset == 'imagenet': 68 | traindir = os.path.join(data_dir, 'train') 69 | valdir = os.path.join(data_dir, 'val') 70 | if normalize is None: 71 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 72 | 73 | train_dataset = datasets.ImageFolder( 74 | traindir, 75 | transforms.Compose([ 76 | transforms.RandomResizedCrop(224), 77 | transforms.RandomHorizontalFlip(), 78 | transforms.ToTensor(), 79 | normalize, 80 | ])) 81 | 82 | if nsubset > 0: 83 | rand_idx = torch.randperm(len(train_dataset))[:nsubset] 84 | print('use a random subset of data:') 85 | print(rand_idx) 86 | train_sampler = SubsetRandomSampler(rand_idx) 87 | else: 88 | train_sampler = None 89 | 90 | train_loader = torch.utils.data.DataLoader( 91 | train_dataset, batch_size=batch_size, shuffle=(train_sampler is None), 92 | num_workers=num_workers, pin_memory=True, sampler=train_sampler) 93 | 94 | val_loader = torch.utils.data.DataLoader( 95 | datasets.ImageFolder(valdir, transforms.Compose([ 96 | transforms.Resize(256), 97 | transforms.CenterCrop(224), 98 | transforms.ToTensor(), 99 | normalize, 100 | ])), 101 | batch_size=val_batch_size, shuffle=False, 102 | num_workers=num_workers, pin_memory=True) 103 | 104 | # use 10K training data to see the training performance 105 | train_loader4eval = torch.utils.data.DataLoader( 106 | datasets.ImageFolder(traindir, transforms.Compose([ 107 | transforms.Resize(256), 108 | transforms.CenterCrop(224), 109 | transforms.ToTensor(), 110 | normalize, 111 | ])), 112 | batch_size=val_batch_size, shuffle=False, 113 | 
num_workers=num_workers, pin_memory=True, 114 | sampler=SubsetRandomSampler(torch.randperm(len(train_dataset))[:10000])) 115 | 116 | return train_loader, val_loader, train_loader4eval 117 | elif dataset == 'mnist-32': 118 | return get_mnist32(batch_size=batch_size, val_batch_size=val_batch_size, num_workers=num_workers) 119 | else: 120 | raise NotImplementedError 121 | 122 | 123 | def ncorrect(output, target, topk=(1,)): 124 | """Computes the number of correct@k for the specified values of k""" 125 | maxk = max(topk) 126 | 127 | _, pred = output.topk(maxk, 1, True, True) 128 | pred = pred.t() 129 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 130 | 131 | res = [] 132 | for k in topk: 133 | correct_k = correct[:k].float().sum().item() 134 | res.append(correct_k) 135 | return res 136 | 137 | 138 | def eval_loss_acc1_acc5(model, data_loader, loss_func=None, cuda=True, class_offset=0): 139 | val_loss = 0.0 140 | val_acc1 = 0.0 141 | val_acc5 = 0.0 142 | num_data = 0 143 | with torch.no_grad(): 144 | model.eval() 145 | for data, target in data_loader: 146 | num_data += target.size(0) 147 | target.data += class_offset 148 | if cuda: 149 | data, target = data.cuda(), target.cuda() 150 | output = model(data) 151 | if loss_func is not None: 152 | val_loss += loss_func(model, data, target).item() 153 | # val_loss += F.cross_entropy(output, target).item() 154 | nc1, nc5 = ncorrect(output.data, target.data, topk=(1, 5)) 155 | val_acc1 += nc1 156 | val_acc5 += nc5 157 | # print('acc:{}, {}'.format(nc1 / target.size(0), nc5 / target.size(0))) 158 | 159 | val_loss /= len(data_loader) 160 | val_acc1 /= num_data 161 | val_acc5 /= num_data 162 | 163 | return val_loss, val_acc1, val_acc5 164 | 165 | 166 | def cross_entropy(input, target, label_smoothing=0.0, size_average=True): 167 | """ Cross entropy with optional label smoothing 168 | Args: 169 | input: predictions from the network (N x C logits) 170 | target: targets (long tensor of class indices) 171 | label_smoothing: amount of probability mass moved off the true class 172 | size_average: if false, sum is returned instead of mean 173 | Examples:: 174 | input = torch.tensor([[1.1, 2.8, 1.3], [1.1, 2.1, 4.8]], 175 | requires_grad=True) 176 | target = torch.tensor([1, 2]) 177 | loss = cross_entropy(input, target, label_smoothing=0.1) 178 | loss.backward() 179 | """ 180 | if label_smoothing <= 0.0: 181 | return F.cross_entropy(input, target) 182 | assert input.dim() == 2 and target.dim() == 1 183 | target_ = torch.unsqueeze(target, 1) 184 | one_hot = torch.zeros_like(input) 185 | one_hot.scatter_(1, target_, 1) 186 | one_hot = torch.clamp(one_hot, max=1.0-label_smoothing, min=label_smoothing/(one_hot.size(1) - 1.0)) 187 | 188 | if size_average: 189 | return torch.mean(torch.sum(-one_hot * F.log_softmax(input, dim=1), dim=1)) 190 | else: 191 | return torch.sum(torch.sum(-one_hot * F.log_softmax(input, dim=1), dim=1))
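 # A small worked example of the smoothing above (illustration only): with 3 classes, # label_smoothing=0.1 and true class 1, the one-hot row [0, 1, 0] is clamped to # [0.05, 0.9, 0.05], i.e. 1 - 0.1 on the true class and 0.1 / (3 - 1) = 0.05 on each # other class; the loss is then the mean over the batch of sum(-one_hot * log_softmax(input), dim=1). 192 | 193 | 194 | def joint_loss(model, data, target, teacher_model, distill, label_smoothing=0.0): 195 | criterion = lambda pred, y: cross_entropy(pred, y, label_smoothing=label_smoothing) 196 | output = model(data) 197 | if distill <= 0.0: 198 | return criterion(output, target) 199 | else: 200 | with torch.no_grad(): 201 | teacher_output = teacher_model(data).data 202 | distill_loss = torch.mean((output - teacher_output) ** 2) 203 | if distill >= 1.0: 204 | return distill_loss 205 | else: 206 | class_loss = criterion(output, target) 207 | # print("distill loss={:.4e}, class loss={:.4e}".format(distill_loss, class_loss)) 208 | return distill *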
distill_loss + (1.0 - distill) * class_loss 209 | 210 | 211 | def argmax(a): 212 | return max(range(len(a)), key=a.__getitem__) 213 | 214 | 215 | def expand_user(path): 216 | return os.path.abspath(os.path.expanduser(path)) 217 | 218 | 219 | def model_snapshot(model, new_file, old_file=None, verbose=False): 220 | from collections import OrderedDict 221 | import torch 222 | if isinstance(model, torch.nn.DataParallel): 223 | model = model.module 224 | if old_file and os.path.exists(expand_user(old_file)): 225 | if verbose: 226 | print("Removing old model {}".format(expand_user(old_file))) 227 | os.remove(expand_user(old_file)) 228 | if verbose: 229 | print("Saving model to {}".format(expand_user(new_file))) 230 | 231 | state_dict = OrderedDict() 232 | for k, v in model.state_dict().items(): 233 | if v.is_cuda: 234 | v = v.cpu() 235 | state_dict[k] = v 236 | torch.save(state_dict, expand_user(new_file)) 237 | -------------------------------------------------------------------------------- /proj_utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | def copy_model_weights(model, W_flat, W_shapes, param_name='weight'): 8 | offset = 0 9 | if isinstance(W_shapes, list): 10 | W_shapes = iter(W_shapes) 11 | for name, W in model.named_parameters(): 12 | if name.endswith(param_name): 13 | name_, shape = next(W_shapes) 14 | if shape is None: 15 | continue 16 | assert name_ == name 17 | numel = W.numel() 18 | W.data.copy_(W_flat[offset: offset + numel].view(shape)) 19 | offset += numel 20 | 21 | 22 | def reset_model_param(model): 23 | for M in model.modules(): 24 | if hasattr(M, 'reset_parameters'): 25 | M.reset_parameters() 26 | 27 | 28 | def idxproj(model, z_idx, W_shapes, param_name='weight'): 29 | assert isinstance(z_idx, (torch.LongTensor, torch.cuda.LongTensor)) 30 | offset = 0 31 | i = 0 32 | for name, W in model.named_parameters(): 33 | if name.endswith(param_name): 34 | name_, shape = W_shapes[i] 35 | i += 1 36 | assert name_ == name 37 | if shape is None: 38 | continue 39 | mask = z_idx >= offset 40 | mask[z_idx >= (offset + W.numel())] = 0 41 | z_idx_sel = z_idx[mask] 42 | if len(z_idx_sel.shape) != 0: 43 | W.data.view(-1)[z_idx_sel - offset] = 0.0 44 | offset += W.numel() 45 | 46 | 47 | def getmask(model, param_name='weight'): 48 | mask_model = copy.deepcopy(model) 49 | for name, W in mask_model.named_parameters(): 50 | if name.endswith(param_name): 51 | W.data.copy_(W.data != 0.0) 52 | 53 | return mask_model 54 | 55 | 56 | def maskproj(model, mask_model, param_name='weight'): 57 | mask_model_param = mask_model.named_parameters() 58 | for name1, W in model.named_parameters(): 59 | name2, W_mask = next(mask_model_param) 60 | assert name1 == name2 61 | if name1.endswith(param_name) and W.dim() > 1: 62 | W.data.mul_(W_mask.data) 63 | 64 | 65 | def idx2mask(mask_model, z_idx, W_shapes, param_name='weight'): 66 | fill_model_weights(mask_model, 1.0, param_name=param_name) 67 | offset = 0 68 | i = 0 69 | for name, W in mask_model.named_parameters(): 70 | if name.endswith(param_name): 71 | name_, shape = W_shapes[i] 72 | assert name_ == name 73 | mask = z_idx >= offset 74 | mask[z_idx >= (offset + W.numel())] = 0 75 | z_idx_sel = z_idx[mask] 76 | if len(z_idx_sel.shape) != 0: 77 | W.data.view(-1)[z_idx_sel - offset] = 0.0 78 | i += 1 79 | offset += W.numel() 80 | 81 | return mask_model 82 | 83 | 84 | def model_mask(model, param_name='weight'): 85 | mask_model =
copy.deepcopy(model) 86 | fill_model_weights(mask_model, 1.0, param_name=param_name) 87 | 88 | model2_param = model.named_parameters() 89 | for name1, p1 in mask_model.named_parameters(): 90 | name2, p2 = next(model2_param) 91 | assert name1 == name2 92 | if name1.endswith(param_name) and p1.dim() > 1: 93 | p1.data.copy_((p2.data != 0.0).float()) 94 | 95 | return mask_model 96 | 97 | 98 | def filtered_parameters(model, param_name, inverse=False): 99 | for name, param in model.named_parameters(): 100 | if inverse != (name.endswith(param_name)): 101 | yield param 102 | 103 | 104 | def l0proj(model, k, normalized=True, param_name='weight'): 105 | # get all the weights 106 | W_shapes = [] 107 | res = [] 108 | for name, W in model.named_parameters(): 109 | if name.endswith(param_name): 110 | if W.dim() == 1: 111 | W_shapes.append((name, None)) 112 | else: 113 | W_shapes.append((name, W.data.shape)) 114 | res.append(W.data.view(-1)) 115 | 116 | res = torch.cat(res, dim=0) 117 | if normalized: 118 | assert 0.0 <= k <= 1.0 119 | nnz = math.floor(res.shape[0] * k) 120 | else: 121 | assert k >= 1 and round(k) == k 122 | nnz = k 123 | if nnz == res.shape[0]: 124 | z_idx = [] 125 | else: 126 | _, z_idx = torch.topk(torch.abs(res), int(res.shape[0] - nnz), largest=False, sorted=False) 127 | res[z_idx] = 0.0 128 | copy_model_weights(model, res, W_shapes, param_name) 129 | return z_idx, W_shapes 130 | 131 | 132 | def threshold_proj(model, thresh, param_name='weight'): 133 | assert thresh > 0.0 134 | for name, W in model.named_parameters(): 135 | if name.endswith(param_name): 136 | if W.dim() > 1: 137 | W.data[W.data.abs() < thresh] = 0.0 138 | 139 | 140 | def print_model_weights(model, param_name='weight'): 141 | for name, W in model.named_parameters(): 142 | if name.endswith(param_name): 143 | print(name, W.data) 144 | 145 | 146 | def model_weights_diff(model1, model2, param_name='weight'): 147 | res = 0.0 148 | model2_param = model2.named_parameters() 149 | for name1, W1 in model1.named_parameters(): 150 | name2, W2 = next(model2_param) 151 | assert name1 == name2 152 | if name1.endswith(param_name): 153 | res += (W1.data - W2.data).abs().sum() 154 | 155 | return res 156 | 157 | 158 | def model_sparsity(model, normalized=True, param_name='weight'): 159 | nnz = 0 160 | numel = 0 161 | for name, W in model.named_parameters(): 162 | if name.endswith(param_name): 163 | W_nz = torch.nonzero(W.data) 164 | if W_nz.dim() > 0: 165 | nnz += W_nz.shape[0] 166 | numel += torch.numel(W.data) 167 | 168 | return float(nnz) / float(numel) if normalized else float(nnz) 169 | 170 | 171 | def model_sparsity_lb(model, param_name='weight'): 172 | numel = 0 173 | for name, W in model.named_parameters(): 174 | if name.endswith(param_name): 175 | numel += torch.numel(W.data) 176 | 177 | return 1.0 / float(numel) 178 | 179 | 180 | def layers_nnz(model, normalized=True, param_name='weight'): 181 | res = {} 182 | for name, W in model.named_parameters(): 183 | if name.endswith(param_name): 184 | layer_name = name[:-len(param_name)-1] 185 | W_nz = torch.nonzero(W.data) 186 | if W_nz.dim() > 0: 187 | if not normalized: 188 | res[layer_name] = W_nz.shape[0] 189 | else: 190 | # print("{} layer nnz:{}".format(name, torch.nonzero(W.data))) 191 | res[layer_name] = float(W_nz.shape[0]) / torch.numel(W) 192 | else: 193 | res[layer_name] = 0 194 | 195 | return res 196 | 197 | 198 | def layers_nnz_hw(model, param_name='weight'): 199 | """ 200 | Get a dict which contains each layer's nnz on the last two dimensions i.e. 
height and width 201 | :param model: The model containing the layers 202 | :param param_name: The layers' parameter name, i.e. weight 203 | :return: Dict containing layer names and the nnz tensor 204 | """ 205 | res = {} 206 | for name, W in model.named_parameters(): 207 | if name.endswith(param_name): 208 | layer_name = name[:-len(param_name) - 1] 209 | if len(W.size()) < 3: 210 | res[layer_name] = (W.data != 0.0).float().sum().item() 211 | else: 212 | h_times_w = W.size()[-1] * W.size()[-2] 213 | W_nz = (W.data.view(*(W.size()[:-2]), h_times_w) != 0.0).float() 214 | res[layer_name] = torch.sum(W_nz, dim=-1) 215 | 216 | return res 217 | 218 | 219 | def layers_nz_mask(model, param_name='weight'): 220 | res = {} 221 | for name, W in model.named_parameters(): 222 | if name.endswith(param_name): 223 | layer_name = name[:-len(param_name) - 1] 224 | res[layer_name] = (W.data != 0.0).float() 225 | 226 | return res 227 | 228 | 229 | def layers_stat(model, param_names=('weight',), param_filter=lambda p: True): 230 | if isinstance(param_names, str): 231 | param_names = (param_names,) 232 | def match_endswith(name): 233 | for param_name in param_names: 234 | if name.endswith(param_name): 235 | return param_name 236 | return None 237 | res = "########### layer stat ###########\n" 238 | for name, W in model.named_parameters(): 239 | param_name = match_endswith(name) 240 | if param_name is not None: 241 | if param_filter(W): 242 | layer_name = name[:-len(param_name) - 1] 243 | W_nz = torch.nonzero(W.data) 244 | nnz = W_nz.shape[0] / W.data.numel() if W_nz.dim() > 0 else 0.0 245 | W_data_abs = W.data.abs() 246 | res += "{:>20}".format(layer_name) + ' abs(W): min={:.4e}, mean={:.4e}, max={:.4e}, nnz={:.4f}\n'\ 247 | .format(W_data_abs.min().item(), W_data_abs.mean().item(), W_data_abs.max().item(), nnz) 248 | 249 | res += "########### layer stat ###########" 250 | return res 251 | 252 | 253 | def layers_grad_stat(model, param_name='weight'): 254 | res = "########### layer grad stat ###########\n" 255 | for name, W in model.named_parameters(): 256 | if name.endswith(param_name): 257 | layer_name = name[:-len(param_name) - 1] 258 | W_nz = torch.nonzero(W.grad.data) 259 | nnz = W_nz.shape[0] / W.grad.data.numel() if W_nz.dim() > 0 else 0.0 260 | W_data_abs = W.grad.data.abs() 261 | res += "{:>20}".format(layer_name) + ' abs(W.grad): min={:.4e}, mean={:.4e}, max={:.4e}, nnz={:.4f}\n'.format(W_data_abs.min().item(), W_data_abs.mean().item(), W_data_abs.max().item(), nnz) 262 | 263 | res += "########### layer grad stat ###########" 264 | return res 265 | 266 | 267 | def fill_model_weights(model, val, param_name='weight'): 268 | for name, W in model.named_parameters(): 269 | if name.endswith(param_name): 270 | W.data.fill_(val) 271 | 272 | return model 273 | 274 | 275 | def clamp_model_weights(model, min=0.0, max=1.0, param_name='input_mask'): 276 | for name, W in model.named_parameters(): 277 | if name.endswith(param_name): 278 | W.data.clamp_(min=min, max=max) 279 | 280 | return model 281 | 282 | 283 | def round_model_weights(model, param_name='input_mask'): 284 | for name, W in model.named_parameters(): 285 | if name.endswith(param_name): 286 | W.data.round_() 287 | 288 | return model
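 # How the 'input_mask' helpers above are combined during the X (input mask) update, # a sketch based on energy_proj_train.py and assuming the conv layers expose an # 'input_mask' parameter (models.py selects SparseConv2d for this when --input_mask is set): #   clamp_model_weights(model, min=0.0, max=1.0, param_name='input_mask')  # keep the relaxed mask in [0, 1] after each gradient step #   l0proj(model, 0.9, param_name='input_mask')                            # then keep only the largest mask entries #   round_model_weights(model, param_name='input_mask')                    # and finally snap the mask back to {0, 1} 289 | 290 | 291 | def model_support_set(model, param_name='weight'): 292 | res = copy.deepcopy(model) 293 | res_param = res.named_parameters() 294 | for name1, W1 in model.named_parameters(): 295 | name2, W2 = next(res_param) 296 | assert name1 == name2 297 | if name1.endswith(param_name): 298 | W2.data[:] = (W1.data != 0.0) 299 |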
300 | return res 301 | 302 | 303 | def argmax(a): 304 | return max(range(len(a)), key=a.__getitem__) 305 | 306 | 307 | def num_dict_info(d): 308 | res = "{" 309 | for k in d: 310 | res += "{}: {:.4e}, ".format(k, d[k]) 311 | 312 | res += '}' 313 | return res 314 | 315 | 316 | if __name__ == '__main__': 317 | layers = [nn.Conv2d(in_channels=3, out_channels=1, kernel_size=3), nn.Linear(16, 10)] 318 | model = nn.Sequential(*layers) 319 | print_model_weights(model) 320 | model_ = copy.deepcopy(model) 321 | z_idx, W_shapes = l0proj(model_, 100, normalized=False) 322 | print_model_weights(model_) 323 | idxproj(model, z_idx, W_shapes) 324 | print_model_weights(model) 325 | 326 | print("diff={}".format(model_weights_diff(model, model_))) 327 | -------------------------------------------------------------------------------- /sparsity_proj_train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import datetime 4 | import numpy as np 5 | import os 6 | import math 7 | import time 8 | import torch 9 | import random 10 | import sys 11 | import copy 12 | from energynet.model_based.models import get_net_model 13 | from energynet.model_based.proj_utils import fill_model_weights, layers_stat, model_sparsity, l0proj, model_sparsity_lb 14 | from energynet.model_based.sa_energy_model import build_energy_info, energy_eval2, energy_eval2_relax, \ 15 | reset_Xenergy_cache 16 | from energynet.model_free.utils import get_data_loaders, joint_loss, PlotData, eval_loss_acc1_acc5 17 | from utee import misc 18 | 19 | from torchvision import transforms 20 | from utee.misc import model_snapshot 21 | 22 | if __name__ == '__main__': 23 | parser = argparse.ArgumentParser(description='Sparsity Constrained Training (Magnitude based pruning variant)') 24 | parser.add_argument('--net', default='alexnet', help='network arch') 25 | 26 | parser.add_argument('--dataset', default='imagenet', help='dataset used in the experiment') 27 | parser.add_argument('--datadir', default='/home/hyang/ssd2/ILSVRC_CLS', help='dataset dir on this machine') 28 | 29 | parser.add_argument('--batch_size', type=int, default=128, help='batch size for training') 30 | parser.add_argument('--val_batch_size', type=int, default=512, help='batch size for evaluation') 31 | parser.add_argument('--num_workers', type=int, default=8, help='number of workers for the training loader') 32 | 33 | parser.add_argument('--epochs', type=int, default=20, help='number of epochs to train') 34 | parser.add_argument('--lr', type=float, default=0.001, help='primal learning rate') 35 | parser.add_argument('--adlr', action='store_true', help='adaptive lr based on sparsity (lr = lr/sparsity)') 36 | parser.add_argument('--l2wd', type=float, default=0.0, help='l2 weight decay') 37 | parser.add_argument('--momentum', type=float, default=0.0, help='primal momentum') 38 | parser.add_argument('--lr_decay', type=float, default=1.0, help='learning rate decay factor (default: 1)') 39 | parser.add_argument('--lr_decay_s', type=int, default=10, help='learning rate decay start epoch (default: 10)') 40 | parser.add_argument('--lr_decay_i', type=int, default=10, help='learning rate decay epoch interval (default: 10)') 41 | parser.add_argument('--lr_decay_add', action='store_true', help='use additive lr decay (otherwise use multiplicative)') 42 | 43 | parser.add_argument('--proj_int', type=int, default=1, help='how many batches for each projection') 44 | parser.add_argument('--nodp', action='store_true', help='turn off dropout') 45 | 46 | parser.add_argument('--randinit',
action='store_true', help='use random init') 47 | parser.add_argument('--pretrain', default=None, help='file to load pretrained model') 48 | parser.add_argument('--eval', action='store_true', help='eval mode') 49 | 50 | parser.add_argument('--seed', type=int, default=117, help='random seed (default: 117)') 51 | parser.add_argument('--log_interval', type=int, default=100, 52 | help='how many batches to wait before logging training status') 53 | parser.add_argument('--test_interval', type=int, default=1, help='how many epochs to wait before another test') 54 | parser.add_argument('--save_interval', type=int, default=-1, help='how many epochs to wait before saving a model') 55 | parser.add_argument('--logdir', default=None, help='folder to save the log') 56 | parser.add_argument('--distill', type=float, default=0.0, help='distill loss weight') 57 | parser.add_argument('--budget', type=float, default=0.0, help='sparsity budget (relative)') 58 | parser.add_argument('--exp_bdecay', action='store_true', help='exponential budget decay') 59 | parser.add_argument('--mgpu', action='store_true', help='enable using multiple gpus') 60 | 61 | args = parser.parse_args() 62 | args.cuda = torch.cuda.is_available() 63 | 64 | if args.logdir is None: 65 | args.logdir = 'log/' + sys.argv[0] + str(datetime.datetime.now().strftime("_%Y_%m_%d_AT_%H_%M_%S")) 66 | 67 | args.logdir = os.path.join(os.path.dirname(__file__), args.logdir) 68 | # rm old contents in dir 69 | print('remove old contents in {}'.format(args.logdir)) 70 | os.system('rm -rf ' + args.logdir) 71 | 72 | # create log file 73 | misc.logger.init(args.logdir, 'train_log') 74 | print = misc.logger.info 75 | 76 | # backup the src 77 | os.system('zip -q ' + os.path.join(args.logdir, 'src.zip') + ' {}/*.py'.format( 78 | os.path.dirname(os.path.realpath(__file__)))) 79 | 80 | print('command:\npython {}'.format(' '.join(sys.argv))) 81 | print("=================FLAGS==================") 82 | for k, v in args.__dict__.items(): 83 | print('{}: {}'.format(k, v)) 84 | print("========================================") 85 | 86 | # set up random seeds 87 | torch.manual_seed(args.seed) 88 | if args.cuda: 89 | torch.cuda.manual_seed(args.seed) 90 | np.random.seed(args.seed) 91 | random.seed(args.seed) 92 | 93 | # get training and validation data loaders 94 | normalize = None 95 | class_offset = 0 96 | if args.net == 'mobilenet-imagenet': 97 | normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 98 | class_offset = 1 99 | tr_loader, val_loader, train_loader4eval = get_data_loaders(data_dir=args.datadir, 100 | dataset=args.dataset, 101 | batch_size=args.batch_size, 102 | val_batch_size=args.val_batch_size, 103 | num_workers=args.num_workers, 104 | normalize=normalize) 105 | # get network model 106 | model, teacher_model = get_net_model(net=args.net, pretrained_dataset=args.dataset, dropout=(not args.nodp), 107 | pretrained=not args.randinit) 108 | 109 | # pretrained model 110 | if args.pretrain is not None and os.path.isfile(args.pretrain): 111 | print('load pretrained model:{}'.format(args.pretrain)) 112 | model.load_state_dict(torch.load(args.pretrain)) 113 | elif args.pretrain is not None: 114 | print('fail to load pretrained model: {}'.format(args.pretrain)) 115 | 116 | if args.mgpu: 117 | assert len(os.environ['CUDA_VISIBLE_DEVICES'].split(',')) > 1 118 | model = torch.nn.DataParallel(model) 119 | teacher_model = torch.nn.DataParallel(teacher_model) 120 | 121 | # for energy estimate 122 | print('================model energy summary================') 123
| energy_info = build_energy_info(model) 124 | energy_estimator = lambda m: sum(energy_eval2(m, energy_info, verbose=False).values()) 125 | energy_estimator_relaxed = lambda m: sum(energy_eval2_relax(m, energy_info, verbose=False).values()) 126 | 127 | reset_Xenergy_cache(energy_info) 128 | cur_energy = sum(energy_eval2(model, energy_info, verbose=True).values()) 129 | cur_energy_relaxed = energy_estimator_relaxed(model) 130 | 131 | dense_model = fill_model_weights(copy.deepcopy(model), 1.0) 132 | budget_ub = energy_estimator_relaxed(dense_model) 133 | zero_model = fill_model_weights(copy.deepcopy(model), 0.0) 134 | budget_lb = energy_estimator_relaxed(zero_model) 135 | 136 | del zero_model, dense_model 137 | 138 | proj_func = lambda m, budget: l0proj(m, budget, normalized=True) 139 | print('energy on dense DNN:{:.4e}, on zero DNN:{:.4e}, normalized_lb={:.4e}'.format(budget_ub, budget_lb, 140 | budget_lb / budget_ub)) 141 | print('energy on current DNN:{:.4e}, normalized={:.4e}'.format(cur_energy, cur_energy / budget_ub)) 142 | print('====================================================') 143 | print('current energy {:.4e}, relaxed: {:.4e}'.format(cur_energy, cur_energy_relaxed)) 144 | 145 | 146 | netl2wd = args.l2wd 147 | 148 | if args.cuda: 149 | if args.distill > 0.0: 150 | teacher_model.cuda() 151 | model.cuda() 152 | 153 | optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=netl2wd) 154 | 155 | loss_func = lambda m, x, y: joint_loss(model=m, data=x, target=y, teacher_model=teacher_model, distill=args.distill) 156 | 157 | plot_data = PlotData() 158 | if args.eval or args.dataset != 'imagenet': 159 | val_loss, val_acc1, val_acc5 = eval_loss_acc1_acc5(model, val_loader, loss_func, args.cuda, 160 | class_offset=class_offset) 161 | print('**Validation loss:{:.4e}, top-1 accuracy:{:.5f}, top-5 accuracy:{:.5f}'.format(val_loss, val_acc1, 162 | val_acc5)) 163 | # also evaluate training data 164 | tr_loss, tr_acc1, tr_acc5 = eval_loss_acc1_acc5(model, train_loader4eval, loss_func, args.cuda, 165 | class_offset=class_offset) 166 | print('###Training loss:{:.4e}, top-1 accuracy:{:.5f}, top-5 accuracy:{:.5f}'.format(tr_loss, tr_acc1, tr_acc5)) 167 | else: 168 | val_acc1 = 0.0 169 | print('For imagenet, skip the first validation evaluation.') 170 | 171 | best_acc = val_acc1 172 | 173 | old_file = None 174 | cur_sparsity = model_sparsity(model) 175 | sparsity_step = max(0.0, cur_sparsity - args.budget) / ((len(tr_loader) * args.epochs) / args.proj_int) 176 | 177 | sparsity_lb = model_sparsity_lb(model) 178 | sparsity_decay_factor = min(1.0, max(args.budget, sparsity_lb) / cur_sparsity) ** \ 179 | (1.0 / ((len(tr_loader) * args.epochs) / args.proj_int)) 180 | t_begin = time.time() 181 | log_tic = t_begin 182 | cur_budget = cur_sparsity 183 | lr = args.lr 184 | 185 | for epoch in range(args.epochs): 186 | # decay lr 187 | if epoch >= args.lr_decay_s and (epoch - args.lr_decay_s) % args.lr_decay_i == 0: 188 | lr *= args.lr_decay 189 | 190 | for batch_idx, (data, target) in enumerate(tr_loader): 191 | model.train() 192 | if args.adlr: 193 | optimizer.param_groups[0]['lr'] = lr / cur_sparsity 194 | else: 195 | optimizer.param_groups[0]['lr'] = lr 196 | if args.cuda: 197 | data, target = data.cuda(), target.cuda() 198 | 199 | loss = loss_func(model, data, target + class_offset) 200 | # update network weights 201 | optimizer.zero_grad() 202 | loss.backward() 203 | optimizer.step() 204 | 205 | if (batch_idx > 0 and batch_idx % args.proj_int == 0) or batch_idx == 
len(tr_loader) - 1: 206 | proj_func(model, cur_budget) 207 | if epoch == args.epochs - 1 and batch_idx >= len(tr_loader) - 1 - args.proj_int: 208 | cur_budget = args.budget 209 | else: 210 | if args.exp_bdecay: 211 | cur_budget = max(cur_budget * sparsity_decay_factor, args.budget) 212 | else: 213 | cur_budget = max(cur_budget - sparsity_step, args.budget) 214 | 215 | if batch_idx % args.log_interval == 0: 216 | print('======================================================') 217 | print('+-------------- epoch {}, batch {}/{} ----------------+'.format(epoch, batch_idx, 218 | len(tr_loader))) 219 | log_toc = time.time() 220 | print( 221 | 'primal update: net loss={:.4e}, lr={:.4e}, current normalized budget: {:.4e}, time_elapsed={:.3f}s'.format( 222 | loss.item(), optimizer.param_groups[0]['lr'], cur_budget, log_toc - log_tic)) 223 | log_tic = time.time() 224 | if batch_idx % args.proj_int == 0: 225 | cur_sparsity = model_sparsity(model) 226 | print('sparsity:{}'.format(model_sparsity(model))) 227 | print(layers_stat(model)) 228 | print('+-----------------------------------------------------+') 229 | 230 | cur_energy = energy_estimator(model) 231 | cur_energy_relaxed = energy_estimator_relaxed(model) 232 | if epoch % args.test_interval == 0: 233 | plot_data.append('energy', cur_energy) 234 | plot_data.append('normalized energy', cur_energy / budget_ub) 235 | plot_data.append('budget', cur_budget) 236 | 237 | val_loss, val_acc1, val_acc5 = eval_loss_acc1_acc5(model, val_loader, loss_func, args.cuda, 238 | class_offset=class_offset) 239 | plot_data.append('val_loss', val_loss) 240 | plot_data.append('val_acc1', val_acc1) 241 | plot_data.append('val_acc5', val_acc5) 242 | 243 | # also evaluate training data 244 | tr_loss, tr_acc1, tr_acc5 = eval_loss_acc1_acc5(model, train_loader4eval, loss_func, 245 | args.cuda, class_offset=class_offset) 246 | print('###Training loss:{:.4e}, top-1 accuracy:{:.5f}, top-5 accuracy:{:.5f}'.format(tr_loss, tr_acc1, 247 | tr_acc5)) 248 | plot_data.append('tr_loss', tr_loss) 249 | plot_data.append('tr_acc1', tr_acc1) 250 | plot_data.append('tr_acc5', tr_acc5) 251 | 252 | if val_acc1 > best_acc: 253 | pass 254 | # new_file = os.path.join(args.logdir, 'model_best-{}.pkl'.format(epoch)) 255 | # misc.model_snapshot(primal_model.net, new_file, old_file=old_file, verbose=True) 256 | # best_acc = val_acc1 257 | # old_file = new_file 258 | print( 259 | '***Validation loss:{:.4e}, top-1 accuracy:{:.5f}, top-5 accuracy:{:.5f}, current normalized energy:{:.4e}, {:.4e}(relaxed)'.format( 260 | val_loss, val_acc1, 261 | val_acc5, cur_energy / budget_ub, cur_energy_relaxed / budget_ub)) 262 | # save current model 263 | model_snapshot(model, os.path.join(args.logdir, 'primal_model_latest.pkl')) 264 | plot_data.dump(os.path.join(args.logdir, 'plot_data.pkl')) 265 | 266 | if args.save_interval > 0 and epoch % args.save_interval == 0: 267 | model_snapshot(model, os.path.join(args.logdir, 'primal_model_epoch{}.pkl'.format(epoch))) 268 | 269 | elapse_time = time.time() - t_begin 270 | speed_epoch = elapse_time / (1 + epoch) 271 | eta = speed_epoch * (args.epochs - epoch) 272 | print("Elapsed {:.2f}s, ets {:.2f}s".format(elapse_time, eta)) 273 | -------------------------------------------------------------------------------- /energy_proj_train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import numpy as np 4 | import os 5 | import math 6 | import time 7 | import torch 8 | import random 9 | import sys 10 
| import copy 11 | from models import get_net_model 12 | from proj_utils import fill_model_weights, layers_stat, model_sparsity, filtered_parameters, \ 13 | l0proj, round_model_weights, clamp_model_weights 14 | from sa_energy_model import build_energy_info, energy_eval2, energy_eval2_relax, energy_proj2, \ 15 | reset_Xenergy_cache 16 | from utils import get_data_loaders, joint_loss, eval_loss_acc1_acc5, model_snapshot 17 | 18 | if __name__ == '__main__': 19 | parser = argparse.ArgumentParser(description='Model-Based Energy Constrained Training') 20 | parser.add_argument('--net', default='alexnet', help='network arch') 21 | 22 | parser.add_argument('--dataset', default='imagenet', help='dataset used in the experiment') 23 | parser.add_argument('--datadir', default='./ILSVRC_CLS', help='dataset dir in this machine') 24 | 25 | parser.add_argument('--batch_size', type=int, default=128, help='batch size for training') 26 | parser.add_argument('--val_batch_size', type=int, default=512, help='batch size for evaluation') 27 | parser.add_argument('--num_workers', type=int, default=8, help='number of workers for training loader') 28 | 29 | parser.add_argument('--epochs', type=int, default=30, help='number of epochs to train') 30 | parser.add_argument('--lr', type=float, default=0.001, help='learning rate') 31 | parser.add_argument('--xlr', type=float, default=1e-4, help='learning rate for input mask') 32 | 33 | parser.add_argument('--l2wd', type=float, default=1e-4, help='l2 weight decay') 34 | parser.add_argument('--xl2wd', type=float, default=1e-5, help='l2 weight decay (for input mask)') 35 | 36 | parser.add_argument('--momentum', type=float, default=0.9, help='momentum') 37 | 38 | parser.add_argument('--proj_int', type=int, default=10, help='how many batches for each projection') 39 | parser.add_argument('--nodp', action='store_true', help='turn off dropout') 40 | parser.add_argument('--input_mask', action='store_true', help='enable input mask') 41 | 42 | parser.add_argument('--randinit', action='store_true', help='use random init') 43 | parser.add_argument('--pretrain', default=None, help='file to load pretrained model') 44 | parser.add_argument('--eval', action='store_true', help='evaluate testset in the begining') 45 | 46 | parser.add_argument('--seed', type=int, default=117, help='random seed') 47 | parser.add_argument('--log_interval', type=int, default=100, 48 | help='how many batches to wait before logging training status') 49 | parser.add_argument('--test_interval', type=int, default=1, help='how many epochs to wait before another test') 50 | parser.add_argument('--save_interval', type=int, default=-1, help='how many epochs to wait before save a model') 51 | parser.add_argument('--logdir', default=None, help='folder to save to the log') 52 | parser.add_argument('--distill', type=float, default=0.5, help='distill loss weight') 53 | parser.add_argument('--budget', type=float, default=0.2, help='energy budget (relative)') 54 | parser.add_argument('--exp_bdecay', action='store_true', help='exponential budget decay') 55 | parser.add_argument('--mgpu', action='store_true', help='enable using multiple gpus') 56 | parser.add_argument('--skip1', action='store_true', help='skip the first W update') 57 | 58 | args = parser.parse_args() 59 | args.cuda = torch.cuda.is_available() 60 | 61 | if args.logdir is None: 62 | args.logdir = 'log/' + sys.argv[0] + str(datetime.datetime.now().strftime("_%Y_%m_%d_AT_%H_%M_%S")) 63 | 64 | args.logdir = os.path.join(os.path.dirname(__file__), args.logdir) 65 | if 
not os.path.exists(args.logdir): 66 | os.makedirs(args.logdir) 67 | print('command:\npython {}'.format(' '.join(sys.argv))) 68 | print("=================FLAGS==================") 69 | for k, v in args.__dict__.items(): 70 | print('{}: {}'.format(k, v)) 71 | print("========================================") 72 | 73 | # set up random seeds 74 | torch.manual_seed(args.seed) 75 | if args.cuda: 76 | torch.cuda.manual_seed(args.seed) 77 | np.random.seed(args.seed) 78 | random.seed(args.seed) 79 | 80 | # get training and validation data loaders 81 | normalize = None 82 | tr_loader, val_loader, train_loader4eval = get_data_loaders(data_dir=args.datadir, 83 | dataset=args.dataset, 84 | batch_size=args.batch_size, 85 | val_batch_size=args.val_batch_size, 86 | num_workers=args.num_workers, 87 | normalize=normalize) 88 | # get network model 89 | model, teacher_model = get_net_model(net=args.net, pretrained_dataset=args.dataset, dropout=(not args.nodp), 90 | pretrained=not args.randinit, input_mask=args.input_mask) 91 | 92 | # pretrained model 93 | if args.pretrain is not None and os.path.isfile(args.pretrain): 94 | print('load pretrained model:{}'.format(args.pretrain)) 95 | model.load_state_dict(torch.load(args.pretrain)) 96 | elif args.pretrain is not None: 97 | print('fail to load pretrained model: {}'.format(args.pretrain)) 98 | 99 | # set up multi-gpus 100 | if args.mgpu: 101 | assert len(os.environ['CUDA_VISIBLE_DEVICES'].split(',')) > 1 102 | model = torch.nn.DataParallel(model) 103 | teacher_model = torch.nn.DataParallel(teacher_model) 104 | 105 | # for energy estimate 106 | print('================model energy summary================') 107 | energy_info = build_energy_info(model) 108 | energy_estimator = lambda m: sum(energy_eval2(m, energy_info, verbose=False).values()) 109 | energy_estimator_relaxed = lambda m: sum(energy_eval2_relax(m, energy_info, verbose=False).values()) 110 | 111 | reset_Xenergy_cache(energy_info) 112 | cur_energy = sum(energy_eval2(model, energy_info, verbose=True).values()) 113 | cur_energy_relaxed = energy_estimator_relaxed(model) 114 | 115 | dense_model = fill_model_weights(copy.deepcopy(model), 1.0) 116 | budget_ub = energy_estimator_relaxed(dense_model) 117 | zero_model = fill_model_weights(copy.deepcopy(model), 0.0) 118 | budget_lb = energy_estimator_relaxed(zero_model) 119 | 120 | del zero_model, dense_model 121 | args.budget = max(args.budget, budget_lb / budget_ub) 122 | 123 | proj_func = lambda m, budget, grad=False, in_place=True: energy_proj2(m, energy_info, budget, grad=grad, 124 | in_place=in_place, param_name='weight') 125 | print('energy on dense DNN:{:.4e}, on zero DNN:{:.4e}, normalized_lb={:.4e}'.format(budget_ub, budget_lb, 126 | budget_lb / budget_ub)) 127 | print('energy on current DNN:{:.4e}, normalized={:.4e}'.format(cur_energy, cur_energy / budget_ub)) 128 | print('====================================================') 129 | print('current energy {:.4e}, relaxed: {:.4e}'.format(cur_energy, cur_energy_relaxed)) 130 | 131 | netl2wd = args.l2wd 132 | 133 | if args.cuda: 134 | if args.distill > 0.0: 135 | teacher_model.cuda() 136 | model.cuda() 137 | 138 | loss_func = lambda m, x, y: joint_loss(model=m, data=x, target=y, teacher_model=teacher_model, distill=args.distill) 139 | 140 | if args.eval or args.dataset != 'imagenet': 141 | val_loss, val_acc1, val_acc5 = eval_loss_acc1_acc5(model, val_loader, loss_func, args.cuda) 142 | print('**Validation loss:{:.4e}, top-1 accuracy:{:.5f}, top-5 accuracy:{:.5f}'.format(val_loss, val_acc1, 143 | val_acc5)) 
144 |     # also evaluate training data
145 |     tr_loss, tr_acc1, tr_acc5 = eval_loss_acc1_acc5(model, train_loader4eval, loss_func, args.cuda)
146 |     print('###Training loss:{:.4e}, top-1 accuracy:{:.5f}, top-5 accuracy:{:.5f}'.format(tr_loss, tr_acc1, tr_acc5))
147 | else:
148 |     val_acc1 = 0.0
149 |     print('For ImageNet, skip the first validation evaluation.')
150 | 
151 | old_file = None
152 | 
153 | energy_step = math.ceil(
154 |     max(0.0, cur_energy - args.budget * budget_ub) / ((len(tr_loader) * args.epochs) / args.proj_int))
155 | 
156 | energy_decay_factor = min(1.0, (args.budget * budget_ub) / cur_energy) ** \
157 |                       (1.0 / ((len(tr_loader) * args.epochs) / args.proj_int))
158 | 
159 | optimizer = torch.optim.SGD(filtered_parameters(model, param_name='input_mask', inverse=True), lr=args.lr, momentum=args.momentum, weight_decay=netl2wd)
160 | if args.input_mask:
161 |     Xoptimizer = torch.optim.Adam(filtered_parameters(model, param_name='input_mask', inverse=False), lr=args.xlr, weight_decay=args.xl2wd)
162 | 
163 | cur_budget = cur_energy_relaxed
164 | lr = args.lr
165 | xlr = args.xlr
166 | cur_sparsity = model_sparsity(model)
167 | 
168 | best_acc_pruned = None
169 | Xbudget = 0.9
170 | iter_idx = 0
171 | 
172 | W_proj_time = 0.0
173 | W_proj_time_cnt = 1e-15
174 | while True:
175 |     # update W
176 |     if not (args.skip1 and iter_idx == 0):
177 |         t_begin = time.time()
178 |         log_tic = t_begin
179 |         for epoch in range(args.epochs):
180 |             for batch_idx, (data, target) in enumerate(tr_loader):
181 |                 model.train()
182 |                 if args.cuda:
183 |                     data, target = data.cuda(), target.cuda()
184 | 
185 |                 loss = loss_func(model, data, target)
186 |                 # update network weights
187 |                 optimizer.zero_grad()
188 |                 loss.backward()
189 |                 optimizer.step()
190 | 
191 |                 if args.proj_int == 1 or (batch_idx > 0 and batch_idx % args.proj_int == 0) or batch_idx == len(tr_loader) - 1:
192 |                     temp_tic = time.time()
193 |                     proj_func(model, cur_budget)
194 |                     W_proj_time += time.time() - temp_tic
195 |                     W_proj_time_cnt += 1
196 |                     if epoch == args.epochs - 1 and batch_idx >= len(tr_loader) - 1 - args.proj_int:
197 |                         cur_budget = args.budget * budget_ub
198 |                     else:
199 |                         if args.exp_bdecay:
200 |                             cur_budget = max(cur_budget * energy_decay_factor, args.budget * budget_ub)
201 |                         else:
202 |                             cur_budget = max(cur_budget - energy_step, args.budget * budget_ub)
203 | 
204 |                 if batch_idx % args.log_interval == 0:
205 |                     print('======================================================')
206 |                     print('+-------------- epoch {}, batch {}/{} ----------------+'.format(epoch, batch_idx,
207 |                                                                                            len(tr_loader)))
208 |                     log_toc = time.time()
209 |                     print(
210 |                         'primal update: net loss={:.4e}, lr={:.4e}, current normalized budget: {:.4e}, time_elapsed={:.3f}s, average projection time={:.3f}s'.format(
211 |                             loss.item(), optimizer.param_groups[0]['lr'], cur_budget / budget_ub, log_toc - log_tic, W_proj_time / W_proj_time_cnt))
212 |                     log_tic = time.time()
213 |                     if batch_idx % args.proj_int == 0:
214 |                         cur_sparsity = model_sparsity(model)
215 |                         print('sparsity:{}'.format(cur_sparsity))
216 |                         print(layers_stat(model, param_names='weight', param_filter=lambda p: p.dim() > 1))
217 |                     print('+-----------------------------------------------------+')
218 | 
219 |             cur_energy = energy_estimator(model)
220 |             cur_energy_relaxed = energy_estimator_relaxed(model)
221 |             cur_sparsity = model_sparsity(model)
222 |             if epoch % args.test_interval == 0:
223 |                 val_loss, val_acc1, val_acc5 = eval_loss_acc1_acc5(model, val_loader, loss_func, args.cuda)
224 | 
225 |                 # also evaluate training data
226 |                 tr_loss, tr_acc1, tr_acc5 = eval_loss_acc1_acc5(model, train_loader4eval, loss_func, args.cuda)
227 |                 print('###Training loss:{:.4e}, top-1 accuracy:{:.5f}, top-5 accuracy:{:.5f}'.format(tr_loss, tr_acc1,
228 |                                                                                                       tr_acc5))
229 | 
230 |                 print(
231 |                     '***Validation loss:{:.4e}, top-1 accuracy:{:.5f}, top-5 accuracy:{:.5f}, current normalized energy:{:.4e}, {:.4e}(relaxed), sparsity: {:.4e}'.format(
232 |                         val_loss, val_acc1,
233 |                         val_acc5, cur_energy / budget_ub, cur_energy_relaxed / budget_ub, cur_sparsity))
234 |             # save current model
235 |             model_snapshot(model, os.path.join(args.logdir, 'primal_model_latest.pkl'))
236 | 
237 |             if args.save_interval > 0 and epoch % args.save_interval == 0:
238 |                 model_snapshot(model, os.path.join(args.logdir, 'Wprimal_model_epoch{}_{}.pkl'.format(iter_idx, epoch)))
239 | 
240 |             elapse_time = time.time() - t_begin
241 |             speed_epoch = elapse_time / (1 + epoch)
242 |             eta = speed_epoch * (args.epochs - epoch)
243 |             print("Updating Weights, Elapsed {:.2f}s, eta {:.2f}s".format(elapse_time, eta))
244 | 
245 |     if not args.input_mask:
246 |         print("Completed weights training.")
247 |         break
248 |     else:
249 |         print("Continue to train input mask.")
250 | 
251 |     if best_acc_pruned is not None and val_acc1 <= best_acc_pruned:
252 |         print("Pruned accuracy does not improve, stopping here!")
253 |         break
254 |     best_acc_pruned = val_acc1
255 | 
256 |     # update X
257 |     t_begin = time.time()
258 |     log_tic = t_begin
259 |     for epoch in range(args.epochs):
260 |         for batch_idx, (data, target) in enumerate(tr_loader):
261 |             model.train()
262 |             Xoptimizer.param_groups[0]['lr'] = xlr
263 |             if args.cuda:
264 |                 data, target = data.cuda(), target.cuda()
265 | 
266 |             loss = loss_func(model, data, target)
267 |             # update the input mask
268 |             Xoptimizer.zero_grad()
269 |             loss.backward()
270 |             Xoptimizer.step()
271 |             clamp_model_weights(model, min=0.0, max=1.0, param_name='input_mask')
272 | 
273 |             if (batch_idx > 0 and batch_idx % args.proj_int == 0) or batch_idx == len(tr_loader) - 1:
274 |                 l0proj(model, Xbudget, param_name='input_mask')
275 | 
276 |             if batch_idx % args.log_interval == 0:
277 |                 print('======================================================')
278 |                 print('+-------------- epoch {}, batch {}/{} ----------------+'.format(epoch, batch_idx,
279 |                                                                                        len(tr_loader)))
280 |                 log_toc = time.time()
281 |                 print('primal update: net loss={:.4e}, xlr={:.4e}, time_elapsed={:.3f}s'.format(
282 |                     loss.item(), Xoptimizer.param_groups[0]['lr'], log_toc - log_tic))
283 |                 log_tic = time.time()
284 |                 if batch_idx % args.proj_int == 0:
285 |                     cur_sparsity = model_sparsity(model, param_name='input_mask')
286 |                     print('sparsity:{}'.format(cur_sparsity))
287 |                     print(layers_stat(model, param_names='input_mask'))
288 |                 print('+-----------------------------------------------------+')
289 | 
290 |         cur_energy = energy_estimator(model)
291 |         cur_energy_relaxed = energy_estimator_relaxed(model)
292 |         cur_sparsity = model_sparsity(model, param_name='input_mask')
293 |         if epoch % args.test_interval == 0:
294 | 
295 |             val_loss, val_acc1, val_acc5 = eval_loss_acc1_acc5(model, val_loader, loss_func, args.cuda)
296 | 
297 |             # also evaluate training data
298 |             tr_loss, tr_acc1, tr_acc5 = eval_loss_acc1_acc5(model, train_loader4eval, loss_func, args.cuda)
299 |             print(
300 |                 '###Training loss:{:.4e}, top-1 accuracy:{:.5f}, top-5 accuracy:{:.5f}'.format(tr_loss, tr_acc1,
301 |                                                                                                tr_acc5))
302 | 
303 |             print(
304 |                 '***Validation loss:{:.4e}, top-1 accuracy:{:.5f}, top-5 accuracy:{:.5f}, current normalized energy:{:.4e}, {:.4e}(relaxed), sparsity: {:.4e}'.format(
305 |                     val_loss, val_acc1,
306 |                     val_acc5, cur_energy / budget_ub, cur_energy_relaxed / budget_ub, cur_sparsity))
307 |         # save current model
308 |         model_snapshot(model, os.path.join(args.logdir, 'primal_model_latest.pkl'))
309 | 
310 |         if args.save_interval > 0 and epoch % args.save_interval == 0:
311 |             model_snapshot(model, os.path.join(args.logdir, 'Xprimal_model_epoch{}_{}.pkl'.format(iter_idx, epoch)))
312 | 
313 |         elapse_time = time.time() - t_begin
314 |         speed_epoch = elapse_time / (1 + epoch)
315 |         eta = speed_epoch * (args.epochs - epoch)
316 |         print("Updating input mask, Elapsed {:.2f}s, eta {:.2f}s".format(elapse_time, eta))
317 | 
318 |     round_model_weights(model, param_name='input_mask')
319 |     # refresh X_energy_cache
320 |     reset_Xenergy_cache(energy_info)
321 |     cur_energy = energy_estimator(model)
322 |     cur_energy_relaxed = energy_estimator_relaxed(model)
323 | 
324 |     iter_idx += 1
325 |     Xbudget -= 0.1
326 | 
--------------------------------------------------------------------------------
/sa_energy_model.py:
--------------------------------------------------------------------------------
1 | import math
2 | 
3 | import torch.nn.functional as F
4 | 
5 | import torch
6 | from torch import nn as nn
7 | from torch.nn import Parameter
8 | 
9 | from proj_utils import copy_model_weights, layers_nnz, fill_model_weights
10 | 
11 | # using hardware parameters from Eyeriss
12 | 
13 | default_s1 = int(100 * 1024 / 2)  # input cache
14 | default_s2 = 1 * int(8 * 1024 / 2)  # kernel cache
15 | default_m = 12
16 | default_n = 14
17 | 
18 | # unit energy constants
19 | default_e_mac = 1.0 + 1.0 + 1.0  # including both read and write RF
20 | default_e_mem = 200.0
21 | default_e_cache = 6.0
22 | default_e_rf = 1.0
23 | 
24 | 
25 | class Layer_energy(object):
26 |     def __init__(self, **kwargs):
27 |         super(Layer_energy, self).__init__()
28 |         self.h = kwargs['h'] if 'h' in kwargs else None
29 |         self.w = kwargs['w'] if 'w' in kwargs else None
30 |         self.c = kwargs['c'] if 'c' in kwargs else None
31 |         self.d = kwargs['d'] if 'd' in kwargs else None
32 |         self.xi = kwargs['xi'] if 'xi' in kwargs else None
33 |         self.g = kwargs['g'] if 'g' in kwargs else None
34 |         self.p = kwargs['p'] if 'p' in kwargs else None
35 |         self.m = kwargs['m'] if 'm' in kwargs else None
36 |         self.n = kwargs['n'] if 'n' in kwargs else None
37 |         self.s1 = kwargs['s1'] if 's1' in kwargs else None
38 |         self.s2 = kwargs['s2'] if 's2' in kwargs else None
39 |         self.r = kwargs['r'] if 'r' in kwargs else None
40 |         self.is_conv = self.r is not None
41 | 
42 |         if self.h is not None:
43 |             self.h_ = max(0.0, math.floor((self.h + 2.0 * self.p - self.r) / float(self.xi)) + 1)
44 |         if self.w is not None:
45 |             self.w_ = max(0.0, math.floor((self.w + 2.0 * self.p - self.r) / float(self.xi)) + 1)
46 | 
47 |         self.cached_Xenergy = None
48 | 
49 |     def get_alpha(self, e_mem, e_cache, e_rf):
50 |         if self.is_conv:
51 |             return e_mem + \
52 |                    (math.ceil((float(self.d) / self.g) / self.n) * (self.r ** 2) / float(self.xi ** 2)) * e_cache + \
53 |                    ((float(self.d) / self.g) * (self.r ** 2) / (self.xi ** 2)) * e_rf
54 |         else:
55 |             if self.c <= default_s1:
56 |                 return e_mem + math.ceil(float(self.d) / self.n) * e_cache + float(self.d) * e_rf
57 |             else:
58 |                 return math.ceil(float(self.d) / self.n) * e_mem + math.ceil(float(self.d) / self.n) * e_cache + float(
59 |                     self.d) * e_rf
60 | 
61 |     def get_beta(self, e_mem, e_cache, e_rf, in_cache=None):
62 |         if self.is_conv:
63 |             n = 1 if in_cache else math.ceil(self.h_ * self.w_ / self.m)
64 |             return n * e_mem + math.ceil(self.h_ * self.w_ / self.m) * e_cache + \
65 |                    (self.h_ * self.w_) * e_rf
66 |         else:
67 |             return e_mem + e_cache + e_rf
68 | 
69 |     def get_gamma(self, e_mem, k=None):
70 |         if self.is_conv:
71 |             rows_per_batch = math.floor(self.s1 / float(k))
72 |             assert rows_per_batch >= self.r
73 |             # print(self.__dict__)
74 |             # print('###########', rows_per_batch, self.s1, k)
75 |             # print('conv input data energy (2):{:.2e}'.format(float(k) * (self.r - 1) * (math.ceil(float(self.h) / (rows_per_batch - self.r + 1)) - 1)))
76 | 
77 |             return (float(self.d) * self.h_ * self.w_) * e_mem + \
78 |                    float(k) * (self.r - self.xi) * \
79 |                    max(0.0, (math.ceil(float(self.h) / (rows_per_batch - self.r + self.xi)) - 1)) * e_mem
80 |         else:
81 |             return float(self.d) * e_mem
82 | 
83 |     def get_knapsack_weight_W(self, e_mac, e_mem, e_cache, e_rf, in_cache=None, crelax=False):
84 |         if self.is_conv:
85 |             if crelax:
86 |                 # use relaxed computation energy estimation (larger than the real computation energy)
87 |                 return self.get_beta(e_mem, e_cache, e_rf, in_cache) + e_mac * self.h_ * self.w_
88 |             else:
89 |                 # computation energy is accounted for elsewhere
90 |                 return self.get_beta(e_mem, e_cache, e_rf, in_cache) + e_mac * 0.0
91 |         else:
92 |             return self.get_beta(e_mem, e_cache, e_rf, in_cache) + e_mac
93 | 
94 |     def get_knapsack_bound_W(self, e_mem, e_cache, e_rf, X_nnz, k):
95 |         if self.is_conv:
96 |             return self.get_gamma(e_mem, k) + self.get_alpha(e_mem, e_cache, e_rf) * X_nnz
97 |         else:
98 |             return self.get_gamma(e_mem) + self.get_alpha(e_mem, e_cache, e_rf) * X_nnz
99 | 
100 | 
101 | def build_energy_info(model, m=default_m, n=default_n, s1=default_s1, s2=default_s2):
102 |     res = {}
103 |     for name, p in model.named_parameters():
104 |         if name.endswith('input_mask'):
105 |             layer_name = name[:-len('input_mask') - 1]
106 |             if layer_name in res:
107 |                 res[layer_name]['h'] = p.size()[1]
108 |                 res[layer_name]['w'] = p.size()[2]
109 |             else:
110 |                 res[layer_name] = {'h': p.size()[1], 'w': p.size()[2]}
111 |         elif name.endswith('.hw'):
112 |             layer_name = name[:-len('hw') - 1]
113 |             if layer_name in res:
114 |                 res[layer_name]['h'] = float(p.data[0])
115 |                 res[layer_name]['w'] = float(p.data[1])
116 |             else:
117 |                 res[layer_name] = {'h': float(p.data[0]), 'w': float(p.data[1])}
118 |         elif name.endswith('.xi'):
119 |             layer_name = name[:-len('xi') - 1]
120 |             if layer_name in res:
121 |                 res[layer_name]['xi'] = float(p.data[0])
122 |             else:
123 |                 res[layer_name] = {'xi': float(p.data[0])}
124 |         elif name.endswith('.g'):
125 |             layer_name = name[:-len('g') - 1]
126 |             if layer_name in res:
127 |                 res[layer_name]['g'] = float(p.data[0])
128 |             else:
129 |                 res[layer_name] = {'g': float(p.data[0])}
130 |         elif name.endswith('.p'):
131 |             layer_name = name[:-len('p') - 1]
132 |             if layer_name in res:
133 |                 res[layer_name]['p'] = float(p.data[0])
134 |             else:
135 |                 res[layer_name] = {'p': float(p.data[0])}
136 |         elif name.endswith('weight'):
137 |             if len(p.size()) == 2 or len(p.size()) == 4:
138 |                 layer_name = name[:-len('weight') - 1]
139 |                 if layer_name in res:
140 |                     res[layer_name]['d'] = p.size()[0]
141 |                     res[layer_name]['c'] = p.size()[1]
142 |                 else:
143 |                     res[layer_name] = {'d': p.size()[0], 'c': p.size()[1]}
144 |                 if p.dim() > 2:
145 |                     # (out_channels, in_channels, kernel_size[0], kernel_size[1])
146 |                     assert p.dim() == 4
147 |                     res[layer_name]['r'] = p.size()[2]
148 |             else:
149 |                 continue
150 | 
151 |             res[layer_name]['m'] = m
152 |             res[layer_name]['n'] = n
153 |             res[layer_name]['s1'] = s1
154 |             res[layer_name]['s2'] = s2
155 | 
156 |     for layer_name in res:
157 |         res[layer_name] = Layer_energy(**(res[layer_name]))
158 |         if res[layer_name].g is not None and res[layer_name].g > 1:
159 |             res[layer_name].c *= res[layer_name].g
160 |     return res
161 | 
162 | 
163 | def reset_Xenergy_cache(energy_info):
164 |     for layer_name in energy_info:
165 |         energy_info[layer_name].cached_Xenergy = None
166 |     return energy_info
167 | 
168 | 
169 | def conv_cache_overlap(X_supp, padding, kernel_size, stride, k_X):
170 |     rs = X_supp.transpose(0, 1).contiguous().view(X_supp.size(1), -1).sum(dim=1).cpu()  # nonzero input count per row (over channels and width)
171 |     rs = torch.cat([torch.zeros(padding, dtype=rs.dtype, device=rs.device),
172 |                     rs,
173 |                     torch.zeros(padding, dtype=rs.dtype, device=rs.device)])  # pad rows of zeros to account for convolution padding
174 |     res = 0
175 |     beg = 0
176 |     end = None
177 |     while beg + kernel_size - 1 < rs.size(0):
178 |         if end is not None:
179 |             if beg < end:
180 |                 res += rs[beg:end].sum().item()  # overlapping rows between consecutive cache windows are re-read
181 |         n_elements = 0
182 |         for i in range(rs.size(0) - beg):
183 |             if n_elements + rs[beg+i] <= k_X:
184 |                 n_elements += rs[beg+i]
185 |                 if beg + i == rs.size(0) - 1:
186 |                     end = rs.size(0)
187 |             else:
188 |                 end = beg + i
189 |                 break
190 |         assert end - beg >= kernel_size, 'can only hold {} rows with {} elements < {} rows in {}, cache size={}'.format(end - beg, n_elements, kernel_size, X_supp.size(), k_X)
191 |         # print('map size={}. begin={}, end={}'.format(X_supp.size(), beg, end))
192 |         beg += (math.floor((end - beg - kernel_size) / stride) + 1) * stride
193 |     return res
194 | 
195 | 
196 | def energy_eval2(model, energy_info, e_mac=default_e_mac, e_mem=default_e_mem, e_cache=default_e_cache,
197 |                  e_rf=default_e_rf, verbose=False, crelax=False):
198 |     X_nnz_dict = layers_nnz(model, normalized=False, param_name='input_mask')
199 | 
200 |     W_nnz_dict = layers_nnz(model, normalized=False, param_name='weight')
201 | 
202 |     W_energy = []
203 |     C_energy = []
204 |     X_energy = []
205 |     X_supp_dict = {}
206 |     for name, p in model.named_parameters():
207 |         if name.endswith('input_mask'):
208 |             layer_name = name[:-len('input_mask') - 1]
209 |             X_supp_dict[layer_name] = (p.data != 0.0).float()
210 | 
211 |     for name, p in model.named_parameters():
212 |         if name.endswith('weight'):
213 |             if p is None or p.dim() == 1:
214 |                 continue
215 |             layer_name = name[:-len('weight') - 1]
216 |             einfo = energy_info[layer_name]
217 | 
218 |             if einfo.is_conv:
219 |                 X_nnz = einfo.h * einfo.w * einfo.c
220 |             else:
221 |                 X_nnz = einfo.c
222 |             if layer_name in X_nnz_dict:
223 |                 # this layer has sparse input
224 |                 X_nnz = X_nnz_dict[layer_name]
225 | 
226 |             if layer_name in X_supp_dict:
227 |                 X_supp = X_supp_dict[layer_name].unsqueeze(0)
228 |             else:
229 |                 if einfo.is_conv:
230 |                     X_supp = torch.ones(1, int(einfo.c), int(einfo.h), int(einfo.w), dtype=p.dtype, device=p.device)
231 |                 else:
232 |                     X_supp = None
233 | 
234 |             unfoldedX = None
235 | 
236 |             # input data access energy
237 |             if einfo.is_conv:
238 |                 h_, w_ = max(0.0, math.floor((einfo.h + 2 * einfo.p - einfo.r) / einfo.xi) + 1), max(0.0, math.floor((einfo.w + 2 * einfo.p - einfo.r) / einfo.xi) + 1)
239 |                 if verbose:
240 |                     print('Layer: {}, input shape: ({}, {}, {}), output shape: ({}, {}, {}), weight shape: {}'
241 |                           .format(layer_name, einfo.c, einfo.h, einfo.w, einfo.d, h_, w_, p.shape))
242 |                 unfoldedX = F.unfold(X_supp, kernel_size=int(einfo.r), padding=int(einfo.p), stride=int(einfo.xi)).squeeze(0)
243 |                 assert unfoldedX.size(1) == h_ * w_, 'unfolded X size={}, but h_ * w_ = {}, W.size={}'.format(unfoldedX.size(), h_ * w_, p.size())
244 |                 unfoldedX_nnz = (unfoldedX != 0.0).float().sum().item()
245 | 
246 |                 X_energy_cache = unfoldedX_nnz * math.ceil((float(einfo.d) / einfo.g) / einfo.n) * e_cache
247 |                 X_energy_rf = unfoldedX_nnz * math.ceil(float(einfo.d) / einfo.g) * e_rf
248 | 
249 |                 X_energy_mem = X_nnz * e_mem + \
250 |                                conv_cache_overlap(X_supp.squeeze(0), int(einfo.p), int(einfo.r), int(einfo.xi), default_s1) * e_mem + \
251 |                                unfoldedX.size(1) * einfo.d * e_mem
252 |                 X_energy_this = X_energy_mem + X_energy_rf + X_energy_cache
253 |             else:
254 |                 X_energy_cache = math.ceil(float(einfo.d) / einfo.n) * e_cache * X_nnz
255 |                 X_energy_rf = float(einfo.d) * e_rf * X_nnz
256 |                 X_energy_mem = e_mem * (math.ceil(float(einfo.d) / einfo.n) * max(0.0, X_nnz - default_s1)
257 |                                         + min(X_nnz, default_s1)) + e_mem * float(einfo.d)
258 | 
259 |                 X_energy_this = X_energy_mem + X_energy_rf + X_energy_cache
260 | 
261 |             einfo.cached_Xenergy = X_energy_this
262 |             X_energy.append(X_energy_this)
263 | 
264 |             # kernel weights data access energy
265 |             if einfo.is_conv:
266 |                 output_hw = unfoldedX.size(1)
267 |                 W_energy_cache = math.ceil(output_hw / einfo.m) * W_nnz_dict[layer_name] * e_cache
268 |                 W_energy_rf = output_hw * W_nnz_dict[layer_name] * e_rf
269 |                 W_energy_mem = (math.ceil(output_hw / einfo.m) * max(0.0, W_nnz_dict[layer_name] - default_s2)\
270 |                                 + min(default_s2, W_nnz_dict[layer_name])) * e_mem
271 |                 W_energy_this = W_energy_cache + W_energy_rf + W_energy_mem
272 |             else:
273 |                 W_energy_this = einfo.get_beta(e_mem, e_cache, e_rf, in_cache=None) * W_nnz_dict[layer_name]
274 |             W_energy.append(W_energy_this)
275 | 
276 |             # computation energy
277 |             if einfo.is_conv:
278 |                 if crelax:
279 |                     N_mac = energy_info[layer_name].h_ * float(energy_info[layer_name].w_) * W_nnz_dict[layer_name]
280 |                 else:
281 |                     N_mac = torch.sum(
282 |                         F.conv2d(X_supp, (p.data != 0.0).float(), None, int(energy_info[layer_name].xi),
283 |                                  int(energy_info[layer_name].p), 1, int(energy_info[layer_name].g))).item()
284 |                 C_energy_this = e_mac * N_mac
285 |             else:
286 |                 C_energy_this = e_mac * (W_nnz_dict[layer_name])
287 | 
288 |             C_energy.append(C_energy_this)
289 | 
290 |             if verbose:
291 |                 print("Layer: {}, W_energy={:.2e}, C_energy={:.2e}, X_energy={:.2e}".format(layer_name,
292 |                                                                                             W_energy[-1],
293 |                                                                                             C_energy[-1],
294 |                                                                                             X_energy[-1]))
295 | 
296 |     return {'W': sum(W_energy), 'C': sum(C_energy), 'X': sum(X_energy)}
297 | 
298 | 
299 | def energy_eval2_relax(model, energy_info, e_mac=default_e_mac, e_mem=default_e_mem, e_cache=default_e_cache,
300 |                        e_rf=default_e_rf, verbose=False):
301 |     W_nnz_dict = layers_nnz(model, normalized=False, param_name='weight')
302 | 
303 |     W_energy = []
304 |     C_energy = []
305 |     X_energy = []
306 |     X_supp_dict = {}
307 |     for name, p in model.named_parameters():
308 |         if name.endswith('input_mask'):
309 |             layer_name = name[:-len('input_mask') - 1]
310 |             X_supp_dict[layer_name] = (p.data != 0.0).float()
311 | 
312 |     for name, p in model.named_parameters():
313 |         if name.endswith('weight'):
314 |             if p is None or p.dim() == 1:
315 |                 continue
316 |             layer_name = name[:-len('weight') - 1]
317 |             assert energy_info[layer_name].cached_Xenergy is not None
318 |             X_energy.append(energy_info[layer_name].cached_Xenergy)
319 |             assert X_energy[-1] > 0
320 |             if not energy_info[layer_name].is_conv:
321 |                 # in_cache is not needed in fc layers
322 |                 in_cache = None
323 |                 W_energy.append(
324 |                     energy_info[layer_name].get_beta(e_mem, e_cache, e_rf, in_cache) * W_nnz_dict[layer_name])
325 |                 C_energy.append(e_mac * (W_nnz_dict[layer_name]))
326 |                 if verbose:
327 |                     knapsack_weight1 = energy_info[layer_name].get_knapsack_weight_W(e_mac, e_mem, e_cache, e_rf,
328 |                                                                                       in_cache=None, crelax=True)
329 |                     if hasattr(knapsack_weight1, 'mean'):
330 |                         knapsack_weight1 = knapsack_weight1.mean()
331 |                     print(layer_name + " weight: {:.4e}".format(knapsack_weight1))
332 | 
333 |             else:
334 |                 beta1 = energy_info[layer_name].get_beta(e_mem, e_cache, e_rf, in_cache=True)
335 |                 beta2 = energy_info[layer_name].get_beta(e_mem, e_cache, e_rf, in_cache=False)
336 | 
337 |                 W_nnz = W_nnz_dict[layer_name]
338 |                 W_energy_this = beta1 * min(energy_info[layer_name].s2, W_nnz) + beta2 * max(0, W_nnz - energy_info[
339 |                     layer_name].s2)
340 |                 W_energy.append(W_energy_this)
341 |                 C_energy.append(e_mac * energy_info[layer_name].h_ * float(energy_info[layer_name].w_) * W_nnz)
342 | 
343 |             if verbose:
344 |                 print("Layer: {}, W_energy={:.2e}, C_energy={:.2e}, X_energy={:.2e}".format(layer_name,
345 |                                                                                             W_energy[-1],
346 |                                                                                             C_energy[-1],
347 |                                                                                             X_energy[-1]))
348 | 
349 |     return {'W': sum(W_energy), 'C': sum(C_energy), 'X': sum(X_energy)}
350 | 
351 | 
352 | def energy_proj2(model, energy_info, budget, e_mac=default_e_mac, e_mem=default_e_mem, e_cache=default_e_cache,
353 |                  e_rf=default_e_rf, grad=False, in_place=True, preserve=0.0, param_name='weight'):
354 |     knapsack_bound = budget
355 |     param_flats = []
356 |     knapsack_weight_all = []
357 |     score_all = []
358 |     param_shapes = []
359 |     bound_bias = 0.0
360 | 
361 |     for name, p in model.named_parameters():
362 |         if name.endswith(param_name):
363 |             if p is None or (param_name == 'weight' and p.dim() == 1):
364 |                 # skip batch_norm layer
365 |                 param_shapes.append((name, None))
366 |                 continue
367 |             else:
368 |                 param_shapes.append((name, p.data.shape))
369 | 
370 |             layer_name = name[:-len(param_name) - 1]
371 |             assert energy_info[layer_name].cached_Xenergy is not None
372 |             if grad:
373 |                 p_flat = p.grad.data.view(-1)
374 |             else:
375 |                 p_flat = p.data.view(-1)
376 |             score = p_flat ** 2
377 | 
378 |             if param_name == 'weight':
379 |                 knapsack_weight = energy_info[layer_name].get_knapsack_weight_W(e_mac, e_mem, e_cache, e_rf,
380 |                                                                                 in_cache=True, crelax=True)
381 |                 if hasattr(knapsack_weight, 'view'):
382 |                     knapsack_weight = knapsack_weight.view(1, -1, 1, 1)
383 |                 knapsack_weight = torch.zeros_like(p.data).add_(knapsack_weight).view(-1)
384 | 
385 |                 # preserve part of weights
386 |                 if preserve > 0.0:
387 |                     if preserve > 1:
388 |                         n_preserve = preserve
389 |                     else:
390 |                         n_preserve = round(p_flat.numel() * preserve)
391 |                     _, preserve_idx = torch.topk(score, k=n_preserve, largest=True, sorted=False)
392 |                     score[preserve_idx] = float('inf')
393 | 
394 |                 if energy_info[layer_name].is_conv and p_flat.numel() > energy_info[layer_name].s2:
395 |                     delta = energy_info[layer_name].get_beta(e_mem, e_cache, e_rf, in_cache=False) \
396 |                             - energy_info[layer_name].get_beta(e_mem, e_cache, e_rf, in_cache=True)
397 |                     assert delta >= 0
398 |                     _, out_cache_idx = torch.topk(score, k=p_flat.numel() - energy_info[layer_name].s2, largest=False,
399 |                                                   sorted=False)
400 |                     knapsack_weight[out_cache_idx] += delta
401 | 
402 |                 bound_const = energy_info[layer_name].cached_Xenergy
403 | 
404 |                 assert bound_const > 0
405 |                 bound_bias += bound_const
406 |                 knapsack_bound -= bound_const
407 | 
408 |             else:
409 |                 raise ValueError('unsupported parameter name')
410 | 
411 |             score_all.append(score)
412 |             knapsack_weight_all.append(knapsack_weight)
413 |             # print(layer_name, X_nnz, knapsack_weight)
414 |             param_flats.append(p_flat)
415 | 
416 |     param_flats = torch.cat(param_flats, dim=0)
417 |     knapsack_weight_all = torch.cat(knapsack_weight_all, dim=0)
418 |     score_all = torch.cat(score_all, dim=0) / knapsack_weight_all
419 | 
420 |     _, sorted_idx = torch.sort(score_all, descending=True)
421 |     cumsum = torch.cumsum(knapsack_weight_all[sorted_idx], dim=0)
422 |     res_nnz = torch.nonzero(cumsum <= knapsack_bound).max()
423 |     z_idx = sorted_idx[-(param_flats.numel() - res_nnz):]
424 | 
425 |     if in_place:
426 |         param_flats[z_idx] = 0.0
427 |         copy_model_weights(model, param_flats, param_shapes, param_name)
428 |     return z_idx, param_shapes
429 | 
430 | 
431 | class myConv2d(nn.Conv2d):
432 |     def __init__(self, h_in, w_in, in_channels, out_channels, kernel_size, stride=1,
433 |                  padding=0, dilation=1, groups=1, bias=True):
434 |         super(myConv2d, self).__init__(in_channels, out_channels, kernel_size, stride,
435 |                                        padding, dilation, groups, bias)
436 |         self.h_in = h_in
437 |         self.w_in = w_in
438 |         self.xi = Parameter(torch.LongTensor(1), requires_grad=False)
439 |         self.xi.data[0] = stride
440 |         self.g = Parameter(torch.LongTensor(1), requires_grad=False)
441 |         self.g.data[0] = groups
442 |         self.p = Parameter(torch.LongTensor(1), requires_grad=False)
443 |         self.p.data[0] = padding
444 | 
445 |     def __repr__(self):
446 |         s = ('{name}({h_in}, {w_in}, {in_channels}, {out_channels}, kernel_size={kernel_size}'
447 |              ', stride={stride}')
448 |         if self.padding != (0,) * len(self.padding):
449 |             s += ', padding={padding}'
450 |         if self.dilation != (1,) * len(self.dilation):
451 |             s += ', dilation={dilation}'
452 |         if self.output_padding != (0,) * len(self.output_padding):
453 |             s += ', output_padding={output_padding}'
454 |         if self.groups != 1:
455 |             s += ', groups={groups}'
456 |         if self.bias is None:
457 |             s += ', bias=False'
458 |         s += ')'
459 |         return s.format(name=self.__class__.__name__, **self.__dict__)
460 | 
461 | 
462 | class FixHWConv2d(myConv2d):
463 |     def __init__(self, h_in, w_in, in_channels, out_channels, kernel_size, stride=1,
464 |                  padding=0, dilation=1, groups=1, bias=True):
465 |         super(FixHWConv2d, self).__init__(h_in, w_in, in_channels, out_channels, kernel_size, stride,
466 |                                           padding, dilation, groups, bias)
467 | 
468 |         self.hw = Parameter(torch.LongTensor(2), requires_grad=False)
469 |         self.hw.data[0] = h_in
470 |         self.hw.data[1] = w_in
471 | 
472 |     def forward(self, input):
473 |         # Input: :math:`(N, C_{in}, H_{in}, W_{in})`
474 |         assert input.size(2) == self.hw.data[0] and input.size(3) == self.hw.data[1], 'input_size={}, but hw={}'.format(
475 |             input.size(), self.hw.data)
476 |         return super(FixHWConv2d, self).forward(input)
477 | 
478 | 
479 | class SparseConv2d(myConv2d):
480 |     def __init__(self, h_in, w_in, in_channels, out_channels, kernel_size, stride=1,
481 |                  padding=0, dilation=1, groups=1, bias=True):
482 |         super(SparseConv2d, self).__init__(h_in, w_in, in_channels, out_channels, kernel_size, stride,
483 |                                            padding, dilation, groups, bias)
484 | 
485 |         self.input_mask = Parameter(torch.Tensor(in_channels, h_in, w_in))
486 |         self.input_mask.data.fill_(1.0)
487 | 
488 |     def forward(self, input):
489 |         # print("###{}, {}".format(input.size(), self.input_mask.size()))
490 |         return super(SparseConv2d, self).forward(input * self.input_mask)
491 | 
492 | 
493 | def conv2d_out_dim(dim, kernel_size, padding=0, stride=1, dilation=1, ceil_mode=False):
494 |     if ceil_mode:
495 |         return int(math.ceil((dim + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1))
496 |     else:
497 |         return int(math.floor((dim + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1))
498 | 
499 | 
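500 | 
501 | if __name__ == '__main__':
502 |     # Minimal usage sketch (added for illustration; not part of the original
503 |     # repo): build a toy conv + fc model and estimate its inference energy.
504 |     # The estimators only read parameter tensors, so no forward pass is run.
505 |     toy = nn.Sequential(
506 |         FixHWConv2d(8, 8, 1, 4, kernel_size=3, padding=1),
507 |         nn.Linear(4 * 8 * 8, 10),
508 |     )
509 |     info = build_energy_info(toy)
510 |     info = reset_Xenergy_cache(info)
511 |     exact = energy_eval2(toy, info, verbose=True)  # also fills each layer's cached_Xenergy
512 |     relaxed = energy_eval2_relax(toy, info)  # relaxed estimate; reuses the cached input energies
513 |     print('exact total={:.4e}, relaxed total={:.4e}'.format(
514 |         sum(exact.values()), sum(relaxed.values())))
515 | 
516 |     # Project the weights onto a tighter energy budget (60% of the relaxed
517 |     # estimate, an arbitrary choice here) and re-evaluate.
518 |     energy_proj2(toy, info, 0.6 * sum(relaxed.values()))
519 |     print('after projection: {:.4e}'.format(sum(energy_eval2(toy, info).values())))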
--------------------------------------------------------------------------------