├── models
│   ├── __init__.py
│   ├── search_space_mbv2.py
│   ├── search_space_res.py
│   ├── model_derived.py
│   ├── operations.py
│   ├── dropped_model.py
│   └── search_space_base.py
├── tools
│   ├── __init__.py
│   ├── collections.py
│   ├── io.py
│   ├── lr_scheduler.py
│   ├── utils.py
│   ├── config_yaml.py
│   └── multadds_count.py
├── dataset
│   ├── __init__.py
│   ├── mk_img_list.py
│   ├── mk_split_img_list.py
│   ├── torchvision_extension.py
│   ├── img2lmdb.py
│   ├── prefetch_data.py
│   ├── lmdb_dataset.py
│   └── imagenet_data.py
├── run_apis
│   ├── __init__.py
│   ├── derive_arch_res.py
│   ├── derive_arch_mbv2.py
│   ├── derive_arch.py
│   ├── latency_measure.py
│   ├── validation.py
│   ├── optimizer.py
│   ├── retrain.py
│   ├── trainer.py
│   └── search.py
├── configs
│   ├── __init__.py
│   ├── imagenet_val_cfg.py
│   ├── imagenet_search_cfg_resnet.yaml
│   ├── search_config.py
│   ├── imagenet_search_cfg_mbv2.yaml
│   └── imagenet_train_cfg.py
├── imgs
│   ├── archs.png
│   ├── res_comp.png
│   ├── mbv2_comp.png
│   ├── mbv2_results.png
│   ├── res_results.png
│   └── search_space.png
├── .gitignore
├── README.md
├── latency_list
│   ├── lat_list_resv2_32_reg10
│   ├── lat_list_resv2_32
│   └── lat_list_densenas_mbv2_xp32
└── LICENSE
/models/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/tools/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/run_apis/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/configs/__init__.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append("..")
--------------------------------------------------------------------------------
/imgs/archs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JaminFong/DenseNAS/HEAD/imgs/archs.png
--------------------------------------------------------------------------------
/imgs/res_comp.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JaminFong/DenseNAS/HEAD/imgs/res_comp.png
--------------------------------------------------------------------------------
/imgs/mbv2_comp.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JaminFong/DenseNAS/HEAD/imgs/mbv2_comp.png
--------------------------------------------------------------------------------
/imgs/mbv2_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JaminFong/DenseNAS/HEAD/imgs/mbv2_results.png
--------------------------------------------------------------------------------
/imgs/res_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JaminFong/DenseNAS/HEAD/imgs/res_results.png
--------------------------------------------------------------------------------
/imgs/search_space.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JaminFong/DenseNAS/HEAD/imgs/search_space.png
--------------------------------------------------------------------------------
/configs/imagenet_val_cfg.py:
--------------------------------------------------------------------------------
1 | from tools.collections import AttrDict
2 | 
3 | __C = AttrDict()
4 | 
5 | cfg = __C
6 | 
7 | __C.net_config=""
8 | 
9 | __C.data=AttrDict()
10 | __C.data.num_workers=16
11 | __C.data.batch_size=1024
12 | __C.data.dataset='imagenet'
13 | __C.data.train_data_type='lmdb'
14 | __C.data.val_data_type='lmdb'
15 | __C.data.patch_dataset=False
16 | __C.data.num_examples=1281167
17 | __C.data.input_size=(3,224,224)
--------------------------------------------------------------------------------
/dataset/mk_img_list.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | 
4 | def get_list(data_path, output_path):
5 |     for split in os.listdir(data_path):
6 |         split_path = os.path.join(data_path, split)
7 |         if not os.path.isdir(split_path):
8 |             continue
9 |         f = open(os.path.join(output_path, split + '_datalist'), 'w')  # 'w' rather than 'a+': re-running should not append duplicate entries
10 |         for sub in os.listdir(split_path):
11 |             sub_path = os.path.join(split_path, sub)
12 |             if not os.path.isdir(sub_path):
13 |                 continue
14 |             for image in os.listdir(sub_path):
15 |                 image_name = sub + '/' + image
16 |                 f.writelines(image_name + '\n')
17 |         f.close()
18 | 
19 | 
20 | if __name__ == '__main__':
21 |     parser = argparse.ArgumentParser("Params")
22 |     parser.add_argument('--image_path', type=str, default='', help='the path of the images')
23 |     parser.add_argument('--output_path', type=str, default='', help='the output path of the list file')
24 |     args = parser.parse_args()
25 | 
26 |     get_list(args.image_path, args.output_path)
27 | 
--------------------------------------------------------------------------------
/dataset/mk_split_img_list.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | 
4 | 
5 | def get_list(data_path, output_path):
6 |     for split in os.listdir(data_path):
7 |         if split == 'train':
8 |             split_path = os.path.join(data_path, split)
9 |             if not os.path.isdir(split_path):
10 |                 continue
11 |             f_train = open(os.path.join(output_path, split + '_datalist'), 'w')
12 |             f_val = open(os.path.join(output_path, 'val' + '_datalist'), 'w')
13 |             class_list = os.listdir(split_path)
14 |             for sub in class_list[:100]:
15 |                 sub_path = os.path.join(split_path, sub)
16 |                 if not os.path.isdir(sub_path):
17 |                     continue
18 |                 img_list = os.listdir(sub_path)
19 |                 train_len = int(0.8*len(img_list))
20 |                 for image in img_list[:train_len]:
21 |                     image_name = os.path.join(sub, image)
22 |                     f_train.writelines(image_name + '\n')
23 |                 for image in img_list[train_len:]:
24 |                     image_name = os.path.join(sub, image)
25 |                     f_val.writelines(image_name + '\n')
26 |             f_train.close()
27 |             f_val.close()
28 | 
29 | if __name__ == '__main__':
30 |     parser = argparse.ArgumentParser("Params")
31 |     parser.add_argument('--image_path', type=str, default='', help='the path of the images')
32 |     parser.add_argument('--output_path', type=str, default='.', help='the output path of the list file')
33 |     args = parser.parse_args()
34 | 
35 |     get_list(args.image_path, args.output_path)
--------------------------------------------------------------------------------
/models/search_space_mbv2.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | 
3 | from .operations import OPS
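# Descriptive note (added for readability, inferred from the __init__ below): the
# supernet assembled in this file stacks a stride-2 stem conv, a fixed mbconv_k3_t1
# head block, the searched Blocks from search_space_base, a 1x1 conv block,
# global average pooling and a linear classifier.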
4 | from .search_space_base import Conv1_1_Block, Block 5 | from .search_space_base import Network as BaseSearchSpace 6 | 7 | class Network(BaseSearchSpace): 8 | def __init__(self, init_ch, dataset, config): 9 | super(Network, self).__init__(init_ch, dataset, config) 10 | 11 | self.input_block = nn.Sequential( 12 | nn.Conv2d(in_channels=3, out_channels=self._C_input, kernel_size=3, 13 | stride=2, padding=1, bias=False), 14 | nn.BatchNorm2d(self._C_input, affine=False, track_running_stats=True), 15 | nn.ReLU6(inplace=True) 16 | ) 17 | 18 | self.head_block = OPS['mbconv_k3_t1'](self._C_input, self._head_dim, 1, affine=False, track_running_stats=True) 19 | 20 | self.blocks = nn.ModuleList() 21 | 22 | for i in range(self.num_blocks): 23 | input_config = self.input_configs[i] 24 | self.blocks.append(Block( 25 | input_config['in_chs'], 26 | input_config['ch'], 27 | input_config['strides'], 28 | input_config['num_stack_layers'], 29 | self.config 30 | )) 31 | 32 | self.conv1_1_block = Conv1_1_Block(self.input_configs[-1]['in_chs'], 33 | self.config.optim.last_dim) 34 | 35 | self.global_pooling = nn.AdaptiveAvgPool2d(1) 36 | self.classifier = nn.Linear(self.config.optim.last_dim, self._num_classes) 37 | 38 | self.init_model(model_init=config.optim.init_mode) 39 | self.set_bn_param(self.config.optim.bn_momentum, self.config.optim.bn_eps) 40 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /dataset/torchvision_extension.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torchvision.transforms as transforms 3 | from torchvision.transforms import functional as F 4 | 5 | #In this file some more transformations (apart from the ones defined in torchvision.transform) 6 | #are added. Particularly helpful to train imagenet, and in the style of the transforms 7 | #used by fb.resnet https://github.com/facebook/fb.resnet.torch/blob/master/datasets/imagenet.lua 8 | 9 | #This file is taken from a proposed pull request on the torchvision github project. 10 | #At the moment this pull request has not been accepted yet, that is why I report it here. 
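#
# A minimal, hypothetical usage sketch (the eigval/eigvec numbers below are the
# PCA lighting statistics commonly used for ImageNet, e.g. by fb.resnet; they are
# assumptions added here, not values shipped with this repo):
#
#   import torch
#   lighting = Lighting(alphastd=0.1,
#                       eigval=torch.tensor([0.2175, 0.0188, 0.0045]),
#                       eigvec=torch.tensor([[-0.5675,  0.7192,  0.4009],
#                                            [-0.5808, -0.0045, -0.8140],
#                                            [-0.5836, -0.6948,  0.4203]]))
#   tensor_img = lighting(tensor_img)  # expects a 3xHxW float tensor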
11 | #Link to the pull request: https://github.com/pytorch/vision/pull/27/files
12 | 
13 | class Lighting(object):
14 | 
15 |     """Lighting noise(AlexNet - style PCA - based noise)"""
16 | 
17 |     def __init__(self, alphastd, eigval, eigvec):
18 |         self.alphastd = alphastd
19 |         self.eigval = eigval
20 |         self.eigvec = eigvec
21 | 
22 |     def __call__(self, img):
23 |         # img is supposed to be a torch tensor
24 | 
25 |         if self.alphastd == 0:
26 |             return img
27 | 
28 |         alpha = img.new().resize_(3).normal_(0, self.alphastd)
29 |         rgb = self.eigvec.type_as(img).clone()\
30 |             .mul(alpha.view(1, 3).expand(3, 3))\
31 |             .mul(self.eigval.view(1, 3).expand(3, 3))\
32 |             .sum(1).squeeze()
33 | 
34 |         return img.add(rgb.view(3, 1, 1).expand_as(img))
35 | 
36 | 
37 | class RandomScale(object):
38 | 
39 |     """ResNet style data augmentation"""
40 | 
41 |     def __init__(self, minSize, maxSize):
42 |         self.minSize = minSize
43 |         self.maxSize = maxSize
44 | 
45 |     def __call__(self, img):
46 | 
47 |         targetSz = int(round(random.uniform(self.minSize, self.maxSize)))
48 | 
49 |         return F.resize(img, targetSz)
50 | 
--------------------------------------------------------------------------------
/run_apis/derive_arch_res.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from .derive_arch import BaseArchGenerate
3 | 
4 | class ArchGenerate(BaseArchGenerate):
5 |     def __init__(self, super_network, config):
6 |         super(ArchGenerate, self).__init__(super_network, config)
7 | 
8 |     def derive_archs(self, betas, head_alphas, stack_alphas, if_display=True):
9 | 
10 |         self.update_arch_params(betas, head_alphas, stack_alphas)
11 | 
12 |         # [[ch, head_op, [stack_op], num_layers, stride], ..., [...]]
13 |         derived_archs = []
14 |         ch_path, derived_chs = self.derive_chs()
15 | 
16 |         layer_count = 0
17 |         for i, (ch_idx, ch) in enumerate(zip(ch_path, derived_chs)):
18 |             if ch_idx == 0 or i == len(derived_chs)-1:
19 |                 continue
20 | 
21 |             block_idx = ch_idx - 1
22 |             input_config = self.input_configs[block_idx]
23 | 
24 |             head_id = input_config['in_block_idx'].index(ch_path[i-1])
25 |             head_alpha = self.head_alphas[block_idx][head_id]
26 |             head_op = self.derive_ops(head_alpha, 'head')
27 | 
28 |             stride = input_config['strides'][input_config['in_block_idx'].index(ch_path[i-1])]
29 | 
30 |             stack_ops = []
31 |             for stack_alpha in self.stack_alphas[block_idx]:
32 |                 stack_op = self.derive_ops(stack_alpha, 'stack')
33 |                 if stack_op != 'skip_connect':
34 |                     stack_ops.append(stack_op)
35 |                     layer_count += 1
36 | 
37 |             derived_archs.append(
38 |                 [[derived_chs[i-1], ch], head_op, stack_ops, len(stack_ops), stride]
39 |             )
40 | 
41 |         layer_count += len(derived_archs)
42 |         if if_display:
43 |             logging.info('Derived arch: \n' + '|\n'.join(map(str, derived_archs)))
44 |             logging.info('Total {} layers.'.format(layer_count))
45 | 
46 |         return derived_archs
--------------------------------------------------------------------------------
/run_apis/derive_arch_mbv2.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from .derive_arch import BaseArchGenerate
3 | 
4 | class ArchGenerate(BaseArchGenerate):
5 |     def __init__(self, super_network, config):
6 |         super(ArchGenerate, self).__init__(super_network, config)
7 | 
8 |     def derive_archs(self, betas, head_alphas, stack_alphas, if_display=True):
9 | 
10 |         self.update_arch_params(betas, head_alphas, stack_alphas)
11 | 
12 |         # [[ch, head_op, [stack_op], num_layers, stride], ..., [...]]
13 |         derived_archs = [[[self.config.optim.init_dim,
self.config.optim.head_dim], 'mbconv_k3_t1', [], 0, 1],] 14 | ch_path, derived_chs = self.derive_chs() 15 | 16 | layer_count = 0 17 | for i, (ch_idx, ch) in enumerate(zip(ch_path, derived_chs)): 18 | if ch_idx == 0 or i == len(derived_chs)-1: 19 | continue 20 | 21 | block_idx = ch_idx - 1 22 | input_config = self.input_configs[block_idx] 23 | 24 | head_id = input_config['in_block_idx'].index(ch_path[i-1]) 25 | head_alpha = self.head_alphas[block_idx][head_id] 26 | head_op = self.derive_ops(head_alpha, 'head') 27 | 28 | stride = input_config['strides'][input_config['in_block_idx'].index(ch_path[i-1])] 29 | 30 | stack_ops = [] 31 | for stack_alpha in self.stack_alphas[block_idx]: 32 | stack_op = self.derive_ops(stack_alpha, 'stack') 33 | if stack_op != 'skip_connect': 34 | stack_ops.append(stack_op) 35 | layer_count += 1 36 | 37 | derived_archs.append( 38 | [[derived_chs[i-1], ch], head_op, stack_ops, len(stack_ops), stride] 39 | ) 40 | derived_archs.append([[derived_chs[-2], self.config.optim.last_dim], 'conv1_1']) 41 | 42 | layer_count += len(derived_archs) 43 | if if_display: 44 | logging.info('Derived arch: \n' + '|\n'.join(map(str, derived_archs))) 45 | logging.info('Total {} layers.'.format(layer_count)) 46 | 47 | return derived_archs -------------------------------------------------------------------------------- /configs/imagenet_search_cfg_resnet.yaml: -------------------------------------------------------------------------------- 1 | net_type: res 2 | train_params: 3 | epochs: 70 4 | use_seed: True 5 | seed: 2 6 | 7 | search_params: 8 | val_start_epoch: 50 9 | arch_update_epoch: 10 10 | sample_policy: prob # prob uniform 11 | weight_sample_num: 1 12 | softmax_temp: 0.9 13 | 14 | PRIMITIVES_stack: ['basic_block', 15 | 'skip_connect',] 16 | PRIMITIVES_head: ['basic_block', 17 | ] 18 | 19 | adjoin_connect_nums: [10, 10, 10, 10, 10, 10, 10] 20 | net_scale: 21 | chs: [32, 22 | 48, 56, 64, 23 | 72, 96, 112, 24 | 128, 144, 160, 176, 192, 208, 224, 25 | 240, 256, 272, 288, 480, 496, 512] 26 | fm_sizes: [112, 27 | 56, 56, 56, 28 | 28, 28, 28, 29 | 14, 14, 14, 14, 14, 14, 14, 30 | 7, 7, 7, 7, 7, 7, 7] 31 | stage: [0, 32 | 1, 1, 1, 33 | 2, 2, 2, 34 | 3, 3, 3, 3, 4, 4, 4, 35 | 5, 5, 5, 5, 6, 6, 6, 36 | 7] 37 | num_layers: [0, 38 | 0, 0, 0, 39 | 5, 5, 5, 40 | 15, 15, 15, 15, 5, 5, 5, 41 | 5, 5, 5, 5, 1, 1, 1] 42 | 43 | optim: 44 | last_dim: 512 45 | init_dim: 32 46 | bn_momentum: 0.1 47 | bn_eps: 0.001 48 | weight: 49 | init_lr: 0.2 50 | min_lr: 0.0001 51 | lr_decay_type: cosine 52 | momentum: 0.9 53 | weight_decay: 0.00004 54 | arch: 55 | alpha_lr: 0.0003 56 | beta_lr: 0.0003 57 | weight_decay: 0.001 58 | 59 | if_sub_obj: True 60 | sub_obj: 61 | type: flops 62 | skip_reg: True 63 | log_base: 3500. 
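    # Note: run_apis/optimizer.py folds this sub-objective into the loss as
    # log(sub_obj) / log(log_base) * sub_loss_factor, so log_base sets the
    # FLOPs scale of the penalty and sub_loss_factor its weight.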
64 | sub_loss_factor: 0.2 65 | 66 | if_resume: False 67 | resume: 68 | load_path: '' 69 | load_epoch: 9 70 | 71 | data: 72 | batch_size: 512 73 | num_workers: 16 74 | dataset: imagenet 75 | train_data_type: lmdb 76 | val_data_type: lmdb 77 | patch_dataset: False 78 | input_size: (3,224,224) 79 | type_of_data_aug: random_sized 80 | color: False 81 | random_sized: 82 | min_scale: 0.08 83 | 84 | -------------------------------------------------------------------------------- /configs/search_config.py: -------------------------------------------------------------------------------- 1 | from tools.collections import AttrDict 2 | 3 | __C = AttrDict() 4 | search_cfg = __C 5 | 6 | 7 | __C.search_params=AttrDict() 8 | __C.search_params.arch_update_epoch=10 9 | __C.search_params.val_start_epoch=120 10 | __C.search_params.sample_policy='prob' # prob uniform 11 | __C.search_params.weight_sample_num=1 12 | __C.search_params.softmax_temp=1. 13 | 14 | __C.search_params.adjoin_connect_nums = [] 15 | __C.search_params.net_scale = AttrDict() 16 | __C.search_params.net_scale.chs = [] 17 | __C.search_params.net_scale.fm_sizes = [] 18 | __C.search_params.net_scale.stage = [] 19 | __C.search_params.net_scale.num_layers = [] 20 | 21 | __C.search_params.PRIMITIVES_stack = [ 22 | 'mbconv_k3_t3', 23 | 'mbconv_k3_t6', 24 | 'mbconv_k5_t3', 25 | 'mbconv_k5_t6', 26 | 'mbconv_k7_t3', 27 | 'mbconv_k7_t6', 28 | 'skip_connect', 29 | ] 30 | __C.search_params.PRIMITIVES_head = [ 31 | 'mbconv_k3_t3', 32 | 'mbconv_k3_t6', 33 | 'mbconv_k5_t3', 34 | 'mbconv_k5_t6', 35 | 'mbconv_k7_t3', 36 | 'mbconv_k7_t6', 37 | ] 38 | 39 | __C.optim=AttrDict() 40 | __C.optim.init_dim=16 41 | __C.optim.head_dim=16 42 | __C.optim.last_dim=1984 43 | __C.optim.weight=AttrDict() 44 | __C.optim.weight.init_lr=0.1 45 | __C.optim.weight.min_lr=1e-4 46 | __C.optim.weight.lr_decay_type='cosine' 47 | __C.optim.weight.momentum=0.9 48 | __C.optim.weight.weight_decay=4e-5 49 | 50 | __C.optim.arch=AttrDict() 51 | __C.optim.arch.alpha_lr=3e-4 52 | __C.optim.arch.beta_lr=3e-4 53 | __C.optim.arch.weight_decay=1e-3 54 | 55 | __C.optim.if_sub_obj=True 56 | __C.optim.sub_obj=AttrDict() 57 | __C.optim.sub_obj.type='latency' # latency / flops 58 | __C.optim.sub_obj.skip_reg=False 59 | __C.optim.sub_obj.log_base=15.0 60 | __C.optim.sub_obj.sub_loss_factor=0.1 61 | __C.optim.sub_obj.latency_list_path='' 62 | 63 | -------------------------------------------------------------------------------- /run_apis/derive_arch.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 4 | class BaseArchGenerate(object): 5 | def __init__(self, super_network, config): 6 | self.config = config 7 | self.num_blocks = len(super_network.block_chs) # including the input 32 and 1280 block 8 | self.super_chs = super_network.block_chs 9 | self.input_configs = super_network.input_configs 10 | 11 | 12 | def update_arch_params(self, betas, head_alphas, stack_alphas): 13 | self.betas = betas 14 | self.head_alphas = head_alphas 15 | self.stack_alphas = stack_alphas 16 | 17 | 18 | def derive_chs(self): 19 | """ 20 | using viterbi algorithm to choose the best path of the super net 21 | """ 22 | path_p_max = [] # [[max_last_state_id, trans_prob], ...] 
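        # Illustration with made-up numbers (not from the repo): if block 1 can be
        # reached from block 0 with beta 0.6, and block 2 from block 1 with beta 0.7,
        # the accumulated path probability stored for block 2 is 1 * 0.6 * 0.7 = 0.42.
        # Each entry keeps only its best predecessor, so the chosen channel path is
        # recovered afterwards by backtracking from the last block, as in Viterbi decoding.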
23 | 24 | path_p_max.append([0, 1]) 25 | 26 | for input_config in self.input_configs: 27 | block_path_prob_max = [None, 0] 28 | for in_block_id, beta_id in zip(input_config['in_block_idx'], input_config['beta_idx']): 29 | path_prob = path_p_max[in_block_id][1]*self.betas[in_block_id][beta_id] 30 | if path_prob > block_path_prob_max[1]: 31 | block_path_prob_max = [in_block_id ,path_prob] 32 | 33 | path_p_max.append(block_path_prob_max) 34 | 35 | ch_idx = len(path_p_max) - 1 36 | ch_path = [] 37 | ch_path.append(ch_idx) 38 | 39 | while 1: 40 | ch_idx = path_p_max[ch_idx][0] 41 | ch_path.append(ch_idx) 42 | if ch_idx == 0: 43 | break 44 | 45 | derived_chs = [self.super_chs[ch_id] for ch_id in ch_path] 46 | 47 | ch_path = ch_path[::-1] 48 | derived_chs = derived_chs[::-1] 49 | 50 | return ch_path, derived_chs 51 | 52 | 53 | def derive_ops(self, alpha, alpha_type): 54 | assert alpha_type in ['head', 'stack'] 55 | 56 | if alpha_type == 'head': 57 | op_type = self.config.search_params.PRIMITIVES_head[alpha.index(max(alpha))] 58 | elif alpha_type == 'stack': 59 | op_type = self.config.search_params.PRIMITIVES_stack[alpha.index(max(alpha))] 60 | 61 | return op_type 62 | 63 | 64 | def derive_archs(self, betas, head_alphas, stack_alphas, if_display=True): 65 | raise NotImplementedError -------------------------------------------------------------------------------- /configs/imagenet_search_cfg_mbv2.yaml: -------------------------------------------------------------------------------- 1 | net_type: mbv2 2 | train_params: 3 | epochs: 150 4 | use_seed: True 5 | seed: 2 6 | 7 | search_params: 8 | arch_update_epoch: 50 9 | val_start_epoch: 120 10 | sample_policy: prob # prob uniform 11 | weight_sample_num: 1 12 | softmax_temp: 1. 13 | 14 | PRIMITIVES_stack: ['mbconv_k3_t3', 15 | 'mbconv_k3_t6', 16 | 'mbconv_k5_t3', 17 | 'mbconv_k5_t6', 18 | 'mbconv_k7_t3', 19 | 'mbconv_k7_t6', 20 | 'skip_connect',] 21 | PRIMITIVES_head: ['mbconv_k3_t3', 22 | 'mbconv_k3_t6', 23 | 'mbconv_k5_t3', 24 | 'mbconv_k5_t6', 25 | 'mbconv_k7_t3', 26 | 'mbconv_k7_t6', 27 | ] 28 | # search space 29 | adjoin_connect_nums: [4, 4, 4, 4, 4, 4, 4] 30 | net_scale: 31 | chs: [16, 24, 32, 40, 48, 56, 64, 72, 96, 112, 128, 160, 176, 192, 320, 352, 384] 32 | fm_sizes: [112, 56, 28, 28, 28, 14, 14, 14, 14, 14, 14, 7, 7, 7, 7, 7, 7] 33 | stage: [0, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6, 7] 34 | num_layers: [0, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 0, 0] 35 | 36 | optim: 37 | if_sub_obj: True 38 | sub_obj: 39 | type: latency 40 | skip_reg: False 41 | log_base: 15.0 42 | sub_loss_factor: 0.15 43 | latency_list_path: lat_list_densenas_mbv2_xp32 44 | 45 | if_resume: False 46 | resume: 47 | load_path: '' 48 | load_epoch: 49 49 | 50 | init_dim: 16 51 | head_dim: 16 52 | last_dim: 1984 53 | bn_momentum: 0.1 54 | bn_eps: 0.001 55 | weight: 56 | init_lr: 0.2 57 | min_lr: 0.0001 58 | lr_decay_type: cosine 59 | momentum: 0.9 60 | weight_decay: 0.00004 61 | arch: 62 | alpha_lr: 0.0003 63 | beta_lr: 0.0003 64 | weight_decay: 0.001 65 | 66 | data: 67 | dataset: imagenet 68 | batch_size: 352 69 | num_workers: 16 70 | train_data_type: lmdb 71 | val_data_type: lmdb 72 | input_size: (3,224,224) 73 | type_of_data_aug: random_sized # random_sized / rand_scale 74 | color: False 75 | random_sized: 76 | min_scale: 0.08 77 | -------------------------------------------------------------------------------- /models/search_space_res.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from 
.operations import OPS
4 | from .search_space_base import Conv1_1_Block, Block
5 | from .search_space_base import Network as BaseSearchSpace
6 | 
7 | class Network(BaseSearchSpace):
8 |     def __init__(self, init_ch, dataset, config, groups=1, base_width=64, dilation=1, norm_layer=None):
9 |         super(Network, self).__init__(init_ch, dataset, config)
10 | 
11 |         if norm_layer is None:
12 |             norm_layer = nn.BatchNorm2d
13 |         if groups != 1 or base_width != 64:
14 |             raise ValueError('BasicBlock only supports groups=1 and base_width=64')
15 |         if dilation > 1:
16 |             raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
17 |         # The stem convolution downsamples the input with stride 2
18 |         self.input_block = nn.Sequential(
19 |             nn.Conv2d(3, self._C_input, kernel_size=3, stride=2, padding=1, bias=False),
20 |             norm_layer(self._C_input),
21 |             nn.ReLU(inplace=True),
22 |         )
23 | 
24 |         self.blocks = nn.ModuleList()
25 | 
26 |         for i in range(self.num_blocks):
27 |             input_config = self.input_configs[i]
28 |             self.blocks.append(Block(
29 |                 input_config['in_chs'],
30 |                 input_config['ch'],
31 |                 input_config['strides'],
32 |                 input_config['num_stack_layers'],
33 |                 self.config
34 |             ))
35 | 
36 |         if 'bottle_neck' in self.config.search_params.PRIMITIVES_stack:
37 |             conv1_1_input_dim = [ch * 4 for ch in self.input_configs[-1]['in_chs']]
38 |             last_dim = self.config.optim.last_dim * 4
39 |         else:
40 |             conv1_1_input_dim = self.input_configs[-1]['in_chs']
41 |             last_dim = self.config.optim.last_dim
42 |         self.conv1_1_block = Conv1_1_Block(conv1_1_input_dim, last_dim)
43 |         self.global_pooling = nn.AdaptiveAvgPool2d((1, 1))
44 |         self.classifier = nn.Linear(last_dim, self._num_classes)
45 | 
46 |         for m in self.modules():
47 |             if isinstance(m, nn.Conv2d):
48 |                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
49 |             elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
50 |                 if m.affine:
51 |                     nn.init.constant_(m.weight, 1)
52 |                     nn.init.constant_(m.bias, 0)
53 | 
54 | 
--------------------------------------------------------------------------------
/configs/imagenet_train_cfg.py:
--------------------------------------------------------------------------------
1 | from tools.collections import AttrDict
2 | 
3 | __C = AttrDict()
4 | 
5 | cfg = __C
6 | 
7 | __C.net_type='mbv2' # mbv2 / res
8 | __C.net_config="""[[16, 16], 'mbconv_k3_t1', [], 0, 1]|
9 | [[16, 24], 'mbconv_k3_t6', [], 0, 2]|
10 | [[24, 48], 'mbconv_k7_t6', ['mbconv_k3_t3'], 1, 2]|
11 | [[48, 72], 'mbconv_k5_t6', ['mbconv_k3_t6', 'mbconv_k3_t3'], 2, 2]|
12 | [[72, 128], 'mbconv_k3_t6', ['mbconv_k3_t3', 'mbconv_k3_t3'], 2, 1]|
13 | [[128, 160], 'mbconv_k3_t6', ['mbconv_k7_t3', 'mbconv_k5_t6', 'mbconv_k7_t3'], 3, 2]|
14 | [[160, 176], 'mbconv_k3_t3', ['mbconv_k3_t6', 'mbconv_k3_t6', 'mbconv_k3_t6'], 3, 1]|
15 | [[176, 384], 'mbconv_k7_t6', [], 0, 1]|
16 | [[384, 1984], 'conv1_1']"""
17 | 
18 | __C.train_params=AttrDict()
19 | __C.train_params.epochs=240
20 | __C.train_params.use_seed=True
21 | __C.train_params.seed=0
22 | 
23 | __C.optim=AttrDict()
24 | __C.optim.init_lr=0.5
25 | __C.optim.min_lr=1e-5
26 | __C.optim.lr_schedule='cosine' # cosine poly
27 | __C.optim.momentum=0.9
28 | __C.optim.weight_decay=4e-5
29 | __C.optim.use_grad_clip=False
30 | __C.optim.grad_clip=10
31 | __C.optim.label_smooth=True
32 | __C.optim.smooth_alpha=0.1
33 | __C.optim.init_mode='he_fout'
34 | 
35 | __C.optim.if_resume=False
36 | __C.optim.resume=AttrDict()
37 | __C.optim.resume.load_path=''
38 | __C.optim.resume.load_epoch=0
39 | 
40 |
__C.optim.use_warm_up=False 41 | __C.optim.warm_up=AttrDict() 42 | __C.optim.warm_up.epoch=5 43 | __C.optim.warm_up.init_lr=0.0001 44 | __C.optim.warm_up.target_lr=0.1 45 | 46 | __C.optim.use_multi_stage=False 47 | __C.optim.multi_stage=AttrDict() 48 | __C.optim.multi_stage.stage_epochs=330 49 | 50 | __C.optim.cosine=AttrDict() 51 | __C.optim.cosine.use_restart=False 52 | __C.optim.cosine.restart=AttrDict() 53 | __C.optim.cosine.restart.lr_period=[10, 20, 40, 80, 160, 320] 54 | __C.optim.cosine.restart.lr_step=[0, 10, 30, 70, 150, 310] 55 | 56 | __C.optim.bn_momentum=0.1 57 | __C.optim.bn_eps=0.001 58 | 59 | __C.data=AttrDict() 60 | __C.data.num_workers=16 61 | __C.data.batch_size=1024 62 | __C.data.dataset='imagenet' #imagenet 63 | __C.data.train_data_type='lmdb' 64 | __C.data.val_data_type='img' 65 | __C.data.patch_dataset=False 66 | # __C.data.num_examples=1281167 67 | __C.data.input_size=(3,224,224) 68 | __C.data.type_of_data_aug='random_sized' # random_sized / rand_scale 69 | __C.data.random_sized=AttrDict() 70 | __C.data.random_sized.min_scale=0.08 71 | __C.data.mean=[0.485, 0.456, 0.406] 72 | __C.data.std=[0.229, 0.224, 0.225] 73 | __C.data.color=False -------------------------------------------------------------------------------- /tools/collections.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | ############################################################################## 15 | 16 | """A simple attribute dictionary used for representing configuration options.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | from __future__ import unicode_literals 22 | 23 | 24 | class AttrDict(dict): 25 | 26 | IMMUTABLE = '__immutable__' 27 | 28 | def __init__(self, *args, **kwargs): 29 | super(AttrDict, self).__init__(*args, **kwargs) 30 | self.__dict__[AttrDict.IMMUTABLE] = False 31 | 32 | def __getattr__(self, name): 33 | if name in self.__dict__: 34 | return self.__dict__[name] 35 | elif name in self: 36 | return self[name] 37 | else: 38 | raise AttributeError(name) 39 | 40 | def __setattr__(self, name, value): 41 | if not self.__dict__[AttrDict.IMMUTABLE]: 42 | if name in self.__dict__: 43 | self.__dict__[name] = value 44 | else: 45 | self[name] = value 46 | else: 47 | raise AttributeError( 48 | 'Attempted to set "{}" to "{}", but AttrDict is immutable'. 49 | format(name, value) 50 | ) 51 | 52 | def immutable(self, is_immutable): 53 | """Set immutability to is_immutable and recursively apply the setting 54 | to all nested AttrDicts. 
55 | """ 56 | self.__dict__[AttrDict.IMMUTABLE] = is_immutable 57 | # Recursively set immutable state 58 | for v in self.__dict__.values(): 59 | if isinstance(v, AttrDict): 60 | v.immutable(is_immutable) 61 | for v in self.values(): 62 | if isinstance(v, AttrDict): 63 | v.immutable(is_immutable) 64 | 65 | def is_immutable(self): 66 | return self.__dict__[AttrDict.IMMUTABLE] 67 | -------------------------------------------------------------------------------- /run_apis/latency_measure.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import importlib 3 | import logging 4 | import os 5 | import sys 6 | 7 | import torch 8 | import torch.backends.cudnn as cudnn 9 | 10 | from configs.imagenet_train_cfg import cfg 11 | from configs.search_config import search_cfg 12 | from tools import utils 13 | from tools.config_yaml import merge_cfg_from_file, update_cfg_from_cfg 14 | 15 | if __name__ == '__main__': 16 | 17 | parser = argparse.ArgumentParser("Params") 18 | parser.add_argument('--save', type=str, default='./', help='experiment name') 19 | parser.add_argument('--input_size', type=str, default='[32, 3, 224, 224]', help='data input size') 20 | parser.add_argument('--meas_times', type=int, default=5000, help='measure times') 21 | parser.add_argument('--list_name', type=str, default='', help='output list name') 22 | parser.add_argument('--device', choices=['gpu', 'cpu']) 23 | parser.add_argument('-c', '--config', metavar='C', default=None, help='The Configuration file') 24 | 25 | args = parser.parse_args() 26 | 27 | update_cfg_from_cfg(search_cfg, cfg) 28 | if args.config is not None: 29 | merge_cfg_from_file(args.config, cfg) 30 | config = cfg 31 | 32 | args.save = os.path.join(args.save, 'output') 33 | utils.create_exp_dir(args.save) 34 | 35 | args.input_size = eval(args.input_size) 36 | if len(args.input_size) != 4: 37 | raise ValueError('The batch size should be specified.') 38 | 39 | log_format = '%(asctime)s %(message)s' 40 | logging.basicConfig(stream=sys.stdout, level=logging.INFO, 41 | format=log_format, datefmt='%m/%d %I:%M:%S %p') 42 | fh = logging.FileHandler(os.path.join(args.save, 'log.txt')) 43 | fh.setFormatter(logging.Formatter(log_format)) 44 | logging.getLogger().addHandler(fh) 45 | 46 | if not torch.cuda.is_available(): 47 | logging.info('no gpu device available') 48 | sys.exit(1) 49 | 50 | cudnn.benchmark = True 51 | cudnn.enabled = True 52 | 53 | SearchSpace = importlib.import_module('models.search_space_'+config.net_type).Network 54 | super_model = SearchSpace(config.optim.init_dim, config.data.dataset, config) 55 | 56 | super_model.eval() 57 | logging.info("Params = %fMB" % utils.count_parameters_in_MB(super_model)) 58 | 59 | if args.device == 'gpu': 60 | super_model = super_model.cuda() 61 | 62 | latency_list, total_latency = super_model.get_cost_list( 63 | args.input_size, cost_type='latency', 64 | use_gpu = (args.device == 'gpu'), 65 | meas_times = args.meas_times) 66 | 67 | logging.info('latency_list:\n' + str(latency_list)) 68 | logging.info('total latency: ' + str(total_latency) + 'ms') 69 | 70 | with open(os.path.join(args.save, args.list_name), 'w') as f: 71 | f.write(str(latency_list)) 72 | -------------------------------------------------------------------------------- /run_apis/validation.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import pprint 5 | import sys 6 | import time 7 | 8 | import torch 9 | import 
torch.backends.cudnn as cudnn 10 | import torch.nn as nn 11 | 12 | from configs.imagenet_val_cfg import cfg 13 | from dataset import imagenet_data 14 | from models import model_derived 15 | from tools import utils 16 | from tools.multadds_count import comp_multadds 17 | 18 | from .trainer import Trainer 19 | 20 | if __name__ == '__main__': 21 | 22 | parser = argparse.ArgumentParser("Params") 23 | parser.add_argument('--report_freq', type=float, default=50, help='report frequency') 24 | parser.add_argument('--data_path', type=str, default='../data', help='location of the dataset') 25 | parser.add_argument('--load_path', type=str, default='./model_path', help='model loading path') 26 | parser.add_argument('--save', type=str, default='./', help='the path of output') 27 | 28 | args = parser.parse_args() 29 | config = cfg 30 | 31 | args.save = os.path.join(args.save, 'output') 32 | utils.create_exp_dir(args.save) 33 | 34 | log_format = '%(asctime)s %(message)s' 35 | logging.basicConfig(stream=sys.stdout, level=logging.INFO, 36 | format=log_format, datefmt='%m/%d %I:%M:%S %p') 37 | fh = logging.FileHandler(os.path.join(args.save, 'log.txt')) 38 | fh.setFormatter(logging.Formatter(log_format)) 39 | logging.getLogger().addHandler(fh) 40 | 41 | if not torch.cuda.is_available(): 42 | logging.info('no gpu device available') 43 | sys.exit(1) 44 | 45 | cudnn.benchmark = True 46 | cudnn.enabled = True 47 | 48 | logging.info("args = %s", args) 49 | logging.info('Training with config:') 50 | logging.info(pprint.pformat(config)) 51 | 52 | config.net_config, net_type = utils.load_net_config(os.path.join(args.load_path, 'net_config')) 53 | 54 | derivedNetwork = getattr(model_derived, '%s_Net' % net_type.upper()) 55 | model = derivedNetwork(config.net_config, config=config) 56 | 57 | logging.info("Network Structure: \n" + '\n'.join(map(str, model.net_config))) 58 | logging.info("Params = %.2fMB" % utils.count_parameters_in_MB(model)) 59 | logging.info("Mult-Adds = %.2fMB" % comp_multadds(model, input_size=config.data.input_size)) 60 | 61 | model = model.cuda() 62 | model = nn.DataParallel(model) 63 | utils.load_model(model, os.path.join(args.load_path, 'weights.pt')) 64 | 65 | imagenet = imagenet_data.ImageNet12(trainFolder=os.path.join(args.data_path, 'train'), 66 | testFolder=os.path.join(args.data_path, 'val'), 67 | num_workers=config.data.num_workers, 68 | data_config=config.data) 69 | valid_queue = imagenet.getTestLoader(config.data.batch_size) 70 | trainer = Trainer(None, valid_queue, None, None, 71 | None, config, args.report_freq) 72 | 73 | with torch.no_grad(): 74 | val_acc_top1, val_acc_top5, valid_obj, batch_time = trainer.infer(model) 75 | -------------------------------------------------------------------------------- /dataset/img2lmdb.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path as osp 4 | import sys 5 | 6 | import cv2 7 | import lmdb 8 | import msgpack 9 | import numpy as np 10 | from PIL import Image 11 | from tqdm import tqdm 12 | 13 | image_size=None 14 | 15 | class Datum(object): 16 | def __init__(self, shape=None, image=None, label=None): 17 | self.shape = shape 18 | self.image = image 19 | self.label = label 20 | 21 | def SerializeToString(self, img=None): 22 | image_data = self.image.astype(np.uint8).tobytes() 23 | label_data = np.uint16(self.label).tobytes() 24 | return msgpack.packb(image_data+label_data, use_bin_type=True) 25 | 26 | def ParseFromString(self, raw_data, orig_img): 27 | raw_data 
= msgpack.unpackb(raw_data, raw=False)
28 |         raw_img_data = raw_data[:-2]
29 |         # np.frombuffer shares the memory of the buffer, while np.fromstring would copy it
30 |         image_data = np.frombuffer(raw_img_data, dtype=np.uint8)
31 |         self.image = cv2.imdecode(image_data, cv2.IMREAD_COLOR)
32 | 
33 |         raw_label_data = raw_data[-2:]
34 |         self.label = np.frombuffer(raw_label_data, dtype=np.uint16)
35 | 
36 | 
37 | def create_dataset(output_path, image_folder, image_list, image_size):
38 |     image_name_list = [i.strip() for i in open(image_list)]
39 |     n_samples = len(image_name_list)
40 |     env = lmdb.open(output_path, map_size=1099511627776, meminit=False, map_async=True) # 1TB
41 | 
42 |     txn = env.begin(write=True)
43 |     classes = sorted(d for d in os.listdir(image_folder) if os.path.isdir(os.path.join(image_folder, d)))  # sorted so class indices are deterministic across filesystems
44 |     for idx, image_name in enumerate(tqdm(image_name_list)):
45 |         image_path = os.path.join(image_folder, image_name)
46 |         label_name = image_name.split('/')[0]
47 |         label = classes.index(label_name)
48 |         if not os.path.isfile(image_path):
49 |             raise RuntimeError('%s does not exist' % image_path)
50 | 
51 |         img = cv2.imread(image_path)
52 |         img_orig = img
53 | 
54 |         if image_size:
55 |             resize_ratio = float(image_size)/min(img.shape[0:2])
56 |             new_size = (int(img.shape[1]*resize_ratio), int(img.shape[0]*resize_ratio)) #inverse order for cv2
57 |             img = cv2.resize(src=img, dsize=new_size)
58 |         img = cv2.imencode('.JPEG', img)[1]
59 | 
60 |         image = np.asarray(img)
61 |         datum = Datum(image.shape, image, label)
62 |         txn.put(image_name.encode('ascii'), datum.SerializeToString())
63 | 
64 |         if (idx + 1) % 1000 == 0:
65 |             txn.commit()
66 |             txn = env.begin(write=True)
67 |     txn.commit()
68 |     env.sync()
69 |     env.close()
70 | 
71 |     print(f'Created dataset with {n_samples:d} samples')
72 | 
73 | 
74 | if __name__ == '__main__':
75 |     parser = argparse.ArgumentParser("Params")
76 |     parser.add_argument('--image_size', type=int, default=None, help='the size of the images you want to pack')
77 |     parser.add_argument('--image_path', type=str, default='', help='the path of the images')
78 |     parser.add_argument('--list_path', type=str, default='', help='the path of the image list')
79 |     parser.add_argument('--output_path', type=str, default='', help='the output path of the lmdb file')
80 |     parser.add_argument('--split', type=str, default='', help='the split path of the images: train / val')
81 |     args = parser.parse_args()
82 | 
83 |     image_size = args.image_size if args.image_size else image_size
84 |     image_folder = osp.join(args.image_path, args.split)
85 |     image_list = osp.join(args.list_path, '{}_datalist'.format(args.split))
86 | 
87 |     if image_size:
88 |         output_path = osp.join(args.output_path, '{}_{}'.format(args.split, image_size))
89 |     else:
90 |         output_path = osp.join(args.output_path, args.split)
91 | 
92 |     create_dataset(output_path, image_folder, image_list, image_size)
93 | 
--------------------------------------------------------------------------------
/dataset/prefetch_data.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import time
4 | 
5 | class data_prefetcher():
6 |     def __init__(self, loader, mean=None, std=None, is_cutout=False, cutout_length=16):
7 |         self.loader = iter(loader)
8 |         self.stream = torch.cuda.Stream()
9 |         if mean is None:
10 |             self.mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1,3,1,1)
11 |         else:
12 |             self.mean = torch.tensor([m * 255 for m in mean]).cuda().view(1,3,1,1)
13 |         if std is None:
14 |             self.std = 
torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1,3,1,1) 15 | else: 16 | self.std = torch.tensor([s * 255 for s in std]).cuda().view(1,3,1,1) 17 | self.is_cutout = is_cutout 18 | self.cutout_length = cutout_length 19 | self.preload() 20 | 21 | def normalize(self, data): 22 | data = data.float() 23 | data = data.sub_(self.mean).div_(self.std) 24 | return data 25 | 26 | def cutout(self, data): 27 | batch_size, h, w = data.shape[0], data.shape[2], data.shape[3] 28 | mask = torch.ones(batch_size, h, w).cuda() 29 | y = torch.randint(low=0, high=h, size=(batch_size,)) 30 | x = torch.randint(low=0, high=w, size=(batch_size,)) 31 | 32 | y1 = torch.clamp(y - self.cutout_length // 2, 0, h) 33 | y2 = torch.clamp(y + self.cutout_length // 2, 0, h) 34 | x1 = torch.clamp(x - self.cutout_length // 2, 0, w) 35 | x2 = torch.clamp(x + self.cutout_length // 2, 0, w) 36 | for i in range(batch_size): 37 | mask[i][y1[i]: y2[i], x1[i]: x2[i]] = 0. 38 | mask = mask.expand_as(data.transpose(0,1)).transpose(0,1) 39 | data *= mask 40 | return data 41 | 42 | def preload(self): 43 | try: 44 | self.next_input, self.next_target = next(self.loader) 45 | except StopIteration: 46 | self.next_input = None 47 | self.next_target = None 48 | return 49 | # if record_stream() doesn't work, another option is to make sure device inputs are created 50 | # on the main stream. 51 | # self.next_input_gpu = torch.empty_like(self.next_input, device='cuda') 52 | # self.next_target_gpu = torch.empty_like(self.next_target, device='cuda') 53 | # Need to make sure the memory allocated for next_* is not still in use by the main stream 54 | # at the time we start copying to next_*: 55 | # self.stream.wait_stream(torch.cuda.current_stream()) 56 | with torch.cuda.stream(self.stream): 57 | self.next_input = self.next_input.cuda(non_blocking=True) 58 | self.next_target = self.next_target.cuda(non_blocking=True) 59 | self.next_input = self.normalize(self.next_input) 60 | if self.is_cutout: 61 | self.next_input = self.cutout(self.next_input) 62 | 63 | def next(self): 64 | torch.cuda.current_stream().wait_stream(self.stream) 65 | input = self.next_input 66 | target = self.next_target 67 | if input is not None: 68 | input.record_stream(torch.cuda.current_stream()) 69 | if target is not None: 70 | target.record_stream(torch.cuda.current_stream()) 71 | self.preload() 72 | return input, target 73 | 74 | 75 | def fast_collate(batch): 76 | imgs = [img[0] for img in batch] 77 | targets = torch.tensor([target[1] for target in batch], dtype=torch.int64) 78 | w = imgs[0].size[0] 79 | h = imgs[0].size[1] 80 | tensor = torch.zeros((len(imgs), 3, h, w), dtype=torch.uint8) 81 | for i, img in enumerate(imgs): 82 | nump_array = np.asarray(img, dtype=np.uint8) 83 | if(nump_array.ndim < 3): 84 | nump_array = np.expand_dims(nump_array, axis=-1) 85 | nump_array = np.rollaxis(nump_array, 2) 86 | 87 | tensor[i] += torch.from_numpy(nump_array) 88 | 89 | return tensor, targets 90 | -------------------------------------------------------------------------------- /tools/io.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | ##############################################################################
15 | 
16 | """IO utilities."""
17 | 
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | from __future__ import unicode_literals
22 | 
23 | import pickle
24 | import hashlib
25 | import logging
26 | import os
27 | import re
28 | import sys
29 | import urllib.request  # Python 3: the request submodule must be imported explicitly
30 | 
31 | logger = logging.getLogger(__name__)
32 | 
33 | _DETECTRON_S3_BASE_URL = 'https://s3-us-west-2.amazonaws.com/detectron'
34 | 
35 | 
36 | def save_object(obj, file_name):
37 |     """Save a Python object by pickling it."""
38 |     file_name = os.path.abspath(file_name)
39 |     with open(file_name, 'wb') as f:
40 |         pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)
41 | 
42 | 
43 | def cache_url(url_or_file, cache_dir):
44 |     """Download the file specified by the URL to the cache_dir and return the
45 |     path to the cached file. If the argument is not a URL, simply return it as
46 |     is.
47 |     """
48 |     is_url = re.match(r'^(?:http)s?://', url_or_file, re.IGNORECASE) is not None
49 | 
50 |     if not is_url:
51 |         return url_or_file
52 | 
53 |     url = url_or_file
54 |     assert url.startswith(_DETECTRON_S3_BASE_URL), \
55 |         ('Detectron only automatically caches URLs in the Detectron S3 '
56 |          'bucket: {}').format(_DETECTRON_S3_BASE_URL)
57 | 
58 |     cache_file_path = url.replace(_DETECTRON_S3_BASE_URL, cache_dir)
59 |     if os.path.exists(cache_file_path):
60 |         assert_cache_file_is_ok(url, cache_file_path)
61 |         return cache_file_path
62 | 
63 |     cache_file_dir = os.path.dirname(cache_file_path)
64 |     if not os.path.exists(cache_file_dir):
65 |         os.makedirs(cache_file_dir)
66 | 
67 |     logger.info('Downloading remote file {} to {}'.format(url, cache_file_path))
68 |     download_url(url, cache_file_path)
69 |     assert_cache_file_is_ok(url, cache_file_path)
70 |     return cache_file_path
71 | 
72 | 
73 | def assert_cache_file_is_ok(url, file_path):
74 |     """Check that cache file has the correct hash."""
75 |     # File is already in the cache, verify that the md5sum matches and
76 |     # return local path
77 |     cache_file_md5sum = _get_file_md5sum(file_path)
78 |     ref_md5sum = _get_reference_md5sum(url)
79 |     assert cache_file_md5sum == ref_md5sum, \
80 |         ('Target URL {} appears to be downloaded to the local cache file '
81 |          '{}, but the md5 hash of the local file does not match the '
82 |          'reference (actual: {} vs. expected: {}). You may wish to delete '
83 |          'the cached file and try again to trigger automatic '
84 |          'download.').format(url, file_path, cache_file_md5sum, ref_md5sum)
85 | 
86 | 
87 | def _progress_bar(count, total):
88 |     """Report download progress.
89 |     Credit:
90 |     https://stackoverflow.com/questions/3173320/text-progress-bar-in-the-console/27871113
91 |     """
92 |     bar_len = 60
93 |     filled_len = int(round(bar_len * count / float(total)))
94 | 
95 |     percents = round(100.0 * count / float(total), 1)
96 |     bar = '=' * filled_len + '-' * (bar_len - filled_len)
97 | 
98 |     sys.stdout.write(
99 |         ' [{}] {}% of {:.1f}MB file \r'.
100 |         format(bar, percents, total / 1024 / 1024)
101 |     )
102 |     sys.stdout.flush()
103 |     if count >= total:
104 |         sys.stdout.write('\n')
105 | 
106 | 
107 | def download_url(
108 |     url, dst_file_path, chunk_size=8192, progress_hook=_progress_bar
109 | ):
110 |     """Download url and write it to dst_file_path.
111 |     Credit:
112 |     https://stackoverflow.com/questions/2028517/python-urllib2-progress-hook
113 |     """
114 |     response = urllib.request.urlopen(url)
115 |     total_size = response.info().get('Content-Length').strip()  # Message.get(); getheader() was Python 2 only
116 |     total_size = int(total_size)
117 |     bytes_so_far = 0
118 | 
119 |     with open(dst_file_path, 'wb') as f:
120 |         while 1:
121 |             chunk = response.read(chunk_size)
122 |             bytes_so_far += len(chunk)
123 |             if not chunk:
124 |                 break
125 |             if progress_hook:
126 |                 progress_hook(bytes_so_far, total_size)
127 |             f.write(chunk)
128 | 
129 |     return bytes_so_far
130 | 
131 | 
132 | def _get_file_md5sum(file_name):
133 |     """Compute the md5 hash of a file."""
134 |     hash_obj = hashlib.md5()
135 |     with open(file_name, 'rb') as f:  # binary mode: hashlib needs raw bytes
136 |         hash_obj.update(f.read())
137 |     return hash_obj.hexdigest()
138 | 
139 | 
140 | def _get_reference_md5sum(url):
141 |     """By convention the md5 hash for url is stored in url + '.md5sum'."""
142 |     url_md5sum = url + '.md5sum'
143 |     md5sum = urllib.request.urlopen(url_md5sum).read().strip().decode()  # decode bytes for comparison with hexdigest()
144 |     return md5sum
145 | 
--------------------------------------------------------------------------------
/dataset/lmdb_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | import cv2
4 | import msgpack
5 | import numpy as np
6 | import torch.utils.data as data
7 | from PIL import Image
8 | 
9 | import lmdb
10 | 
11 | 
12 | class Datum(object):
13 |     def __init__(self, shape=None, image=None, label=None):
14 |         self.shape = shape
15 |         self.image = image
16 |         self.label = label
17 | 
18 |     def SerializeToString(self):
19 |         image_data = self.image.astype(np.uint8).tobytes()
20 |         label_data = np.uint16(self.label).tobytes()
21 |         return msgpack.packb(image_data+label_data, use_bin_type=True)
22 | 
23 |     def ParseFromString(self, raw_data):
24 |         raw_data = msgpack.unpackb(raw_data, raw=False)
25 |         raw_img_data = raw_data[:-2]
26 |         image_data = np.frombuffer(raw_img_data, dtype=np.uint8)  # frombuffer shares the buffer's memory, while fromstring would copy it
27 |         self.image = cv2.imdecode(image_data, cv2.IMREAD_COLOR)
28 | 
29 |         raw_label_data = raw_data[-2:]
30 |         self.label = np.frombuffer(raw_label_data, dtype=np.uint16)
31 | 
32 | 
33 | class DatasetFolder(data.Dataset):
34 |     """
35 |     Args:
36 |         root (string): Root directory path.
37 |         transform (callable, optional): A function/transform that takes in
38 |             a sample and returns a transformed version.
39 |             E.g, ``transforms.RandomCrop`` for images.
40 |         target_transform (callable, optional): A function/transform that takes
41 |             in the target and transforms it.
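        list_path (string): Path to a text file with one LMDB key (image name)
            per line; it defines ``samples`` below.
        patch_dataset (bool, optional): If True, ``root`` is read as a directory
            of LMDB shards (the indexing in ``__getitem__`` assumes ten) and each
            sample is routed to its shard by index.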
42 | 
43 |     Attributes:
44 |         samples (list): List of (sample path, class_index) tuples
45 |     """
46 | 
47 |     def __init__(self, root, list_path, transform=None, target_transform=None, patch_dataset=False):
48 |         self.root = root
49 |         self.patch_dataset = patch_dataset
50 | 
51 |         if patch_dataset:
52 |             self.txn = []
53 |             for path in os.listdir(root):
54 |                 lmdb_path = os.path.join(root, path)
55 |                 if os.path.isdir(lmdb_path):
56 |                     env = lmdb.open(lmdb_path,
57 |                                     readonly=True,
58 |                                     lock=False,
59 |                                     readahead=False,
60 |                                     meminit=False)
61 |                     txn = env.begin(write=False)
62 |                     self.txn.append(txn)
63 | 
64 |         else:
65 |             self.env = lmdb.open(root,
66 |                                  readonly=True,
67 |                                  lock=False,
68 |                                  readahead=False,
69 |                                  meminit=False)
70 |             self.txn = self.env.begin(write=False)
71 | 
72 |         self.list_path = list_path
73 |         self.samples = [image_name.strip() for image_name in open(list_path)]
74 | 
75 |         if len(self.samples) == 0:
76 |             raise(RuntimeError("Found 0 files in subfolders of: " + root + "\n"))
77 | 
78 |         self.transform = transform
79 |         self.target_transform = target_transform
80 | 
81 |     def __getitem__(self, index):
82 |         """
83 |         Args:
84 |             index (int): Index
85 | 
86 |         Returns:
87 |             tuple: (sample, target) where target is class_index of the target class.
88 |         """
89 | 
90 |         img_name = self.samples[index]
91 | 
92 |         if self.patch_dataset:
93 |             txn_index = index // (len(self.samples) // 10)
94 |             if txn_index==10:
95 |                 txn_index = 9
96 |             txn = self.txn[txn_index]
97 |         else:
98 |             txn = self.txn
99 | 
100 |         datum = Datum()
101 |         data_bin = txn.get(img_name.encode('ascii'))
102 |         if data_bin is None:
103 |             raise RuntimeError(f'Key {img_name} not found')
104 |         datum.ParseFromString(data_bin)
105 | 
106 |         sample = Image.fromarray(cv2.cvtColor(datum.image, cv2.COLOR_BGR2RGB))
107 |         target = int(datum.label)  # np.int was removed in NumPy 1.24; the built-in int suffices
108 | 
109 |         if self.transform is not None:
110 |             sample = self.transform(sample)
111 |         if self.target_transform is not None:
112 |             target = self.target_transform(target)
113 | 
114 |         return sample, target
115 | 
116 |     def __len__(self):
117 |         return len(self.samples)
118 | 
119 |     def __repr__(self):
120 |         fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
121 |         fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
122 |         fmt_str += '    Root Location: {}\n'.format(self.root)
123 |         tmp = '    Transforms (if any): '
124 |         fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
125 |         tmp = '    Target Transforms (if any): '
126 |         fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
127 |         return fmt_str
128 | 
129 | 
130 | class ImageFolder(DatasetFolder):
131 |     def __init__(self, root, list_path, transform=None, target_transform=None, patch_dataset=False):
132 |         super(ImageFolder, self).__init__(root, list_path,
133 |                                           transform=transform,
134 |                                           target_transform=target_transform,
135 |                                           patch_dataset=patch_dataset)
136 |         self.imgs = self.samples
137 | 
--------------------------------------------------------------------------------
/tools/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.optim.lr_scheduler import CosineAnnealingLR
3 | from torch.optim.optimizer import Optimizer
4 | import math
5 | 
6 | class CosineRestartAnnealingLR(object):
7 |     # decay as step
8 |     # T_max refers to the max update step
9 | 
10 |     def __init__(self, optimizer, T_max, lr_period, lr_step, eta_min=0, last_step=-1,
11 |                  use_warmup=False, warmup_mode='linear', warmup_steps=0,
warmup_startlr=0, 12 | warmup_targetlr=0, use_restart=False): 13 | 14 | self.use_warmup = use_warmup 15 | self.warmup_mode = warmup_mode 16 | self.warmup_steps = warmup_steps 17 | self.warmup_startlr = warmup_startlr 18 | self.warmup_targetlr = warmup_targetlr 19 | self.use_restart = use_restart 20 | self.T_max = T_max 21 | self.eta_min = eta_min 22 | 23 | if self.use_restart == False: 24 | self.lr_period = [self.T_max - self.warmup_steps] 25 | self.lr_step = [self.warmup_steps] 26 | else: 27 | self.lr_period = lr_period 28 | self.lr_step = lr_step 29 | 30 | self.last_step = last_step 31 | self.cycle_length = self.lr_period[0] 32 | self.cur = 0 33 | 34 | if not isinstance(optimizer, Optimizer): 35 | raise TypeError('{} is not an Optimizer'.format( 36 | type(optimizer).__name__)) 37 | self.optimizer = optimizer 38 | if last_step == -1: 39 | for group in optimizer.param_groups: 40 | group.setdefault('initial_lr', group['lr']) 41 | else: 42 | for i, group in enumerate(optimizer.param_groups): 43 | if 'initial_lr' not in group: 44 | raise KeyError("param 'initial_lr' is not specified " 45 | "in param_groups[{}] when resuming an optimizer".format(i)) 46 | self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups)) 47 | 48 | 49 | def step(self, step=None): 50 | 51 | if step is not None: 52 | self.last_step = step 53 | else: 54 | self.last_step += 1 55 | 56 | for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): 57 | param_group['lr'] = lr 58 | 59 | 60 | def get_lr(self): 61 | 62 | lrs = [] 63 | for base_lr in self.base_lrs: 64 | if self.use_warmup and self.last_step < self.warmup_steps: 65 | if self.warmup_mode == 'constant': 66 | lrs.append(self.warmup_startlr) 67 | elif self.warmup_mode =='linear': 68 | cur_lr = self.warmup_startlr + \ 69 | float(self.warmup_targetlr-self.warmup_startlr)/self.warmup_steps*self.last_step 70 | lrs.append(cur_lr) 71 | else: 72 | raise NotImplementedError 73 | 74 | else: 75 | if (self.last_step) in self.lr_step: 76 | self.cycle_length = self.lr_period[self.lr_step.index(self.last_step)] 77 | self.cur = self.last_step 78 | 79 | peri_iter = self.last_step-self.cur 80 | 81 | if peri_iter <= self.cycle_length: 82 | unit_cycle = (1 + math.cos(peri_iter * math.pi / self.cycle_length)) / 2 83 | adjusted_cycle = unit_cycle * (base_lr - self.eta_min) + self.eta_min 84 | lrs.append(adjusted_cycle) 85 | else: 86 | lrs.append(self.eta_min) 87 | 88 | return lrs 89 | 90 | 91 | def display_lr_curve(self, total_steps): 92 | lrs = [] 93 | for _ in range(total_steps): 94 | self.step() 95 | lrs.append(self.get_lr()[0]) 96 | import matplotlib.pyplot as plt 97 | plt.plot(lrs) 98 | plt.show() 99 | 100 | 101 | def get_lr_scheduler(config, optimizer, num_examples=None): 102 | 103 | if num_examples is None: 104 | num_examples = config.data.num_examples 105 | epoch_steps = num_examples // config.data.batch_size + 1 106 | 107 | if config.optim.use_multi_stage: 108 | max_steps = epoch_steps * config.optim.multi_stage.stage_epochs 109 | else: 110 | max_steps = epoch_steps * config.train_params.epochs 111 | 112 | period_steps = [epoch_steps * x for x in config.optim.cosine.restart.lr_period] 113 | step_steps = [epoch_steps * x for x in config.optim.cosine.restart.lr_step] 114 | 115 | init_lr = config.optim.init_lr 116 | 117 | use_warmup = config.optim.use_warm_up 118 | if use_warmup: 119 | warmup_steps = config.optim.warm_up.epoch * epoch_steps 120 | warmup_startlr = config.optim.warm_up.init_lr 121 | warmup_targetlr = config.optim.warm_up.target_lr 
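        # In CosineRestartAnnealingLR.get_lr above, linear warmup interpolates
        # lr = warmup_startlr + (warmup_targetlr - warmup_startlr) / warmup_steps * step
        # before the cosine schedule takes over.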
122 | else: 123 | warmup_steps = 0 124 | warmup_startlr = init_lr 125 | warmup_targetlr = init_lr 126 | 127 | if config.optim.lr_schedule == 'cosine': 128 | scheduler = CosineRestartAnnealingLR(optimizer, 129 | float(max_steps), 130 | period_steps, 131 | step_steps, 132 | eta_min=config.optim.min_lr, 133 | use_warmup=use_warmup, 134 | warmup_steps=warmup_steps, 135 | warmup_startlr=warmup_startlr, 136 | warmup_targetlr=warmup_targetlr, 137 | use_restart=config.optim.cosine.use_restart) 138 | # scheduler = CosineAnnealingLR(optimizer, config.train_params.epochs, config.optim.min_lr) 139 | elif config.optim.lr_schedule == 'poly': 140 | raise NotImplementedError 141 | else: 142 | raise NotImplementedError 143 | 144 | return scheduler 145 | 146 | -------------------------------------------------------------------------------- /run_apis/optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from models.dropped_model import Dropped_Network 5 | 6 | 7 | class Optimizer(object): 8 | 9 | def __init__(self, model, criterion, config): 10 | self.config = config 11 | self.weight_sample_num = self.config.search_params.weight_sample_num 12 | self.criterion = criterion 13 | self.Dropped_Network = lambda model: Dropped_Network( 14 | model, softmax_temp=config.search_params.softmax_temp) 15 | 16 | arch_params_id = list(map(id, model.module.arch_parameters)) 17 | weight_params = filter(lambda p: id(p) not in arch_params_id, model.parameters()) 18 | 19 | self.weight_optimizer = torch.optim.SGD( 20 | weight_params, 21 | config.optim.weight.init_lr, 22 | momentum=config.optim.weight.momentum, 23 | weight_decay=config.optim.weight.weight_decay) 24 | 25 | self.arch_optimizer = torch.optim.Adam( 26 | [{'params': model.module.arch_alpha_params, 'lr': config.optim.arch.alpha_lr}, 27 | {'params': model.module.arch_beta_params, 'lr': config.optim.arch.beta_lr}], 28 | betas=(0.5, 0.999), 29 | weight_decay=config.optim.arch.weight_decay) 30 | 31 | 32 | def arch_step(self, input_valid, target_valid, model, search_stage): 33 | head_sampled_w_old, alpha_head_index = \ 34 | model.module.sample_branch('head', 2, search_stage= search_stage) 35 | stack_sampled_w_old, alpha_stack_index = \ 36 | model.module.sample_branch('stack', 2, search_stage= search_stage) 37 | self.arch_optimizer.zero_grad() 38 | 39 | dropped_model = nn.DataParallel(self.Dropped_Network(model)) 40 | logits, sub_obj = dropped_model(input_valid) 41 | sub_obj = torch.mean(sub_obj) 42 | loss = self.criterion(logits, target_valid) 43 | if self.config.optim.if_sub_obj: 44 | loss_sub_obj = torch.log(sub_obj) / torch.log(torch.tensor(self.config.optim.sub_obj.log_base)) 45 | sub_loss_factor = self.config.optim.sub_obj.sub_loss_factor 46 | loss += loss_sub_obj * sub_loss_factor 47 | loss.backward() 48 | self.arch_optimizer.step() 49 | 50 | self.rescale_arch_params(head_sampled_w_old, 51 | stack_sampled_w_old, 52 | alpha_head_index, 53 | alpha_stack_index, 54 | model) 55 | return logits.detach(), loss.item(), sub_obj.item() 56 | 57 | 58 | def weight_step(self, *args, **kwargs): 59 | return self.weight_step_(*args, **kwargs) 60 | 61 | 62 | def weight_step_(self, input_train, target_train, model, search_stage): 63 | _, _ = model.module.sample_branch('head', self.weight_sample_num, search_stage=search_stage) 64 | _, _ = model.module.sample_branch('stack', self.weight_sample_num, search_stage=search_stage) 65 | 66 | self.weight_optimizer.zero_grad() 67 | dropped_model = 
nn.DataParallel(self.Dropped_Network(model)) 68 | logits, sub_obj = dropped_model(input_train) 69 | sub_obj = torch.mean(sub_obj) 70 | loss = self.criterion(logits, target_train) 71 | loss.backward() 72 | self.weight_optimizer.step() 73 | 74 | return logits.detach(), loss.item(), sub_obj.item() 75 | 76 | 77 | def valid_step(self, input_valid, target_valid, model): 78 | _, _ = model.module.sample_branch('head', 1, training=False) 79 | _, _ = model.module.sample_branch('stack', 1, training=False) 80 | 81 | dropped_model = nn.DataParallel(self.Dropped_Network(model)) 82 | logits, sub_obj = dropped_model(input_valid) 83 | sub_obj = torch.mean(sub_obj) 84 | loss = self.criterion(logits, target_valid) 85 | 86 | return logits, loss.item(), sub_obj.item() 87 | 88 | 89 | def rescale_arch_params(self, alpha_head_weights_drop, 90 | alpha_stack_weights_drop, 91 | alpha_head_index, 92 | alpha_stack_index, 93 | model): 94 | 95 | def comp_rescale_value(old_weights, new_weights, index): 96 | old_exp_sum = old_weights.exp().sum() 97 | new_drop_arch_params = torch.gather(new_weights, dim=-1, index=index) 98 | new_exp_sum = new_drop_arch_params.exp().sum() 99 | rescale_value = torch.log(old_exp_sum / new_exp_sum) 100 | rescale_mat = torch.zeros_like(new_weights).scatter_(0, index, rescale_value) 101 | return rescale_value, rescale_mat 102 | 103 | def rescale_params(old_weights, new_weights, indices): 104 | for i, (old_weights_block, indices_block) in enumerate(zip(old_weights, indices)): 105 | for j, (old_weights_branch, indices_branch) in enumerate( 106 | zip(old_weights_block, indices_block)): 107 | rescale_value, rescale_mat = comp_rescale_value(old_weights_branch, 108 | new_weights[i][j], 109 | indices_branch) 110 | new_weights[i][j].data.add_(rescale_mat) 111 | 112 | # rescale the arch params for head layers 113 | rescale_params(alpha_head_weights_drop, model.module.alpha_head_weights, alpha_head_index) 114 | # rescale the arch params for stack layers 115 | rescale_params(alpha_stack_weights_drop, model.module.alpha_stack_weights, alpha_stack_index) 116 | 117 | 118 | def set_param_grad_state(self, stage): 119 | def set_grad_state(params, state): 120 | for group in params: 121 | for param in group['params']: 122 | param.requires_grad_(state) 123 | if stage == 'Arch': 124 | state_list = [True, False] # [arch, weight] 125 | elif stage == 'Weights': 126 | state_list = [False, True] 127 | else: 128 | state_list = [False, False] 129 | set_grad_state(self.arch_optimizer.param_groups, state_list[0]) 130 | set_grad_state(self.weight_optimizer.param_groups, state_list[1]) 131 | -------------------------------------------------------------------------------- /tools/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import shutil 4 | import sys 5 | import time 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | 11 | 12 | class AverageMeter(object): 13 | def __init__(self): 14 | self.reset() 15 | 16 | def reset(self): 17 | self.avg = 0 18 | self.sum = 0 19 | self.cnt = 0 20 | 21 | def update(self, val, n=1): 22 | self.cur = val 23 | self.sum += val * n 24 | self.cnt += n 25 | self.avg = self.sum / self.cnt 26 | 27 | 28 | def accuracy(output, target, topk=(1, 5)): 29 | maxk = max(topk) 30 | batch_size = target.size(0) 31 | 32 | _, pred = output.topk(maxk, 1, True, True) 33 | pred = pred.t() 34 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 35 | 36 | res = [] 37 | for k in topk: 38 | correct_k = 
correct[:k].contiguous().view(-1).float().sum(0) # contiguous() keeps the flatten valid on newer PyTorch 39 | res.append(correct_k.mul_(100.0/batch_size)) 40 | return res 41 | 42 | 43 | def count_parameters_in_MB(model): 44 | return np.sum([np.prod(v.size()) for name, v in model.named_parameters() if "aux" not in name])/1e6 45 | 46 | 47 | def save_checkpoint(state, is_best, save): 48 | filename = os.path.join(save, 'checkpoint.pth.tar') 49 | torch.save(state, filename) 50 | if is_best: 51 | best_filename = os.path.join(save, 'model_best.pth.tar') 52 | shutil.copyfile(filename, best_filename) 53 | 54 | 55 | def save(model, model_path): 56 | torch.save(model.state_dict(), model_path) 57 | 58 | 59 | def load_net_config(path): 60 | with open(path, 'r') as f: 61 | net_config = '' 62 | while True: 63 | line = f.readline().strip() 64 | if 'net_type' in line: 65 | net_type = line.split(': ')[-1] 66 | break 67 | else: 68 | net_config += line 69 | return net_config, net_type 70 | 71 | 72 | def load_model(model, model_path): 73 | logging.info('Start loading the model from ' + model_path) 74 | if 'http' in model_path: 75 | model_addr = model_path 76 | model_path = model_path.split('/')[-1] 77 | if os.path.isfile(model_path): 78 | os.system('rm ' + model_path) 79 | os.system('wget -q ' + model_addr) 80 | model.load_state_dict(torch.load(model_path)) 81 | logging.info('Loading the model finished!') 82 | 83 | 84 | def create_exp_dir(path): 85 | if not os.path.exists(path): 86 | os.mkdir(path) 87 | print('Experiment dir : {}'.format(path)) 88 | 89 | 90 | def cross_entropy_with_label_smoothing(pred, target, label_smoothing=0.): 91 | """ 92 | Label smoothing implementation. 93 | This function is taken from https://github.com/MIT-HAN-LAB/ProxylessNAS/blob/master/proxyless_nas/utils.py 94 | """ 95 | 96 | logsoftmax = nn.LogSoftmax(dim=1).cuda() 97 | n_classes = pred.size(1) 98 | # convert to one-hot 99 | target = torch.unsqueeze(target, 1) 100 | soft_target = torch.zeros_like(pred) 101 | soft_target.scatter_(1, target, 1) 102 | # label smoothing 103 | soft_target = soft_target * (1 - label_smoothing) + label_smoothing / n_classes 104 | return torch.mean(torch.sum(- soft_target * logsoftmax(pred), 1)) 105 | 106 | 107 | def parse_net_config(net_config): 108 | str_configs = net_config.split('|') 109 | return [eval(str_config) for str_config in str_configs] 110 | 111 | 112 | def set_seed(seed): 113 | np.random.seed(seed) 114 | torch.manual_seed(seed) 115 | torch.cuda.manual_seed(seed) 116 | 117 | 118 | def set_logging(save_path, log_name='log.txt'): 119 | log_format = '%(asctime)s %(message)s' 120 | date_format = '%m/%d %H:%M:%S' 121 | logging.basicConfig(stream=sys.stdout, level=logging.INFO, 122 | format=log_format, datefmt=date_format) 123 | fh = logging.FileHandler(os.path.join(save_path, log_name)) 124 | fh.setFormatter(logging.Formatter(log_format, date_format)) 125 | logging.getLogger().addHandler(fh) 126 | 127 | 128 | def create_save_dir(save_path, job_name): 129 | if job_name != '': 130 | job_name = time.strftime("%Y%m%d-%H%M%S-") + job_name 131 | save_path = os.path.join(save_path, job_name) 132 | create_exp_dir(save_path) 133 | os.system('cp -r ./* '+save_path) 134 | save_path = os.path.join(save_path, 'output') 135 | create_exp_dir(save_path) 136 | else: 137 | save_path = os.path.join(save_path, 'output') 138 | create_exp_dir(save_path) 139 | return save_path, job_name 140 | 141 | 142 | def latency_measure(module, input_size, batch_size, meas_times, mode='gpu'): 143 | assert mode in ['gpu', 'cpu'] 144 | 145 | latency = [] 146 | module.eval() 147 | input_size = (batch_size,) +
tuple(input_size) 148 | input_data = torch.randn(input_size) 149 | if mode=='gpu': 150 | input_data = input_data.cuda() 151 | module.cuda() 152 | 153 | for i in range(meas_times): 154 | with torch.no_grad(): 155 | start = time.time() 156 | _ = module(input_data) 157 | torch.cuda.synchronize() 158 | if i >= 100: 159 | latency.append(time.time() - start) 160 | print(np.mean(latency) * 1e3, 'ms') 161 | return np.mean(latency) * 1e3 162 | 163 | 164 | def latency_measure_fw(module, input_data, meas_times): 165 | latency = [] 166 | module.eval() 167 | 168 | for i in range(meas_times): 169 | with torch.no_grad(): 170 | start = time.time() 171 | output_data = module(input_data) 172 | torch.cuda.synchronize() 173 | if i >= 100: 174 | latency.append(time.time() - start) 175 | print(np.mean(latency) * 1e3, 'ms') 176 | return np.mean(latency) * 1e3, output_data 177 | 178 | 179 | def record_topk(k, rec_list, data, comp_attr, check_attr): 180 | def get_insert_idx(orig_list, data, comp_attr): 181 | start = 0 182 | end = len(orig_list) 183 | while start < end: 184 | mid = (start + end) // 2 185 | if data[comp_attr] < orig_list[mid][comp_attr]: 186 | start = mid + 1 187 | else: 188 | end = mid 189 | return start 190 | 191 | if_insert = False 192 | insert_idx = get_insert_idx(rec_list, data, comp_attr) 193 | if insert_idx < k: 194 | rec_list.insert(insert_idx, data) 195 | if_insert = True 196 | while len(rec_list) > k: 197 | rec_list.pop() 198 | return if_insert 199 | -------------------------------------------------------------------------------- /run_apis/retrain.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import ast 3 | import logging 4 | import os 5 | import pprint 6 | import sys 7 | import time 8 | 9 | import numpy as np 10 | import torch 11 | import torch.backends.cudnn as cudnn 12 | import torch.nn as nn 13 | from tensorboardX import SummaryWriter 14 | 15 | from configs.imagenet_train_cfg import cfg as config 16 | from dataset import imagenet_data 17 | from models import model_derived 18 | from tools import utils 19 | from tools.lr_scheduler import get_lr_scheduler 20 | from tools.multadds_count import comp_multadds 21 | 22 | from .trainer import Trainer 23 | 24 | if __name__ == '__main__': 25 | parser = argparse.ArgumentParser("Train_Params") 26 | parser.add_argument('--report_freq', type=float, default=500, help='report frequency') 27 | parser.add_argument('--data_path', type=str, default='../data', help='location of the data corpus') 28 | parser.add_argument('--load_path', type=str, default='./model_path', help='model loading path') 29 | parser.add_argument('--save', type=str, default='../', help='experiment name') 30 | parser.add_argument('--tb_path', type=str, default='', help='tensorboard output path') 31 | parser.add_argument('--meas_lat', type=ast.literal_eval, default='False', help='whether to measure the latency of the model') 32 | parser.add_argument('--job_name', type=str, default='', help='job_name') 33 | args = parser.parse_args() 34 | 35 | if args.job_name != '': 36 | args.job_name = time.strftime("%Y%m%d-%H%M%S-") + args.job_name 37 | args.save = os.path.join(args.save, args.job_name) 38 | utils.create_exp_dir(args.save) 39 | os.system('cp -r ./* '+args.save) 40 | else: 41 | args.save = os.path.join(args.save, 'output') 42 | utils.create_exp_dir(args.save) 43 | 44 | if args.tb_path == '': 45 | args.tb_path = args.save 46 | writer = SummaryWriter(args.tb_path) 47 | 48 | utils.set_logging(args.save) 49 | 50 | if not 
torch.cuda.is_available(): 51 | logging.info('no gpu device available') 52 | sys.exit(1) 53 | 54 | cudnn.benchmark = True 55 | cudnn.enabled = True 56 | 57 | if config.train_params.use_seed: 58 | utils.set_seed(config.train_params.seed) 59 | 60 | logging.info("args = %s", args) 61 | logging.info('Training with config:') 62 | logging.info(pprint.pformat(config)) 63 | 64 | if os.path.isfile(os.path.join(args.load_path, 'net_config')): 65 | config.net_config, config.net_type = utils.load_net_config( 66 | os.path.join(args.load_path, 'net_config')) 67 | derivedNetwork = getattr(model_derived, '%s_Net' % config.net_type.upper()) 68 | model = derivedNetwork(config.net_config, config=config) 69 | 70 | model.eval() 71 | if hasattr(model, 'net_config'): 72 | logging.info("Network Structure: \n" + '|\n'.join(map(str, model.net_config))) 73 | if args.meas_lat: 74 | latency_cpu = utils.latency_measure(model, (3, 224, 224), 1, 2000, mode='cpu') 75 | logging.info('latency_cpu (batch 1): %.2fms' % latency_cpu) 76 | latency_gpu = utils.latency_measure(model, (3, 224, 224), 32, 5000, mode='gpu') 77 | logging.info('latency_gpu (batch 32): %.2fms' % latency_gpu) 78 | params = utils.count_parameters_in_MB(model) 79 | logging.info("Params = %.2fMB" % params) 80 | mult_adds = comp_multadds(model, input_size=config.data.input_size) 81 | logging.info("Mult-Adds = %.2fMB" % mult_adds) 82 | 83 | model = nn.DataParallel(model) 84 | 85 | # whether to resume from a checkpoint 86 | if config.optim.if_resume: 87 | utils.load_model(model, config.optim.resume.load_path) 88 | start_epoch = config.optim.resume.load_epoch + 1 89 | else: 90 | start_epoch = 0 91 | 92 | model = model.cuda() 93 | 94 | if config.optim.label_smooth: 95 | criterion = utils.cross_entropy_with_label_smoothing 96 | else: 97 | criterion = nn.CrossEntropyLoss() 98 | criterion = criterion.cuda() 99 | 100 | optimizer = torch.optim.SGD( 101 | model.parameters(), 102 | config.optim.init_lr, 103 | momentum=config.optim.momentum, 104 | weight_decay=config.optim.weight_decay 105 | ) 106 | 107 | imagenet = imagenet_data.ImageNet12(trainFolder=os.path.join(args.data_path, 'train'), 108 | testFolder=os.path.join(args.data_path, 'val'), 109 | num_workers=config.data.num_workers, 110 | type_of_data_augmentation=config.data.type_of_data_aug, 111 | data_config=config.data) 112 | 113 | if config.optim.use_multi_stage: 114 | (train_queue, week_train_queue), valid_queue = imagenet.getSetTrainTestLoader(config.data.batch_size) 115 | else: 116 | train_queue, valid_queue = imagenet.getTrainTestLoader(config.data.batch_size) 117 | 118 | scheduler = get_lr_scheduler(config, optimizer, train_queue.dataset.__len__()) 119 | scheduler.last_step = start_epoch * (train_queue.dataset.__len__() // config.data.batch_size + 1)-1 120 | 121 | trainer = Trainer(train_queue, valid_queue, optimizer, criterion, scheduler, config, args.report_freq) 122 | 123 | best_epoch = [0, 0, 0] # [epoch, acc_top1, acc_top5] 124 | for epoch in range(start_epoch, config.train_params.epochs): 125 | 126 | if config.optim.use_multi_stage and epoch>=config.optim.multi_stage.stage_epochs: 127 | train_data = week_train_queue 128 | else: 129 | train_data = train_queue 130 | 131 | train_acc_top1, train_acc_top5, train_obj, batch_time, data_time = trainer.train(model, epoch) 132 | 133 | with torch.no_grad(): 134 | val_acc_top1, val_acc_top5, batch_time, data_time = trainer.infer(model, epoch) 135 | 136 | if val_acc_top1 > best_epoch[1]: 137 | best_epoch = [epoch, val_acc_top1, val_acc_top5] 138 | utils.save(model, 
os.path.join(args.save, 'weights.pt')) 139 | logging.info('BEST EPOCH %d val_top1 %.2f val_top5 %.2f', best_epoch[0], best_epoch[1], best_epoch[2]) 140 | 141 | writer.add_scalar('train_acc_top1', train_acc_top1, epoch) 142 | writer.add_scalar('train_loss', train_obj, epoch) 143 | writer.add_scalar('val_acc_top1', val_acc_top1, epoch) 144 | 145 | if hasattr(model.module, 'net_config'): 146 | logging.info("Network Structure: \n" + '|\n'.join(map(str, model.module.net_config))) 147 | logging.info("Params = %.2fMB" % params) 148 | logging.info("Mult-Adds = %.2fMB" % mult_adds) 149 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DenseNAS 2 | 3 | The code of the CVPR2020 paper [Densely Connected Search Space for More Flexible Neural Architecture Search](https://arxiv.org/abs/1906.09607). 4 | 5 | Neural architecture search (NAS) has dramatically advanced the development of neural network design. We revisit the search space design in most previous NAS methods and find that the number of blocks and the widths of blocks are set manually. However, block counts and block widths determine the network scale (depth and width) and have a great influence on both the accuracy and the model cost (FLOPs/latency). 6 | 7 | We propose to search block counts and block widths by designing a densely connected search space, i.e., DenseNAS. The new search space is represented as a dense super network, which is built upon our designed routing blocks. In the super network, routing blocks are densely connected and we search for the best path between them to derive the final architecture. We further propose a chained cost estimation algorithm to approximate the model cost during the search. Both the accuracy and model cost are optimized in DenseNAS. 8 | ![search_space](./imgs/search_space.png) 9 | 10 | ## Updates 11 | * 2020.6 The search code is released, including both the MobileNetV2- and ResNet-based search spaces. 12 | 13 | ## Requirements 14 | 15 | * pytorch >= 1.0.1 16 | * python >= 3.6 17 | 18 | ## Search 19 | 20 | 1. Prepare the image set for search, which contains 100 classes of the original ImageNet dataset. 20% of the images are used as the validation set and 80% as the training set. 21 | 22 | 1). Generate the split list of the image data.
23 | `python dataset/mk_split_img_list.py --image_path 'the path of your ImageNet data' --output_path 'the path to output the list file'` 24 | 25 | 2). Use the image list obtained above to make the lmdb file.
26 | `python dataset/img2lmdb.py --image_path 'the path of your ImageNet data' --list_path 'the path of your image list generated above' --output_path 'the path to output the lmdb file' --split 'split folder (train/val)'` 27 | 28 | 2. Build the latency lookup table (LUT) of the search space using the following script, or directly use the ones provided in `./latency_list/`.
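If you use the provided lists, note that they are stored as plain Python literals (nested lists of measured op latencies in milliseconds, as produced by the measurement script). A minimal sketch for loading and inspecting one, assuming that plain-literal format:

```python
import ast

# The provided latency files are plain Python literals (nested lists of
# op latencies in ms), so ast.literal_eval can parse them without eval().
with open('latency_list/lat_list_resv2_32') as f:
    lat_list = ast.literal_eval(f.read())

print(type(lat_list), len(lat_list))
```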
29 | `python -m run_apis.latency_measure --save 'output path' --input_size 'the input image size' --meas_times 'the number of measurement runs per op' --list_name 'the name of the output LUT' --device 'gpu or cpu' --config 'the path of the yaml config'` 30 | 31 | 3. Search for the architectures. (We perform the search process on 4 32G V100 GPUs.)
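During the search, accuracy and model cost are optimized jointly: in `run_apis/optimizer.py` the measured sub-objective (latency or FLOPs) enters the architecture loss as a log-scaled penalty. A minimal sketch of that combination, with `log_base` and `sub_loss_factor` taken from the `optim.sub_obj` section of the search config:

```python
import torch

def arch_loss(logits, target, sub_obj, criterion, log_base, sub_loss_factor):
    # Cross-entropy for accuracy plus a log-scaled model-cost penalty,
    # mirroring Optimizer.arch_step in run_apis/optimizer.py.
    loss = criterion(logits, target)
    cost_term = torch.log(sub_obj) / torch.log(torch.tensor(log_base))
    return loss + cost_term * sub_loss_factor
```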
32 | For MobileNetV2 search:
33 | `python -m run_apis.search --data_path 'the path of the split dataset' --config configs/imagenet_search_cfg_mbv2.yaml`
34 | For ResNet search:
35 | `python -m run_apis.search --data_path 'the path of the split dataset' --config configs/imagenet_search_cfg_resnet.yaml` 36 | 37 | ## Train 38 | 39 | 1. (Optional) We pack the ImageNet data as lmdb files for faster IO. The lmdb files can be made as follows. If you don't want to use lmdb data, just set `__C.data.train_data_type='img'` in the training config file `imagenet_train_cfg.py`. 40 | 41 | 1). Generate the list of the image data.
42 | `python dataset/mk_img_list.py --image_path 'the path of your image data' --output_path 'the path to output the list file'` 43 | 44 | 2). Use the image list obtained above to make the lmdb file.
45 | `python dataset/img2lmdb.py --image_path 'the path of your image data' --list_path 'the path of your image list' --output_path 'the path to output the lmdb file' --split 'split folder (train/val)'` 46 | 47 | 2. Train the searched model with the following script by setting `__C.net_config` to the architecture obtained in the above search process. You can also train your customized model by redefining the variable `model` in `retrain.py`. A sketch of the `net_config` format is shown below.
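For reference, `parse_net_config` in `tools/utils.py` splits the `net_config` string on `|` and evaluates each block description of the form `[[in_ch, out_ch], head_op, [stack_ops], num_stack_layers, stride]` (see the docstrings in `models/model_derived.py`). A hypothetical sketch of the format (the channels and operations below are illustrative, not a searched architecture):

```python
# Hypothetical two-block net_config string; real ones are written out by the
# search process. Each '|'-separated part follows
# [[in_ch, out_ch], head_op, [stack_ops], num_stack_layers, stride].
net_config = (
    "[[16, 24], 'mbconv_k3_t3', ['mbconv_k3_t3'], 1, 2]|"
    "[[24, 40], 'mbconv_k5_t6', ['mbconv_k5_t6', 'mbconv_k3_t3'], 2, 2]"
)
```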
48 | `python -m run_apis.retrain --data_path 'The path of ImageNet data' --load_path 'The path you put the net_config of the model'` 49 | 50 | ## Evaluate 51 | 52 | 1. Download the related files of the pretrained model and put `net_config` and `weights.pt` into `model_path`. 53 | 2. `python -m run_apis.validation --data_path 'The path of ImageNet data' --load_path 'The path you put the pre-trained model'` 54 | 55 | ## Results 56 | 57 | For experiments on the MobileNetV2-based search space, DenseNAS achieves 75.3\% top-1 accuracy on ImageNet with only 361M FLOPs and 17.9ms latency on a single TITAN-XP. The larger model searched by DenseNAS achieves 76.1\% accuracy with only 479M FLOPs. DenseNAS further improves the ImageNet classification accuracies of ResNet-18, -34 and -50-B by 1.5\%, 0.5\% and 0.3\% with 200M, 600M and 680M FLOPs reductions, respectively. 58 | 59 | The comparison of model performance on ImageNet under the MobileNetV2-based search spaces. 60 | 61 |
![mbv2_results](./imgs/mbv2_results.png)
62 | 63 | 64 |
![mbv2_comp](./imgs/mbv2_comp.png)
65 | 66 | The comparison of model performance on ImageNet under the ResNet-based search spaces. 67 | 68 |
![res_results](./imgs/res_results.png)
69 | 70 | 71 |
![res_comp](./imgs/res_comp.png)
72 | 73 | Our pre-trained models can be downloaded in the following links. The complete list of the models can be found in [DenseNAS_modelzoo](https://drive.google.com/open?id=183oIMF6IowZrj81kenVBkQoIMlip9kLo). 74 | 75 | | Model | FLOPs | Latency | Top-1(%)| 76 | |----------------------|-------|---------|---------| 77 | | [DenseNAS-Large](https://drive.google.com/open?id=14Zgc-IlxjaRtGyDHJSdMpLHVvOd0Km1u) | 479M | 28.9ms | 76.1 | 78 | | [DenseNAS-A](https://drive.google.com/open?id=1ZdephrAY4GVRqv9SvOXoJDUmO-kWhhml) | 251M | 13.6ms | 73.1 | 79 | | [DenseNAS-B](https://drive.google.com/open?id=1djhL5P1vsWVqWuT5lR7UCxEhw4cET__7) | 314M | 15.4ms | 74.6 | 80 | | [DenseNAS-C](https://drive.google.com/open?id=1L2mqir89b1UiBkePmrtjG6QLi9MqzRdQ) | 361M | 17.9ms | 75.3 | 81 | | [DenseNAS-R1](https://drive.google.com/open?id=1YaMWb1LKpgSS5mgBcB3CthTGtTIOtxWw) | 1.61B | 12.0ms | 73.5 | 82 | | [DenseNAS-R2](https://drive.google.com/open?id=1Qawst3E2hqdam2TiTFo2BhBXS-M6AWdh) | 3.06B | 22.2ms | 75.8 | 83 | | [DenseNAS-R3](https://drive.google.com/open?id=14RwIGWsurNvevhxL9AcnlngU0KR8WeX-) | 3.41B | 41.7ms | 78.0 | 84 | 85 | ![archs](imgs/archs.png) 86 | 87 | ## Citation 88 | If you find this repository/work helpful in your research, welcome to cite it. 89 | ``` 90 | @inproceedings{fang2019densely, 91 | title={Densely connected search space for more flexible neural architecture search}, 92 | author={Fang, Jiemin and Sun, Yuzhu and Zhang, Qian and Li, Yuan and Liu, Wenyu and Wang, Xinggang}, 93 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition}, 94 | year={2020} 95 | } 96 | ``` -------------------------------------------------------------------------------- /latency_list/lat_list_resv2_32_reg10: -------------------------------------------------------------------------------- 1 | [0.7470053614992084, [[[[1.7158220271871547]], []], [[[1.6621649144875883], [1.7133765027980612]], []], [[[1.9134202629628807], [1.8592351855653704], [1.8933377362260915]], []], [[[1.0193982991305264], [0.9558140388642898], [0.8666880443842724]], [[0.6160835545472425, 0.06160835545472425], [0.6116041270169345, 0.061160412701693444], [0.6122468938731184, 0.061224689387311834], [0.6109231168573552, 0.06109231168573552], [0.6104896766970856, 0.06104896766970856]]], [[[0.877161074166346], [1.1209101147121854], [1.0645518158421372], [1.0146694231514979]], [[0.8019653474441683, 0.08019653474441682], [0.7993539415224634, 0.07993539415224635], [0.8006674593145197, 0.08006674593145197], [0.799418796192516, 0.0799418796192516], [0.8003230046744298, 0.08003230046744299]]], [[[1.2716174125671387], [1.1164091813443888], [1.2395282466002184], [1.195189832436918], [1.1547757399202596]], [[1.094416416052616, 0.1094416416052616], [1.0937702535378813, 0.10937702535378813], [1.0924938712457213, 0.10924938712457213], [1.0943670465488626, 0.10943670465488627], [1.092695009828818, 0.1092695009828818]]], [[[0.5315265029367775], [0.49857038440126356], [0.4235670301649306]], [[0.41330046123928493, 0.041330046123928495], [0.41206465827094185, 0.04120646582709418], [0.4113578796386719, 0.04113578796386719], [0.4114802196772412, 0.041148021967724116], [0.41177434150618736, 0.04117743415061874], [0.41203164091013894, 0.04120316409101389], [0.4116343488596907, 0.04116343488596907], [0.4124539307873658, 0.04124539307873658], [0.41144332500419234, 0.04114433250041923], [0.41133061803952614, 0.04113306180395261], [0.41104555130004883, 0.04110455513000488], [0.41221972667809686, 0.041221972667809685], [0.4108423897714326, 
0.041084238977143264], [0.4113530149363508, 0.041135301493635076], [0.41111775118895255, 0.04111177511889526]]], [[[0.5795735301393451], [0.7294338158886843], [0.6747800412804189], [0.5880859163072375]], [[0.4475170915777033, 0.04475170915777033], [0.4473568453933253, 0.044735684539332535], [0.44746613261675594, 0.044746613261675595], [0.4469970982484143, 0.044699709824841435], [0.4433615520747021, 0.04433615520747021], [0.4466982321305708, 0.04466982321305708], [0.44393339542427446, 0.04439333954242745], [0.44414787581472687, 0.04441478758147269], [0.44315029876400724, 0.04431502987640072], [0.44436098349214803, 0.0444360983492148], [0.4429972292196871, 0.04429972292196871], [0.4465308093061351, 0.04465308093061351], [0.4408782660359084, 0.04408782660359084], [0.44157577283454663, 0.044157577283454666], [0.44270563607264046, 0.044270563607264043]]], [[[0.5950747355066165], [0.5492691319398205], [0.7678800881511033], [0.7344900237189399], [0.6177222849142672]], [[0.4978300826718109, 0.049783008267181086], [0.4936368296844791, 0.04936368296844791], [0.4946338287507645, 0.04946338287507645], [0.4952423259465381, 0.049524232594653814], [0.4930364483534688, 0.04930364483534688], [0.49611074755890205, 0.0496110747558902], [0.4933389991220802, 0.04933389991220802], [0.4946485673538362, 0.049464856735383624], [0.4925500985347864, 0.04925500985347864], [0.4923339082737162, 0.04923339082737162], [0.4922379869403261, 0.04922379869403261], [0.4916638798183865, 0.04916638798183865], [0.4915951960014574, 0.04915951960014574], [0.49248155921396586, 0.049248155921396586], [0.4925919060755257, 0.04925919060755257]]], [[[0.7117494188173853], [0.6895356467275908], [0.6656108480511289], [0.8351398236823805], [0.7782968366988982], [0.6874867400737724]], [[0.6256991685038865, 0.06256991685038865], [0.6235957145690918, 0.06235957145690918], [0.6226729383372297, 0.06226729383372297], [0.6236689981788095, 0.06236689981788095], [0.6236180873832318, 0.06236180873832318], [0.6223808153711184, 0.06223808153711184], [0.6238544589341288, 0.06238544589341288], [0.621514826109915, 0.0621514826109915], [0.6216128185541943, 0.06216128185541943], [0.6210656358738138, 0.06210656358738138], [0.6469456595603865, 0.06469456595603865], [0.6190218106664792, 0.061902181066647924], [0.61852570736047, 0.061852570736047], [0.6181354715366557, 0.06181354715366557], [0.621467214642149, 0.062146721464214905]]], [[[0.7768191954102179], [0.7580275005764431], [0.7367050045668477], [0.7167703455144709]], [[0.6815665900105178, 0.06815665900105178], [0.6803986038824524, 0.06803986038824525], [0.6794075291566174, 0.06794075291566173], [0.6797389309815686, 0.06797389309815685], [0.6779417847142075, 0.06779417847142075]]], [[[0.9306495839899237], [0.8969544160245645], [0.8677374955379602], [0.8321288378551753], [0.8342863814999358]], [[0.8310986046839242, 0.08310986046839242], [0.8355484345946649, 0.08355484345946648], [0.8296608443212028, 0.08296608443212028], [0.8263708605910793, 0.08263708605910794], [0.8277203338314788, 0.08277203338314788]]], [[[1.0617327449297664], [0.9705225867454452], [0.9498220983177724], [0.9233483882865521], [0.8980551392141014], [0.8740888701544868]], [[0.8907880927577163, 0.08907880927577164], [0.8912024112662884, 0.08912024112662884], [0.8915229999657833, 0.08915229999657834], [0.8909804170781916, 0.08909804170781917], [0.890022504209268, 0.0890022504209268]]], [[[0.5085483223500877], [0.49751910296353424], [0.4761004688763859]], [[0.41085341964105165, 0.04108534196410517], [0.4084869827887025, 
0.04084869827887025], [0.40876537862450185, 0.040876537862450185], [0.40820923718539154, 0.040820923718539154], [0.4086827509330981, 0.04086827509330981]]], [[[0.5120689218694514], [0.5295760944636182], [0.5133399337229102], [0.4979153353758532]], [[0.44370985994435325, 0.04437098599443533], [0.4397995785029248, 0.043979957850292475], [0.43866104549831814, 0.043866104549831815], [0.44069872962103945, 0.04406987296210395], [0.4398404468189586, 0.04398404468189586]]], [[[0.5302517823498658], [0.5341541646706938], [0.6771185903838186], [0.6394316692544957], [0.6240752008226182]], [[0.4810813701514042, 0.04810813701514042], [0.4783633261015921, 0.04783633261015921], [0.47798166371355155, 0.04779816637135516], [0.47773036089810456, 0.047773036089810456], [0.4787054447212605, 0.04787054447212605]]], [[[0.5921873420175879], [0.5748705671291159], [0.5476833593965781], [0.6850687662760417], [0.6735895137594203], [0.6462683340515754]], [[0.51674510493423, 0.051674510493423], [0.5135976425325027, 0.05135976425325027], [0.5137101086703213, 0.05137101086703213], [0.5135496457417805, 0.051354964574178055], [0.5139973669341117, 0.05139973669341117]]], [[[1.1904039045776984], [1.171813565071183], [1.13944036792023], [1.1230145078716856]], [[1.400846206780636, 0.1400846206780636]]], [[[1.5440915810941447], [1.2339090578483813], [1.1897495780328307], [1.1610502666897244], [1.1645580060554275]], [[1.4041198865331783, 0.14041198865331783]]], [[[1.448338104016853], [1.4705032050007523], [1.1230347614095668], [1.1128196330985638], [1.0730700059370561], [1.0587836516023885]], [[1.2187801226220951, 0.12187801226220951]]]], [0.13975803298179548, 0.1369504976754237, 0.13584091205789584], 0.19902790435636888] -------------------------------------------------------------------------------- /latency_list/lat_list_resv2_32: -------------------------------------------------------------------------------- 1 | [0.7470053614992084, [[[[1.7158220271871547]], []], [[[1.6621649144875883], [1.7133765027980612]], []], [[[1.9134202629628807], [1.8592351855653704], [1.8933377362260915]], []], [[[1.0193982991305264], [0.9558140388642898], [0.8666880443842724]], [[0.6160835545472425, 0.005083878835042318], [0.6116041270169345, 0.004423748363148083], [0.6122468938731184, 0.004710428642504143], [0.6109231168573552, 0.005227267140089863], [0.6104896766970856, 0.004845990075005426]]], [[[0.877161074166346], [1.1209101147121854], [1.0645518158421372], [1.0146694231514979]], [[0.8019653474441683, 0.005188060529304274], [0.7993539415224634, 0.004809480724912701], [0.8006674593145197, 0.005086792839898003], [0.799418796192516, 0.0048404992228806626], [0.8003230046744298, 0.005279695144807449]]], [[[1.2716174125671387], [1.1164091813443888], [1.2395282466002184], [1.195189832436918], [1.1547757399202596]], [[1.094416416052616, 0.004327369458747632], [1.0937702535378813, 0.0049243069658375755], [1.0924938712457213, 0.00507874922318892], [1.0943670465488626, 0.005123567099523063], [1.092695009828818, 0.005048718115296026]]], [[[0.5315265029367775], [0.49857038440126356], [0.4235670301649306]], [[0.41330046123928493, 0.0046339902010830965], [0.41206465827094185, 0.004920863141917219], [0.4113578796386719, 0.005258743209068222], [0.4114802196772412, 0.0052337212996049366], [0.41177434150618736, 0.004909568362765842], [0.41203164091013894, 0.004904535081651476], [0.4116343488596907, 0.00530339250660906], [0.4124539307873658, 0.004672931902336352], [0.41144332500419234, 0.00468335970483645], [0.41133061803952614, 0.005149263324159564], 
[0.41104555130004883, 0.004897117614746094], [0.41221972667809686, 0.004751080214375197], [0.4108423897714326, 0.005267075817994397], [0.4113530149363508, 0.004888375600179036], [0.41111775118895255, 0.0051637851830684785]]], [[[0.5795735301393451], [0.7294338158886843], [0.6747800412804189], [0.5880859163072375]], [[0.4475170915777033, 0.004560899252843376], [0.4473568453933253, 0.005116848030475655], [0.44746613261675594, 0.0050102580677379265], [0.4469970982484143, 0.004821449819237295], [0.4433615520747021, 0.004868603715992937], [0.4466982321305708, 0.005760216953778508], [0.44393339542427446, 0.004536623906607579], [0.44414787581472687, 0.004877586557407572], [0.44315029876400724, 0.004928280608822601], [0.44436098349214803, 0.004839006096425683], [0.4429972292196871, 0.004484942465117484], [0.4465308093061351, 0.005215683368721394], [0.4408782660359084, 0.004838741186893348], [0.44157577283454663, 0.004812563308561691], [0.44270563607264046, 0.005174333398992365]]], [[[0.5950747355066165], [0.5492691319398205], [0.7678800881511033], [0.7344900237189399], [0.6177222849142672]], [[0.4978300826718109, 0.004748093961465239], [0.4936368296844791, 0.004589581730389836], [0.4946338287507645, 0.005070175787415167], [0.4952423259465381, 0.004672811488912563], [0.4930364483534688, 0.004822509457366635], [0.49611074755890205, 0.0048711564805772566], [0.4933389991220802, 0.004833105838660038], [0.4946485673538362, 0.004776294785316544], [0.4925500985347864, 0.004769792460431956], [0.4923339082737162, 0.005009198429608586], [0.4922379869403261, 0.00521529804576527], [0.4916638798183865, 0.004769985121910018], [0.4915951960014574, 0.004877273482505721], [0.49248155921396586, 0.005180137326018979], [0.4925919060755257, 0.005257298247982757]]], [[[0.7117494188173853], [0.6895356467275908], [0.6656108480511289], [0.8351398236823805], [0.7782968366988982], [0.6874867400737724]], [[0.6256991685038865, 0.0046522930414989744], [0.6235957145690918, 0.005195285334731593], [0.6226729383372297, 0.00455521573924055], [0.6236689981788095, 0.0046892840452868526], [0.6236180873832318, 0.004846182736483488], [0.6223808153711184, 0.005256985173080907], [0.6238544589341288, 0.004499079001070274], [0.621514826109915, 0.00439465647996074], [0.6216128185541943, 0.005834873276527482], [0.6210656358738138, 0.005216092774362275], [0.6469456595603865, 0.0049324469132856884], [0.6190218106664792, 0.0048540096090297505], [0.61852570736047, 0.005814354829113893], [0.6181354715366557, 0.004709561665852864], [0.621467214642149, 0.004667681877059166]]], [[[0.7768191954102179], [0.7580275005764431], [0.7367050045668477], [0.7167703455144709]], [[0.6815665900105178, 0.00501138995392154], [0.6803986038824524, 0.004988993057096848], [0.6794075291566174, 0.005249977111816406], [0.6797389309815686, 0.004805507081927675], [0.6779417847142075, 0.004855791727701823]]], [[[0.9306495839899237], [0.8969544160245645], [0.8677374955379602], [0.8321288378551753], [0.8342863814999358]], [[0.8310986046839242, 0.004771863571321121], [0.8355484345946649, 0.004956047944348268], [0.8296608443212028, 0.004708622441147313], [0.8263708605910793, 0.004478680967080473], [0.8277203338314788, 0.004992942617397115]]], [[[1.0617327449297664], [0.9705225867454452], [0.9498220983177724], [0.9233483882865521], [0.8980551392141014], [0.8740888701544868]], [[0.8907880927577163, 0.0046964125199751424], [0.8912024112662884, 0.004646922602798], [0.8915229999657833, 0.004888905419243706], [0.8909804170781916, 0.0051665065264461015], [0.890022504209268, 
0.004725648899271031]]], [[[0.5085483223500877], [0.49751910296353424], [0.4761004688763859]], [[0.41085341964105165, 0.004508663909603851], [0.4084869827887025, 0.004639240226360282], [0.40876537862450185, 0.005059145917796126], [0.40820923718539154, 0.005027236360492128], [0.4086827509330981, 0.004860897256870463]]], [[[0.5120689218694514], [0.5295760944636182], [0.5133399337229102], [0.4979153353758532]], [[0.44370985994435325, 0.00485039720631609], [0.4397995785029248, 0.0051300935070924085], [0.43866104549831814, 0.00495421766030668], [0.44069872962103945, 0.00507662994693024], [0.4398404468189586, 0.0051092860674617265]]], [[[0.5302517823498658], [0.5341541646706938], [0.6771185903838186], [0.6394316692544957], [0.6240752008226182]], [[0.4810813701514042, 0.005186904560435902], [0.4783633261015921, 0.005180474483605587], [0.47798166371355155, 0.005193695877537583], [0.47773036089810456, 0.004894829759694109], [0.4787054447212605, 0.004925438852021189]]], [[[0.5921873420175879], [0.5748705671291159], [0.5476833593965781], [0.6850687662760417], [0.6735895137594203], [0.6462683340515754]], [[0.51674510493423, 0.004966836987119733], [0.5135976425325027, 0.005099749324297664], [0.5137101086703213, 0.004888158856016217], [0.5135496457417805, 0.005136788493455059], [0.5139973669341117, 0.004399376686173255]]], [[[1.1904039045776984], [1.171813565071183], [1.13944036792023], [1.1230145078716856]], [[1.400846206780636, 0.005112127824263139]]], [[[1.5440915810941447], [1.2339090578483813], [1.1897495780328307], [1.1610502666897244], [1.1645580060554275]], [[1.4041198865331783, 0.004999589438390251]]], [[[1.448338104016853], [1.4705032050007523], [1.1230347614095668], [1.1128196330985638], [1.0730700059370561], [1.0587836516023885]], [[1.2187801226220951, 0.005037182509297072]]]], [0.13975803298179548, 0.1369504976754237, 0.13584091205789584], 0.19902790435636888] -------------------------------------------------------------------------------- /models/model_derived.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from .operations import OPS 7 | from tools.utils import parse_net_config 8 | 9 | 10 | class Block(nn.Module): 11 | 12 | def __init__(self, in_ch, block_ch, head_op, stack_ops, stride): 13 | super(Block, self).__init__() 14 | self.head_layer = OPS[head_op](in_ch, block_ch, stride, 15 | affine=True, track_running_stats=True) 16 | 17 | modules = [] 18 | for stack_op in stack_ops: 19 | modules.append(OPS[stack_op](block_ch, block_ch, 1, 20 | affine=True, track_running_stats=True)) 21 | self.stack_layers = nn.Sequential(*modules) 22 | 23 | def forward(self, x): 24 | x = self.head_layer(x) 25 | x = self.stack_layers(x) 26 | return x 27 | 28 | 29 | class Conv1_1_Block(nn.Module): 30 | 31 | def __init__(self, in_ch, block_ch): 32 | super(Conv1_1_Block, self).__init__() 33 | self.conv1_1 = nn.Sequential( 34 | nn.Conv2d(in_channels=in_ch, out_channels=block_ch, 35 | kernel_size=1, stride=1, padding=0, bias=False), 36 | nn.BatchNorm2d(block_ch), 37 | nn.ReLU6(inplace=True) 38 | ) 39 | 40 | def forward(self, x): 41 | return self.conv1_1(x) 42 | 43 | 44 | class MBV2_Net(nn.Module): 45 | def __init__(self, net_config, config=None): 46 | """ 47 | net_config=[[in_ch, out_ch], head_op, [stack_ops], num_stack_layers, stride] 48 | """ 49 | super(MBV2_Net, self).__init__() 50 | self.config = config 51 | self.net_config = parse_net_config(net_config) 52 | self.in_chs = 
self.net_config[0][0][0] 53 | self._num_classes = 1000 54 | 55 | self.input_block = nn.Sequential( 56 | nn.Conv2d(in_channels=3, out_channels=self.in_chs, kernel_size=3, 57 | stride=2, padding=1, bias=False), 58 | nn.BatchNorm2d(self.in_chs), 59 | nn.ReLU6(inplace=True) 60 | ) 61 | self.blocks = nn.ModuleList() 62 | for config in self.net_config: 63 | if config[1] == 'conv1_1': 64 | continue 65 | self.blocks.append(Block(config[0][0], config[0][1], 66 | config[1], config[2], config[-1])) 67 | 68 | if self.net_config[-1][1] == 'conv1_1': 69 | block_last_dim = self.net_config[-1][0][0] 70 | last_dim = self.net_config[-1][0][1] 71 | else: 72 | block_last_dim = self.net_config[-1][0][1] 73 | self.conv1_1_block = Conv1_1_Block(block_last_dim, last_dim) 74 | 75 | self.global_pooling = nn.AdaptiveAvgPool2d(1) 76 | self.classifier = nn.Linear(last_dim, self._num_classes) 77 | 78 | self.init_model() 79 | self.set_bn_param(0.1, 0.001) 80 | 81 | 82 | def forward(self,x): 83 | block_data = self.input_block(x) 84 | for i, block in enumerate(self.blocks): 85 | block_data = block(block_data) 86 | block_data = self.conv1_1_block(block_data) 87 | 88 | out = self.global_pooling(block_data) 89 | logits = self.classifier(out.view(out.size(0),-1)) 90 | 91 | return logits 92 | 93 | def init_model(self, model_init='he_fout', init_div_groups=True): 94 | for m in self.modules(): 95 | if isinstance(m, nn.Conv2d): 96 | if model_init == 'he_fout': 97 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 98 | if init_div_groups: 99 | n /= m.groups 100 | m.weight.data.normal_(0, math.sqrt(2. / n)) 101 | elif model_init == 'he_fin': 102 | n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels 103 | if init_div_groups: 104 | n /= m.groups 105 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 106 | else: 107 | raise NotImplementedError 108 | elif isinstance(m, nn.BatchNorm2d): 109 | m.weight.data.fill_(1) 110 | m.bias.data.zero_() 111 | elif isinstance(m, nn.Linear): 112 | if m.bias is not None: 113 | m.bias.data.zero_() 114 | elif isinstance(m, nn.BatchNorm1d): 115 | m.weight.data.fill_(1) 116 | m.bias.data.zero_() 117 | 118 | def set_bn_param(self, bn_momentum, bn_eps): 119 | for m in self.modules(): 120 | if isinstance(m, nn.BatchNorm2d): 121 | m.momentum = bn_momentum 122 | m.eps = bn_eps 123 | return 124 | 125 | 126 | class RES_Net(nn.Module): 127 | def __init__(self, net_config, config=None): 128 | """ 129 | net_config=[[in_ch, out_ch], head_op, [stack_ops], num_stack_layers, stride] 130 | """ 131 | super(RES_Net, self).__init__() 132 | self.config = config 133 | self.net_config = parse_net_config(net_config) 134 | self.in_chs = self.net_config[0][0][0] 135 | self._num_classes = 1000 136 | 137 | self.input_block = nn.Sequential( 138 | nn.Conv2d(in_channels=3, out_channels=self.in_chs, kernel_size=3, 139 | stride=2, padding=1, bias=False), 140 | nn.BatchNorm2d(self.in_chs), 141 | nn.ReLU6(inplace=True), 142 | ) 143 | self.blocks = nn.ModuleList() 144 | for config in self.net_config: 145 | self.blocks.append(Block(config[0][0], config[0][1], 146 | config[1], config[2], config[-1])) 147 | 148 | self.global_pooling = nn.AdaptiveAvgPool2d((1, 1)) 149 | if self.net_config[-1][1] == 'bottle_neck': 150 | last_dim = self.net_config[-1][0][-1] * 4 151 | else: 152 | last_dim = self.net_config[-1][0][1] 153 | self.classifier = nn.Linear(last_dim, self._num_classes) 154 | 155 | for m in self.modules(): 156 | if isinstance(m, nn.Conv2d): 157 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 158 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 159 | if m.affine==True: 160 | nn.init.constant_(m.weight, 1) 161 | nn.init.constant_(m.bias, 0) 162 | 163 | def forward(self, x): 164 | block_data = self.input_block(x) 165 | for i, block in enumerate(self.blocks): 166 | block_data = block(block_data) 167 | 168 | out = self.global_pooling(block_data) 169 | out = torch.flatten(out, 1) 170 | logits = self.classifier(out) 171 | return logits -------------------------------------------------------------------------------- /dataset/imagenet_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torchvision 5 | import torchvision.transforms as transforms 6 | 7 | from . import lmdb_dataset 8 | from . import torchvision_extension as transforms_extension 9 | from .prefetch_data import fast_collate 10 | 11 | class ImageNet12(object): 12 | 13 | def __init__(self, trainFolder, testFolder, num_workers=8, pin_memory=True, 14 | size_images=224, scaled_size=256, type_of_data_augmentation='rand_scale', 15 | data_config=None): 16 | 17 | self.data_config = data_config 18 | self.trainFolder = trainFolder 19 | self.testFolder = testFolder 20 | self.num_workers = num_workers 21 | self.pin_memory = pin_memory 22 | self.patch_dataset = self.data_config.patch_dataset 23 | 24 | #images will be rescaled to match this size 25 | if not isinstance(size_images, int): 26 | raise ValueError('size_images must be an int. 
It will be scaled to a square image') 27 | self.size_images = size_images 28 | self.scaled_size = scaled_size 29 | 30 | type_of_data_augmentation = type_of_data_augmentation.lower() 31 | if type_of_data_augmentation not in ('rand_scale', 'random_sized'): 32 | raise ValueError('type_of_data_augmentation must be either rand-scale or random-sized') 33 | self.type_of_data_augmentation = type_of_data_augmentation 34 | 35 | 36 | def _getTransformList(self, aug_type): 37 | 38 | assert aug_type in ['rand_scale', 'random_sized', 'week_train', 'validation'] 39 | list_of_transforms = [] 40 | 41 | if aug_type == 'validation': 42 | list_of_transforms.append(transforms.Resize(self.scaled_size)) 43 | list_of_transforms.append(transforms.CenterCrop(self.size_images)) 44 | 45 | elif aug_type == 'week_train': 46 | list_of_transforms.append(transforms.Resize(256)) 47 | list_of_transforms.append(transforms.RandomCrop(self.size_images)) 48 | list_of_transforms.append(transforms.RandomHorizontalFlip()) 49 | 50 | else: 51 | if aug_type == 'rand_scale': 52 | list_of_transforms.append(transforms_extension.RandomScale(256, 480)) 53 | list_of_transforms.append(transforms.RandomCrop(self.size_images)) 54 | list_of_transforms.append(transforms.RandomHorizontalFlip()) 55 | 56 | elif aug_type == 'random_sized': 57 | list_of_transforms.append(transforms.RandomResizedCrop(self.size_images, 58 | scale=(self.data_config.random_sized.min_scale, 1.0))) 59 | list_of_transforms.append(transforms.RandomHorizontalFlip()) 60 | 61 | if self.data_config.color: 62 | list_of_transforms.append(transforms.ColorJitter(brightness=0.4, 63 | contrast=0.4, 64 | saturation=0.4)) 65 | return transforms.Compose(list_of_transforms) 66 | 67 | 68 | def _getTrainSet(self): 69 | 70 | train_transform = self._getTransformList(self.type_of_data_augmentation) 71 | 72 | if self.data_config.train_data_type == 'img': 73 | train_set = torchvision.datasets.ImageFolder(self.trainFolder, train_transform) 74 | elif self.data_config.train_data_type == 'lmdb': 75 | train_set = lmdb_dataset.ImageFolder(self.trainFolder, 76 | os.path.join(self.trainFolder, '..', 'train_datalist'), 77 | train_transform, 78 | patch_dataset=self.patch_dataset) 79 | self.train_num_examples = train_set.__len__() 80 | 81 | return train_set 82 | 83 | 84 | def _getWeekTrainSet(self): 85 | 86 | train_transform = self._getTransformList('week_train') 87 | if self.data_config.train_data_type == 'img': 88 | train_set = torchvision.datasets.ImageFolder(self.trainFolder, train_transform) 89 | elif self.data_config.train_data_type == 'lmdb': 90 | train_set = lmdb_dataset.ImageFolder(self.trainFolder, 91 | os.path.join(self.trainFolder, '..', 'train_datalist'), 92 | train_transform, 93 | patch_dataset=self.patch_dataset) 94 | self.train_num_examples = train_set.__len__() 95 | return train_set 96 | 97 | 98 | def _getTestSet(self): 99 | 100 | test_transform = self._getTransformList('validation') 101 | if self.data_config.val_data_type == 'img': 102 | test_set = torchvision.datasets.ImageFolder(self.testFolder, test_transform) 103 | elif self.data_config.val_data_type == 'lmdb': 104 | test_set = lmdb_dataset.ImageFolder(self.testFolder, 105 | os.path.join(self.testFolder, '..', 'val_datalist'), 106 | test_transform) 107 | self.test_num_examples = test_set.__len__() 108 | return test_set 109 | 110 | 111 | def getTrainLoader(self, batch_size, shuffle=True): 112 | 113 | train_set = self._getTrainSet() 114 | train_loader = torch.utils.data.DataLoader( 115 | train_set, batch_size=batch_size, 
shuffle=shuffle, 116 | num_workers=self.num_workers, pin_memory=self.pin_memory, 117 | sampler=None, collate_fn=fast_collate) 118 | return train_loader 119 | 120 | 121 | def getWeekTrainLoader(self, batch_size, shuffle=True): 122 | 123 | train_set = self._getWeekTrainSet() 124 | train_loader = torch.utils.data.DataLoader(train_set, 125 | batch_size=batch_size, 126 | shuffle=shuffle, 127 | num_workers=self.num_workers, 128 | pin_memory=self.pin_memory, 129 | collate_fn=fast_collate) 130 | return train_loader 131 | 132 | 133 | def getTestLoader(self, batch_size, shuffle=False): 134 | 135 | test_set = self._getTestSet() 136 | 137 | test_loader = torch.utils.data.DataLoader( 138 | test_set, batch_size=batch_size, shuffle=shuffle, 139 | num_workers=self.num_workers, pin_memory=self.pin_memory, sampler=None, 140 | collate_fn=fast_collate) 141 | return test_loader 142 | 143 | 144 | def getTrainTestLoader(self, batch_size, train_shuffle=True, val_shuffle=False): 145 | 146 | train_loader = self.getTrainLoader(batch_size, train_shuffle) 147 | test_loader = self.getTestLoader(batch_size, val_shuffle) 148 | return train_loader, test_loader 149 | 150 | 151 | def getSetTrainTestLoader(self, batch_size, train_shuffle=True, val_shuffle=False): 152 | 153 | train_loader = self.getTrainLoader(batch_size, train_shuffle) 154 | week_train_loader = self.getWeekTrainLoader(batch_size, train_shuffle) 155 | test_loader = self.getTestLoader(batch_size, val_shuffle) 156 | return (train_loader, week_train_loader), test_loader 157 | -------------------------------------------------------------------------------- /models/operations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | OPS = { 5 | 'mbconv_k3_t1': lambda C_in, C_out, stride, affine, track_running_stats: MBConv(C_in, C_out, 3, stride, 1, t=1, affine=affine, track_running_stats=track_running_stats), 6 | 'mbconv_k3_t3': lambda C_in, C_out, stride, affine, track_running_stats: MBConv(C_in, C_out, 3, stride, 1, t=3, affine=affine, track_running_stats=track_running_stats), 7 | 'mbconv_k3_t6': lambda C_in, C_out, stride, affine, track_running_stats: MBConv(C_in, C_out, 3, stride, 1, t=6, affine=affine, track_running_stats=track_running_stats), 8 | 'mbconv_k5_t3': lambda C_in, C_out, stride, affine, track_running_stats: MBConv(C_in, C_out, 5, stride, 2, t=3, affine=affine, track_running_stats=track_running_stats), 9 | 'mbconv_k5_t6': lambda C_in, C_out, stride, affine, track_running_stats: MBConv(C_in, C_out, 5, stride, 2, t=6, affine=affine, track_running_stats=track_running_stats), 10 | 'mbconv_k7_t3': lambda C_in, C_out, stride, affine, track_running_stats: MBConv(C_in, C_out, 7, stride, 3, t=3, affine=affine, track_running_stats=track_running_stats), 11 | 'mbconv_k7_t6': lambda C_in, C_out, stride, affine, track_running_stats: MBConv(C_in, C_out, 7, stride, 3, t=6, affine=affine, track_running_stats=track_running_stats), 12 | 'basic_block': lambda C_in, C_out, stride, affine, track_running_stats: BasicBlock(C_in, C_out, stride, affine=affine, track_running_stats=track_running_stats), 13 | 'bottle_neck': lambda C_in, C_out, stride, affine, track_running_stats: Bottleneck(C_in, C_out, stride, affine=affine, track_running_stats=track_running_stats), 14 | 'skip_connect': lambda C_in, C_out, stride, affine, track_running_stats: Skip(C_in, C_out, 1, affine=affine, track_running_stats=track_running_stats), 15 | } 16 | 17 | 18 | class MBConv(nn.Module): 19 | def __init__(self, C_in, 
C_out, kernel_size, stride, padding, t=3, affine=True, 20 | track_running_stats=True, use_se=False): 21 | super(MBConv, self).__init__() 22 | self.t = t 23 | if self.t > 1: 24 | self._expand_conv = nn.Sequential( 25 | nn.Conv2d(C_in, C_in*self.t, kernel_size=1, stride=1, padding=0, groups=1, bias=False), 26 | nn.BatchNorm2d(C_in*self.t, affine=affine, track_running_stats=track_running_stats), 27 | nn.ReLU6(inplace=True)) 28 | 29 | self._depthwise_conv = nn.Sequential( 30 | nn.Conv2d(C_in*self.t, C_in*self.t, kernel_size=kernel_size, stride=stride, padding=padding, groups=C_in*self.t, bias=False), 31 | nn.BatchNorm2d(C_in*self.t, affine=affine, track_running_stats=track_running_stats), 32 | nn.ReLU6(inplace=True)) 33 | 34 | self._project_conv = nn.Sequential( 35 | nn.Conv2d(C_in*self.t, C_out, kernel_size=1, stride=1, padding=0, groups=1, bias=False), 36 | nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats)) 37 | else: 38 | self._expand_conv = None 39 | 40 | self._depthwise_conv = nn.Sequential( 41 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, groups=C_in, bias=False), 42 | nn.BatchNorm2d(C_in, affine=affine, track_running_stats=track_running_stats), 43 | nn.ReLU6(inplace=True)) 44 | 45 | self._project_conv = nn.Sequential( 46 | nn.Conv2d(C_in, C_out, 1, 1, 0, bias=False), 47 | nn.BatchNorm2d(C_out)) 48 | 49 | def forward(self, x): 50 | input_data = x 51 | if self._expand_conv is not None: 52 | x = self._expand_conv(x) 53 | x = self._depthwise_conv(x) 54 | out_data = self._project_conv(x) 55 | 56 | if out_data.shape == input_data.shape: 57 | return out_data + input_data 58 | else: 59 | return out_data 60 | 61 | 62 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 63 | """3x3 convolution with padding""" 64 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 65 | padding=dilation, groups=groups, bias=False, dilation=dilation) 66 | 67 | def conv1x1(in_planes, out_planes, stride=1): 68 | """1x1 convolution""" 69 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 70 | 71 | class BasicBlock(nn.Module): 72 | def __init__(self, inplanes, planes, stride=1, groups=1, 73 | base_width=64, dilation=1, norm_layer=None, 74 | affine=True, track_running_stats=True): 75 | super(BasicBlock, self).__init__() 76 | if norm_layer is None: 77 | norm_layer = nn.BatchNorm2d 78 | if groups != 1 or base_width != 64: 79 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 80 | if dilation > 1: 81 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 82 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 83 | self.conv1 = conv3x3(inplanes, planes, stride) 84 | self.bn1 = norm_layer(planes, affine=affine, track_running_stats=track_running_stats) 85 | self.relu = nn.ReLU(inplace=True) 86 | self.conv2 = conv3x3(planes, planes) 87 | self.bn2 = norm_layer(planes, affine=affine, track_running_stats=track_running_stats) 88 | self.downsample = None 89 | if stride != 1 or inplanes != planes: 90 | self.downsample = nn.Sequential( 91 | conv1x1(inplanes, planes, stride), 92 | norm_layer(planes, affine=affine, track_running_stats=track_running_stats), 93 | ) 94 | 95 | def forward(self, x): 96 | identity = x 97 | out = self.conv1(x) 98 | out = self.bn1(out) 99 | out = self.relu(out) 100 | out = self.conv2(out) 101 | out = self.bn2(out) 102 | 103 | if self.downsample is not None: 104 | identity = self.downsample(x) 105 | out += 
identity 106 | out = self.relu(out) 107 | 108 | return out 109 | 110 | 111 | class Bottleneck(nn.Module): 112 | def __init__(self, inplanes, planes, stride=1, affine=True, track_running_stats=True): 113 | super(Bottleneck, self).__init__() 114 | if inplanes != 32: 115 | inplanes *= 4 116 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 117 | self.bn1 = nn.BatchNorm2d(planes) 118 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 119 | padding=1, bias=False) 120 | self.bn2 = nn.BatchNorm2d(planes) 121 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 122 | self.bn3 = nn.BatchNorm2d(planes * 4) 123 | self.relu = nn.ReLU(inplace=True) 124 | self.stride = stride 125 | self.downsample = None 126 | if stride != 1 or inplanes != planes*4: 127 | self.downsample = nn.Sequential( 128 | conv1x1(inplanes, planes * 4, stride), 129 | nn.BatchNorm2d(planes * 4, affine=affine, track_running_stats=track_running_stats), 130 | ) 131 | 132 | def forward(self, x): 133 | residual = x 134 | 135 | out = self.conv1(x) 136 | out = self.bn1(out) 137 | out = self.relu(out) 138 | 139 | out = self.conv2(out) 140 | out = self.bn2(out) 141 | out = self.relu(out) 142 | 143 | out = self.conv3(out) 144 | out = self.bn3(out) 145 | 146 | if self.downsample is not None: 147 | residual = self.downsample(x) 148 | 149 | out += residual 150 | out = self.relu(out) 151 | 152 | return out 153 | 154 | 155 | class Skip(nn.Module): 156 | def __init__(self, C_in, C_out, stride, affine=True, track_running_stats=True): 157 | super(Skip, self).__init__() 158 | if C_in!=C_out: 159 | skip_conv = nn.Sequential( 160 | nn.Conv2d(C_in, C_out, kernel_size=1, stride=stride, padding=0, groups=1, bias=False), 161 | nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats)) 162 | stride = 1 163 | self.op=Identity(stride) 164 | 165 | if C_in!=C_out: 166 | self.op=nn.Sequential(skip_conv, self.op) 167 | 168 | def forward(self,x): 169 | return self.op(x) 170 | 171 | class Identity(nn.Module): 172 | def __init__(self, stride): 173 | super(Identity, self).__init__() 174 | self.stride = stride 175 | 176 | def forward(self, x): 177 | if self.stride == 1: 178 | return x 179 | else: 180 | return x[:, :, ::self.stride, ::self.stride] -------------------------------------------------------------------------------- /run_apis/trainer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import time 3 | 4 | import torch.nn as nn 5 | 6 | from dataset.prefetch_data import data_prefetcher 7 | from tools import utils 8 | 9 | 10 | class Trainer(object): 11 | def __init__(self, train_data, val_data, optimizer=None, criterion=None, 12 | scheduler=None, config=None, report_freq=None): 13 | self.train_data = train_data 14 | self.val_data = val_data 15 | self.optimizer = optimizer 16 | self.criterion = criterion 17 | self.scheduler = scheduler 18 | self.config = config 19 | self.report_freq = report_freq 20 | 21 | def train(self, model, epoch): 22 | objs = utils.AverageMeter() 23 | top1 = utils.AverageMeter() 24 | top5 = utils.AverageMeter() 25 | data_time = utils.AverageMeter() 26 | batch_time = utils.AverageMeter() 27 | model.train() 28 | start = time.time() 29 | 30 | prefetcher = data_prefetcher(self.train_data) 31 | input, target = prefetcher.next() 32 | step = 0 33 | while input is not None: 34 | data_t = time.time() - start 35 | self.scheduler.step() 36 | n = input.size(0) 37 | if step==0: 38 | logging.info('epoch %d lr %e', epoch, 
self.optimizer.param_groups[0]['lr']) 39 | self.optimizer.zero_grad() 40 | 41 | logits= model(input) 42 | if self.config.optim.label_smooth: 43 | loss = self.criterion(logits, target, self.config.optim.smooth_alpha) 44 | else: 45 | loss = self.criterion(logits, target) 46 | 47 | loss.backward() 48 | if self.config.optim.use_grad_clip: 49 | nn.utils.clip_grad_norm_(model.parameters(), self.config.optim.grad_clip) 50 | self.optimizer.step() 51 | 52 | prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5)) 53 | 54 | batch_t = time.time() - start 55 | start = time.time() 56 | 57 | objs.update(loss.item(), n) 58 | top1.update(prec1.item(), n) 59 | top5.update(prec5.item(), n) 60 | data_time.update(data_t) 61 | batch_time.update(batch_t) 62 | if step!=0 and step % self.report_freq == 0: 63 | logging.info( 64 | 'Train epoch %03d step %03d | loss %.4f top1_acc %.2f top5_acc %.2f | batch_time %.3f data_time %.3f', 65 | epoch, step, objs.avg, top1.avg, top5.avg, batch_time.avg, data_time.avg) 66 | input, target = prefetcher.next() 67 | step += 1 68 | logging.info('EPOCH%d Train_acc top1 %.2f top5 %.2f batch_time %.3f data_time %.3f', 69 | epoch, top1.avg, top5.avg, batch_time.avg, data_time.avg) 70 | 71 | return top1.avg, top5.avg, objs.avg, batch_time.avg, data_time.avg 72 | 73 | 74 | def infer(self, model, epoch=0): 75 | top1 = utils.AverageMeter() 76 | top5 = utils.AverageMeter() 77 | data_time = utils.AverageMeter() 78 | batch_time = utils.AverageMeter() 79 | model.eval() 80 | 81 | start = time.time() 82 | prefetcher = data_prefetcher(self.val_data) 83 | input, target = prefetcher.next() 84 | step = 0 85 | while input is not None: 86 | step += 1 87 | data_t = time.time() - start 88 | n = input.size(0) 89 | 90 | logits = model(input) 91 | 92 | prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5)) 93 | 94 | batch_t = time.time() - start 95 | top1.update(prec1.item(), n) 96 | top5.update(prec5.item(), n) 97 | data_time.update(data_t) 98 | batch_time.update(batch_t) 99 | 100 | if step % self.report_freq == 0: 101 | logging.info( 102 | 'Val epoch %03d step %03d | top1_acc %.2f top5_acc %.2f | batch_time %.3f data_time %.3f', 103 | epoch, step, top1.avg, top5.avg, batch_time.avg, data_time.avg) 104 | start = time.time() 105 | input, target = prefetcher.next() 106 | 107 | logging.info('EPOCH%d Valid_acc top1 %.2f top5 %.2f batch_time %.3f data_time %.3f', 108 | epoch, top1.avg, top5.avg, batch_time.avg, data_time.avg) 109 | return top1.avg, top5.avg, batch_time.avg, data_time.avg 110 | 111 | 112 | class SearchTrainer(object): 113 | def __init__(self, train_data, val_data, search_optim, criterion, scheduler, config, args): 114 | self.train_data = train_data 115 | self.val_data = val_data 116 | self.search_optim = search_optim 117 | self.criterion = criterion 118 | self.scheduler = scheduler 119 | self.sub_obj_type = config.optim.sub_obj.type 120 | self.args = args 121 | 122 | 123 | def train(self, model, epoch, optim_obj='Weights', search_stage=0): 124 | assert optim_obj in ['Weights', 'Arch'] 125 | objs = utils.AverageMeter() 126 | top1 = utils.AverageMeter() 127 | top5 = utils.AverageMeter() 128 | sub_obj_avg = utils.AverageMeter() 129 | data_time = utils.AverageMeter() 130 | batch_time = utils.AverageMeter() 131 | model.train() 132 | 133 | start = time.time() 134 | if optim_obj == 'Weights': 135 | prefetcher = data_prefetcher(self.train_data) 136 | elif optim_obj == 'Arch': 137 | prefetcher = data_prefetcher(self.val_data) 138 | 139 | input, target = prefetcher.next() 140 | step = 0 141 | 
while input is not None: 142 | input, target = input.cuda(), target.cuda() 143 | data_t = time.time() - start 144 | n = input.size(0) 145 | if optim_obj == 'Weights': 146 | self.scheduler.step() 147 | if step==0: 148 | logging.info('epoch %d weight_lr %e', epoch, self.search_optim.weight_optimizer.param_groups[0]['lr']) 149 | logits, loss, sub_obj = self.search_optim.weight_step(input, target, model, search_stage) 150 | elif optim_obj == 'Arch': 151 | if step==0: 152 | logging.info('epoch %d arch_lr %e', epoch, self.search_optim.arch_optimizer.param_groups[0]['lr']) 153 | logits, loss, sub_obj = self.search_optim.arch_step(input, target, model, search_stage) 154 | 155 | prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5)) 156 | del logits, input, target 157 | 158 | batch_t = time.time() - start 159 | objs.update(loss, n) 160 | top1.update(prec1.item(), n) 161 | top5.update(prec5.item(), n) 162 | sub_obj_avg.update(sub_obj) 163 | data_time.update(data_t) 164 | batch_time.update(batch_t) 165 | 166 | if step!=0 and step % self.args.report_freq == 0: 167 | logging.info( 168 | 'Train%s epoch %03d step %03d | loss %.4f %s %.2f top1_acc %.2f top5_acc %.2f | batch_time %.3f data_time %.3f', 169 | optim_obj ,epoch, step, objs.avg, self.sub_obj_type, sub_obj_avg.avg, 170 | top1.avg, top5.avg, batch_time.avg, data_time.avg) 171 | start = time.time() 172 | step += 1 173 | input, target = prefetcher.next() 174 | return top1.avg, top5.avg, objs.avg, sub_obj_avg.avg, batch_time.avg 175 | 176 | 177 | def infer(self, model, epoch): 178 | objs = utils.AverageMeter() 179 | top1 = utils.AverageMeter() 180 | top5 = utils.AverageMeter() 181 | sub_obj_avg = utils.AverageMeter() 182 | data_time = utils.AverageMeter() 183 | batch_time = utils.AverageMeter() 184 | 185 | model.train() # don't use running_mean and running_var during search 186 | start = time.time() 187 | prefetcher = data_prefetcher(self.val_data) 188 | input, target = prefetcher.next() 189 | step = 0 190 | while input is not None: 191 | step += 1 192 | data_t = time.time() - start 193 | n = input.size(0) 194 | 195 | logits, loss, sub_obj = self.search_optim.valid_step(input, target, model) 196 | prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5)) 197 | 198 | batch_t = time.time() - start 199 | objs.update(loss, n) 200 | top1.update(prec1.item(), n) 201 | top5.update(prec5.item(), n) 202 | sub_obj_avg.update(sub_obj) 203 | data_time.update(data_t) 204 | batch_time.update(batch_t) 205 | 206 | if step % self.args.report_freq == 0: 207 | logging.info( 208 | 'Val epoch %03d step %03d | loss %.4f %s %.2f top1_acc %.2f top5_acc %.2f | batch_time %.3f data_time %.3f', 209 | epoch, step, objs.avg, self.sub_obj_type, sub_obj_avg.avg, top1.avg, top5.avg, 210 | batch_time.avg, data_time.avg) 211 | start = time.time() 212 | input, target = prefetcher.next() 213 | 214 | return top1.avg, top5.avg, objs.avg, sub_obj_avg.avg, batch_time.avg 215 | -------------------------------------------------------------------------------- /tools/config_yaml.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | ############################################################################## 15 | # 16 | # Based on: 17 | # -------------------------------------------------------- 18 | # Fast R-CNN 19 | # Copyright (c) 2015 Microsoft 20 | # Licensed under The MIT License [see LICENSE for details] 21 | # Written by Ross Girshick 22 | # -------------------------------------------------------- 23 | 24 | from __future__ import absolute_import 25 | from __future__ import division 26 | from __future__ import print_function 27 | from __future__ import unicode_literals 28 | 29 | from ast import literal_eval 30 | from io import IOBase 31 | #from future.utils import iteritems 32 | import copy 33 | import logging 34 | import numpy as np 35 | import os 36 | import os.path as osp 37 | import yaml 38 | 39 | from .collections import AttrDict 40 | from .io import cache_url 41 | 42 | logger = logging.getLogger(__name__) 43 | _DEPRECATED_KEYS = set() # keys ignored (with a warning) by _key_is_deprecated below 44 | _RENAMED_KEYS = {} # maps an old key to its new key, or to (new key, extra message) 45 | 46 | def load_cfg(cfg_to_load): 47 | """Wrapper around yaml.load used for maintaining backward compatibility""" 48 | if isinstance(cfg_to_load, IOBase): 49 | cfg_to_load = ''.join(cfg_to_load.readlines()) 50 | return yaml.load(cfg_to_load) 51 | 52 | def load_cfg_to_dict(cfg_filename): 53 | with open(cfg_filename, 'r') as f: 54 | # yaml_cfg = AttrDict(load_cfg(f)) 55 | yaml_cfg = load_cfg(f) 56 | return yaml_cfg 57 | 58 | def merge_cfg_from_file(cfg_filename, global_config): 59 | """Load a yaml config file and merge it into the global config.""" 60 | with open(cfg_filename, 'r') as f: 61 | yaml_cfg = AttrDict(load_cfg(f)) 62 | _merge_a_into_b(yaml_cfg, global_config) 63 | 64 | 65 | def merge_cfg_from_cfg(cfg_other, global_config): 66 | """Merge `cfg_other` into the global config.""" 67 | _merge_a_into_b(cfg_other, global_config) 68 | 69 | def update_cfg_from_file(cfg_filename, global_config): 70 | with open(cfg_filename, 'r') as f: 71 | yaml_cfg = AttrDict(load_cfg(f)) 72 | update_cfg_from_cfg(yaml_cfg, global_config) 73 | 74 | def update_cfg_from_cfg(cfg_other, global_config, stack=None): 75 | assert isinstance(cfg_other, AttrDict), \ 76 | '`cfg_other` (cur type {}) must be an instance of {}'.format(type(cfg_other), AttrDict) 77 | assert isinstance(global_config, AttrDict), \ 78 | '`global_config` (cur type {}) must be an instance of {}'.format(type(global_config), AttrDict) 79 | 80 | for k, v_ in cfg_other.items(): 81 | full_key = '.'.join(stack) + '.'
+ k if stack is not None else k 82 | 83 | v = copy.deepcopy(v_) 84 | v = _decode_cfg_value(v) 85 | 86 | if k not in global_config: 87 | global_config[k] = v 88 | if isinstance(v, AttrDict): 89 | try: 90 | stack_push = [k] if stack is None else stack + [k] 91 | update_cfg_from_cfg(v, global_config[k], stack=stack_push) 92 | except BaseException: 93 | raise 94 | 95 | else: 96 | # Recursively merge dicts 97 | v = _check_and_coerce_cfg_value_type(v, global_config[k], k, full_key) 98 | if isinstance(v, AttrDict): 99 | try: 100 | stack_push = [k] if stack is None else stack + [k] 101 | update_cfg_from_cfg(v, global_config[k], stack=stack_push) 102 | except BaseException: 103 | raise 104 | else: 105 | global_config[k] = v 106 | 107 | 108 | def merge_cfg_from_list(cfg_list, global_config): 109 | """Merge config keys, values in a list (e.g., from command line) into the 110 | global config. For example, `cfg_list = ['TEST.NMS', 0.5]`. 111 | """ 112 | assert len(cfg_list) % 2 == 0 113 | for full_key, v in zip(cfg_list[0::2], cfg_list[1::2]): 114 | if _key_is_deprecated(full_key): 115 | continue 116 | if _key_is_renamed(full_key): 117 | _raise_key_rename_error(full_key) 118 | key_list = full_key.split('.') 119 | d = global_config 120 | for subkey in key_list[:-1]: 121 | assert subkey in d, 'Non-existent key: {}'.format(full_key) 122 | d = d[subkey] 123 | subkey = key_list[-1] 124 | assert subkey in d, 'Non-existent key: {}'.format(full_key) 125 | value = _decode_cfg_value(v) 126 | value = _check_and_coerce_cfg_value_type( 127 | value, d[subkey], subkey, full_key 128 | ) 129 | d[subkey] = value 130 | 131 | 132 | def _merge_a_into_b(a, b, stack=None): 133 | """Merge config dictionary a into config dictionary b, clobbering the 134 | options in b whenever they are also specified in a. 135 | """ 136 | assert isinstance(a, AttrDict), \ 137 | '`a` (cur type {}) must be an instance of {}'.format(type(a), AttrDict) 138 | assert isinstance(b, AttrDict), \ 139 | '`b` (cur type {}) must be an instance of {}'.format(type(b), AttrDict) 140 | 141 | for k, v_ in a.items(): 142 | full_key = '.'.join(stack) + '.' + k if stack is not None else k 143 | # a must specify keys that are in b 144 | if k not in b: 145 | raise KeyError('Non-existent config key: {}'.format(full_key)) 146 | 147 | v = copy.deepcopy(v_) 148 | v = _decode_cfg_value(v) 149 | v = _check_and_coerce_cfg_value_type(v, b[k], k, full_key) 150 | 151 | # Recursively merge dicts 152 | if isinstance(v, AttrDict): 153 | try: 154 | stack_push = [k] if stack is None else stack + [k] 155 | _merge_a_into_b(v, b[k], stack=stack_push) 156 | except BaseException: 157 | raise 158 | else: 159 | b[k] = v 160 | 161 | def _key_is_deprecated(full_key): 162 | if full_key in _DEPRECATED_KEYS: 163 | logger.warn( 164 | 'Deprecated config key (ignoring): {}'.format(full_key) 165 | ) 166 | return True 167 | return False 168 | 169 | 170 | def _key_is_renamed(full_key): 171 | return full_key in _RENAMED_KEYS 172 | 173 | 174 | def _raise_key_rename_error(full_key): 175 | new_key = _RENAMED_KEYS[full_key] 176 | if isinstance(new_key, tuple): 177 | msg = ' Note: ' + new_key[1] 178 | new_key = new_key[0] 179 | else: 180 | msg = '' 181 | raise KeyError( 182 | 'Key {} was renamed to {}; please update your config.{}'. 183 | format(full_key, new_key, msg) 184 | ) 185 | 186 | 187 | def _decode_cfg_value(v): 188 | """Decodes a raw config value (e.g., from a yaml config files or command 189 | line argument) into a Python object. 
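A minimal illustration (the literals are assumptions for the example, not values from this repo's configs):
    _decode_cfg_value('(3, 224, 224)')  # -> the tuple (3, 224, 224), via literal_eval
    _decode_cfg_value('0.5')            # -> the float 0.5
    _decode_cfg_value('foo/bar')        # -> 'foo/bar'; literal_eval fails, so the string passes through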
190 | """ 191 | # Configs parsed from raw yaml will contain dictionary keys that need to be 192 | # converted to AttrDict objects 193 | if isinstance(v, dict): 194 | return AttrDict(v) 195 | # All remaining processing is only applied to strings 196 | # Try to interpret `v` as a: 197 | # string, number, tuple, list, dict, boolean, or None 198 | try: 199 | v = literal_eval(v) 200 | # The following two excepts allow v to pass through when it represents a 201 | # string. 202 | # 203 | # Longer explanation: 204 | # The type of v is always a string (before calling literal_eval), but 205 | # sometimes it *represents* a string and other times a data structure, like 206 | # a list. In the case that v represents a string, what we got back from the 207 | # yaml parser is 'foo' *without quotes* (so, not '"foo"'). literal_eval is 208 | # ok with '"foo"', but will raise a ValueError if given 'foo'. In other 209 | # cases, like paths (v = 'foo/bar' and not v = '"foo/bar"'), literal_eval 210 | # will raise a SyntaxError. 211 | except ValueError: 212 | pass 213 | except SyntaxError: 214 | pass 215 | return v 216 | 217 | 218 | def _check_and_coerce_cfg_value_type(value_a, value_b, key, full_key): 219 | """Checks that `value_a`, which is intended to replace `value_b` is of the 220 | right type. The type is correct if it matches exactly or is one of a few 221 | cases in which the type can be easily coerced. 222 | """ 223 | # The types must match (with some exceptions) 224 | type_b = type(value_b) 225 | type_a = type(value_a) 226 | if type_a is type_b: 227 | return value_a 228 | 229 | # Exceptions: numpy arrays, strings, tuple<->list 230 | if isinstance(value_b, np.ndarray): 231 | value_a = np.array(value_a, dtype=value_b.dtype) 232 | elif isinstance(value_a, tuple) and isinstance(value_b, list): 233 | value_a = list(value_a) 234 | elif isinstance(value_a, list) and isinstance(value_b, tuple): 235 | value_a = tuple(value_a) 236 | else: 237 | raise ValueError( 238 | 'Type mismatch ({} vs. {}) with values ({} vs. 
{}) for config ' 239 | 'key: {}'.format(type_b, type_a, value_b, value_a, full_key) 240 | ) 241 | return value_a 242 | -------------------------------------------------------------------------------- /models/dropped_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class MixedOp(nn.Module): 7 | def __init__(self, dropped_mixed_ops, softmax_temp=1.): 8 | super(MixedOp, self).__init__() 9 | self.softmax_temp = softmax_temp 10 | self._ops = nn.ModuleList() 11 | for op in dropped_mixed_ops: 12 | self._ops.append(op) 13 | 14 | def forward(self, x, alphas, branch_indices, mixed_sub_obj): 15 | op_weights = torch.stack([alphas[branch_index] for branch_index in branch_indices]) 16 | op_weights = F.softmax(op_weights / self.softmax_temp, dim=-1) 17 | return sum(op_weight * op(x) for op_weight, op in zip(op_weights, self._ops)), \ 18 | sum(op_weight * mixed_sub_obj[branch_index] for op_weight, branch_index in zip( 19 | op_weights, branch_indices)) 20 | 21 | 22 | class HeadLayer(nn.Module): 23 | def __init__(self, dropped_mixed_ops, softmax_temp=1.): 24 | super(HeadLayer, self).__init__() 25 | self.head_branches = nn.ModuleList() 26 | for mixed_ops in dropped_mixed_ops: 27 | self.head_branches.append(MixedOp(mixed_ops, softmax_temp)) 28 | 29 | def forward(self, inputs, betas, alphas, head_index, head_sub_obj): 30 | head_data = [] 31 | count_sub_obj = [] 32 | for input_data, head_branch, alpha, head_idx, branch_sub_obj in zip( 33 | inputs, self.head_branches, alphas, head_index, head_sub_obj): 34 | data, sub_obj = head_branch(input_data, alpha, head_idx, branch_sub_obj) 35 | head_data.append(data) 36 | count_sub_obj.append(sub_obj) 37 | 38 | return sum(branch_weight * data for branch_weight, data in zip(betas, head_data)), \ 39 | count_sub_obj 40 | 41 | 42 | class StackLayers(nn.Module): 43 | def __init__(self, num_block_layers, dropped_mixed_ops, softmax_temp=1.): 44 | super(StackLayers, self).__init__() 45 | 46 | if num_block_layers != 0: 47 | self.stack_layers = nn.ModuleList() 48 | for i in range(num_block_layers): 49 | self.stack_layers.append(MixedOp(dropped_mixed_ops[i], softmax_temp)) 50 | else: 51 | self.stack_layers = None 52 | 53 | def forward(self, x, alphas, stack_index, stack_sub_obj): 54 | 55 | if self.stack_layers is not None: 56 | count_sub_obj = 0 57 | for stack_layer, alpha, stack_idx, layer_sub_obj in zip(self.stack_layers, alphas, stack_index, stack_sub_obj): 58 | x, sub_obj = stack_layer(x, alpha, stack_idx, layer_sub_obj) 59 | count_sub_obj += sub_obj 60 | return x, count_sub_obj 61 | 62 | else: 63 | return x, 0 64 | 65 | 66 | class Block(nn.Module): 67 | def __init__(self, num_block_layers, dropped_mixed_ops, softmax_temp=1.): 68 | super(Block, self).__init__() 69 | self.head_layer = HeadLayer(dropped_mixed_ops[0], softmax_temp) 70 | self.stack_layers = StackLayers(num_block_layers, dropped_mixed_ops[1], softmax_temp) 71 | 72 | def forward(self, inputs, betas, head_alphas, stack_alphas, head_index, stack_index, block_sub_obj): 73 | x, head_sub_obj = self.head_layer(inputs, betas, head_alphas, head_index, block_sub_obj[0]) 74 | x, stack_sub_obj = self.stack_layers(x, stack_alphas, stack_index, block_sub_obj[1]) 75 | 76 | return x, [head_sub_obj, stack_sub_obj] 77 | 78 | 79 | class Dropped_Network(nn.Module): 80 | def __init__(self, super_model, alpha_head_index=None, alpha_stack_index=None, softmax_temp=1.): 81 | super(Dropped_Network, 
self).__init__() 82 | 83 | self.softmax_temp = softmax_temp 84 | # static modules loading 85 | self.input_block = super_model.module.input_block 86 | if hasattr(super_model.module, 'head_block'): 87 | self.head_block = super_model.module.head_block 88 | self.conv1_1_block = super_model.module.conv1_1_block 89 | self.global_pooling = super_model.module.global_pooling 90 | self.classifier = super_model.module.classifier 91 | 92 | # architecture parameters loading 93 | self.alpha_head_weights = super_model.module.alpha_head_weights 94 | self.alpha_stack_weights = super_model.module.alpha_stack_weights 95 | self.beta_weights = super_model.module.beta_weights 96 | self.alpha_head_index = alpha_head_index if alpha_head_index is not None else \ 97 | super_model.module.alpha_head_index 98 | self.alpha_stack_index = alpha_stack_index if alpha_stack_index is not None else \ 99 | super_model.module.alpha_stack_index 100 | 101 | # config loading 102 | self.config = super_model.module.config 103 | self.input_configs = super_model.module.input_configs 104 | self.output_configs = super_model.module.output_configs 105 | self.sub_obj_list = super_model.module.sub_obj_list 106 | 107 | # dynamic blocks loading 108 | self.blocks = nn.ModuleList() 109 | 110 | for i, block in enumerate(super_model.module.blocks): 111 | input_config = self.input_configs[i] 112 | 113 | dropped_mixed_ops = [] 114 | # for the head layers 115 | head_mixed_ops = [] 116 | for j, head_index in enumerate(self.alpha_head_index[i]): 117 | head_mixed_ops.append([block.head_layer.head_branches[j]._ops[k] for k in head_index]) 118 | dropped_mixed_ops.append(head_mixed_ops) 119 | 120 | stack_mixed_ops = [] 121 | for j, stack_index in enumerate(self.alpha_stack_index[i]): 122 | stack_mixed_ops.append([block.stack_layers.stack_layers[j]._ops[k] for k in stack_index]) 123 | dropped_mixed_ops.append(stack_mixed_ops) 124 | 125 | self.blocks.append(Block( 126 | input_config['num_stack_layers'], 127 | dropped_mixed_ops 128 | )) 129 | 130 | def forward(self, x): 131 | ''' 132 | To approximate the total sub_obj (latency/flops), we first create the obj list for the blocks 133 | as follows: 134 | [[[head_flops_1, head_flops_2, ...], stack_flops], ...] 135 | Then we compute the whole obj approximation from the end to the beginning.
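(A worked instance of the recursion below, with assumed numbers: if block b feeds two later blocks through output paths weighted beta = (0.7, 0.3), and the accumulated costs of those two paths (head cost plus the downstream block's cost) are 10 and 20, then block b adds 0.7*10 + 0.3*20 = 13 on top of its own cost.)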
For block b, 136 | flops'_b = sum(beta_{bi} * (head_flops_{bi} + stack_flops_{i}) for i in out_idx[b]) 137 | The total flops equals flops'_0 138 | ''' 139 | sub_obj_list = [] 140 | block_datas = [] 141 | branch_weights = [] 142 | for betas in self.beta_weights: 143 | branch_weights.append(F.softmax(betas / self.softmax_temp, dim=-1)) 144 | 145 | block_data = self.input_block(x) 146 | if hasattr(self, 'head_block'): 147 | block_data = self.head_block(block_data) 148 | 149 | block_datas.append(block_data) 150 | sub_obj_list.append([[],torch.tensor(self.sub_obj_list[0]).cuda()]) 151 | 152 | for i in range(len(self.blocks)+1): 153 | config = self.input_configs[i] 154 | inputs = [block_datas[i] for i in config['in_block_idx']] 155 | betas = [branch_weights[block_id][beta_id] 156 | for block_id, beta_id in zip(config['in_block_idx'], config['beta_idx'])] 157 | 158 | if i == len(self.blocks): 159 | block_data, block_sub_obj = self.conv1_1_block(inputs, betas, self.sub_obj_list[2]) 160 | 161 | else: 162 | block_data, block_sub_obj = self.blocks[i](inputs, 163 | betas, 164 | self.alpha_head_weights[i], 165 | self.alpha_stack_weights[i], 166 | self.alpha_head_index[i], 167 | self.alpha_stack_index[i], 168 | self.sub_obj_list[1][i]) 169 | block_datas.append(block_data) 170 | sub_obj_list.append(block_sub_obj) 171 | 172 | out = self.global_pooling(block_datas[-1]) 173 | logits = self.classifier(out.view(out.size(0),-1)) 174 | 175 | # chained cost estimation 176 | for i, out_config in enumerate(self.output_configs[::-1]): 177 | block_id = len(self.output_configs)-i-1 178 | sum_obj = [] 179 | for j, out_id in enumerate(out_config['out_id']): 180 | head_id = self.input_configs[out_id-1]['in_block_idx'].index(block_id) 181 | head_obj = sub_obj_list[out_id][0][head_id] 182 | stack_obj = sub_obj_list[out_id][1] 183 | sub_obj_j = branch_weights[block_id][j] * (head_obj + stack_obj) 184 | sum_obj.append(sub_obj_j) 185 | sub_obj_list[-i-2][1] += sum(sum_obj) 186 | 187 | net_sub_obj = torch.tensor(self.sub_obj_list[-1]).cuda() + sub_obj_list[0][1] 188 | return logits, net_sub_obj.expand(1) 189 | 190 | @property 191 | def arch_parameters(self): 192 | arch_params = nn.ParameterList() 193 | arch_params.extend(self.beta_weights) 194 | arch_params.extend(self.alpha_head_weights) 195 | arch_params.extend(self.alpha_stack_weights) 196 | return arch_params 197 | 198 | @property 199 | def arch_alpha_params(self): 200 | alpha_params = nn.ParameterList() 201 | alpha_params.extend(self.alpha_head_weights) 202 | alpha_params.extend(self.alpha_stack_weights) 203 | return alpha_params 204 | -------------------------------------------------------------------------------- /tools/multadds_count.py: -------------------------------------------------------------------------------- 1 | import torch 2 | # Original implementation: 3 | # https://github.com/warmspringwinds/pytorch-segmentation-detection/blob/master/pytorch_segmentation_detection/utils/flops_benchmark.py 4 | 5 | # ---- Public functions 6 | 7 | def comp_multadds(model, input_size=(3,224,224)): 8 | input_size = (1,) + tuple(input_size) 9 | model = model.cuda() 10 | input_data = torch.randn(input_size).cuda() 11 | model = add_flops_counting_methods(model) 12 | model.start_flops_count() 13 | with torch.no_grad(): 14 | _ = model(input_data) 15 | 16 | mult_adds = model.compute_average_flops_cost() / 1e6 17 | return mult_adds 18 | 19 | 20 | def comp_multadds_fw(model, input_data, use_gpu=True): 21 | model = add_flops_counting_methods(model) 22 | if use_gpu: 23 | model = 
model.cuda() 24 | model.start_flops_count() 25 | with torch.no_grad(): 26 | output_data = model(input_data) 27 | 28 | mult_adds = model.compute_average_flops_cost() / 1e6 29 | return mult_adds, output_data 30 | 31 | 32 | def add_flops_counting_methods(net_main_module): 33 | """Adds flops counting functions to an existing model. After that 34 | the flops count should be activated and the model should be run on an input 35 | image. 36 | Example: 37 | fcn = add_flops_counting_methods(fcn) 38 | fcn = fcn.cuda().train() 39 | fcn.start_flops_count() 40 | _ = fcn(batch) 41 | fcn.compute_average_flops_cost() / 1e9 / 2 # Result in GFLOPs per image in batch 42 | Important: dividing by 2 only works for resnet models -- see below for the details 43 | of flops computation. 44 | Attention: we are counting multiply-add as two flops in this work, because in 45 | most resnet models convolutions are bias-free (BN layers act as bias there), 46 | so it makes sense to count multiply and add as separate flops. 47 | This is why in the above example we divide by 2 in order to be consistent with 48 | most modern benchmarks. For example, in "Spatially Adaptive Computation Time for Residual 49 | Networks" by Figurnov et al. multiply-add was counted as two flops. 50 | This module computes the average flops, which is necessary for dynamic networks that 51 | execute a different number of layers per input. For static networks it is enough to run the network 52 | once and get statistics (above example). 53 | Implementation: 54 | The module works by adding batch_count to the main module, which tracks the sum 55 | of all batch sizes that were run through the network. 56 | Also each convolutional layer of the network tracks the overall number of flops 57 | performed. 58 | The parameters are updated with the help of registered hook-functions which 59 | are called each time the respective layer is executed. 60 | Parameters 61 | ---------- 62 | net_main_module : torch.nn.Module 63 | Main module containing network 64 | Returns 65 | ------- 66 | net_main_module : torch.nn.Module 67 | Updated main module with new methods/attributes that are used 68 | to compute flops. 69 | """ 70 | 71 | # adding additional methods to the existing module object; 72 | # done this way so that each function has access to the self object 73 | net_main_module.start_flops_count = start_flops_count.__get__(net_main_module) 74 | net_main_module.stop_flops_count = stop_flops_count.__get__(net_main_module) 75 | net_main_module.reset_flops_count = reset_flops_count.__get__(net_main_module) 76 | net_main_module.compute_average_flops_cost = compute_average_flops_cost.__get__(net_main_module) 77 | 78 | net_main_module.reset_flops_count() 79 | 80 | # Adding variables necessary for masked flops computation 81 | net_main_module.apply(add_flops_mask_variable_or_reset) 82 | 83 | return net_main_module 84 | 85 | 86 | def compute_average_flops_cost(self): 87 | """ 88 | A method that will be available after add_flops_counting_methods() is called 89 | on a desired net object. 90 | Returns current mean flops consumption per image.
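Illustrative example (assumed numbers, not measured values): after two forward passes with batch size 8 each, __batch_counter__ is 16; if the conv/linear hooks accumulated 1.6e9 flops in total, this method returns 1e8 flops per image.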
91 | """ 92 | 93 | batches_count = self.__batch_counter__ 94 | 95 | flops_sum = 0 96 | 97 | for module in self.modules(): 98 | 99 | if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear): 100 | flops_sum += module.__flops__ 101 | 102 | 103 | return flops_sum / batches_count 104 | 105 | 106 | def start_flops_count(self): 107 | """ 108 | A method that will be available after add_flops_counting_methods() is called 109 | on a desired net object. 110 | Activates the computation of mean flops consumption per image. 111 | Call it before you run the network. 112 | """ 113 | 114 | add_batch_counter_hook_function(self) 115 | 116 | self.apply(add_flops_counter_hook_function) 117 | 118 | 119 | def stop_flops_count(self): 120 | """ 121 | A method that will be available after add_flops_counting_methods() is called 122 | on a desired net object. 123 | Stops computing the mean flops consumption per image. 124 | Call whenever you want to pause the computation. 125 | """ 126 | 127 | remove_batch_counter_hook_function(self) 128 | 129 | self.apply(remove_flops_counter_hook_function) 130 | 131 | 132 | def reset_flops_count(self): 133 | """ 134 | A method that will be available after add_flops_counting_methods() is called 135 | on a desired net object. 136 | Resets statistics computed so far. 137 | """ 138 | 139 | add_batch_counter_variables_or_reset(self) 140 | 141 | self.apply(add_flops_counter_variable_or_reset) 142 | 143 | 144 | def add_flops_mask(module, mask): 145 | def add_flops_mask_func(module): 146 | if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear): 147 | module.__mask__ = mask 148 | 149 | module.apply(add_flops_mask_func) 150 | 151 | 152 | def remove_flops_mask(module): 153 | module.apply(add_flops_mask_variable_or_reset) 154 | 155 | 156 | # ---- Internal functions 157 | 158 | 159 | def conv_flops_counter_hook(conv_module, input, output): 160 | # Can have multiple inputs, getting the first one 161 | input = input[0] 162 | 163 | batch_size = input.shape[0] 164 | output_height, output_width = output.shape[2:] 165 | 166 | kernel_height, kernel_width = conv_module.kernel_size 167 | in_channels = conv_module.in_channels 168 | out_channels = conv_module.out_channels 169 | 170 | conv_per_position_flops = (kernel_height * kernel_width * in_channels * out_channels) / conv_module.groups 171 | 172 | active_elements_count = batch_size * output_height * output_width 173 | 174 | if conv_module.__mask__ is not None: 175 | # (b, 1, h, w) 176 | flops_mask = conv_module.__mask__.expand(batch_size, 1, output_height, output_width) 177 | active_elements_count = flops_mask.sum() 178 | 179 | overall_conv_flops = conv_per_position_flops * active_elements_count 180 | 181 | bias_flops = 0 182 | 183 | if conv_module.bias is not None: 184 | bias_flops = out_channels * active_elements_count 185 | 186 | overall_flops = overall_conv_flops + bias_flops 187 | 188 | conv_module.__flops__ += overall_flops 189 | 190 | 191 | def linear_flops_counter_hook(linear_module, input, output): 192 | 193 | input = input[0] 194 | batch_size = input.shape[0] 195 | overall_flops = linear_module.in_features * linear_module.out_features * batch_size 196 | 197 | # bias_flops = 0 198 | 199 | # if conv_module.bias is not None: 200 | # bias_flops = out_channels * active_elements_count 201 | 202 | # overall_flops = overall_conv_flops + bias_flops 203 | 204 | linear_module.__flops__ += overall_flops 205 | 206 | 207 | def batch_counter_hook(module, input, output): 208 | # Can have multiple inputs, getting 
the first one 209 | input = input[0] 210 | 211 | batch_size = input.shape[0] 212 | 213 | module.__batch_counter__ += batch_size 214 | 215 | 216 | def add_batch_counter_variables_or_reset(module): 217 | module.__batch_counter__ = 0 218 | 219 | 220 | def add_batch_counter_hook_function(module): 221 | if hasattr(module, '__batch_counter_handle__'): 222 | return 223 | 224 | handle = module.register_forward_hook(batch_counter_hook) 225 | module.__batch_counter_handle__ = handle 226 | 227 | 228 | def remove_batch_counter_hook_function(module): 229 | if hasattr(module, '__batch_counter_handle__'): 230 | module.__batch_counter_handle__.remove() 231 | 232 | del module.__batch_counter_handle__ 233 | 234 | 235 | def add_flops_counter_variable_or_reset(module): 236 | if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear): 237 | module.__flops__ = 0 238 | 239 | 240 | def add_flops_counter_hook_function(module): 241 | if isinstance(module, torch.nn.Conv2d): 242 | if hasattr(module, '__flops_handle__'): 243 | return 244 | 245 | handle = module.register_forward_hook(conv_flops_counter_hook) 246 | module.__flops_handle__ = handle 247 | elif isinstance(module, torch.nn.Linear): 248 | 249 | if hasattr(module, '__flops_handle__'): 250 | return 251 | 252 | handle = module.register_forward_hook(linear_flops_counter_hook) 253 | module.__flops_handle__ = handle 254 | 255 | 256 | def remove_flops_counter_hook_function(module): 257 | if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear): 258 | 259 | if hasattr(module, '__flops_handle__'): 260 | module.__flops_handle__.remove() 261 | 262 | del module.__flops_handle__ 263 | 264 | 265 | # --- Masked flops counting 266 | 267 | 268 | # Also being run in the initialization 269 | def add_flops_mask_variable_or_reset(module): 270 | if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear): 271 | module.__mask__ = None 272 | 273 | -------------------------------------------------------------------------------- /run_apis/search.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import ast 3 | import importlib 4 | import logging 5 | import os 6 | import pprint 7 | import sys 8 | import time 9 | 10 | import numpy as np 11 | import torch 12 | import torch.backends.cudnn as cudnn 13 | import torch.nn as nn 14 | from tensorboardX import SummaryWriter 15 | 16 | from configs.search_config import search_cfg 17 | from configs.imagenet_train_cfg import cfg 18 | from dataset import imagenet_data 19 | from models import model_derived 20 | from tools import utils 21 | from tools.config_yaml import merge_cfg_from_file, update_cfg_from_cfg 22 | from tools.lr_scheduler import get_lr_scheduler 23 | from tools.multadds_count import comp_multadds 24 | 25 | from .optimizer import Optimizer 26 | from .trainer import SearchTrainer 27 | 28 | if __name__ == '__main__': 29 | 30 | parser = argparse.ArgumentParser("Search_Configs") 31 | parser.add_argument('--report_freq', type=float, default=100, help='report frequency') 32 | parser.add_argument('--data_path', type=str, default='../data', help='location of the data corpus') 33 | parser.add_argument('--save', type=str, default='../', help='experiment name') 34 | parser.add_argument('--tb_path', type=str, default='', help='tensorboard output path') 35 | parser.add_argument('--job_name', type=str, default='', help='job_name') 36 | parser.add_argument('-c', '--config', metavar='C', default=None, help='The Configuration file') 37 | 
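    # Illustrative invocation (a sketch; the dataset path is a placeholder, not
    # taken from this repo's docs). Since this script uses package-relative
    # imports, it is launched as a module from the repository root, e.g.:
    #   python -m run_apis.search --data_path /path/to/imagenet \
    #       -c configs/imagenet_search_cfg_mbv2.yaml --job_name mbv2_search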
38 | args = parser.parse_args() 39 | 40 | update_cfg_from_cfg(search_cfg, cfg) 41 | if args.config is not None: 42 | merge_cfg_from_file(args.config, cfg) 43 | config = cfg 44 | 45 | if args.job_name != '': 46 | args.job_name = time.strftime("%Y%m%d-%H%M%S-") + args.job_name 47 | args.save = os.path.join(args.save, args.job_name) 48 | utils.create_exp_dir(args.save) 49 | os.system('cp -r ./* '+args.save) 50 | args.save = os.path.join(args.save, 'output') 51 | utils.create_exp_dir(args.save) 52 | else: 53 | args.save = os.path.join(args.save, 'output') 54 | utils.create_exp_dir(args.save) 55 | 56 | if args.tb_path == '': 57 | args.tb_path = args.save 58 | 59 | log_format = '%(asctime)s %(message)s' 60 | date_format = '%m/%d %H:%M:%S' 61 | logging.basicConfig(stream=sys.stdout, level=logging.INFO, 62 | format=log_format, datefmt=date_format) 63 | fh = logging.FileHandler(os.path.join(args.save, 'log.txt')) 64 | fh.setFormatter(logging.Formatter(log_format, date_format)) 65 | logging.getLogger().addHandler(fh) 66 | 67 | if not torch.cuda.is_available(): 68 | logging.info('No gpu device available') 69 | sys.exit(1) 70 | cudnn.benchmark = True 71 | cudnn.enabled = True 72 | 73 | if config.train_params.use_seed: 74 | np.random.seed(config.train_params.seed) 75 | torch.manual_seed(config.train_params.seed) 76 | torch.cuda.manual_seed(config.train_params.seed) 77 | 78 | logging.info("args = %s", args) 79 | logging.info('Training with config:') 80 | logging.info(pprint.pformat(config)) 81 | job_name = os.popen('cd %s && pwd -P && cd -' % args.save).readline().split('/')[-2] 82 | 83 | writer = SummaryWriter(args.tb_path) 84 | 85 | criterion = nn.CrossEntropyLoss() 86 | criterion = criterion.cuda() 87 | 88 | SearchSpace = importlib.import_module('models.search_space_'+config.net_type).Network 89 | ArchGenerater = importlib.import_module('.derive_arch_'+config.net_type, __package__).ArchGenerate 90 | derivedNetwork = getattr(model_derived, '%s_Net' % config.net_type.upper()) 91 | 92 | super_model = SearchSpace(config.optim.init_dim, config.data.dataset, config) 93 | arch_gener = ArchGenerater(super_model, config) 94 | der_Net = lambda net_config: derivedNetwork(net_config, 95 | config=config) 96 | super_model = nn.DataParallel(super_model) 97 | 98 | # whether to resume from a checkpoint 99 | if config.optim.if_resume: 100 | utils.load_model(super_model, config.optim.resume.load_path) 101 | start_epoch = config.optim.resume.load_epoch + 1 102 | else: 103 | start_epoch = 0 104 | 105 | super_model = super_model.cuda() 106 | 107 | if config.optim.sub_obj.type=='flops': 108 | flops_list, total_flops = super_model.module.get_cost_list( 109 | config.data.input_size, cost_type='flops') 110 | super_model.module.sub_obj_list = flops_list 111 | logging.info("Super Network flops (M) list: \n") 112 | logging.info(str(flops_list)) 113 | logging.info("Total flops: " + str(total_flops)) 114 | elif config.optim.sub_obj.type=='latency': 115 | with open(os.path.join('latency_list', config.optim.sub_obj.latency_list_path), 'r') as f: 116 | latency_list = eval(f.readline()) 117 | super_model.module.sub_obj_list = latency_list 118 | logging.info("Super Network latency (ms) list: \n") 119 | logging.info(str(latency_list)) 120 | else: 121 | raise NotImplementedError 122 | logging.info("Num Params = %.2fMB", utils.count_parameters_in_MB(super_model)) 123 | 124 | if config.data.dataset == 'imagenet': 125 | imagenet = imagenet_data.ImageNet12(trainFolder=os.path.join(args.data_path, 'train'), 126 | 
testFolder=os.path.join(args.data_path, 'val'), 127 | num_workers=config.data.num_workers, 128 | type_of_data_augmentation=config.data.type_of_data_aug, 129 | data_config=config.data) 130 | train_queue, valid_queue = imagenet.getTrainTestLoader(config.data.batch_size, 131 | train_shuffle=True, 132 | val_shuffle=True) 133 | else: 134 | raise NotImplementedError 135 | 136 | search_optim = Optimizer(super_model, criterion, config) 137 | 138 | scheduler = get_lr_scheduler(config, search_optim.weight_optimizer, imagenet.train_num_examples) 139 | scheduler.last_step = start_epoch * (imagenet.train_num_examples // config.data.batch_size + 1) 140 | 141 | search_trainer = SearchTrainer(train_queue, valid_queue, search_optim, criterion, scheduler, config, args) 142 | 143 | betas, head_alphas, stack_alphas = super_model.module.display_arch_params() 144 | derived_archs = arch_gener.derive_archs(betas, head_alphas, stack_alphas) 145 | derived_model = der_Net('|'.join(map(str, derived_archs))) 146 | logging.info("Derived Model Mult-Adds = %.2fM" % comp_multadds(derived_model, 147 | input_size=config.data.input_size)) 148 | logging.info("Derived Model Num Params = %.2fMB", utils.count_parameters_in_MB(derived_model)) 149 | 150 | best_epoch = [0, 0, 0] # [epoch, acc_top1, acc_top5] 151 | rec_list = [] 152 | for epoch in range(start_epoch, config.train_params.epochs): 153 | # training part 1: update the architecture parameters 154 | if epoch >= config.search_params.arch_update_epoch: 155 | search_stage = 1 156 | search_optim.set_param_grad_state('Arch') 157 | train_acc_top1, train_acc_top5, train_obj, sub_obj, batch_time = search_trainer.train( 158 | super_model, epoch, 'Arch', search_stage) 159 | logging.info('EPOCH%d Arch Train_acc top1 %.2f top5 %.2f loss %.4f %s %.2f batch_time %.3f', 160 | epoch, train_acc_top1, train_acc_top5, train_obj, config.optim.sub_obj.type, sub_obj, batch_time) 161 | writer.add_scalar('arch_train_acc_top1', train_acc_top1, epoch) 162 | writer.add_scalar('arch_train_loss', train_obj, epoch) 163 | else: 164 | search_stage = 0 165 | 166 | # training part 2: update the operator parameters 167 | search_optim.set_param_grad_state('Weights') 168 | train_acc_top1, train_acc_top5, train_obj, sub_obj, batch_time = search_trainer.train( 169 | super_model, epoch, 'Weights', search_stage) 170 | logging.info('EPOCH%d Weights Train_acc top1 %.2f top5 %.2f loss %.4f %s %.2f | batch_time %.3f', 171 | epoch, train_acc_top1, train_acc_top5, train_obj, config.optim.sub_obj.type, sub_obj, batch_time) 172 | writer.add_scalar('weight_train_acc_top1', train_acc_top1, epoch) 173 | writer.add_scalar('weight_train_loss', train_obj, epoch) 174 | 175 | # validation 176 | if epoch >= config.search_params.val_start_epoch: 177 | with torch.no_grad(): 178 | val_acc_top1, val_acc_top5, valid_obj, sub_obj, batch_time = search_trainer.infer(super_model, epoch) 179 | logging.info('EPOCH%d Valid_acc top1 %.2f top5 %.2f %s %.2f batch_time %.3f', 180 | epoch, val_acc_top1, val_acc_top5, config.optim.sub_obj.type, sub_obj, batch_time) 181 | writer.add_scalar('arch_val_acc', val_acc_top1, epoch) 182 | writer.add_scalar('arch_whole_{}'.format(config.optim.sub_obj.type), sub_obj, epoch) 183 | 184 | if val_acc_top1 > best_epoch[1]: 185 | best_epoch = [epoch, val_acc_top1, val_acc_top5] 186 | utils.save(super_model, os.path.join(args.save, 'weights_best.pt')) 187 | logging.info('BEST EPOCH %d val_top1 %.2f val_top5 %.2f', best_epoch[0], best_epoch[1], best_epoch[2]) 188 | else: 189 | utils.save(super_model,
os.path.join(args.save, 'weights_best.pt')) 190 | 191 | betas, head_alphas, stack_alphas = super_model.module.display_arch_params() 192 | derived_arch = arch_gener.derive_archs(betas, head_alphas, stack_alphas) 193 | derived_arch_str = '|\n'.join(map(str, derived_arch)) 194 | derived_model = der_Net(derived_arch_str) 195 | derived_flops = comp_multadds(derived_model, input_size=config.data.input_size) 196 | derived_params = utils.count_parameters_in_MB(derived_model) 197 | logging.info("Derived Model Mult-Adds = %.2fM" % derived_flops) 198 | logging.info("Derived Model Num Params = %.2fMB" % derived_params) 199 | writer.add_scalar('derived_flops', derived_flops, epoch) 200 | 201 | if (epoch+1)==config.search_params.arch_update_epoch: 202 | utils.save(super_model, os.path.join(args.save, 'weights_{}.pt'.format(epoch))) 203 | 204 | if epoch >= config.search_params.val_start_epoch: 205 | epoch_rec = {'top1_acc': val_acc_top1, 206 | 'epoch': epoch, 207 | 'multadds': derived_flops, 208 | 'params': derived_params, 209 | 'arch': derived_arch_str} 210 | if_update = utils.record_topk(2, rec_list, epoch_rec, 'top1_acc', 'arch') 211 | if if_update: 212 | with open(os.path.join(args.save, 'top_results'), 'w') as f: 213 | f.write(str(rec_list) + '\n') 214 | f.write(job_name) 215 | with open(os.path.join(args.save, 'excel_record'), 'w') as f: 216 | for record in rec_list: 217 | f.write(',,,{:.2f}M,{:.2f}MB,,,,{},{}\n'.format( 218 | record['multadds'], record['params'], 219 | job_name, record['epoch'])) 220 | f.write(record['arch']+'\n') 221 | 222 | logging.info('\nTop2 arch records for Excel: ') 223 | for record in rec_list: 224 | logging.info('\n,,,{:.2f}M,{:.2f}MB,,,,{},{}'.format( 225 | record['multadds'], record['params'], job_name, record['epoch'])) 226 | logging.info('\n'+record['arch']) 227 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types.
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /latency_list/lat_list_densenas_mbv2_xp32: -------------------------------------------------------------------------------- 1 | [2.130230027015763, [[[[2.0424278577168784, 3.9147443482370092, 2.826663460394349, 5.422255848393296, 3.750695026282108, 7.1249667803446455]], [[1.6153342073613948, 3.0373611353864574, 2.3987865447998047, 4.654994853819259, 3.261269535681214, 6.318014583202324, 0.004162114075940065], [1.6315834690826108, 3.0268227933633205, 2.4445987951875936, 4.614120661610305, 3.1950888730058766, 6.186397436893348, 0.004321204291449653], [1.6260789139102203, 3.0519497274148346, 2.389646034048061, 4.737014577846335, 3.23591497209337, 6.18449805962919, 0.004463557041052616]]], [[[0.7256638883340238, 1.3191255415328826, 0.9336039755079482, 1.7637216201936357, 1.2678955540512546, 2.410705258147885]], [[0.4845246160873259, 0.871481124800865, 0.6886357490462486, 1.2964955002370506, 0.9334284849841186, 1.7795484475415162, 0.00426448956884519], [0.4868063782200669, 0.8848115892121287, 0.6880599079710064, 1.2923753863633283, 0.9581135017703278, 1.7851776787728975, 0.0040594736735026045], [0.4841990422720861, 0.8834096398016419, 0.676188324436997, 1.2965060725356594, 0.9642027122805816, 1.8043483628167047, 0.005086961418691307]]], [[[0.47651384816025244, 0.8765797904043486, 0.6815321276886294, 1.2988936780679106, 0.9366233902748184, 1.7686583538248082], [0.7290506122088192, 1.3409103046764026, 0.9578582493945805, 1.766649472593057, 1.2713261324949938, 2.4222339764989984]], [[0.714959664778276, 1.1088877494889076, 0.9645739709488069, 1.6069072665590227, 1.283542893149636, 2.241900617426092, 0.004182656606038411], [0.7255952285997795, 1.1247719899572508, 0.9729856192463576, 1.6400691234704219, 1.2840147933574637, 2.223358057966136, 0.004817090853296145], [0.7210639992145577, 1.118303236335215, 0.9568171067671342, 1.6137872320232969, 1.2853490463410964, 2.220366771775063, 0.0044962613269536175]]], [[[0.7221531867980957, 1.099681348511667, 0.9537572571725557, 1.6170255583946151, 1.258736475549563, 2.228106127844917], [0.5150983550331809, 0.8821482610220861, 0.7235027563692343, 1.285614919180822, 0.9779198241956306, 1.7995608214176062], [0.7342606120639378, 1.3387289432564167, 0.9407694171173404, 1.778323794856216, 1.2683106913711084, 2.3811431364579634]], [[0.7432978803461249, 1.3567979648859814, 1.0377877890461622, 1.9703037811048103, 1.43352544668949, 2.7093362085746997, 0.004813309871789181], [0.7404989184755267, 1.3694960179955067, 1.0372354526712437, 1.943115730478306, 1.4028442989696157, 2.726303037970957, 0.004229617841315992], [0.7309939885380292, 1.3608129819234211, 1.0305414536986688, 1.9552293449941307, 1.401155861941251, 2.7030082664104422, 0.00416184916640773]]], [[[0.3884150042678371, 0.7135931891624373, 0.5162300244726317, 0.954577417084665, 0.6934213397478817, 1.298581903631037], [0.4412888998937125, 0.59423191378815, 0.5425221751434635, 0.789706056768244, 0.6770676795882408, 1.076465303247625], [0.32727501609108667, 0.45259521465108854, 0.3506885634528266, 0.6308338377210828, 0.4525994291209211, 0.8446315081432612]], [[0.3733000129160255, 0.40592995556918055, 0.37138982252641156, 0.6014351411299271, 0.43380426638054126, 0.8142890593018195, 0.00443302019677981], [0.3568933467672329, 0.4109610210765492, 0.3373644087049696, 0.5994894769456651, 0.43327100349195075, 0.8130553515270503, 0.004622550925823173], [0.34756833856756036, 0.4080691000427863, 0.33895901959351815, 0.6000231251572117, 
0.42908610719622986, 0.8126938945115214, 0.004636446634928385]]], [[[0.31285572533655653, 0.41354148074834035, 0.3234596685929732, 0.5967205702656447, 0.42944281992286143, 0.8081769702410456], [0.38233056212916516, 0.7180399124068444, 0.506961490168716, 0.958221440363412, 0.6713989286711722, 1.2972089497729986], [0.4516297879845205, 0.584752680075289, 0.5289154341726592, 0.7857127623124556, 0.6657846287043407, 1.0508054434651077], [0.3216169338033657, 0.4714334613145, 0.341750852989428, 0.6193477216393056, 0.4556712237271396, 0.8496491595952198]], [[0.3577318336024429, 0.46187034761062773, 0.36865022447374135, 0.6880141508699668, 0.490064789550473, 0.9283094213466452, 0.004251918407401653], [0.35744508107503253, 0.4682911526073109, 0.3670787811279297, 0.6878670297487818, 0.4916710323757596, 0.9247834995539501, 0.004101233048872515], [0.34304789822511, 0.46506821507155294, 0.3702619340684679, 0.6890249252319336, 0.48857850257796476, 0.9265595493894635, 0.004756330239652383]]], [[[0.32409022552798494, 0.4537436938045001, 0.3600639044636428, 0.6765010621812608, 0.4840817355146312, 0.919375443699384], [0.3252748287085331, 0.39878895788481744, 0.32009264435430973, 0.5979235967000326, 0.4323175218370226, 0.8091085125701596], [0.3775752433622726, 0.7120554856579713, 0.507318853127836, 0.9536376866427335, 0.6691536277231545, 1.2858475819982664], [0.4203124961467704, 0.5663527382744684, 0.5275549310626406, 0.7793375940033884, 0.6773896169180822, 1.0505064328511555]], [[0.3410032301238089, 0.5246755089422669, 0.4104647251090618, 0.7753114748482752, 0.5516608796938501, 1.0461353774022573, 0.004761459851505781], [0.3467666500746602, 0.5258405810654765, 0.41413808109784367, 0.7768700580404262, 0.5466712364042648, 1.0479942475906527, 0.004405325109308416], [0.3509985316883434, 0.52610416605015, 0.4151724324081883, 0.7738488611548838, 0.5471140928942747, 1.0415510697798296, 0.004083628606314611]]], [[[0.3314778780696368, 0.5213186716792559, 0.4044876194963552, 0.7711166564864342, 0.5433830107101287, 1.0415975734440968], [0.31709625263406777, 0.4603457450866699, 0.3640068661082875, 0.6819324059919878, 0.49100427916555694, 0.9225822940017239], [0.30831818628792806, 0.403487417432997, 0.3342929030909683, 0.5993133121066623, 0.43180643910109395, 0.8127121973519373]], [[0.39097884688714535, 0.714596136651858, 0.5550253030025598, 1.0415667476076067, 0.734896683933759, 1.4084796953682948, 0.00444214753430299], [0.39187493950429586, 0.7178884323197182, 0.5559180962918985, 1.0453465731457028, 0.7339226597487325, 1.4049145669648142, 0.004976397812968553], [0.3937760748044409, 0.7125301072091768, 0.5555001894632975, 1.0429809551046352, 0.7313448010068951, 1.4042837932856396, 0.004143618574046126]]], [[[0.3865939679771963, 0.7147598989082105, 0.5532094444891419, 1.0452751920680807, 0.7332464661261049, 1.4074049092302419], [0.3078347263914166, 0.5324024865121553, 0.413363076219655, 0.7835227070432721, 0.5587832373802109, 1.0567584904757412], [0.3132787136116413, 0.46954424694331004, 0.37312808662954006, 0.6916724310980903, 0.5062123019285876, 0.9339909120039507], [0.3129829541601316, 0.41637312282215466, 0.33040607818449386, 0.6098940637376573, 0.447241272589173, 0.8206562562422318]], [[0.4596819781293773, 0.8472787250172008, 0.6554982638118243, 1.2365547334304965, 0.8659759916440405, 1.6601162968259868, 0.004585535839350537], [0.4650660717126095, 0.8608341217041016, 0.6551676567154701, 1.2411773566043738, 0.8663152685069074, 1.657198920394435, 0.0049291475854738795], [0.46498640619143095, 0.8596966001722547, 
0.6567297078142262, 1.2392559918490322, 0.8636613566466051, 1.6618470952968405, 0.004425771308667732]]], [[[0.4479286887428977, 0.8394645440458047, 0.6429801083574391, 1.2274719729568018, 0.8534268658570568, 1.6510065155799942], [0.38627870155103283, 0.7220117251078287, 0.5556040821653424, 1.0479230832571935, 0.7360771689752135, 1.4107511741946441], [0.31525987567323627, 0.5351946811483365, 0.41544396467883177, 0.7865141618131387, 0.5511918934908779, 1.0561841425269543], [0.32904622530696365, 0.4712159465057681, 0.37493383041535966, 0.6946381655606356, 0.497800147894657, 0.9343978612109869]], [[0.5242839244881061, 0.9681367151664966, 0.7472212386853767, 1.4109373333478215, 0.9878720659198184, 1.9006856041725235, 0.004231592621466126], [0.5298834858518657, 0.9775502513153386, 0.7481799703655821, 1.4128783014085557, 0.9842759431010545, 1.889745057231248, 0.004668765597873264], [0.5280137543726449, 0.9755541339065089, 0.7452876399261783, 1.408308395231613, 0.9832095618199821, 1.8874812848640212, 0.005188975671325067]]], [[[0.33862624505553585, 0.6294144765295164, 0.4290954512779159, 0.810597280059198, 0.5483022121467976, 1.0510839356316461], [0.31762077350809115, 0.559200927464649, 0.37940246890289614, 0.7181714038656216, 0.481732493699199, 0.9220814223241324], [0.31610031320591164, 0.47471125920613605, 0.3283519937534525, 0.6014666171989055, 0.4142167833116319, 0.7825871188231188]], [[0.35007797106347904, 0.4412787610834295, 0.3423054049713443, 0.5852765025514545, 0.40640749112524166, 0.7427115873856978, 0.004194794279156309], [0.33576348815301454, 0.4444064275182859, 0.33754336713540434, 0.5864705461444277, 0.4062691601839932, 0.741335888101597, 0.0042497991311429735], [0.347997010356248, 0.4433897047331839, 0.34060092887493093, 0.5851883599252411, 0.40573748675259674, 0.7406246300899622, 0.00453204819650361]]], [[[0.3260928693443838, 0.440192800579649, 0.3165682879361239, 0.5835583715727835, 0.38191544889199613, 0.7402725894041736], [0.3473482709942442, 0.636373818522752, 0.4305076839947941, 0.8120978721464523, 0.5494835400822187, 1.0491823668431755], [0.32828465856686984, 0.5603574261520848, 0.400460705612645, 0.7195479942090584, 0.5029239558210277, 0.9257183893762454], [0.3226006633103496, 0.4737337671145044, 0.32909169341578626, 0.606823545513731, 0.41159104819249626, 0.7856156368448277]], [[0.34432806149877687, 0.48627826902601456, 0.35802836369986485, 0.6468107483603737, 0.43973238781245066, 0.8166859125850177, 0.004473575437911833], [0.34155412153764203, 0.4882547590467665, 0.3579827510949337, 0.6446643790813408, 0.4383250197978935, 0.813951227400038, 0.0045983478276416506], [0.3423182892076897, 0.49038797917992183, 0.3526753849453397, 0.6453903275306778, 0.4380230951790858, 0.814479264346036, 0.00436650382147895]]], [[[0.31611798989652384, 0.5243616152291346, 0.34948548885306924, 0.6844820879926585, 0.4360723254656551, 0.8539520851289384], [0.3155406797775115, 0.441239843464861, 0.3129776318868001, 0.5864422971552069, 0.3826266346555767, 0.7415137387285329], [0.341466942218819, 0.6324205013236613, 0.43117410004741014, 0.813953322593612, 0.549957872641207, 1.0534283849928114], [0.32031817869706586, 0.5669151412116156, 0.3825205985945885, 0.7180473298737498, 0.4875175399009628, 0.9277093530905367]], [[0.33806482950846356, 0.5306571180170233, 0.38618906579836454, 0.7049235189803923, 0.4787800769613247, 0.8875630841110692, 0.004487085824060922], [0.3448331958115703, 0.5330200869627674, 0.3988961739973588, 0.7003687848948469, 0.47722561190826723, 0.8862311912305427, 
0.004113322556620896], [0.3474274789444124, 0.5329236117276278, 0.3846857282850477, 0.7030708139592952, 0.4763580813552394, 0.8857826993922995, 0.0040633269030638415]]], [[[0.32171962237117263, 0.5723524575281626, 0.40756129255198464, 0.7432415991118461, 0.4988236619968607, 0.927507034455887], [0.3153900185016671, 0.5259105412646977, 0.37463693907766626, 0.6845260388923413, 0.4606634679466787, 0.8534405448219993], [0.3186584241462476, 0.48466046651204425, 0.3262169433362556, 0.6259828625303326, 0.4041823473843661, 0.7803911873788545]], [[0.5497075572158351, 1.0372076612530332, 0.6927506851427483, 1.3213739250645493, 0.8462854587670529, 1.630257669121328, 0.004380182786421342], [0.5530304860587072, 1.0435495232090806, 0.693162354555997, 1.321483693941675, 0.8439054392805003, 1.6276801475370772, 0.004326261655248777], [0.5519224176503191, 1.042814471504905, 0.6919608694134336, 1.3225178766732264, 0.8431200065998116, 1.6249998410542805, 0.004217167093296244]]], [[[0.5501972545276989, 1.058707164995598, 0.694795594070897, 1.3447087220471314, 0.8489726047323208, 1.6523896082483156], [0.3268535691078263, 0.5884998013274838, 0.4162517460909757, 0.7617164380622633, 0.508760129562532, 0.9446223576863607], [0.31738830335212476, 0.5434732726126006, 0.383788262954866, 0.7018130716651377, 0.46994878788187044, 0.8708676665720314], [0.31652166385843294, 0.49711326155999697, 0.3372591192072088, 0.640807007298325, 0.41489528887199634, 0.7945932282341851]], [[0.6180085317052976, 1.1803230372342197, 0.7760107878482703, 1.4968219670382412, 0.9470603923604946, 1.8356977809559216, 0.00415854983859592], [0.6226346468684649, 1.1885866733512493, 0.7775900339839434, 1.4975251573504824, 0.944289607231063, 1.833727817342739, 0.005115378986705434], [0.6210861784039121, 1.187923966032086, 0.776205351858428, 1.4959932096076733, 0.9441822225397283, 1.8313705559932825, 0.0047388221278335104]]], [[[0.6076182259453667, 1.1761999130249023, 0.7682572229944095, 1.4921375236125907, 0.9383929377854472, 1.827945179409451], [0.5541018524555245, 1.0647979408803612, 0.6995696973319006, 1.348778743936558, 0.8533838782647644, 1.6592941621337274], [0.3311891025967068, 0.5921483280682804, 0.4187226054644344, 0.7656114269988705, 0.5120331109172166, 0.9497349671643189], [0.31480916822799526, 0.5449382223264136, 0.38729677296648124, 0.7054818278611309, 0.4725636857928652, 0.8741613590356074]], [[0.674338870578342, 1.2978931870123354, 0.8485467024523803, 1.6414815970141476, 1.0351639564591224, 2.0140134445344557, 0.004651305651423907], [0.6819091662011966, 1.305962861186326, 0.8498003747728136, 1.645497866351195, 1.032505324392608, 2.0091032259392017, 0.00428734403668028], [0.6797817259123831, 1.3053699936529604, 0.8498629415878142, 1.642856477486967, 1.032006981396916, 2.0077070322903716, 0.004319518503516611]]]], [0.3001111926454486, 0.28059167091292564, 0.26766254444314974], 0.4041198769001046] -------------------------------------------------------------------------------- /models/search_space_base.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from tools.multadds_count import comp_multadds_fw 10 | from tools.utils import latency_measure_fw 11 | from . 
import operations 12 | from .operations import OPS 13 | 14 | 15 | class MixedOp(nn.Module): 16 | def __init__(self, C_in, C_out, stride, primitives): 17 | super(MixedOp, self).__init__() 18 | self._ops = nn.ModuleList() 19 | for primitive in primitives: 20 | op = OPS[primitive](C_in, C_out, stride, affine=False, track_running_stats=True) 21 | self._ops.append(op) 22 | 23 | 24 | class HeadLayer(nn.Module): 25 | def __init__(self, in_chs, ch, strides, config): 26 | super(HeadLayer, self).__init__() 27 | self.head_branches = nn.ModuleList() 28 | for in_ch, stride in zip(in_chs, strides): 29 | self.head_branches.append( 30 | MixedOp(in_ch, ch, stride, 31 | config.search_params.PRIMITIVES_head) 32 | ) 33 | 34 | 35 | class StackLayers(nn.Module): 36 | def __init__(self, ch, num_block_layers, config, primitives): 37 | super(StackLayers, self).__init__() 38 | 39 | if num_block_layers != 0: 40 | self.stack_layers = nn.ModuleList() 41 | for i in range(num_block_layers): 42 | self.stack_layers.append(MixedOp(ch, ch, 1, primitives)) 43 | else: 44 | self.stack_layers = None 45 | 46 | 47 | class Block(nn.Module): 48 | def __init__(self, in_chs, block_ch, strides, num_block_layers, config): 49 | super(Block, self).__init__() 50 | assert len(in_chs) == len(strides) 51 | self.head_layer = HeadLayer(in_chs, block_ch, strides, config) 52 | self.stack_layers = StackLayers(block_ch, num_block_layers, config, config.search_params.PRIMITIVES_stack) 53 | 54 | 55 | class Conv1_1_Branch(nn.Module): 56 | def __init__(self, in_ch, block_ch): 57 | super(Conv1_1_Branch, self).__init__() 58 | self.conv1_1 = nn.Sequential( 59 | nn.Conv2d(in_channels=in_ch, out_channels=block_ch, 60 | kernel_size=1, stride=1, padding=0, bias=False), 61 | nn.BatchNorm2d(block_ch, affine=False, track_running_stats=True), 62 | nn.ReLU6(inplace=True) 63 | ) 64 | 65 | def forward(self, x): 66 | return self.conv1_1(x) 67 | 68 | 69 | class Conv1_1_Block(nn.Module): 70 | def __init__(self, in_chs, block_ch): 71 | super(Conv1_1_Block, self).__init__() 72 | self.conv1_1_branches = nn.ModuleList() 73 | for in_ch in in_chs: 74 | self.conv1_1_branches.append(Conv1_1_Branch(in_ch, block_ch)) 75 | 76 | def forward(self, inputs, betas, block_sub_obj): 77 | branch_weights = F.softmax(torch.stack(betas), dim=-1) 78 | return sum(branch_weight * branch(input_data) for input_data, branch, branch_weight in zip( 79 | inputs, self.conv1_1_branches, branch_weights)), \ 80 | [block_sub_obj, 0] 81 | 82 | 83 | class Network(nn.Module): 84 | def __init__(self, init_ch, dataset, config): 85 | super(Network, self).__init__() 86 | self.config = config 87 | self._C_input = init_ch 88 | self._head_dim = self.config.optim.head_dim 89 | self._dataset = dataset 90 | # use 100-class sub dataset for search 91 | self._num_classes = 100 92 | 93 | self.initialize() 94 | 95 | 96 | def initialize(self): 97 | self._init_block_config() 98 | self._create_output_list() 99 | self._create_input_list() 100 | self._init_betas() 101 | self._init_alphas() 102 | self._init_sample_branch() 103 | 104 | 105 | def init_model(self, model_init='he_fout', init_div_groups=True): 106 | for m in self.modules(): 107 | if isinstance(m, nn.Conv2d): 108 | if model_init == 'he_fout': 109 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 110 | if init_div_groups: 111 | n /= m.groups 112 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 113 | elif model_init == 'he_fin': 114 | n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels 115 | if init_div_groups: 116 | n /= m.groups 117 | m.weight.data.normal_(0, math.sqrt(2. / n)) 118 | else: 119 | raise NotImplementedError 120 | elif isinstance(m, nn.BatchNorm2d): 121 | if m.affine: 122 | m.weight.data.fill_(1) 123 | m.bias.data.zero_() 124 | elif isinstance(m, nn.Linear): 125 | if m.bias is not None: 126 | m.bias.data.zero_() 127 | elif isinstance(m, nn.BatchNorm1d): 128 | if m.affine: 129 | m.weight.data.fill_(1) 130 | m.bias.data.zero_() 131 | 132 | 133 | def set_bn_param(self, bn_momentum, bn_eps): 134 | for m in self.modules(): 135 | if isinstance(m, nn.BatchNorm2d): 136 | m.momentum = bn_momentum 137 | m.eps = bn_eps 138 | return 139 | 140 | 141 | def _init_betas(self): 142 | r""" 143 | Beta weights over the output channel choices in the head layer of each block. 144 | """ 145 | self.beta_weights = nn.ParameterList() 146 | for block in self.output_configs: 147 | num_betas = len(block['out_chs']) 148 | self.beta_weights.append( 149 | nn.Parameter(1e-3 * torch.randn(num_betas)) 150 | ) 151 | 152 | 153 | def _init_alphas(self): 154 | r""" 155 | Alpha weights over the candidate op types in each block. 156 | """ 157 | self.alpha_head_weights = nn.ParameterList() 158 | self.alpha_stack_weights = nn.ParameterList() 159 | 160 | for block in self.input_configs[:-1]: 161 | num_head_alpha = len(block['in_block_idx']) 162 | self.alpha_head_weights.append(nn.Parameter( 163 | 1e-3 * torch.randn(num_head_alpha, len(self.config.search_params.PRIMITIVES_head))) 164 | ) 165 | 166 | num_layers = block['num_stack_layers'] 167 | self.alpha_stack_weights.append(nn.Parameter( 168 | 1e-3 * torch.randn(num_layers, len(self.config.search_params.PRIMITIVES_stack))) 169 | ) 170 | 171 | @property 172 | def arch_parameters(self): 173 | arch_params = nn.ParameterList() 174 | arch_params.extend(self.beta_weights) 175 | arch_params.extend(self.alpha_head_weights) 176 | arch_params.extend(self.alpha_stack_weights) 177 | return arch_params 178 | 179 | @property 180 | def arch_beta_params(self): 181 | return self.beta_weights 182 | 183 | @property 184 | def arch_alpha_params(self): 185 | alpha_params = nn.ParameterList() 186 | alpha_params.extend(self.alpha_head_weights) 187 | alpha_params.extend(self.alpha_stack_weights) 188 | return alpha_params 189 | 190 | 191 | def display_arch_params(self, display=True): 192 | branch_weights = [] 193 | head_op_weights = [] 194 | stack_op_weights = [] 195 | for betas in self.beta_weights: 196 | branch_weights.append(F.softmax(betas, dim=-1)) 197 | for head_alpha in self.alpha_head_weights: 198 | head_op_weights.append(F.softmax(head_alpha, dim=-1)) 199 | for stack_alpha in self.alpha_stack_weights: 200 | stack_op_weights.append(F.softmax(stack_alpha, dim=-1)) 201 | 202 | if display: 203 | logging.info('branch_weights \n' + '\n'.join(map(str, branch_weights))) 204 | if len(self.config.search_params.PRIMITIVES_head) > 1: 205 | logging.info('head_op_weights \n' + '\n'.join(map(str, head_op_weights))) 206 | logging.info('stack_op_weights \n' + '\n'.join(map(str, stack_op_weights))) 207 | 208 | return [x.tolist() for x in branch_weights], \ 209 | [x.tolist() for x in head_op_weights], \ 210 | [x.tolist() for x in stack_op_weights] 211 | 212 | 213 | def _init_sample_branch(self): 214 | _, _ = self.sample_branch('head', 1, training=False) 215 | _, _ = self.sample_branch('stack', 1, training=False) 216 | 217 | 218 | def sample_branch(self, params_type, sample_num, training=True,
search_stage=1, if_sort=True): 219 | r""" 220 | Sample candidate operations for each block; the sampling is computed with torch. 221 | input: params_type ('head' or 'stack') 222 | output: the sampled alpha params and their indices 223 | """ 224 | 225 | def sample(param, weight, sample_num, sample_policy='prob', if_sort=True): 226 | if sample_num >= weight.shape[-1]: 227 | sample_policy = 'all' 228 | assert param.shape == weight.shape 229 | assert sample_policy in ['prob', 'uniform', 'all'] 230 | if param.shape[0] == 0: 231 | return [], [] 232 | if sample_policy == 'prob': 233 | sampled_index = torch.multinomial(weight, num_samples=sample_num, replacement=False) 234 | elif sample_policy == 'uniform': 235 | weight = torch.ones_like(weight) 236 | sampled_index = torch.multinomial(weight, num_samples=sample_num, replacement=False) 237 | else: 238 | sampled_index = torch.arange(start=0, end=weight.shape[-1], step=1, device=weight.device 239 | ).repeat(param.shape[0], 1) 240 | if if_sort: 241 | sampled_index, _ = torch.sort(sampled_index, descending=False) 242 | sampled_param_old = torch.gather(param, dim=-1, index=sampled_index) 243 | return sampled_param_old, sampled_index 244 | 245 | if params_type == 'head': 246 | params = self.alpha_head_weights 247 | elif params_type == 'stack': 248 | params = self.alpha_stack_weights 249 | else: 250 | raise TypeError('params_type must be "head" or "stack"') 251 | 252 | weights = [] 253 | sampled_params_old = [] 254 | sampled_indices = [] 255 | if training: 256 | sample_policy = self.config.search_params.sample_policy if search_stage == 1 else 'uniform' 257 | else: 258 | sample_policy = 'all' 259 | 260 | for param in params: 261 | weights.append(F.softmax(param, dim=-1)) 262 | 263 | for param, weight in zip(params, weights): # iterate over the list dimension (one param tensor per block) 264 | sampled_param_old, sampled_index = sample( 265 | param, weight, sample_num, sample_policy, if_sort) 266 | sampled_params_old.append(sampled_param_old) 267 | sampled_indices.append(sampled_index) 268 | 269 | if params_type == 'head': 270 | self.alpha_head_index = sampled_indices 271 | elif params_type == 'stack': 272 | self.alpha_stack_index = sampled_indices 273 | return sampled_params_old, sampled_indices 274 | 275 | 276 | def _init_block_config(self): 277 | self.block_chs = self.config.search_params.net_scale.chs 278 | self.block_fm_sizes = self.config.search_params.net_scale.fm_sizes 279 | self.num_blocks = len(self.block_chs) - 1 # excluding the head and tail 280 | self.num_block_layers = self.config.search_params.net_scale.num_layers 281 | if hasattr(self.config.search_params.net_scale, 'stage'): 282 | self.block_stage = self.config.search_params.net_scale.stage 283 | 284 | self.block_chs.append(self.config.optim.last_dim) 285 | self.block_fm_sizes.append(self.block_fm_sizes[-1]) 286 | self.num_block_layers.append(0) 287 | 288 | 289 | def _create_output_list(self): 290 | r""" 291 | Generate the output config of each block, which contains: 292 | 'ch': the channel number of the block 293 | 'out_chs': the possible output channel numbers 294 | 'strides': the corresponding strides 295 | """ 296 | 297 | self.output_configs = [] 298 | for i in range(len(self.block_chs)-1): 299 | if hasattr(self, 'block_stage'): 300 | stage = self.block_stage[i] 301 | output_config = {'ch': self.block_chs[i], 302 | 'fm_size': self.block_fm_sizes[i], 303 | 'out_chs': [], 304 | 'out_fms': [], 305 | 'strides': [], 306 | 'out_id': [], 307 | 'num_stack_layers': self.num_block_layers[i]} 308 | for j in range(self.config.search_params.adjoin_connect_nums[stage]): 309 | out_index = i + j + 1 310 | if out_index >= len(self.block_chs): 311 | break 312 | if hasattr(self,
'block_stage'): 313 | block_stage = getattr(self, 'block_stage') 314 | if block_stage[out_index] - block_stage[i] > 1: 315 | break 316 | fm_size_ratio = self.block_fm_sizes[i] / self.block_fm_sizes[out_index] 317 | if fm_size_ratio == 2: 318 | output_config['strides'].append(2) 319 | elif fm_size_ratio == 1: 320 | output_config['strides'].append(1) 321 | else: 322 | break # only connect to blocks whose feature-map size ratio is 1 or 2 323 | output_config['out_chs'].append(self.block_chs[out_index]) 324 | output_config['out_fms'].append(self.block_fm_sizes[out_index]) 325 | output_config['out_id'].append(out_index) 326 | 327 | self.output_configs.append(output_config) 328 | 329 | logging.info('Network output configs: \n' + '\n'.join(map(str, self.output_configs))) 330 | 331 | 332 | def _create_input_list(self): 333 | r""" 334 | Generate the input config of each block for constructing the whole network. 335 | Each config dict contains: 336 | 'ch': the channel number of the block 337 | 'in_chs': all the possible input channel numbers 338 | 'strides': the corresponding strides 339 | 'in_block_idx': the indices of the input blocks 340 | 'beta_idx': the corresponding beta weight indices. 341 | """ 342 | 343 | self.input_configs = [] 344 | for i in range(1, len(self.block_chs)): 345 | input_config = {'ch': self.block_chs[i], 346 | 'fm_size': self.block_fm_sizes[i], 347 | 'in_chs': [], 348 | 'in_fms': [], 349 | 'strides': [], 350 | 'in_block_idx': [], 351 | 'beta_idx': [], 352 | 'num_stack_layers': self.num_block_layers[i]} 353 | for j in range(i): 354 | in_index = i - j - 1 355 | if in_index < 0: 356 | break 357 | output_config = self.output_configs[in_index] 358 | if i in output_config['out_id']: 359 | beta_idx = output_config['out_id'].index(i) 360 | input_config['in_block_idx'].append(in_index) 361 | input_config['in_chs'].append(output_config['ch']) 362 | input_config['in_fms'].append(output_config['fm_size']) 363 | input_config['beta_idx'].append(beta_idx) 364 | input_config['strides'].append(output_config['strides'][beta_idx]) 365 | else: 366 | continue 367 | 368 | self.input_configs.append(input_config) 369 | 370 | logging.info('Network input configs: \n' + '\n'.join(map(str, self.input_configs))) 371 | 372 | 373 | def get_cost_list(self, data_shape, cost_type='flops', 374 | use_gpu=True, meas_times=1000): 375 | cost_list = [] 376 | block_datas = [] 377 | total_cost = 0 378 | if cost_type == 'flops': 379 | cost_func = lambda module, data: comp_multadds_fw( 380 | module, data, use_gpu) 381 | elif cost_type == 'latency': 382 | cost_func = lambda module, data: latency_measure_fw( 383 | module, data, meas_times) 384 | else: 385 | raise NotImplementedError 386 | 387 | if len(data_shape) == 3: 388 | input_data = torch.randn((1,) + tuple(data_shape)) 389 | else: 390 | input_data = torch.randn(tuple(data_shape)) 391 | if use_gpu: 392 | input_data = input_data.cuda() 393 | 394 | cost, block_data = cost_func(self.input_block, input_data) 395 | cost_list.append(cost) 396 | block_datas.append(block_data) 397 | total_cost += cost 398 | if hasattr(self, 'head_block'): 399 | cost, block_data = cost_func(self.head_block, block_data) 400 | cost_list[0] += cost 401 | block_datas[0] = block_data 402 | 403 | block_flops = [] 404 | for block_id, block in enumerate(self.blocks): 405 | input_config = self.input_configs[block_id] 406 | inputs = [block_datas[i] for i in input_config['in_block_idx']] 407 | 408 | head_branch_flops = [] 409 | for branch_id, head_branch in enumerate(block.head_layer.head_branches): 410 |
op_flops = [] 411 | for op in head_branch._ops: 412 | cost, block_data = cost_func(op, inputs[branch_id]) 413 | op_flops.append(cost) 414 | total_cost += cost 415 | 416 | head_branch_flops.append(op_flops) 417 | 418 | stack_layer_flops = [] 419 | if block.stack_layers.stack_layers is not None: 420 | for stack_layer in block.stack_layers.stack_layers: 421 | op_flops = [] 422 | for op in stack_layer._ops: 423 | cost, block_data = cost_func(op, block_data) 424 | if isinstance(op, operations.Skip) and \ 425 | self.config.optim.sub_obj.skip_reg: 426 | # skip_reg regularizes the skip op: its measured cost is too small, so a tenth of the first op's cost is used instead 427 | cost = op_flops[0] / 10. 428 | op_flops.append(cost) 429 | total_cost += cost 430 | stack_layer_flops.append(op_flops) 431 | block_flops.append([head_branch_flops, stack_layer_flops]) 432 | block_datas.append(block_data) 433 | 434 | cost_list.append(block_flops) 435 | 436 | conv1_1_flops = [] 437 | input_config = self.input_configs[-1] 438 | inputs = [block_datas[i] for i in input_config['in_block_idx']] 439 | for branch_id, branch in enumerate(self.conv1_1_block.conv1_1_branches): 440 | cost, block_data = cost_func(branch, inputs[branch_id]) 441 | conv1_1_flops.append(cost) 442 | total_cost += cost 443 | block_datas.append(block_data) 444 | 445 | cost_list.append(conv1_1_flops) 446 | out = block_datas[-1] 447 | out = self.global_pooling(out) 448 | 449 | cost, out = cost_func(self.classifier, out.view(out.size(0), -1)) 450 | cost_list.append(cost) 451 | total_cost += cost 452 | 453 | return cost_list, total_cost 454 | --------------------------------------------------------------------------------
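MixedOp above only registers its candidate operations in _ops; the weighted combination of their outputs is implemented in the derived search spaces, not in this base class. As a minimal sketch of that DARTS-style combination, assuming stand-in candidate ops and the (alpha, index) pair produced by sample_branch (mixed_forward and all shapes below are illustrative, not the repo's API):

import torch
import torch.nn as nn
import torch.nn.functional as F

def mixed_forward(ops, x, sampled_alpha, sampled_index):
    # re-normalize the alphas of the sampled subset, then take the
    # weighted sum over only the sampled ops' outputs
    weights = F.softmax(sampled_alpha, dim=-1)
    return sum(weights[k] * ops[int(idx)](x)
               for k, idx in enumerate(sampled_index))

ops = nn.ModuleList([
    nn.Conv2d(8, 8, 3, padding=1, bias=False),  # stand-ins for OPS primitives
    nn.Conv2d(8, 8, 5, padding=2, bias=False),
    nn.Identity(),
])
x = torch.randn(2, 8, 16, 16)
sampled_index = torch.tensor([0, 2])         # as returned by sample_branch
sampled_alpha = torch.tensor([0.01, -0.02])  # the gathered alphas of those ops
out = mixed_forward(ops, x, sampled_alpha, sampled_index)  # shape (2, 8, 16, 16)

Only the sampled ops are executed, which is what lets the multinomial sampling in sample_branch save memory and compute during search.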
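The 'he_fout' branch of init_model is He initialization computed from the fan-out (kernel area times out_channels, optionally divided by the group count); 'he_fin' is identical with in_channels instead. Reduced to a single layer, the computation is:

import math
import torch.nn as nn

conv = nn.Conv2d(16, 32, kernel_size=3, bias=False)
n = conv.kernel_size[0] * conv.kernel_size[1] * conv.out_channels  # fan-out
conv.weight.data.normal_(0, math.sqrt(2. / n))  # std = sqrt(2 / n)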
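The sample() helper inside sample_branch draws sample_num ops per row of an alpha tensor with torch.multinomial, then sorts the drawn indices and gathers the matching alphas. A standalone walk-through on a toy tensor (the 3 x 6 shape is hypothetical):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
alpha = 1e-3 * torch.randn(3, 6)   # e.g. 3 stack layers x 6 stack primitives
weight = F.softmax(alpha, dim=-1)

# 'prob' policy: draw 2 ops per layer from the current softmax distribution
idx = torch.multinomial(weight, num_samples=2, replacement=False)
# 'uniform' policy makes the same call with torch.ones_like(weight);
# 'all' enumerates every op via torch.arange(...).repeat(alpha.shape[0], 1)

idx, _ = torch.sort(idx, descending=False)             # if_sort=True
sampled_alpha = torch.gather(alpha, dim=-1, index=idx)
print(idx.shape, sampled_alpha.shape)  # torch.Size([3, 2]) torch.Size([3, 2])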
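get_cost_list returns per-op costs nested in the same layout as the alpha weights (the latency_list files above appear to store exactly such structures: mult-adds for 'flops', milliseconds for 'latency'). A common way to fold such numbers into a differentiable sub-objective is the expected cost under the softmax of the alphas; the exact combination used during search lives in the training code, so the following is only a sketch with made-up costs:

import torch
import torch.nn.functional as F

# hypothetical per-op costs for one layer, as one entry of cost_list
op_costs = torch.tensor([0.48, 0.88, 0.68, 1.30, 0.94, 1.77])
alpha = 1e-3 * torch.randn(6, requires_grad=True)

# expected layer cost under the current architecture distribution;
# summing such terms over all layers gives a differentiable penalty
expected_cost = (F.softmax(alpha, dim=-1) * op_costs).sum()
expected_cost.backward()  # gradients flow into the alphas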