├── .gitignore
├── README.md
├── datasets.py
├── flop_counter.py
├── gate.py
├── misc.py
├── models
│   ├── __init__.py
│   ├── mobilenet.py
│   ├── resnet.py
│   └── vgg.py
└── script
    ├── learn_gates.py
    ├── learn_gates_imagenet.py
    ├── prepare_imagenet_list.py
    ├── prune_model.py
    ├── prune_model_imagenet.py
    ├── train_pruned.py
    └── train_pruned_imagenet.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.npy
3 | *.npz
4 | *.pkl
5 | *.ckpt
6 | .idea/
7 | .ipynb_checkpoints/
8 | __pycache__/
9 | .DS_Store
10 | data/
11 | logs/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Pruning from Scratch
2 | 
3 | Official implementation of the paper [Pruning from Scratch](https://www.aaai.org/Papers/AAAI/2020GB/AAAI-WangY.403.pdf)
4 | 
5 | ## Requirements
6 | - pytorch == 1.1.0
7 | - torchvision == 0.2.2
8 | - [apex](https://github.com/NVIDIA/apex) @ commit: 574fe24
9 | 
10 | ## CIFAR10
11 | - learning channel importance gates from randomly initialized weights
12 | ```bash
13 | python script/learn_gates.py -a ARCH --gpu GPU_ID --seed SEED -s SPARSITY -e EXPANSION
14 | ```
15 | where `ARCH` is the network architecture type,
16 | `SPARSITY` is the sparsity ratio $r$ in the regularization term,
17 | and `EXPANSION` is the expanded channel number of the initial conv layer.
18 | - pruning based on channel gates
19 | ```bash
20 | python script/prune_model.py -a ARCH --gpu GPU_ID --seed SEED -s SPARSITY -e EXPANSION -p RATIO
21 | ```
22 | where `RATIO` is the MACs reduction ratio of the pruned model; a larger ratio yields a more compact model.
23 | - training pruned model from scratch
24 | ```bash
25 | python script/train_pruned.py -a ARCH --gpu GPU_ID --seed SEED -s SPARSITY -e EXPANSION -p RATIO --budget_train
26 | ```
27 | where `--budget_train` activates the budget training scheme (Scratch-B) proposed in
28 | [Rethinking the Value of Network Pruning](https://arxiv.org/abs/1810.05270),
29 | which trains the pruned model for the same computation budget as the full model.
30 | Empirically, this training scheme is crucial for improving the pruned model's performance.
31 | 
32 | ## ImageNet
33 | - prepare the ImageNet dataset following the instructions in
34 | [this link](https://github.com/pytorch/examples/tree/master/imagenet),
35 | which results in an `imagenet` folder with `train` and `val` sub-folders.
36 | - generate the image index lists by
37 | ```bash
38 | python script/prepare_imagenet_list.py --data_dir IMAGENET_DATA_DIR/train --dump_path data/imagenet/train_img_list.pkl
39 | python script/prepare_imagenet_list.py --data_dir IMAGENET_DATA_DIR/val --dump_path data/imagenet/val_img_list.pkl
40 | ```
41 | - learning channel importance gates from randomly initialized weights
42 | ```bash
43 | python script/learn_gates_imagenet.py -a ARCH --gpu GPU_ID -s SPARSITY -e EXPANSION -m MULTIPLIER
44 | ```
45 | where `MULTIPLIER` controls the expansion of the channel number on the backbone outputs,
46 | while `EXPANSION` enlarges the intermediate channel numbers in the InvertedResidual and Bottleneck blocks.
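For reference, both `script/learn_gates.py` and `script/learn_gates_imagenet.py` optimize only the gate parameters, using a cross-entropy loss plus a quadratic penalty that pulls the mean gate magnitude toward the target sparsity ratio. A minimal sketch of that objective (`lambd` and `sparsity_level` mirror the scripts' command-line flags):
```python
import torch
import torch.nn.functional as F

def gate_objective(output, target, gates_params, lambd, sparsity_level):
    # standard classification loss on the gated network's predictions
    loss_ce = F.cross_entropy(output, target)
    # quadratic penalty driving the mean |gate| toward the sparsity ratio r
    gates = torch.cat(gates_params)
    loss_reg = lambd * (gates.abs().mean() - sparsity_level) ** 2
    return loss_ce + loss_reg
```
After each optimizer step the scripts clamp every gate to [0, 1], so gates driven to exactly 0 mark their channels as prunable.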
47 | - pruning based on channel gates
48 | ```bash
49 | python script/prune_model_imagenet.py -a ARCH --gpu GPU_ID -s SPARSITY -e EXPANSION -m MULTIPLIER -p RATIO
50 | ```
51 | - training pruned model from scratch (single node, multiple GPUs)
52 | ```bash
53 | python -m torch.distributed.launch --nproc_per_node=NUM_GPU script/train_pruned_imagenet.py \
54 | -a ARCH -e EXPANSION -s SPARSITY -p RATIO -m MULTIPLIER \
55 | -b TRAIN_BATCH_SIZE --lr LR --wd WD --lr_scheduler SCHEDULER \
56 | --budget_train --label_smooth
57 | ```
58 | where `SCHEDULER` is the learning-rate scheduler type: 'multistep' for ResNet50, 'cos' for MobileNets.
59 | 
60 | ## Citation
61 | ```bibtex
62 | @inproceedings{wang2020pruning,
63 |   title={Pruning from Scratch},
64 |   author={Wang, Yulong and Zhang, Xiaolu and Xie, Lingxi and Zhou, Jun and Su, Hang and Zhang, Bo and Hu, Xiaolin},
65 |   booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
66 |   year={2020},
67 |   publisher={AAAI Press},
68 |   address={New York, USA}
69 | }
70 | ```
71 | 
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | from PIL import Image
3 | import os
4 | import os.path
5 | import numpy as np
6 | import sys
7 | if sys.version_info[0] == 2:
8 |     import cPickle as pickle
9 | else:
10 |     import pickle
11 | 
12 | import torch.utils.data as data
13 | import misc
14 | import torchvision.transforms as tfm
15 | import torchvision.datasets as ds
16 | import torch
17 | 
18 | 
19 | class CIFAR10(data.Dataset):
20 |     """`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
21 | 
22 |     Args:
23 |         root (string): Root directory of the dataset, i.e. where the
24 |             ``cifar-10-batches-py`` directory has already been extracted.
25 |         type (string, optional): One of ``'train'`` (first 45,000 training
26 |             images), ``'val'`` (last 5,000 training images), ``'train+val'``
27 |             (all 50,000 training images), or ``'test'`` (the 10,000 test images).
28 |         transform (callable, optional): A function/transform that takes in a PIL
29 |             image and returns a transformed version. E.g., ``transforms.RandomCrop``
30 |         target_transform (callable, optional): A function/transform that takes in
31 |             the target and transforms it. Note that, unlike the torchvision
32 |             version, this class does not download the dataset; the extracted
33 |             archive must already be present under ``root``.
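    Example (a sketch; assumes the archive is already extracted under
    ``root``)::

        >>> trainset = CIFAR10('./data/cifar10', type='train')
        >>> img, target = trainset[0]  # a PIL image and its integer label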
34 | 35 | """ 36 | base_folder = 'cifar-10-batches-py' 37 | url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" 38 | filename = "cifar-10-python.tar.gz" 39 | tgz_md5 = 'c58f30108f718f92721af3b95e74349a' 40 | train_list = [ 41 | ['data_batch_1', 'c99cafc152244af753f735de768cd75f'], 42 | ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'], 43 | ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'], 44 | ['data_batch_4', '634d18415352ddfa80567beed471001a'], 45 | ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'], 46 | ] 47 | 48 | test_list = [ 49 | ['test_batch', '40351d587109b95175f43aff81a1287e'], 50 | ] 51 | 52 | def __init__(self, root, type='train', 53 | transform=None, target_transform=None): 54 | self.root = os.path.expanduser(root) 55 | self.transform = transform 56 | self.target_transform = target_transform 57 | self.type = type 58 | 59 | # now load the picked numpy arrays 60 | train_data = [] 61 | train_labels = [] 62 | for fentry in self.train_list: 63 | f = fentry[0] 64 | file = os.path.join(self.root, self.base_folder, f) 65 | fo = open(file, 'rb') 66 | if sys.version_info[0] == 2: 67 | entry = pickle.load(fo) 68 | else: 69 | entry = pickle.load(fo, encoding='latin1') 70 | train_data.append(entry['data']) 71 | if 'labels' in entry: 72 | train_labels += entry['labels'] 73 | else: 74 | train_labels += entry['fine_labels'] 75 | fo.close() 76 | 77 | train_data = np.concatenate(train_data) 78 | train_data = train_data.reshape((50000, 3, 32, 32)) 79 | train_data = train_data.transpose((0, 2, 3, 1)) # convert to HWC 80 | 81 | if self.type == 'train': 82 | self.data = train_data[:45000] 83 | self.labels = train_labels[:45000] 84 | elif self.type == 'val': 85 | self.data = train_data[45000:] 86 | self.labels = train_labels[45000:] 87 | elif self.type == 'train+val': 88 | self.data = train_data 89 | self.labels = train_labels 90 | else: 91 | f = self.test_list[0][0] 92 | file = os.path.join(self.root, self.base_folder, f) 93 | fo = open(file, 'rb') 94 | if sys.version_info[0] == 2: 95 | entry = pickle.load(fo) 96 | else: 97 | entry = pickle.load(fo, encoding='latin1') 98 | self.data = entry['data'] 99 | if 'labels' in entry: 100 | self.labels = entry['labels'] 101 | else: 102 | self.labels = entry['fine_labels'] 103 | fo.close() 104 | self.data = self.data.reshape((10000, 3, 32, 32)) 105 | self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC 106 | 107 | 108 | def __getitem__(self, index): 109 | """ 110 | Args: 111 | index (int): Index 112 | 113 | Returns: 114 | tuple: (image, target) where target is index of the target class. 
115 | """ 116 | img, target = self.data[index], self.labels[index] 117 | 118 | # doing this so that it is consistent with all other datasets 119 | # to return a PIL Image 120 | img = Image.fromarray(img) 121 | 122 | if self.transform is not None: 123 | img = self.transform(img) 124 | 125 | if self.target_transform is not None: 126 | target = self.target_transform(target) 127 | 128 | return img, target 129 | 130 | def __len__(self): 131 | return len(self.data) 132 | 133 | 134 | class ImageNet(data.Dataset): 135 | def __init__(self, root, type='train', transform=None): 136 | self.root = os.path.expanduser(root) 137 | self.transform = transform 138 | self.type = type 139 | all_train_image_list = misc.load_pickle(os.path.join(self.root, 'train_img_list.pkl')) 140 | all_test_image_list = misc.load_pickle(os.path.join(self.root, 'val_img_list.pkl')) 141 | self.train_image_list = [] 142 | self.train_labels = [] 143 | self.val_image_list = [] 144 | self.val_labels = [] 145 | self.test_image_list = [] 146 | self.test_labels = [] 147 | for i in range(1000): 148 | self.train_image_list += all_train_image_list[i][:-50] 149 | self.train_labels += [i] * len(all_train_image_list[i][:-50]) 150 | self.val_image_list += all_train_image_list[i][-50:] 151 | self.val_labels += [i] * 50 152 | self.test_image_list += all_test_image_list[i] 153 | self.test_labels += [i] * 50 154 | 155 | if self.type == 'train': 156 | self.data = self.train_image_list 157 | self.labels = self.train_labels 158 | elif self.type == 'val': 159 | self.data = self.val_image_list 160 | self.labels = self.val_labels 161 | elif self.type == 'train+val': 162 | self.data = self.train_image_list + self.val_image_list 163 | self.labels = self.train_labels + self.val_labels 164 | elif self.type == 'test': 165 | self.data = self.test_image_list 166 | self.labels = self.test_labels 167 | 168 | def __len__(self): 169 | return len(self.data) 170 | 171 | def __getitem__(self, item): 172 | img_path = self.data[item] 173 | target = self.labels[item] 174 | img = misc.pil_loader(img_path) 175 | if self.transform is not None: 176 | img = self.transform(img) 177 | 178 | return img, target 179 | 180 | 181 | imagenet_pca = { 182 | 'eigval': np.asarray([0.2175, 0.0188, 0.0045]), 183 | 'eigvec': np.asarray([ 184 | [-0.5675, 0.7192, 0.4009], 185 | [-0.5808, -0.0045, -0.8140], 186 | [-0.5836, -0.6948, 0.4203], 187 | ]) 188 | } 189 | 190 | 191 | class Lighting(object): 192 | def __init__(self, alphastd, 193 | eigval=imagenet_pca['eigval'], 194 | eigvec=imagenet_pca['eigvec']): 195 | self.alphastd = alphastd 196 | assert eigval.shape == (3,) 197 | assert eigvec.shape == (3, 3) 198 | self.eigval = eigval 199 | self.eigvec = eigvec 200 | 201 | def __call__(self, img): 202 | if self.alphastd == 0.: 203 | return img 204 | rnd = np.random.randn(3) * self.alphastd 205 | rnd = rnd.astype('float32') 206 | v = rnd 207 | old_dtype = np.asarray(img).dtype 208 | v = v * self.eigval 209 | v = v.reshape((3, 1)) 210 | inc = np.dot(self.eigvec, v).reshape((3,)) 211 | img = np.add(img, inc) 212 | if old_dtype == np.uint8: 213 | img = np.clip(img, 0, 255) 214 | img = Image.fromarray(img.astype(old_dtype), 'RGB') 215 | return img 216 | 217 | def __repr__(self): 218 | return self.__class__.__name__ + '()' 219 | 220 | 221 | def fast_collate(batch): 222 | imgs = [img[0] for img in batch] 223 | targets = torch.tensor([target[1] for target in batch], dtype=torch.int64) 224 | w = imgs[0].size[0] 225 | h = imgs[0].size[1] 226 | tensor = torch.zeros( (len(imgs), 3, h, w), dtype=torch.uint8 
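    # allocate a uint8 CHW batch; float conversion and mean/std normalization
    # are deferred to DataPrefetcher, which performs them on the GPU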
    )
227 |     for i, img in enumerate(imgs):
228 |         nump_array = np.asarray(img, dtype=np.uint8)
229 |         if nump_array.ndim < 3:
230 |             nump_array = np.expand_dims(nump_array, axis=-1)
231 |         nump_array = np.rollaxis(nump_array, 2)
232 | 
233 |         tensor[i] += torch.from_numpy(nump_array)
234 | 
235 |     return tensor, targets
236 | 
237 | 
238 | def get_imagenet_loader(root, batch_size, type='train', mobile_setting=True):
239 |     crop_scale = 0.25 if mobile_setting else 0.08
240 |     jitter_param = 0.4
241 |     lighting_param = 0.1
242 |     if type == 'train':
243 |         transform = tfm.Compose([
244 |             tfm.RandomResizedCrop(224, scale=(crop_scale, 1.0)),
245 |             tfm.ColorJitter(
246 |                 brightness=jitter_param, contrast=jitter_param,
247 |                 saturation=jitter_param),
248 |             Lighting(lighting_param),
249 |             tfm.RandomHorizontalFlip(),
250 |         ])
251 | 
252 |     elif type == 'test':
253 |         transform = tfm.Compose([
254 |             tfm.Resize(256),
255 |             tfm.CenterCrop(224),
256 |         ])
257 | 
258 |     dataset = ds.ImageFolder(root, transform)
259 |     sampler = data.distributed.DistributedSampler(dataset)
260 |     data_loader = data.DataLoader(
261 |         dataset, batch_size=batch_size, shuffle=False,
262 |         num_workers=4, pin_memory=True, sampler=sampler, collate_fn=fast_collate
263 |     )
264 |     if type == 'train':
265 |         return data_loader, sampler
266 | 
267 |     elif type == 'test':
268 |         return data_loader
269 | 
270 | 
271 | class DataPrefetcher:
272 |     def __init__(self, loader):
273 |         self.loader = iter(loader)
274 |         self.stream = torch.cuda.Stream()
275 |         self.mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1, 3, 1, 1)
276 |         self.std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1, 3, 1, 1)
277 |         self.preload()
278 | 
279 |     def preload(self):
280 |         try:
281 |             self.next_input, self.next_target = next(self.loader)
282 |         except StopIteration:
283 |             self.next_input = None
284 |             self.next_target = None
285 |             return
286 |         with torch.cuda.stream(self.stream):
287 |             self.next_input = self.next_input.cuda(non_blocking=True)  # async H2D copy; 'async=' is a reserved word on Python 3.7+
288 |             self.next_target = self.next_target.cuda(non_blocking=True)
289 |             self.next_input = self.next_input.float()
290 |             self.next_input = self.next_input.sub_(self.mean).div_(self.std)
291 | 
292 |     def next(self):
293 |         torch.cuda.current_stream().wait_stream(self.stream)
294 |         input = self.next_input
295 |         target = self.next_target
296 |         self.preload()
297 |         return input, target
--------------------------------------------------------------------------------
/flop_counter.py:
--------------------------------------------------------------------------------
1 | # Modified from https://github.com/Lyken17/pytorch-OpCounter
2 | 
3 | import torch.nn as nn
4 | import torch
5 | import numpy as np
6 | 
7 | def get_model_complexity_info(model, input_res, print_per_layer_stat=True, as_strings=True):
8 |     assert type(input_res) is tuple
9 |     assert len(input_res) == 2
10 |     batch = torch.FloatTensor(1, 3, *input_res)
11 |     flops_model = add_flops_counting_methods(model)
12 |     flops_model.eval().start_flops_count()
13 |     out = flops_model(batch)
14 | 
15 |     if print_per_layer_stat:
16 |         print_model_with_flops(flops_model)
17 |     flops_count = flops_model.compute_average_flops_cost()
18 |     params_count = get_model_parameters_number(flops_model)
19 |     flops_model.stop_flops_count()
20 | 
21 |     if as_strings:
22 |         return flops_to_string(flops_count), params_to_string(params_count)
23 | 
24 |     return flops_count, params_count
25 | 
26 | def flops_to_string(flops, units='GMac', precision=2):
27 |     if units is None:
28 |         if flops // 10**9 > 0:
29 |             return str(round(flops / 10.**9, precision)) + ' GMac'
30 |         elif flops // 10**6 > 0:
31 |             return str(round(flops / 10.**6, precision)) + ' MMac'
32 |         elif flops // 10**3 > 0:
33 |             return str(round(flops / 10.**3, precision)) + ' KMac'
34 |         else:
35 |             return str(flops) + ' Mac'
36 |     else:
37 |         if units == 'GMac':
38 |             return str(round(flops / 10.**9, precision)) + ' ' + units
39 |         elif units == 'MMac':
40 |             return str(round(flops / 10.**6, precision)) + ' ' + units
41 |         elif units == 'KMac':
42 |             return str(round(flops / 10.**3, precision)) + ' ' + units
43 |         else:
44 |             return str(flops) + ' Mac'
45 | 
46 | def params_to_string(params_num):
47 |     if params_num // 10 ** 6 > 0:
48 |         return str(round(params_num / 10 ** 6, 2)) + ' M'
49 |     elif params_num // 10 ** 3 > 0:
50 |         return str(round(params_num / 10 ** 3, 2)) + ' k'
51 |     return str(params_num)  # fewer than 1,000 params: return the raw count (was: returned None)
52 | def print_model_with_flops(model, units='GMac', precision=3):
53 |     total_flops = model.compute_average_flops_cost()
54 | 
55 |     def accumulate_flops(self):
56 |         if is_supported_instance(self):
57 |             return self.__flops__ / model.__batch_counter__
58 |         else:
59 |             flops_sum = 0
60 |             for m in self.children():
61 |                 flops_sum += m.accumulate_flops()
62 |             return flops_sum
63 | 
64 |     def flops_repr(self):
65 |         accumulated_flops_cost = self.accumulate_flops()
66 |         return ', '.join([flops_to_string(accumulated_flops_cost, units=units, precision=precision),
67 |                           '{:.3%} MACs'.format(accumulated_flops_cost / total_flops),
68 |                           self.original_extra_repr()])
69 | 
70 |     def add_extra_repr(m):
71 |         m.accumulate_flops = accumulate_flops.__get__(m)
72 |         flops_extra_repr = flops_repr.__get__(m)
73 |         if m.extra_repr != flops_extra_repr:
74 |             m.original_extra_repr = m.extra_repr
75 |             m.extra_repr = flops_extra_repr
76 |             assert m.extra_repr != m.original_extra_repr
77 | 
78 |     def del_extra_repr(m):
79 |         if hasattr(m, 'original_extra_repr'):
80 |             m.extra_repr = m.original_extra_repr
81 |             del m.original_extra_repr
82 |         if hasattr(m, 'accumulate_flops'):
83 |             del m.accumulate_flops
84 | 
85 |     model.apply(add_extra_repr)
86 |     print(model)
87 |     model.apply(del_extra_repr)
88 | 
89 | def get_model_parameters_number(model):
90 |     params_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
91 |     return params_num
92 | 
93 | def add_flops_counting_methods(net_main_module):
94 |     # adding additional methods to the existing module object,
95 |     # this is done this way so that each function has access to self object
96 |     net_main_module.start_flops_count = start_flops_count.__get__(net_main_module)
97 |     net_main_module.stop_flops_count = stop_flops_count.__get__(net_main_module)
98 |     net_main_module.reset_flops_count = reset_flops_count.__get__(net_main_module)
99 |     net_main_module.compute_average_flops_cost = compute_average_flops_cost.__get__(net_main_module)
100 | 
101 |     net_main_module.reset_flops_count()
102 | 
103 |     # Adding variables necessary for masked flops computation
104 |     net_main_module.apply(add_flops_mask_variable_or_reset)
105 | 
106 |     return net_main_module
107 | 
108 | 
109 | def compute_average_flops_cost(self):
110 |     """
111 |     A method that will be available after add_flops_counting_methods() is called
112 |     on a desired net object.
113 | 
114 |     Returns current mean flops consumption per image.
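    Typical use (a sketch; ``net`` is any ``nn.Module``)::

        net = add_flops_counting_methods(net)
        net.eval().start_flops_count()
        _ = net(torch.zeros(1, 3, 224, 224))
        macs_per_image = net.compute_average_flops_cost()
        net.stop_flops_count()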
115 | 116 | """ 117 | 118 | batches_count = self.__batch_counter__ 119 | flops_sum = 0 120 | for module in self.modules(): 121 | if is_supported_instance(module): 122 | flops_sum += module.__flops__ 123 | 124 | return flops_sum / batches_count 125 | 126 | 127 | def start_flops_count(self): 128 | """ 129 | A method that will be available after add_flops_counting_methods() is called 130 | on a desired net object. 131 | 132 | Activates the computation of mean flops consumption per image. 133 | Call it before you run the network. 134 | 135 | """ 136 | add_batch_counter_hook_function(self) 137 | self.apply(add_flops_counter_hook_function) 138 | 139 | 140 | def stop_flops_count(self): 141 | """ 142 | A method that will be available after add_flops_counting_methods() is called 143 | on a desired net object. 144 | 145 | Stops computing the mean flops consumption per image. 146 | Call whenever you want to pause the computation. 147 | 148 | """ 149 | remove_batch_counter_hook_function(self) 150 | self.apply(remove_flops_counter_hook_function) 151 | 152 | 153 | def reset_flops_count(self): 154 | """ 155 | A method that will be available after add_flops_counting_methods() is called 156 | on a desired net object. 157 | 158 | Resets statistics computed so far. 159 | 160 | """ 161 | add_batch_counter_variables_or_reset(self) 162 | self.apply(add_flops_counter_variable_or_reset) 163 | 164 | 165 | def add_flops_mask(module, mask): 166 | def add_flops_mask_func(module): 167 | if isinstance(module, torch.nn.Conv2d): 168 | module.__mask__ = mask 169 | module.apply(add_flops_mask_func) 170 | 171 | 172 | def remove_flops_mask(module): 173 | module.apply(add_flops_mask_variable_or_reset) 174 | 175 | 176 | # ---- Internal functions 177 | def is_supported_instance(module): 178 | if isinstance(module, (torch.nn.Conv2d,\ 179 | torch.nn.Linear, \ 180 | torch.nn.AvgPool2d,\ 181 | nn.AdaptiveAvgPool2d)): 182 | return True 183 | 184 | return False 185 | 186 | 187 | def empty_flops_counter_hook(module, input, output): 188 | module.__flops__ += 0 189 | 190 | 191 | def linear_flops_counter_hook(module, input, output): 192 | input = input[0] 193 | batch_size = input.shape[0] 194 | module.__flops__ += batch_size * input.shape[1] * output.shape[1] 195 | 196 | 197 | def pool_flops_counter_hook(module, input, output): 198 | input = input[0] 199 | module.__flops__ += np.prod(input.shape) 200 | 201 | 202 | def conv_flops_counter_hook(conv_module, input, output): 203 | # Can have multiple inputs, getting the first one 204 | input = input[0] 205 | 206 | batch_size = input.shape[0] 207 | output_height, output_width = output.shape[2:] 208 | 209 | kernel_height, kernel_width = conv_module.kernel_size 210 | in_channels = conv_module.in_channels 211 | out_channels = conv_module.out_channels 212 | groups = conv_module.groups 213 | 214 | filters_per_channel = out_channels // groups 215 | conv_per_position_flops = kernel_height * kernel_width * in_channels * filters_per_channel 216 | 217 | active_elements_count = batch_size * output_height * output_width 218 | 219 | if conv_module.__mask__ is not None: 220 | # (b, 1, h, w) 221 | flops_mask = conv_module.__mask__.expand(batch_size, 1, output_height, output_width) 222 | active_elements_count = flops_mask.sum() 223 | 224 | overall_conv_flops = conv_per_position_flops * active_elements_count 225 | 226 | bias_flops = 0 227 | 228 | if conv_module.bias is not None: 229 | 230 | bias_flops = out_channels * active_elements_count 231 | 232 | overall_flops = overall_conv_flops + bias_flops 233 | 234 | 
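    # Worked example (no mask, no bias): a 3x3 conv with 64 input and 128
    # output channels on a 56x56 output map costs 3*3*64*128 MACs per output
    # position, i.e. 9 * 64 * 128 * 56 * 56 ≈ 231.2 MMac per image.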
conv_module.__flops__ += overall_flops 235 | 236 | 237 | def batch_counter_hook(module, input, output): 238 | # Can have multiple inputs, getting the first one 239 | input = input[0] 240 | batch_size = input.shape[0] 241 | module.__batch_counter__ += batch_size 242 | 243 | 244 | def add_batch_counter_variables_or_reset(module): 245 | 246 | module.__batch_counter__ = 0 247 | 248 | 249 | def add_batch_counter_hook_function(module): 250 | if hasattr(module, '__batch_counter_handle__'): 251 | return 252 | 253 | handle = module.register_forward_hook(batch_counter_hook) 254 | module.__batch_counter_handle__ = handle 255 | 256 | 257 | def remove_batch_counter_hook_function(module): 258 | if hasattr(module, '__batch_counter_handle__'): 259 | module.__batch_counter_handle__.remove() 260 | del module.__batch_counter_handle__ 261 | 262 | 263 | def add_flops_counter_variable_or_reset(module): 264 | if is_supported_instance(module): 265 | module.__flops__ = 0 266 | 267 | 268 | def add_flops_counter_hook_function(module): 269 | if is_supported_instance(module): 270 | if hasattr(module, '__flops_handle__'): 271 | return 272 | 273 | if isinstance(module, torch.nn.Conv2d): 274 | handle = module.register_forward_hook(conv_flops_counter_hook) 275 | elif isinstance(module, torch.nn.Linear): 276 | handle = module.register_forward_hook(linear_flops_counter_hook) 277 | elif isinstance(module, (torch.nn.AvgPool2d, \ 278 | nn.AdaptiveAvgPool2d)): 279 | handle = module.register_forward_hook(pool_flops_counter_hook) 280 | else: 281 | handle = module.register_forward_hook(empty_flops_counter_hook) 282 | module.__flops_handle__ = handle 283 | 284 | 285 | def remove_flops_counter_hook_function(module): 286 | if is_supported_instance(module): 287 | if hasattr(module, '__flops_handle__'): 288 | module.__flops_handle__.remove() 289 | del module.__flops_handle__ 290 | # --- Masked flops counting 291 | 292 | 293 | # Also being run in the initialization 294 | def add_flops_mask_variable_or_reset(module): 295 | if is_supported_instance(module): 296 | module.__mask__ = None 297 | -------------------------------------------------------------------------------- /gate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import types 4 | 5 | 6 | class TorchGraph(object): 7 | def __init__(self): 8 | self._graph = {} 9 | self.persistence = {} 10 | 11 | def add_tensor_list(self, name, persist=False): 12 | self._graph[name] = [] 13 | self.persistence[name] = persist 14 | 15 | def append_tensor(self, name, val): 16 | self._graph[name].append(val) 17 | 18 | def clear_tensor_list(self, name): 19 | self._graph[name].clear() 20 | 21 | def get_tensor_list(self, name): 22 | return self._graph[name] 23 | 24 | def clear_all_tensors(self): 25 | for k in self._graph.keys(): 26 | if not self.persistence[k]: 27 | self.clear_tensor_list(k) 28 | 29 | 30 | default_graph = TorchGraph() 31 | default_graph.add_tensor_list('gates_params', True) 32 | default_graph.add_tensor_list('selected_idx') 33 | 34 | def apply_func(model, module_type, func, **kwargs): 35 | for m in model.modules(): 36 | if m.__class__.__name__ == module_type: 37 | func(m, **kwargs) 38 | 39 | 40 | def replace_func(model, module_type, func): 41 | for m in model.modules(): 42 | if m.__class__.__name__ == module_type: 43 | m.forward = types.MethodType(func, m) 44 | 45 | 46 | def collect_convbn_gates(m): 47 | default_graph.append_tensor('gates_params', m.gates) 48 | 49 | 50 | def init_convbn_gates(m): 51 | 
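    # One learnable gate per output channel of the ConvBNReLU block,
    # initialized fully open (1.0); new_convbn_forward below scales the
    # BN output by these gates before the ReLU.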
m.gates = nn.Parameter(torch.ones(m.conv.out_channels)) 52 | 53 | 54 | def new_convbn_forward(self, x): 55 | out = self.conv(x) 56 | out = self.bn(out) 57 | out = self.gates.view(1, -1, 1, 1) * out 58 | out = self.relu(out) 59 | return out 60 | 61 | 62 | def collect_basicblock_gates(m): 63 | default_graph.append_tensor('gates_params', m.gates) 64 | 65 | 66 | def init_basicblock_gates(m): 67 | m.gates = nn.Parameter(torch.ones(m.conv1.out_channels)) 68 | 69 | 70 | def new_basicblock_forward(self, x): 71 | out = self.bn1(self.conv1(x)) 72 | out = self.gates.view(1, -1, 1, 1) * out 73 | out = self.bn2(self.conv2(self.relu1(out))) 74 | out += self.shortcut(x) 75 | out = self.relu2(out) 76 | return out 77 | 78 | 79 | def init_conv_depthwise_gates(m): 80 | m.gates = nn.Parameter(torch.ones(m.conv2.out_channels)) 81 | 82 | 83 | def collect_conv_depthwise_gates(m): 84 | default_graph.append_tensor('gates_params', m.gates) 85 | 86 | 87 | def new_conv_depthwise_forward(self, x): 88 | x = self.conv1(x) 89 | x = self.bn1(x) 90 | x = self.relu1(x) 91 | 92 | x = self.conv2(x) 93 | x = self.bn2(x) 94 | x = self.gates.view(1, -1, 1, 1) * x 95 | x = self.relu2(x) 96 | 97 | return x 98 | 99 | 100 | def init_inverted_block_gates(m): 101 | m.gates = nn.Parameter(torch.ones(m.hid)) 102 | 103 | 104 | def collect_inverted_block_gates(m): 105 | default_graph.append_tensor('gates_params', m.gates) 106 | 107 | 108 | def new_inverted_block_forward(self, x): 109 | x = self.conv1(x) 110 | x = self.bn1(x) 111 | x = self.gates.view(1, -1, 1, 1) * x 112 | x = self.relu1(x) 113 | 114 | x = self.conv2(x) 115 | x = self.bn2(x) 116 | x = self.relu2(x) 117 | x = self.conv3(x) 118 | x = self.bn3(x) 119 | return x 120 | 121 | 122 | def collect_bottleneck_gates(m): 123 | default_graph.append_tensor('gates_params', m.gates1) 124 | default_graph.append_tensor('gates_params', m.gates2) 125 | 126 | 127 | def init_bottleneck_gates(m): 128 | m.gates1 = nn.Parameter(torch.ones(m.conv1.out_channels)) 129 | m.gates2 = nn.Parameter(torch.ones(m.conv2.out_channels)) 130 | 131 | 132 | def new_bottleneck_forward(self, x): 133 | out = self.bn1(self.conv1(x)) 134 | out = self.gates1.view(1, -1, 1, 1) * out 135 | out = self.bn2(self.conv2(self.relu1(out))) 136 | out = self.gates2.view(1, -1, 1, 1) * out 137 | out = self.bn3(self.conv3(self.relu2(out))) 138 | out += self.shortcut(x) 139 | out = self.relu3(out) 140 | return out 141 | -------------------------------------------------------------------------------- /misc.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import os 3 | import shutil 4 | import pickle as pkl 5 | import time 6 | from datetime import datetime 7 | import numpy as np 8 | import torch 9 | import random 10 | 11 | 12 | def pil_loader(path): 13 | with open(path, 'rb') as f: 14 | with Image.open(f) as img: 15 | return img.convert('RGB') 16 | 17 | 18 | class Logger(object): 19 | def __init__(self): 20 | self._logger = None 21 | 22 | def init(self, logdir, name='log'): 23 | if self._logger is None: 24 | import logging 25 | if not os.path.exists(logdir): 26 | os.makedirs(logdir) 27 | log_file = os.path.join(logdir, name) 28 | if os.path.exists(log_file): 29 | os.remove(log_file) 30 | self._logger = logging.getLogger() 31 | self._logger.setLevel('INFO') 32 | fh = logging.FileHandler(log_file) 33 | ch = logging.StreamHandler() 34 | self._logger.addHandler(fh) 35 | self._logger.addHandler(ch) 36 | 37 | def info(self, str_info): 38 | now = datetime.now() 39 | 
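        # str(now) looks like 'YYYY-MM-DD HH:MM:SS.ffffff'; keep the time of day
        # and trim microseconds to milliseconds for a compact log prefix.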
display_now = str(now).split(' ')[1][:-3] 40 | self.init(os.path.expanduser('~/tmp_log'), 'tmp.log') 41 | self._logger.info('[' + display_now + ']' + ' ' + str_info) 42 | 43 | logger = Logger() 44 | 45 | 46 | def ensure_dir(path, erase=False): 47 | if os.path.exists(path) and erase: 48 | print("Removing old folder {}".format(path)) 49 | shutil.rmtree(path) 50 | if not os.path.exists(path): 51 | print("Creating folder {}".format(path)) 52 | os.makedirs(path) 53 | 54 | 55 | def load_pickle(path, verbose=True): 56 | begin_st = time.time() 57 | with open(path, 'rb') as f: 58 | if verbose: 59 | print("Loading pickle object from {}".format(path)) 60 | v = pkl.load(f) 61 | if verbose: 62 | print("=> Done ({:.4f} s)".format(time.time() - begin_st)) 63 | return v 64 | 65 | 66 | def dump_pickle(obj, path): 67 | with open(path, 'wb') as f: 68 | print("Dumping pickle object to {}".format(path)) 69 | pkl.dump(obj, f, protocol=pkl.HIGHEST_PROTOCOL) 70 | 71 | 72 | def prepare_logging(args): 73 | args.logdir = os.path.join('./logs', args.logdir) 74 | 75 | logger.init(args.logdir, 'log') 76 | 77 | ensure_dir(args.logdir) 78 | logger.info("=================FLAGS==================") 79 | for k, v in args.__dict__.items(): 80 | logger.info('{}: {}'.format(k, v)) 81 | logger.info("========================================") 82 | 83 | 84 | class AverageMeter(object): 85 | """Computes and stores the average and current value""" 86 | def __init__(self): 87 | self.reset() 88 | 89 | def reset(self): 90 | self.val = 0 91 | self.avg = 0 92 | self.sum = 0 93 | self.count = 0 94 | 95 | def update(self, val, n=1): 96 | self.val = val 97 | self.sum += val * n 98 | self.count += n 99 | self.avg = self.sum / self.count 100 | 101 | 102 | def accuracy(output, target, topk=(1,)): 103 | """Computes the precision@k for the specified values of k""" 104 | maxk = max(topk) 105 | batch_size = target.size(0) 106 | 107 | _, pred = output.topk(maxk, 1, True, True) 108 | pred = pred.t() 109 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 110 | 111 | res = [] 112 | for k in topk: 113 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 114 | res.append(correct_k.mul_(100.0 / batch_size)) 115 | return res 116 | 117 | 118 | def set_seed(seed=None): 119 | if seed is None: 120 | seed = random.randint(0, 9999) 121 | np.random.seed(seed) 122 | torch.manual_seed(seed) 123 | torch.cuda.manual_seed_all(seed) 124 | torch.backends.cudnn.deterministic = True 125 | return seed -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .vgg import vgg16_bn, vgg19_bn 2 | from .resnet import resnet20, resnet32, \ 3 | resnet44, resnet56, resnet110, resnet50 4 | from .mobilenet import mobilenet_v1, mobilenet_v2 5 | 6 | def expanded_cfg(c): 7 | v1_cfg = [c, 2*c, 4*c, 4*c, 8*c, 8*c, 16*c, 16*c, 16*c, 16*c, 16*c, 16*c, 32*c, 32*c] 8 | v2_cfg = [32, 96, 144, 144, 192, 192, 192, 384, 384, 384, 384, 576, 576, 576, 960, 960, 960, 1280] 9 | multiplier = c / 32 10 | if c != 32: 11 | v2_cfg = [int(v * multiplier) for v in v2_cfg] 12 | 13 | defaultcfg = { 14 | 'vgg11_bn' : [c, 'M', c*2, 'M', c*4, c*4, 'M', c*8, c*8, 'M', c*8, c*8], 15 | 'vgg13_bn' : [c, c, 'M', c*2, c*2, 'M', c*4, c*4, 'M', c*8, c*8, 'M', c*8, c*8], 16 | 'vgg16_bn' : [c, c, 'M', c*2, c*2, 'M', c*4, c*4, c*4, 'M', c*8, c*8, c*8, 'M', c*8, c*8, c*8], 17 | 'vgg19_bn' : [c, c, 'M', c*2, c*2, 'M', c*4, c*4, c*4, c*4, 'M', c*8, c*8, c*8, c*8, 'M', c*8, 
c*8, c*8, c*8], 18 | 'resnet20': [c] * 3 + [c*2] * 3 + [c*4] * 3, 19 | 'resnet32': [c] * 5 + [c*2] * 5 + [c*4] * 5, 20 | 'resnet44': [c] * 7 + [c*2] * 7 + [c*4] * 7, 21 | 'resnet56': [c] * 9 + [c*2] * 9 + [c*4] * 9, 22 | 'resnet110': [c] * 18 + [c*2] * 18 + [c*4] * 18, 23 | 'mobilenet_v1': v1_cfg, 24 | 'mobilenet_v2': v2_cfg, 25 | 'resnet50': [c, c, c, c, c, c, 2*c, 2*c, 2*c, 2*c, 2*c, 2*c, 2*c, 2*c, 26 | 4*c, 4*c, 4*c, 4*c, 4*c, 4*c, 4*c, 4*c, 4*c, 4*c, 4*c, 4*c, 27 | 8*c, 8*c, 8*c, 8*c, 8*c, 8*c] 28 | } 29 | return defaultcfg 30 | -------------------------------------------------------------------------------- /models/mobilenet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | 4 | 5 | def expanded_cfg(c): 6 | v1_cfg = [c, 2*c, 4*c, 4*c, 8*c, 8*c, 16*c, 16*c, 16*c, 16*c, 16*c, 16*c, 32*c, 32*c] 7 | v2_cfg = [32, 96, 144, 144, 192, 192, 192, 384, 384, 384, 384, 576, 576, 576, 960, 960, 960, 1280] 8 | multiplier = c / 32 9 | if c != 32: 10 | v2_cfg = [int(v * multiplier) for v in v2_cfg] 11 | 12 | cfg = { 13 | 'mobilenet_v1': v1_cfg, 14 | 'mobilenet_v2': v2_cfg 15 | } 16 | return cfg 17 | 18 | 19 | class ConvBNReLU(nn.Module): 20 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding): 21 | super(ConvBNReLU, self).__init__() 22 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False) 23 | self.bn = nn.BatchNorm2d(out_channels) 24 | self.relu = nn.ReLU(inplace=True) 25 | 26 | def forward(self, x): 27 | x = self.conv(x) 28 | x = self.bn(x) 29 | x = self.relu(x) 30 | return x 31 | 32 | 33 | class ConvDepthWise(nn.Module): 34 | 35 | def __init__(self, inp, oup, stride): 36 | super(ConvDepthWise, self).__init__() 37 | self.conv1 = nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False) 38 | self.bn1 = nn.BatchNorm2d(inp) 39 | self.relu1 = nn.ReLU() 40 | 41 | self.conv2 = nn.Conv2d(inp, oup, 1, 1, 0, bias=False) 42 | self.bn2 = nn.BatchNorm2d(oup) 43 | self.relu2 = nn.ReLU() 44 | 45 | def forward(self, x): 46 | x = self.conv1(x) 47 | x = self.bn1(x) 48 | x = self.relu1(x) 49 | 50 | x = self.conv2(x) 51 | x = self.bn2(x) 52 | x = self.relu2(x) 53 | 54 | return x 55 | 56 | 57 | 58 | class MobileNet(nn.Module): 59 | def __init__(self, n_class, in_channel=32, multiplier=1.0, cfg=None): 60 | super(MobileNet, self).__init__() 61 | # original 62 | if cfg is None: 63 | cfg = expanded_cfg(in_channel)['mobilenet_v1'] 64 | 65 | self.conv1 = ConvBNReLU(3, cfg[0], 3, 2, 1) 66 | self.features = self._make_layers(cfg[0], cfg[1:], ConvDepthWise) 67 | self.pool = nn.AvgPool2d(7) 68 | self.classifier = nn.Sequential( 69 | nn.Linear(cfg[-1], n_class), 70 | ) 71 | 72 | self._initialize_weights() 73 | 74 | def forward(self, x): 75 | x = self.conv1(x) 76 | x = self.features(x) 77 | x = self.pool(x) # global average pooling 78 | x = x.view(x.size(0), -1) 79 | 80 | x = self.classifier(x) 81 | return x 82 | 83 | def _make_layers(self, in_planes, cfg, layer): 84 | layers = [] 85 | for i, x in enumerate(cfg): 86 | out_planes = x 87 | stride = 2 if i in [1, 3, 5, 11] else 1 88 | layers.append(layer(in_planes, out_planes, stride)) 89 | in_planes = out_planes 90 | return nn.Sequential(*layers) 91 | 92 | def _initialize_weights(self): 93 | for m in self.modules(): 94 | if isinstance(m, nn.Conv2d): 95 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 96 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 97 | if m.bias is not None: 98 | m.bias.data.zero_() 99 | elif isinstance(m, nn.BatchNorm2d): 100 | m.weight.data.fill_(1) 101 | m.bias.data.zero_() 102 | elif isinstance(m, nn.Linear): 103 | n = m.weight.size(1) 104 | m.weight.data.normal_(0, 0.01) 105 | m.bias.data.zero_() 106 | 107 | 108 | def mobilenet_v1(num_class, in_channel=32, multiplier=1.0, cfg=None): 109 | return MobileNet(num_class, in_channel, multiplier, cfg) 110 | 111 | 112 | class InvertedBlock(nn.Module): 113 | def __init__(self, inp, oup, hid, stride): 114 | super(InvertedBlock, self).__init__() 115 | self.hid = hid 116 | # pw 117 | self.conv1 = nn.Conv2d(inp, hid, 1, 1, 0, bias=False) 118 | self.bn1 = nn.BatchNorm2d(hid) 119 | self.relu1 = nn.ReLU6(inplace=True) 120 | # dw 121 | self.conv2 = nn.Conv2d(hid, hid, 3, stride, 1, groups=hid, bias=False) 122 | self.bn2 = nn.BatchNorm2d(hid) 123 | self.relu2 = nn.ReLU6(inplace=True) 124 | # pw-linear 125 | self.conv3 = nn.Conv2d(hid, oup, 1, 1, 0, bias=False) 126 | self.bn3 = nn.BatchNorm2d(oup) 127 | 128 | def forward(self, x): 129 | x = self.conv1(x) 130 | x = self.bn1(x) 131 | x = self.relu1(x) 132 | x = self.conv2(x) 133 | x = self.bn2(x) 134 | x = self.relu2(x) 135 | x = self.conv3(x) 136 | x = self.bn3(x) 137 | return x 138 | 139 | 140 | class InvertedResidual(nn.Module): 141 | def __init__(self, inp, oup, hid, stride): 142 | super(InvertedResidual, self).__init__() 143 | self.stride = stride 144 | assert stride in [1, 2] 145 | 146 | self.use_res_connect = self.stride == 1 and inp == oup 147 | 148 | if hid == inp: 149 | self.conv = nn.Sequential( 150 | # dw 151 | nn.Conv2d(hid, hid, 3, stride, 1, groups=hid, bias=False), 152 | nn.BatchNorm2d(hid), 153 | nn.ReLU6(inplace=True), 154 | # pw-linear 155 | nn.Conv2d(hid, oup, 1, 1, 0, bias=False), 156 | nn.BatchNorm2d(oup), 157 | ) 158 | else: 159 | self.conv = InvertedBlock(inp, oup, hid, stride) 160 | 161 | def forward(self, x): 162 | if self.use_res_connect: 163 | return x + self.conv(x) 164 | else: 165 | return self.conv(x) 166 | 167 | 168 | class MobileNetV2(nn.Module): 169 | def __init__(self, n_class=1000, in_channel=32, multiplier=1.0, cfg=None): 170 | super(MobileNetV2, self).__init__() 171 | output_channels = [ 172 | 16, 24, 24, 32, 32, 32, 64, 64, 64, 64, 173 | 96, 96, 96, 160, 160, 160, 320 174 | ] 175 | for i in range(len(output_channels)): 176 | output_channels[i] = int(multiplier * output_channels[i]) 177 | 178 | if cfg is None: 179 | cfg = expanded_cfg(in_channel)['mobilenet_v2'] 180 | 181 | self.features = [ConvBNReLU(3, cfg[0], kernel_size=3, stride=2, padding=1)] 182 | # building inverted residual blocks 183 | inp = cfg[0] 184 | for j, (hid, oup) in enumerate(zip(cfg[:-1], output_channels)): 185 | if j in [1, 3, 6, 13]: 186 | stride = 2 187 | else: 188 | stride = 1 189 | self.features.append(InvertedResidual(inp, oup, hid, stride)) 190 | inp = oup 191 | 192 | # building last several layers 193 | self.features.append(ConvBNReLU(inp, cfg[-1], kernel_size=1, stride=1, padding=0)) 194 | self.features.append(nn.AvgPool2d(7)) 195 | # make it nn.Sequential 196 | self.features = nn.Sequential(*self.features) 197 | 198 | # building classifier 199 | self.classifier = nn.Sequential( 200 | nn.Linear(cfg[-1], n_class), 201 | ) 202 | 203 | self._initialize_weights() 204 | 205 | def forward(self, x): 206 | x = self.features(x) 207 | x = x.view(x.size(0), -1) 208 | x = self.classifier(x) 209 | return x 210 | 211 | def _initialize_weights(self): 212 | for m in self.modules(): 213 | if isinstance(m, nn.Conv2d): 214 | n = 
m.kernel_size[0] * m.kernel_size[1] * m.out_channels 215 | m.weight.data.normal_(0, math.sqrt(2. / n)) 216 | if m.bias is not None: 217 | m.bias.data.zero_() 218 | elif isinstance(m, nn.BatchNorm2d): 219 | m.weight.data.fill_(1) 220 | m.bias.data.zero_() 221 | elif isinstance(m, nn.Linear): 222 | n = m.weight.size(1) 223 | m.weight.data.normal_(0, 0.01) 224 | m.bias.data.zero_() 225 | 226 | def mobilenet_v2(num_class, in_channel=32, multiplier=1.0, cfg=None): 227 | return MobileNetV2(num_class, in_channel, multiplier, cfg) 228 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch.nn as nn 3 | 4 | 5 | def expanded_cfg(c): 6 | defaultcfg = { 7 | 'resnet20': [c] * 3 + [c*2] * 3 + [c*4] * 3, 8 | 'resnet32': [c] * 5 + [c*2] * 5 + [c*4] * 5, 9 | 'resnet44': [c] * 7 + [c*2] * 7 + [c*4] * 7, 10 | 'resnet56': [c] * 9 + [c*2] * 9 + [c*4] * 9, 11 | 'resnet110': [c] * 18 + [c*2] * 18 + [c*4] * 18, 12 | 'resnet50': [c, c, c, c, c, c, 2*c, 2*c, 2*c, 2*c, 2*c, 2*c, 2*c, 2*c, 13 | 4*c, 4*c, 4*c, 4*c, 4*c, 4*c, 4*c, 4*c, 4*c, 4*c, 4*c, 4*c, 14 | 8*c, 8*c, 8*c, 8*c, 8*c, 8*c] 15 | } 16 | return defaultcfg 17 | 18 | class BasicBlock(nn.Module): 19 | expansion = 1 20 | 21 | def __init__(self, in_planes, mid_planes, out_planes, stride=1): 22 | super(BasicBlock, self).__init__() 23 | self.conv1 = nn.Conv2d(in_planes, mid_planes, kernel_size=3, stride=stride, padding=1, bias=False) 24 | self.bn1 = nn.BatchNorm2d(mid_planes) 25 | self.relu1 = nn.ReLU() 26 | self.conv2 = nn.Conv2d(mid_planes, out_planes, kernel_size=3, stride=1, padding=1, bias=False) 27 | self.bn2 = nn.BatchNorm2d(out_planes) 28 | self.relu2 = nn.ReLU() 29 | 30 | self.shortcut = nn.Sequential() 31 | if stride != 1 or in_planes != self.expansion*out_planes: 32 | self.shortcut = nn.Sequential( 33 | nn.Conv2d(in_planes, self.expansion*out_planes, kernel_size=1, stride=stride, bias=False), 34 | nn.BatchNorm2d(self.expansion*out_planes) 35 | ) 36 | 37 | def forward(self, x): 38 | out = self.relu1(self.bn1(self.conv1(x))) 39 | out = self.bn2(self.conv2(out)) 40 | out += self.shortcut(x) 41 | out = self.relu2(out) 42 | return out 43 | 44 | 45 | class CifarResNet(nn.Module): 46 | def __init__(self, depth, expanded_inchannel, num_classes=10, cfg=None): 47 | super(CifarResNet, self).__init__() 48 | if cfg is None: 49 | defaultcfg = expanded_cfg(expanded_inchannel) 50 | cfg = defaultcfg['resnet' + str(depth)] 51 | n_blocks = (depth - 2) // 6 52 | 53 | self.in_planes = expanded_inchannel 54 | self.conv1 = nn.Conv2d(3, expanded_inchannel, kernel_size=3, padding=1, bias=False) 55 | self.bn1 = nn.BatchNorm2d(expanded_inchannel) 56 | self.relu = nn.ReLU(inplace=True) 57 | self.layer1 = self._make_layer(BasicBlock, expanded_inchannel, cfg[0:n_blocks], stride=1) 58 | self.layer2 = self._make_layer(BasicBlock, expanded_inchannel*2, cfg[n_blocks:2*n_blocks], stride=2) 59 | self.layer3 = self._make_layer(BasicBlock, expanded_inchannel*4, cfg[2*n_blocks:], stride=2) 60 | self.avgpool = nn.AvgPool2d(8) 61 | self.linear = nn.Linear(expanded_inchannel*4, num_classes) 62 | 63 | for m in self.modules(): 64 | if isinstance(m, nn.Conv2d): 65 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 66 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 67 | elif isinstance(m, nn.BatchNorm2d): 68 | m.weight.data.fill_(1) 69 | m.bias.data.zero_() 70 | 71 | def _make_layer(self, block, channel, cfg, stride): 72 | layers = [] 73 | for i in range(0, len(cfg)): 74 | layers.append(block(self.in_planes, cfg[i], channel, stride)) 75 | stride = 1 76 | self.in_planes = channel 77 | return nn.Sequential(*layers) 78 | 79 | def forward(self, x): 80 | out = self.relu(self.bn1(self.conv1(x))) 81 | out = self.layer1(out) 82 | out = self.layer2(out) 83 | out = self.layer3(out) 84 | out = self.avgpool(out) 85 | out = out.view(out.size(0), -1) 86 | out = self.linear(out) 87 | return out 88 | 89 | 90 | def resnet20(num_classes, expanded_inchannel=16, cfg=None): 91 | return CifarResNet(20, expanded_inchannel, num_classes, cfg) 92 | 93 | def resnet32(num_classes, expanded_inchannel=16, cfg=None): 94 | return CifarResNet(32, expanded_inchannel, num_classes, cfg) 95 | 96 | def resnet44(num_classes, expanded_inchannel=16, cfg=None): 97 | return CifarResNet(44, expanded_inchannel, num_classes, cfg) 98 | 99 | def resnet56(num_classes, expanded_inchannel=16, cfg=None): 100 | return CifarResNet(56, expanded_inchannel, num_classes, cfg) 101 | 102 | def resnet110(num_classes, expanded_inchannel=16, cfg=None): 103 | return CifarResNet(110, expanded_inchannel, num_classes, cfg) 104 | 105 | 106 | class Bottleneck(nn.Module): 107 | def __init__(self, in_planes, mid_planes1, mid_planes2, out_planes, stride=1): 108 | super(Bottleneck, self).__init__() 109 | self.conv1 = nn.Conv2d(in_planes, mid_planes1, 1, 1, 0, bias=False) 110 | self.bn1 = nn.BatchNorm2d(mid_planes1) 111 | self.relu1 = nn.ReLU(inplace=True) 112 | 113 | self.conv2 = nn.Conv2d(mid_planes1, mid_planes2, 3, stride, 1, bias=False) 114 | self.bn2 = nn.BatchNorm2d(mid_planes2) 115 | self.relu2 = nn.ReLU(inplace=True) 116 | 117 | self.conv3 = nn.Conv2d(mid_planes2, out_planes, 1, 1, 0, bias=False) 118 | self.bn3 = nn.BatchNorm2d(out_planes) 119 | 120 | self.shortcut = nn.Sequential() 121 | if stride != 1 or in_planes != out_planes: 122 | self.shortcut = nn.Sequential( 123 | nn.Conv2d(in_planes, out_planes, 1, stride, 0, bias=False), 124 | nn.BatchNorm2d(out_planes) 125 | ) 126 | self.relu3 = nn.ReLU(inplace=True) 127 | 128 | def forward(self, x): 129 | out = self.relu1(self.bn1(self.conv1(x))) 130 | out = self.relu2(self.bn2(self.conv2(out))) 131 | out = self.bn3(self.conv3(out)) 132 | out += self.shortcut(x) 133 | out = self.relu3(out) 134 | return out 135 | 136 | 137 | class ResNet50(nn.Module): 138 | def __init__(self, num_classes, expanded_inchannel=64, multiplier=1.0, cfg=None): 139 | super(ResNet50, self).__init__() 140 | if cfg is None: 141 | c = expanded_inchannel 142 | cfg = [c, c, c, c, c, c, 2*c, 2*c, 2*c, 2*c, 2*c, 2*c, 2*c, 2*c, 143 | 4*c, 4*c, 4*c, 4*c, 4*c, 4*c, 4*c, 4*c, 4*c, 4*c, 4*c, 4*c, 144 | 8*c, 8*c, 8*c, 8*c, 8*c, 8*c] 145 | 146 | output_channels = [64, 256, 256, 256, 512, 512, 512, 512, 147 | 1024, 1024, 1024, 1024, 1024, 1024, 2048, 2048, 2048] 148 | 149 | for i in range(len(output_channels)): 150 | output_channels[i] = int(multiplier * output_channels[i]) 151 | 152 | self.in_planes = output_channels[0] 153 | self.conv1 = nn.Conv2d(3, self.in_planes, 7, 2, 3, bias=False) 154 | self.bn1 = nn.BatchNorm2d(self.in_planes) 155 | self.relu = nn.ReLU(inplace=True) 156 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 157 | self.layer1 = self._make_layer(cfg[:6], output_channels[1:4], stride=1) 158 | self.layer2 = self._make_layer(cfg[6:14], output_channels[4:8], stride=2) 159 
| self.layer3 = self._make_layer(cfg[14:26], output_channels[8:14], stride=2) 160 | self.layer4 = self._make_layer(cfg[26:], output_channels[14:], stride=2) 161 | self.avgpool = nn.AvgPool2d(7) 162 | self.fc = nn.Linear(output_channels[-1], num_classes) 163 | 164 | for m in self.modules(): 165 | if isinstance(m, nn.Conv2d): 166 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 167 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 168 | nn.init.constant_(m.weight, 1) 169 | nn.init.constant_(m.bias, 0) 170 | 171 | def _make_layer(self, cfg, output_channels, stride): 172 | layers = [Bottleneck(self.in_planes, cfg[0], cfg[1], output_channels[0], stride)] 173 | self.in_planes = output_channels[0] 174 | for i in range(1, len(output_channels)): 175 | layers.append(Bottleneck(self.in_planes, cfg[2*i], cfg[2*i+1], output_channels[i])) 176 | self.in_planes = output_channels[i] 177 | return nn.Sequential(*layers) 178 | 179 | def forward(self, x): 180 | x = self.conv1(x) 181 | x = self.bn1(x) 182 | x = self.relu(x) 183 | x = self.maxpool(x) 184 | 185 | x = self.layer1(x) 186 | x = self.layer2(x) 187 | x = self.layer3(x) 188 | x = self.layer4(x) 189 | 190 | x = self.avgpool(x) 191 | x = x.view(x.size(0), -1) 192 | x = self.fc(x) 193 | 194 | return x 195 | 196 | def resnet50(num_classes, in_channel=64, multiplier=1.0, cfg=None): 197 | return ResNet50(num_classes, in_channel, multiplier, cfg) -------------------------------------------------------------------------------- /models/vgg.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch.nn as nn 3 | 4 | 5 | def expanded_cfg(c): 6 | defaultcfg = { 7 | 'vgg11_bn' : [c, 'M', c*2, 'M', c*4, c*4, 'M', c*8, c*8, 'M', c*8, c*8], 8 | 'vgg13_bn' : [c, c, 'M', c*2, c*2, 'M', c*4, c*4, 'M', c*8, c*8, 'M', c*8, c*8], 9 | 'vgg16_bn' : [c, c, 'M', c*2, c*2, 'M', c*4, c*4, c*4, 'M', c*8, c*8, c*8, 'M', c*8, c*8, c*8], 10 | 'vgg19_bn' : [c, c, 'M', c*2, c*2, 'M', c*4, c*4, c*4, c*4, 'M', c*8, c*8, c*8, c*8, 'M', c*8, c*8, c*8, c*8], 11 | } 12 | return defaultcfg 13 | 14 | 15 | class ConvBNReLU(nn.Module): 16 | def __init__(self, in_channels, out_channels, kernel_size, padding=0, bias=False): 17 | super(ConvBNReLU, self).__init__() 18 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding, bias=bias) 19 | self.bn = nn.BatchNorm2d(out_channels) 20 | self.relu = nn.ReLU(inplace=True) 21 | 22 | def forward(self, x): 23 | x = self.conv(x) 24 | x = self.bn(x) 25 | x = self.relu(x) 26 | return x 27 | 28 | 29 | class CifarVGG(nn.Module): 30 | def __init__(self, depth, expanded_inchannel, num_classes=10, cfg=None): 31 | super(CifarVGG, self).__init__() 32 | if cfg is None: 33 | defaultcfg = expanded_cfg(expanded_inchannel) 34 | cfg = defaultcfg['vgg' + str(depth) + '_bn'] 35 | 36 | self.feature = self.make_layers(cfg) 37 | self.num_classes = num_classes 38 | self.classifier = nn.Linear(cfg[-1], num_classes) 39 | self._initialize_weights() 40 | 41 | def make_layers(self, cfg): 42 | layers = [] 43 | in_channels = 3 44 | for v in cfg: 45 | if v == 'M': 46 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 47 | else: 48 | layers += [ConvBNReLU(in_channels, v, kernel_size=3, padding=1, bias=False)] 49 | in_channels = v 50 | return nn.Sequential(*layers) 51 | 52 | def forward(self, x): 53 | x = self.feature(x) 54 | x = nn.AvgPool2d(2)(x) 55 | x = x.view(x.size(0), -1) 56 | y = self.classifier(x) 57 | return y 58 | 59 | def _initialize_weights(self): 60 | for 
m in self.modules(): 61 | if isinstance(m, nn.Conv2d): 62 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 63 | m.weight.data.normal_(0, math.sqrt(2. / n)) 64 | if m.bias is not None: 65 | m.bias.data.zero_() 66 | elif isinstance(m, nn.BatchNorm2d): 67 | m.weight.data.fill_(0.5) 68 | m.bias.data.zero_() 69 | elif isinstance(m, nn.Linear): 70 | m.weight.data.normal_(0, 0.01) 71 | m.bias.data.zero_() 72 | 73 | 74 | def vgg16_bn(num_classes, expanded_inchannel=64, cfg=None): 75 | return CifarVGG(16, expanded_inchannel, num_classes, cfg) 76 | 77 | def vgg19_bn(num_classes, expanded_inchannel=64, cfg=None): 78 | return CifarVGG(19, expanded_inchannel, num_classes, cfg) 79 | -------------------------------------------------------------------------------- /script/learn_gates.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | import datasets 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import torch 6 | import argparse 7 | import os 8 | 9 | from gate import default_graph, apply_func, replace_func 10 | import models 11 | import misc 12 | 13 | print = misc.logger.info 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--gpu', default='0', type=str) 17 | parser.add_argument('--dataset', default='cifar10', type=str) 18 | parser.add_argument('--arch', '-a', default='vgg16_bn', type=str) 19 | parser.add_argument('--sparsity_level', '-s', default=0.2, type=float) 20 | parser.add_argument('--lr', default=0.01, type=float) 21 | parser.add_argument('--lambd', default=0.5, type=float) 22 | parser.add_argument('--epochs', default=10, type=int) 23 | parser.add_argument('--log_interval', default=100, type=int) 24 | parser.add_argument('--train_batch_size', default=128, type=int) 25 | parser.add_argument('--expanded_inchannel', '-e', default=80, type=int) 26 | parser.add_argument('--seed', default=None, type=int) 27 | 28 | args = parser.parse_args() 29 | args.seed = misc.set_seed(args.seed) 30 | args.num_classes = 10 31 | 32 | args.device = 'cuda' 33 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 34 | 35 | args.logdir = 'seed-%d/%s-%s/channel-%d-sparsity-%.2f' % ( 36 | args.seed, args.dataset, args.arch, args.expanded_inchannel, args.sparsity_level 37 | ) 38 | 39 | misc.prepare_logging(args) 40 | 41 | print('==> Preparing data..') 42 | 43 | transform_train = transforms.Compose([ 44 | transforms.RandomCrop(32, padding=4), 45 | transforms.RandomHorizontalFlip(), 46 | transforms.ToTensor(), 47 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 48 | ]) 49 | 50 | transform_val = transforms.Compose([ 51 | transforms.ToTensor(), 52 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 53 | ]) 54 | 55 | trainset = datasets.CIFAR10(root='./data/cifar10', type='train', transform=transform_train) 56 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.train_batch_size, shuffle=True, num_workers=2) 57 | 58 | valset = datasets.CIFAR10(root='./data/cifar10', type='val', transform=transform_val) 59 | valloader = torch.utils.data.DataLoader(valset, batch_size=100, shuffle=False, num_workers=2) 60 | 61 | print('==> Initializing model...') 62 | model = models.__dict__[args.arch](args.num_classes, args.expanded_inchannel) 63 | 64 | if args.arch.find('vgg') != -1: 65 | from gate import init_convbn_gates, new_convbn_forward, collect_convbn_gates 66 | init_func = init_convbn_gates 67 | new_forward = new_convbn_forward 68 | collect_gates = 
collect_convbn_gates 69 | module_type = 'ConvBNReLU' 70 | 71 | elif args.arch.find('resnet') != -1: 72 | from gate import init_basicblock_gates, new_basicblock_forward, collect_basicblock_gates 73 | init_func = init_basicblock_gates 74 | new_forward = new_basicblock_forward 75 | collect_gates = collect_basicblock_gates 76 | module_type = 'BasicBlock' 77 | 78 | elif args.arch.find('densenet') != -1: 79 | from gate import init_channel_selection_gates, new_channel_selection_forward, collect_channel_selection_gates 80 | init_func = init_channel_selection_gates 81 | new_forward = new_channel_selection_forward 82 | collect_gates = collect_channel_selection_gates 83 | module_type = 'ChannelSelection' 84 | 85 | else: 86 | raise NotImplementedError 87 | 88 | 89 | print('==> Transforming model...') 90 | 91 | apply_func(model, module_type, init_func) 92 | apply_func(model, module_type, collect_gates) 93 | replace_func(model, module_type, new_forward) 94 | 95 | model = model.to(args.device) 96 | 97 | gates_params = default_graph.get_tensor_list('gates_params') 98 | optimizer = torch.optim.Adam(gates_params, lr=args.lr) 99 | 100 | def train(epoch): 101 | model.train() 102 | for i, (data, target) in enumerate(trainloader): 103 | data = data.to(args.device) 104 | target = target.to(args.device) 105 | 106 | optimizer.zero_grad() 107 | output = model(data) 108 | loss_ce = F.cross_entropy(output, target) 109 | loss_reg = args.lambd * (torch.cat(gates_params).abs().mean() - args.sparsity_level) ** 2 110 | loss = loss_ce + loss_reg 111 | 112 | loss.backward() 113 | optimizer.step() 114 | 115 | for p in gates_params: 116 | p.data.clamp_(0, 1) 117 | 118 | if i % args.log_interval == 0: 119 | concat_channels = torch.cat(gates_params) 120 | sparsity = (concat_channels != 0).float().mean() 121 | mean_gate = concat_channels.mean() 122 | acc = (output.max(1)[1] == target).float().mean() 123 | 124 | print('Train Epoch: %d [%d/%d]\tLoss: %.4f, Loss_CE: %.4f, Loss_REG: %.4f, ' 125 | 'Sparsity: %.4f, Mean gate: %.4f, Accuracy: %.4f' % ( 126 | epoch, i, len(trainloader), loss.item(), loss_ce.item(), loss_reg.item(), 127 | sparsity.item(), mean_gate.item(), acc.item() 128 | )) 129 | 130 | 131 | def test(): 132 | model.eval() 133 | test_loss_ce = [] 134 | correct = 0 135 | with torch.no_grad(): 136 | for data, target in valloader: 137 | default_graph.clear_all_tensors() 138 | 139 | data, target = data.to(args.device), target.to(args.device) 140 | output = model(data) 141 | 142 | test_loss_ce.append(F.cross_entropy(output, target).item()) 143 | 144 | pred = output.max(1)[1] 145 | correct += (pred == target).float().sum().item() 146 | 147 | test_sparsity = (torch.cat(gates_params) != 0).float().mean() 148 | acc = correct / len(valloader.dataset) 149 | print('Test set: Loss_CE: %.4f, ' 150 | 'Sparsity: %.4f, Accuracy: %.4f\n' % ( 151 | np.mean(test_loss_ce), 152 | test_sparsity.item(), acc 153 | )) 154 | return acc, test_sparsity 155 | 156 | best_acc = 0 157 | for epoch in range(args.epochs): 158 | train(epoch) 159 | acc, test_sparsity = test() 160 | if test_sparsity <= args.sparsity_level and acc > best_acc: 161 | best_acc = acc 162 | torch.save(model.state_dict(), os.path.join(args.logdir, 'checkpoint.pth')) 163 | 164 | temp_params = [] 165 | for i in range(len(gates_params)): 166 | temp_params.append(gates_params[i].data.clone().cpu()) 167 | 168 | misc.dump_pickle(temp_params, os.path.join(args.logdir, 'channel_gates.pkl')) 169 | 170 | if best_acc == 0: 171 | torch.save(model.state_dict(), os.path.join(args.logdir, 
'checkpoint.pth')) 172 | 173 | temp_params = [] 174 | for i in range(len(gates_params)): 175 | temp_params.append(gates_params[i].data.clone().cpu()) 176 | 177 | misc.dump_pickle(temp_params, os.path.join(args.logdir, 'channel_gates.pkl')) 178 | -------------------------------------------------------------------------------- /script/learn_gates_imagenet.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | import datasets 3 | import torch 4 | import argparse 5 | import os 6 | 7 | from gate import default_graph, apply_func, replace_func 8 | from gate import init_convbn_gates, collect_convbn_gates, new_convbn_forward 9 | import models 10 | import misc 11 | 12 | print = misc.logger.info 13 | 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--gpu', default='0', type=str) 16 | parser.add_argument('--data', default='data/imagenet', type=str) 17 | parser.add_argument('--arch', '-a', default='mobilenet_v1', type=str) 18 | parser.add_argument('--sparsity_level', '-s', default=0.5, type=float) 19 | parser.add_argument('--lr', default=0.01, type=float) 20 | parser.add_argument('--lambd', default=0.05, type=float) 21 | parser.add_argument('--log_interval', default=100, type=int) 22 | parser.add_argument('--eval_interval', default=500, type=int) 23 | parser.add_argument('--train_batch_size', default=100, type=int) 24 | parser.add_argument('--expanded_inchannel', '-e', default=40, type=int) 25 | parser.add_argument('--multiplier', '-m', default=1.0, type=float) 26 | parser.add_argument('--seed', default=None, type=int) 27 | 28 | args = parser.parse_args() 29 | args.seed = misc.set_seed(args.seed) 30 | args.num_classes = 1000 31 | 32 | args.device = 'cuda' 33 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 34 | 35 | args.logdir = 'imagenet-%s/channel-%d-sparsity-%.2f' % ( 36 | args.arch, args.expanded_inchannel, args.sparsity_level 37 | ) 38 | 39 | misc.prepare_logging(args) 40 | 41 | print('==> Preparing data..') 42 | 43 | transform_train = transforms.Compose([ 44 | transforms.RandomResizedCrop(224, scale=(0.25, 1.0)), 45 | transforms.RandomHorizontalFlip(), 46 | ]) 47 | 48 | transform_test = transforms.Compose([ 49 | transforms.Resize(256), 50 | transforms.CenterCrop(224), 51 | ]) 52 | 53 | train_loader = torch.utils.data.DataLoader( 54 | datasets.ImageNet(args.data, 'train', transform_train), 55 | batch_size=args.train_batch_size, shuffle=True, num_workers=32, 56 | pin_memory=True, collate_fn=datasets.fast_collate 57 | ) 58 | test_loader = torch.utils.data.DataLoader( 59 | datasets.ImageNet(args.data, 'val', transform_test), 60 | batch_size=50, shuffle=False, num_workers=32, 61 | pin_memory=True, collate_fn=datasets.fast_collate 62 | ) 63 | print('==> Initializing model...') 64 | model = models.__dict__[args.arch](args.num_classes, args.expanded_inchannel, args.multiplier) 65 | 66 | if args.arch == 'mobilenet_v1': 67 | from gate import init_conv_depthwise_gates, new_conv_depthwise_forward, collect_conv_depthwise_gates 68 | init_func = init_conv_depthwise_gates 69 | new_forward = new_conv_depthwise_forward 70 | collect_gates = collect_conv_depthwise_gates 71 | module_type = 'ConvDepthWise' 72 | elif args.arch == 'mobilenet_v2': 73 | from gate import init_inverted_block_gates, new_inverted_block_forward, collect_inverted_block_gates 74 | init_func = init_inverted_block_gates 75 | new_forward = new_inverted_block_forward 76 | collect_gates = collect_inverted_block_gates 77 | module_type = 'InvertedBlock' 78 | elif 
args.arch == 'resnet50': 79 | from gate import init_bottleneck_gates, new_bottleneck_forward, collect_bottleneck_gates 80 | init_func = init_bottleneck_gates 81 | new_forward = new_bottleneck_forward 82 | collect_gates = collect_bottleneck_gates 83 | module_type = 'Bottleneck' 84 | else: 85 | raise NotImplementedError 86 | 87 | 88 | print('==> Transforming model...') 89 | model_params = [] 90 | for params in model.parameters(): 91 | ps = list(params.size()) 92 | if len(ps) == 4 and ps[1] != 1: 93 | weight_decay = 1e-4 94 | elif len(ps) == 2: 95 | weight_decay = 1e-4 96 | else: 97 | weight_decay = 0 98 | item = {'params': params, 'weight_decay': weight_decay, 99 | 'lr': 0.045, 'momentum': 0.9, 100 | 'nesterov': True} 101 | model_params.append(item) 102 | 103 | apply_func(model, 'ConvBNReLU', init_convbn_gates) 104 | apply_func(model, module_type, init_func) 105 | apply_func(model, 'ConvBNReLU', collect_convbn_gates) 106 | apply_func(model, module_type, collect_gates) 107 | replace_func(model, 'ConvBNReLU', new_convbn_forward) 108 | replace_func(model, module_type, new_forward) 109 | 110 | model = model.to(args.device) 111 | criterion = torch.nn.CrossEntropyLoss().to(args.device) 112 | 113 | gates_params = default_graph.get_tensor_list('gates_params') 114 | optimizer = torch.optim.Adam(gates_params, lr=args.lr) 115 | 116 | def test(): 117 | test_losses = misc.AverageMeter() 118 | test_top1 = misc.AverageMeter() 119 | test_top5 = misc.AverageMeter() 120 | 121 | model.eval() 122 | prefetcher = datasets.DataPrefetcher(test_loader) 123 | with torch.no_grad(): 124 | data, target = prefetcher.next() 125 | while data is not None: 126 | default_graph.clear_all_tensors() 127 | 128 | data, target = data.to(args.device), target.to(args.device) 129 | output = model(data) 130 | 131 | loss = criterion(output, target) 132 | prec1, prec5 = misc.accuracy(output, target, topk=(1, 5)) 133 | test_losses.update(loss.item(), data.size(0)) 134 | test_top1.update(prec1.item(), data.size(0)) 135 | test_top5.update(prec5.item(), data.size(0)) 136 | 137 | data, target = prefetcher.next() 138 | 139 | test_sparsity = (torch.cat(gates_params) != 0).float().mean().item() 140 | print(' * Test set: Loss_CE: %.4f, ' 141 | 'Sparsity: %.4f, Top1 acc: %.4f, Top5 acc: %.4f\n' % ( 142 | test_losses.avg, test_sparsity, test_top1.avg, test_top5.avg 143 | )) 144 | return test_top1.avg, test_sparsity 145 | 146 | best_acc = 0 147 | top1 = misc.AverageMeter() 148 | top5 = misc.AverageMeter() 149 | 150 | prefetcher = datasets.DataPrefetcher(train_loader) 151 | data, target = prefetcher.next() 152 | i = -1 153 | while data is not None: 154 | i += 1 155 | 156 | model.train() 157 | optimizer.zero_grad() 158 | output = model(data) 159 | loss_ce = criterion(output, target) 160 | loss_reg = args.lambd * (torch.cat(gates_params).abs().mean() - args.sparsity_level) ** 2 161 | loss = loss_ce + loss_reg 162 | 163 | loss.backward() 164 | optimizer.step() 165 | 166 | for p in gates_params: 167 | p.data.clamp_(0, 1) 168 | 169 | if i % args.log_interval == 0: 170 | concat_channels = torch.cat(gates_params) 171 | sparsity = (concat_channels != 0).float().mean() 172 | mean_gate = concat_channels.mean() 173 | prec1, prec5 = misc.accuracy(output, target, topk=(1, 5)) 174 | top1.update(prec1.item(), data.size(0)) 175 | top5.update(prec5.item(), data.size(0)) 176 | 177 | print('Train Iter [%d/%d]\tLoss: %.4f, Loss_CE: %.4f, Loss_REG: %.4f, ' 178 | 'Sparsity: %.4f, Mean gate: %.4f, Top1 acc: %.4f, Top5 acc: %.4f' % ( 179 | i, len(train_loader), loss.item(), 
loss_ce.item(), loss_reg.item(), 180 | sparsity.item(), mean_gate.item(), top1.avg, top5.avg 181 | )) 182 | 183 | if i % args.eval_interval == 0 and i > 0: 184 | acc, test_sparsity = test() 185 | if test_sparsity <= args.sparsity_level and acc > best_acc: 186 | best_acc = acc 187 | torch.save(model.state_dict(), os.path.join(args.logdir, 'checkpoint.pth')) 188 | 189 | temp_params = [] 190 | for j in range(len(gates_params)): 191 | temp_params.append(gates_params[j].data.clone().cpu()) 192 | 193 | misc.dump_pickle(temp_params, os.path.join(args.logdir, 'channel_gates.pkl')) 194 | 195 | data, target = prefetcher.next() 196 | 197 | if best_acc == 0: 198 | torch.save(model.state_dict(), os.path.join(args.logdir, 'checkpoint.pth')) 199 | 200 | temp_params = [] 201 | for j in range(len(gates_params)): 202 | temp_params.append(gates_params[j].data.clone().cpu()) 203 | 204 | misc.dump_pickle(temp_params, os.path.join(args.logdir, 'channel_gates.pkl')) 205 | -------------------------------------------------------------------------------- /script/prepare_imagenet_list.py: -------------------------------------------------------------------------------- 1 | import os 2 | from misc import dump_pickle 3 | import argparse 4 | 5 | IMG_EXTENSIONS = [ 6 | '.jpg', '.JPG', '.jpeg', '.JPEG', 7 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 8 | ] 9 | 10 | 11 | def is_image_file(filename): 12 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 13 | 14 | def find_classes(dir): 15 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] 16 | classes.sort() 17 | return classes 18 | 19 | 20 | def prepare_images_list(data_dir, dump_path): 21 | classes = find_classes(data_dir) 22 | data_images_list = [] 23 | for i, class_name in enumerate(classes): 24 | print('processing %d-th class: %s' % (i, class_name)) 25 | temp = [] 26 | class_dir = os.path.join(data_dir, class_name) 27 | filenames = os.listdir(class_dir) 28 | for filename in filenames: 29 | if is_image_file(filename): 30 | temp.append(os.path.join(class_dir, filename)) 31 | 32 | data_images_list.append(temp) 33 | 34 | dump_pickle(data_images_list, dump_path) 35 | 36 | 37 | if __name__ == "__main__": 38 | parser = argparse.ArgumentParser() 39 | parser.add_argument('--data_dir', type=str) 40 | parser.add_argument('--dump_path', type=str) 41 | args = parser.parse_args() 42 | prepare_images_list(args.data_dir, args.dump_path) -------------------------------------------------------------------------------- /script/prune_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import misc 3 | import argparse 4 | import os 5 | import flop_counter 6 | import models 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--gpu', default='0', type=str) 10 | parser.add_argument('--dataset', default='cifar10', type=str) 11 | parser.add_argument('--arch', '-a', default='vgg16_bn', type=str) 12 | parser.add_argument('--sparsity_level', '-s', default=0.2, type=float) 13 | parser.add_argument('--pruned_ratio', '-p', default=0.5, type=float) 14 | parser.add_argument('--max_iter', default=10, type=int) 15 | parser.add_argument('--expanded_inchannel', '-e', default=80, type=int) 16 | parser.add_argument('--seed', default=None, type=int) 17 | 18 | args = parser.parse_args() 19 | args.seed = misc.set_seed(args.seed) 20 | 21 | args.device = 'cuda' 22 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 23 | 24 | args.eps = 0.001 25 | args.num_classes = 10 26 | 27 | 
args.logdir = 'logs/seed-%d/%s-%s/channel-%d-sparsity-%.2f' % ( 28 | args.seed, args.dataset, args.arch, args.expanded_inchannel, args.sparsity_level 29 | ) 30 | 31 | gates_params = misc.load_pickle(os.path.join(args.logdir, 'channel_gates.pkl')) 32 | 33 | def calculate_flops(model, input_size=(1, 3, 32, 32)): 34 | model = flop_counter.add_flops_counting_methods(model) 35 | model.eval().start_flops_count() 36 | inp = torch.randn(*input_size) 37 | out = model(inp) 38 | flops = model.compute_average_flops_cost() 39 | return flops 40 | 41 | print('==> Initializing full model...') 42 | model = models.__dict__[args.arch](args.num_classes) 43 | 44 | full_flops = calculate_flops(model) 45 | print('Full model FLOPS = %.4f (M)' % (full_flops / 1e6)) 46 | 47 | all_gates = torch.cat(gates_params) 48 | gates_lens = [len(p) for p in gates_params] 49 | 50 | start_pruned_ratio = 0 51 | end_pruned_ratio = 1 52 | 53 | pruned_cfg = models.expanded_cfg(args.expanded_inchannel)[args.arch] 54 | 55 | for j in range(args.max_iter): 56 | cur_pruned_ratio = (start_pruned_ratio + end_pruned_ratio) / 2 57 | reserved_channel_num = round(len(all_gates) * (1 - cur_pruned_ratio)) 58 | reserved_index = all_gates.topk(reserved_channel_num)[1] 59 | mask = torch.zeros(len(all_gates)) 60 | mask[reserved_index] = 1 61 | masks = torch.split_with_sizes(mask, gates_lens) 62 | 63 | counter = 0 64 | for i in range(len(pruned_cfg)): 65 | if pruned_cfg[i] == 'M': 66 | continue 67 | else: 68 | pruned_cfg[i] = masks[counter].sum().long().item() 69 | counter += 1 70 | 71 | model = models.__dict__[args.arch](args.num_classes, args.expanded_inchannel, pruned_cfg) 72 | 73 | pruned_flops = calculate_flops(model) 74 | actual_pruned_ratio = 1 - pruned_flops / full_flops 75 | print('Iter %d, start %.2f, end %.2f, pruned ratio = %.4f' % ( 76 | j, start_pruned_ratio, end_pruned_ratio, actual_pruned_ratio 77 | )) 78 | 79 | if abs(actual_pruned_ratio - args.pruned_ratio) / args.pruned_ratio <= args.eps: 80 | print('Successfully reach the target pruned ratio with FLOPS = %.4f (M)' % ( 81 | pruned_flops / 1e6 82 | )) 83 | break 84 | 85 | if actual_pruned_ratio > args.pruned_ratio: 86 | end_pruned_ratio = cur_pruned_ratio 87 | else: 88 | start_pruned_ratio = cur_pruned_ratio 89 | 90 | misc.dump_pickle(pruned_cfg, os.path.join(args.logdir, 'pruned_cfg-%.2f.pkl' % args.pruned_ratio)) 91 | misc.dump_pickle(masks, os.path.join(args.logdir, 'masks-%.2f.pkl' % args.pruned_ratio)) 92 | -------------------------------------------------------------------------------- /script/prune_model_imagenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import misc 3 | import argparse 4 | import os 5 | import flop_counter 6 | import models 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--gpu', default='0', type=str) 10 | parser.add_argument('--arch', '-a', default='vgg16_bn', type=str) 11 | parser.add_argument('--sparsity_level', '-s', default=0.5, type=float) 12 | parser.add_argument('--pruned_ratio', '-p', default=0.5, type=float) 13 | parser.add_argument('--max_iter', default=10, type=int) 14 | parser.add_argument('--expanded_inchannel', '-e', default=40, type=int) 15 | parser.add_argument('--multiplier', '-m', default=1.0, type=float) 16 | 17 | args = parser.parse_args() 18 | 19 | args.device = 'cuda' 20 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 21 | 22 | args.eps = 0.001 23 | args.num_classes = 1000 24 | 25 | args.logdir = 'logs/imagenet-%s/channel-%d-sparsity-%.2f' % ( 26 | 
args.arch, args.expanded_inchannel, args.sparsity_level 27 | ) 28 | 29 | gates_params = misc.load_pickle(os.path.join(args.logdir, 'channel_gates.pkl')) 30 | if args.arch == 'mobilenet_v2': 31 | last_param = gates_params[1].clone() 32 | gates_params.pop(1) 33 | gates_params.append(last_param) 34 | 35 | def calculate_flops(model, input_size=(1, 3, 224, 224)): 36 | model = flop_counter.add_flops_counting_methods(model) 37 | model.eval().start_flops_count() 38 | inp = torch.randn(*input_size) 39 | out = model(inp) 40 | flops = model.compute_average_flops_cost() 41 | return flops 42 | 43 | print('==> Initializing full model...') 44 | model = models.__dict__[args.arch](args.num_classes) 45 | 46 | full_flops = calculate_flops(model) 47 | print('Full model FLOPS = %.4f (M)' % (full_flops / 1e6)) 48 | 49 | all_gates = torch.cat(gates_params) 50 | gates_lens = [len(p) for p in gates_params] 51 | 52 | start_pruned_ratio = 0 53 | end_pruned_ratio = 1 54 | 55 | pruned_cfg = models.expanded_cfg(args.expanded_inchannel)[args.arch] 56 | 57 | for j in range(args.max_iter): 58 | cur_pruned_ratio = (start_pruned_ratio + end_pruned_ratio) / 2 59 | reserved_channel_num = round(len(all_gates) * (1 - cur_pruned_ratio)) 60 | reserved_index = all_gates.topk(reserved_channel_num)[1] 61 | mask = torch.zeros(len(all_gates)) 62 | mask[reserved_index] = 1 63 | masks = torch.split_with_sizes(mask, gates_lens) 64 | 65 | counter = 0 66 | for i in range(len(pruned_cfg)): 67 | pruned_cfg[i] = masks[counter].sum().long().item() 68 | counter += 1 69 | 70 | model = models.__dict__[args.arch](args.num_classes, args.expanded_inchannel, args.multiplier, pruned_cfg) 71 | 72 | pruned_flops = calculate_flops(model) 73 | actual_pruned_ratio = 1 - pruned_flops / full_flops 74 | print('Iter %d, start %.2f, end %.2f, pruned ratio = %.4f' % ( 75 | j, start_pruned_ratio, end_pruned_ratio, actual_pruned_ratio 76 | )) 77 | 78 | if abs(actual_pruned_ratio - args.pruned_ratio) / args.pruned_ratio <= args.eps: 79 | print('Successfully reach the target pruned ratio with FLOPS = %.4f (M)' % ( 80 | pruned_flops / 1e6 81 | )) 82 | break 83 | 84 | if actual_pruned_ratio > args.pruned_ratio: 85 | end_pruned_ratio = cur_pruned_ratio 86 | else: 87 | start_pruned_ratio = cur_pruned_ratio 88 | 89 | misc.dump_pickle(pruned_cfg, os.path.join(args.logdir, 'pruned_cfg-%.2f.pkl' % args.pruned_ratio)) 90 | misc.dump_pickle(masks, os.path.join(args.logdir, 'masks-%.2f.pkl' % args.pruned_ratio)) 91 | -------------------------------------------------------------------------------- /script/train_pruned.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | import datasets 3 | import torch.nn.functional as F 4 | import torch 5 | import argparse 6 | import os 7 | 8 | import models 9 | import misc 10 | 11 | print = misc.logger.info 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--gpu', default='0', type=str) 15 | parser.add_argument('--dataset', default='cifar10', type=str) 16 | parser.add_argument('--arch', '-a', default='vgg16_bn', type=str) 17 | parser.add_argument('--lr', default=0.1, type=float) 18 | parser.add_argument('--mm', default=0.9, type=float) 19 | parser.add_argument('--wd', default=1e-4, type=float) 20 | parser.add_argument('--epochs', default=160, type=int) 21 | parser.add_argument('--log_interval', default=100, type=int) 22 | parser.add_argument('--train_batch_size', default=128, type=int) 23 | parser.add_argument('--sparsity_level', '-s', default=0.2, 
type=float) 24 | parser.add_argument('--pruned_ratio', '-p', default=0.7, type=float) 25 | parser.add_argument('--expanded_inchannel', '-e', default=80, type=int) 26 | parser.add_argument('--seed', default=None, type=int) 27 | parser.add_argument('--budget_train', action='store_true') 28 | 29 | args = parser.parse_args() 30 | args.seed = misc.set_seed(args.seed) 31 | 32 | if args.budget_train: 33 | args.epochs = int(1 / (1 - args.pruned_ratio) * args.epochs) 34 | 35 | args.num_classes = 10 36 | 37 | args.device = 'cuda' 38 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 39 | 40 | args.logdir = 'seed-%d/%s-%s/channel-%d-pruned-%.2f' % ( 41 | args.seed, args.dataset, args.arch, args.expanded_inchannel, args.pruned_ratio 42 | ) 43 | 44 | if args.budget_train: 45 | args.logdir += '-B' 46 | 47 | misc.prepare_logging(args) 48 | 49 | print('==> Preparing data..') 50 | 51 | transform_train = transforms.Compose([ 52 | transforms.RandomCrop(32, padding=4), 53 | transforms.RandomHorizontalFlip(), 54 | transforms.ToTensor(), 55 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 56 | ]) 57 | 58 | transform_val = transforms.Compose([ 59 | transforms.ToTensor(), 60 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 61 | ]) 62 | 63 | trainset = datasets.CIFAR10(root='./data/cifar10', type='train+val', transform=transform_train) 64 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.train_batch_size, shuffle=True, num_workers=2) 65 | 66 | testset = datasets.CIFAR10(root='./data/cifar10', type='test', transform=transform_val) 67 | testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2) 68 | 69 | print('==> Initializing model...') 70 | pruned_cfg = misc.load_pickle('logs/seed-%d/%s-%s/channel-%d-sparsity-%.2f/pruned_cfg-%.2f.pkl' % ( 71 | args.seed, args.dataset, args.arch, args.expanded_inchannel, args.sparsity_level, args.pruned_ratio 72 | )) 73 | 74 | model = models.__dict__[args.arch](args.num_classes, args.expanded_inchannel, pruned_cfg) 75 | 76 | model = model.to(args.device) 77 | 78 | optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.mm, weight_decay=args.wd) 79 | scheduler = torch.optim.lr_scheduler.MultiStepLR( 80 | optimizer, milestones=[int(args.epochs * 0.5), int(args.epochs * 0.75)], gamma=0.1 81 | ) 82 | 83 | 84 | def train(epoch): 85 | model.train() 86 | for i, (data, target) in enumerate(trainloader): 87 | data = data.to(args.device) 88 | target = target.to(args.device) 89 | 90 | optimizer.zero_grad() 91 | output = model(data) 92 | loss = F.cross_entropy(output, target) 93 | loss.backward() 94 | optimizer.step() 95 | pred = output.max(1)[1] 96 | acc = (pred == target).float().mean() 97 | 98 | if i % args.log_interval == 0: 99 | print('Train Epoch: {} [{}/{}]\tLoss: {:.6f}, Accuracy: {:.4f}'.format( 100 | epoch, i, len(trainloader), loss.item(), acc.item() 101 | )) 102 | 103 | def evaluate(loader): 104 | model.eval() 105 | test_loss = 0 106 | correct = 0 107 | with torch.no_grad(): 108 | for data, target in loader: 109 | data, target = data.to(args.device), target.to(args.device) 110 | output = model(data) 111 | test_loss += F.cross_entropy(output, target, reduction='sum').item() 112 | pred = output.max(1)[1] 113 | correct += (pred == target).float().sum().item() 114 | 115 | test_loss /= len(loader.dataset) 116 | acc = correct / len(loader.dataset) 117 | print('Val set: Average loss: {:.4f}, Accuracy: {:.4f}\n'.format( 118 | test_loss, acc 119 | )) 120 | return 
acc 121 | 
122 | for epoch in range(args.epochs):
123 |     train(epoch)
124 |     evaluate(testloader)
125 |     scheduler.step()  # step the LR schedule after the epoch's updates (PyTorch >= 1.1 convention)
126 | 
127 | torch.save(model.state_dict(), os.path.join(args.logdir, 'best_checkpoint.pth'))  # saves the final-epoch model
128 | test_acc = evaluate(testloader)
129 | print('Final saved model test accuracy = %.4f' % test_acc)
130 | 
131 | 
--------------------------------------------------------------------------------
/script/train_pruned_imagenet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import argparse
3 | import shutil
4 | import os
5 | import torch.nn.parallel
6 | import torch.distributed as dist
7 | import torch.nn as nn
8 | 
9 | import models
10 | import datasets
11 | import misc
12 | import math
13 | from apex.parallel import DistributedDataParallel as DDP
14 | import warnings
15 | warnings.filterwarnings("ignore")
16 | 
17 | def print(msg):
18 |     if args.local_rank == 0:
19 |         misc.logger.info(msg)
20 | 
21 | parser = argparse.ArgumentParser()
22 | parser.add_argument('--data', default='data/imagenet', type=str)
23 | parser.add_argument('--arch', '-a', default='mobilenet_v1', type=str)
24 | parser.add_argument('--lr_gamma', default=0.975, type=float)
25 | parser.add_argument('--lr_scheduler', default='cos', type=str)
26 | parser.add_argument('--lr', default=0.05, type=float)
27 | parser.add_argument('--mm', default=0.9, type=float)
28 | parser.add_argument('--wd', default=4e-5, type=float)
29 | parser.add_argument('--epochs', default=150, type=int)
30 | parser.add_argument('--log_interval', default=50, type=int)
31 | parser.add_argument('-b', '--batch_size', default=256, type=int,
32 |                     metavar='N', help='mini-batch size per process (default: 256)')
33 | parser.add_argument("--local_rank", default=0, type=int)
34 | parser.add_argument('--sparsity_level', '-s', default=0.5, type=float)
35 | parser.add_argument('--pruned_ratio', '-p', default=0.5, type=float)
36 | parser.add_argument('--expanded_inchannel', '-e', default=40, type=int)
37 | parser.add_argument('--multiplier', '-m', default=1.0, type=float)
38 | parser.add_argument('--budget_train', action='store_true')
39 | parser.add_argument('--label_smooth', action='store_true')
40 | 
41 | args = parser.parse_args()
42 | if args.budget_train:
43 |     args.epochs = 200 if args.arch == 'resnet50' else 300
44 | 
45 | args.logdir = 'imagenet-%s/channel-%d-pruned-%.2f' % (
46 |     args.arch, args.expanded_inchannel, args.pruned_ratio
47 | )
48 | 
49 | if args.budget_train:
50 |     args.logdir += '-B'
51 | if args.label_smooth:
52 |     args.logdir += '-smooth'
53 | 
54 | torch.backends.cudnn.benchmark = True
55 | args.distributed = False
56 | if 'WORLD_SIZE' in os.environ:
57 |     args.distributed = int(os.environ['WORLD_SIZE']) > 1
58 | 
59 | args.gpu = 0
60 | args.world_size = 1
61 | 
62 | if args.distributed:
63 |     args.gpu = args.local_rank % torch.cuda.device_count()
64 |     torch.cuda.set_device(args.gpu)
65 |     torch.distributed.init_process_group(backend='nccl',
66 |                                          init_method='env://')
67 |     args.world_size = torch.distributed.get_world_size()
68 | 
69 | if args.local_rank == 0:
70 |     misc.prepare_logging(args)
71 | 
72 | print("=> Using model {}".format(args.arch))
73 | pruned_cfg = misc.load_pickle('logs/imagenet-%s/channel-%d-sparsity-%.2f/pruned_cfg-%.2f.pkl' % (
74 |     args.arch, args.expanded_inchannel, args.sparsity_level, args.pruned_ratio
75 | ))
76 | 
77 | model = models.__dict__[args.arch](1000, args.expanded_inchannel, args.multiplier, pruned_cfg)
78 | model = model.cuda()
79 | model = DDP(model, 
delay_allreduce=True) 80 | # define loss function (criterion) and optimizer 81 | criterion = nn.CrossEntropyLoss().cuda() 82 | if args.label_smooth: 83 | class CrossEntropyLabelSmooth(nn.Module): 84 | def __init__(self, num_classes, epsilon): 85 | super(CrossEntropyLabelSmooth, self).__init__() 86 | self.num_classes = num_classes 87 | self.epsilon = epsilon 88 | self.logsoftmax = nn.LogSoftmax(dim=1) 89 | 90 | def forward(self, inputs, targets): 91 | log_probs = self.logsoftmax(inputs) 92 | targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1) 93 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes 94 | loss = (-targets * log_probs).mean(0).sum() 95 | return loss 96 | criterion = CrossEntropyLabelSmooth(num_classes=1000, epsilon=0.1).cuda() 97 | 98 | print('==> Preparing data..') 99 | train_loader, train_sampler = datasets.get_imagenet_loader( 100 | os.path.join(args.data, 'train'), args.batch_size, type='train', mobile_setting=(not args.arch == 'resnet50') 101 | ) 102 | test_loader = datasets.get_imagenet_loader( 103 | os.path.join(args.data, 'val'), 100, type='test', mobile_setting=(not args.arch == 'resnet50') 104 | ) 105 | 106 | def get_lr_scheduler(optimizer): 107 | """get learning rate""" 108 | 109 | if args.lr_scheduler == 'multistep': 110 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( 111 | optimizer, milestones=[int(0.3*args.epochs), int(0.6*args.epochs), int(0.9*args.epochs)], 112 | gamma=0.1) 113 | 114 | elif args.lr_scheduler == 'cos': 115 | lr_dict = {} 116 | for i in range(args.epochs): 117 | lr_dict[i] = 0.5 * (1 + math.cos(math.pi * i / args.epochs)) 118 | lr_lambda = lambda epoch: lr_dict[epoch] 119 | lr_scheduler = torch.optim.lr_scheduler.LambdaLR( 120 | optimizer, lr_lambda=lr_lambda) 121 | return lr_scheduler 122 | 123 | def reduce_tensor(tensor): 124 | rt = tensor.clone() 125 | dist.all_reduce(rt, op=dist.reduce_op.SUM) 126 | rt /= args.world_size 127 | return rt 128 | 129 | def to_python_float(t): 130 | if hasattr(t, 'item'): 131 | return t.item() 132 | else: 133 | return t[0] 134 | 135 | args.lr = args.lr*float(args.batch_size*args.world_size)/256. 
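# The line above applies the linear LR scaling rule from large-minibatch SGD
# (Goyal et al., 2017): the base learning rate is scaled by the total effective
# batch size over 256. For example, with the default lr=0.05, batch_size=256
# per process and 4 GPUs:
#     0.05 * (256 * 4) / 256 = 0.2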
136 | # all depthwise convolution (N, 1, x, x) has no weight decay
137 | # weight decay only on normal conv and fc
138 | model_params = []
139 | for params in model.parameters():
140 |     ps = list(params.size())
141 |     if len(ps) == 4 and ps[1] != 1:
142 |         weight_decay = args.wd
143 |     elif len(ps) == 2:
144 |         weight_decay = args.wd
145 |     else:
146 |         weight_decay = 0
147 |     item = {'params': params, 'weight_decay': weight_decay,
148 |             'lr': args.lr, 'momentum': args.mm,
149 |             'nesterov': True}
150 |     model_params.append(item)
151 | 
152 | optimizer = torch.optim.SGD(model_params)
153 | lr_scheduler = get_lr_scheduler(optimizer)
154 | 
155 | def train(train_loader, model, criterion, optimizer, epoch):
156 |     losses = misc.AverageMeter()
157 |     top1 = misc.AverageMeter()
158 |     top5 = misc.AverageMeter()
159 | 
160 |     # switch to train mode
161 |     prefetcher = datasets.DataPrefetcher(train_loader)
162 |     model.train()
163 | 
164 |     input, target = prefetcher.next()
165 |     i = -1
166 |     while input is not None:
167 |         i += 1
168 | 
169 |         output = model(input)
170 |         loss = criterion(output, target)
171 | 
172 |         # compute gradient and do SGD step
173 |         optimizer.zero_grad()
174 |         loss.backward()
175 |         optimizer.step()
176 | 
177 |         if i % args.log_interval == 0:
178 |             prec1, prec5 = misc.accuracy(output.data, target, topk=(1, 5))
179 | 
180 |             # Average loss and accuracy across processes for logging
181 |             reduced_loss = reduce_tensor(loss.data)
182 |             prec1 = reduce_tensor(prec1)
183 |             prec5 = reduce_tensor(prec5)
184 | 
185 |             # to_python_float incurs a host<->device sync
186 |             losses.update(to_python_float(reduced_loss), input.size(0))
187 |             top1.update(to_python_float(prec1), input.size(0))
188 |             top5.update(to_python_float(prec5), input.size(0))
189 | 
190 |             torch.cuda.synchronize()
191 | 
192 |             print('Epoch: [{0}][{1}/{2}]\t'
193 |                   'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
194 |                   'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
195 |                   'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
196 |                       epoch, i, len(train_loader), loss=losses, top1=top1, top5=top5))
197 | 
198 |         input, target = prefetcher.next()
199 | 
200 | 
201 | def validate(val_loader, model, criterion, epoch):
202 |     losses = misc.AverageMeter()
203 |     top1 = misc.AverageMeter()
204 |     top5 = misc.AverageMeter()
205 | 
206 |     # switch to evaluate mode
207 |     prefetcher = datasets.DataPrefetcher(val_loader)
208 |     model.eval()
209 | 
210 |     input, target = prefetcher.next()
211 |     i = -1
212 |     while input is not None:
213 |         i += 1
214 |         with torch.no_grad():
215 |             output = model(input)
216 |             loss = criterion(output, target)
217 | 
218 |         # measure accuracy and record loss
219 |         prec1, prec5 = misc.accuracy(output.data, target, topk=(1, 5))
220 | 
221 |         reduced_loss = reduce_tensor(loss.data)
222 |         prec1 = reduce_tensor(prec1)
223 |         prec5 = reduce_tensor(prec5)
224 | 
225 |         losses.update(to_python_float(reduced_loss), input.size(0))
226 |         top1.update(to_python_float(prec1), input.size(0))
227 |         top5.update(to_python_float(prec5), input.size(0))
228 | 
229 |         input, target = prefetcher.next()
230 | 
231 |     print(' * Test Epoch {0}, Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}\n'
232 |           .format(epoch, top1=top1, top5=top5))
233 | 
234 |     return top1.avg
235 | 
236 | 
237 | # main
238 | best_prec1 = 0
239 | for epoch in range(args.epochs):
240 |     if args.distributed:
241 |         train_sampler.set_epoch(epoch)
242 | 
243 |     # train for one epoch
244 |     train(train_loader, model, criterion, optimizer, epoch)
245 |     # step the LR schedule after the epoch's updates (PyTorch >= 1.1 convention)
246 |     lr_scheduler.step()
247 | 
248 |     # evaluate on validation set
    prec1 = validate(test_loader, 
model, criterion, epoch) 249 | 250 | # remember best prec@1 and save checkpoint 251 | if args.local_rank == 0: 252 | is_best = prec1 > best_prec1 253 | best_prec1 = max(prec1, best_prec1) 254 | torch.save({ 255 | 'epoch': epoch + 1, 256 | 'arch': args.arch, 257 | 'state_dict': model.state_dict(), 258 | 'best_prec1': best_prec1, 259 | 'optimizer' : optimizer.state_dict(), 260 | }, os.path.join(args.logdir, 'checkpoint.pth.tar')) 261 | 262 | if is_best: 263 | shutil.copyfile(os.path.join(args.logdir, 'checkpoint.pth.tar'), 264 | os.path.join(args.logdir, 'model_best.pth.tar')) 265 | print(' * Save best model @ Epoch {}\n'.format(epoch)) 266 | --------------------------------------------------------------------------------
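
The distributed training script above saves its checkpoints from the DDP-wrapped model, so the keys in `state_dict` carry a `module.` prefix. As a usage sketch (not part of the repository), restoring `model_best.pth.tar` into a bare single-GPU model for evaluation might look like the following; the hard-coded settings mirror the script defaults, and the paths assume `misc.prepare_logging` places run directories under `logs/`, as the other scripts' load paths suggest:

```python
import torch
import models
import misc

arch = 'mobilenet_v1'      # script default (-a)
expanded_inchannel = 40    # -e default
multiplier = 1.0           # -m default
sparsity_level = 0.5       # -s default
pruned_ratio = 0.5         # -p default

# pruned channel configuration produced by script/prune_model_imagenet.py
pruned_cfg = misc.load_pickle(
    'logs/imagenet-%s/channel-%d-sparsity-%.2f/pruned_cfg-%.2f.pkl' % (
        arch, expanded_inchannel, sparsity_level, pruned_ratio))

model = models.__dict__[arch](1000, expanded_inchannel, multiplier, pruned_cfg).cuda()

ckpt = torch.load(
    'logs/imagenet-%s/channel-%d-pruned-%.2f-B-smooth/model_best.pth.tar' % (
        arch, expanded_inchannel, pruned_ratio),
    map_location='cuda')

# strip the 'module.' prefix added by the DistributedDataParallel wrapper
state_dict = {
    (k[len('module.'):] if k.startswith('module.') else k): v
    for k, v in ckpt['state_dict'].items()
}
model.load_state_dict(state_dict)
model.eval()

print('restored epoch %d, best top-1 %.3f' % (ckpt['epoch'], ckpt['best_prec1']))
```

The `-B-smooth` suffix matches a run launched with `--budget_train --label_smooth`; drop either part of the suffix to match other configurations.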