├── models
│   ├── __init__.py
│   ├── search_space_mbv2.py
│   ├── search_space_res.py
│   ├── model_derived.py
│   ├── operations.py
│   ├── dropped_model.py
│   └── search_space_base.py
├── tools
│   ├── __init__.py
│   ├── collections.py
│   ├── io.py
│   ├── lr_scheduler.py
│   ├── utils.py
│   ├── config_yaml.py
│   └── multadds_count.py
├── dataset
│   ├── __init__.py
│   ├── mk_img_list.py
│   ├── mk_split_img_list.py
│   ├── torchvision_extension.py
│   ├── img2lmdb.py
│   ├── prefetch_data.py
│   ├── lmdb_dataset.py
│   └── imagenet_data.py
├── run_apis
│   ├── __init__.py
│   ├── derive_arch_res.py
│   ├── derive_arch_mbv2.py
│   ├── derive_arch.py
│   ├── latency_measure.py
│   ├── validation.py
│   ├── optimizer.py
│   ├── retrain.py
│   ├── trainer.py
│   └── search.py
├── configs
│   ├── __init__.py
│   ├── imagenet_val_cfg.py
│   ├── imagenet_search_cfg_resnet.yaml
│   ├── search_config.py
│   ├── imagenet_search_cfg_mbv2.yaml
│   └── imagenet_train_cfg.py
├── imgs
│   ├── archs.png
│   ├── res_comp.png
│   ├── mbv2_comp.png
│   ├── mbv2_results.png
│   ├── res_results.png
│   └── search_space.png
├── .gitignore
├── README.md
├── latency_list
│   ├── lat_list_resv2_32_reg10
│   ├── lat_list_resv2_32
│   └── lat_list_densenas_mbv2_xp32
└── LICENSE
/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tools/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/run_apis/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/configs/__init__.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append("..")
3 |
--------------------------------------------------------------------------------
/imgs/archs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JaminFong/DenseNAS/HEAD/imgs/archs.png
--------------------------------------------------------------------------------
/imgs/res_comp.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JaminFong/DenseNAS/HEAD/imgs/res_comp.png
--------------------------------------------------------------------------------
/imgs/mbv2_comp.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JaminFong/DenseNAS/HEAD/imgs/mbv2_comp.png
--------------------------------------------------------------------------------
/imgs/mbv2_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JaminFong/DenseNAS/HEAD/imgs/mbv2_results.png
--------------------------------------------------------------------------------
/imgs/res_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JaminFong/DenseNAS/HEAD/imgs/res_results.png
--------------------------------------------------------------------------------
/imgs/search_space.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JaminFong/DenseNAS/HEAD/imgs/search_space.png
--------------------------------------------------------------------------------
/configs/imagenet_val_cfg.py:
--------------------------------------------------------------------------------
1 | from tools.collections import AttrDict
2 |
3 | __C = AttrDict()
4 |
5 | cfg = __C
6 |
7 | __C.net_config=""
8 |
9 | __C.data=AttrDict()
10 | __C.data.num_workers=16
11 | __C.data.batch_size=1024
12 | __C.data.dataset='imagenet'
13 | __C.data.train_data_type='lmdb'
14 | __C.data.val_data_type='lmdb'
15 | __C.data.patch_dataset=False
16 | __C.data.num_examples=1281167
17 | __C.data.input_size=(3,224,224)
--------------------------------------------------------------------------------
/dataset/mk_img_list.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | def get_list(data_path, output_path):
5 | for split in os.listdir(data_path):
6 | split_path = os.path.join(data_path, split)
7 | if not os.path.isdir(split_path):
8 | continue
9 | f = open(os.path.join(output_path, split + '_datalist'), 'a+')
10 | for sub in os.listdir(split_path):
11 | sub_path = os.path.join(split_path, sub)
12 | if not os.path.isdir(sub_path):
13 | continue
14 | for image in os.listdir(sub_path):
15 | image_name = sub + '/' + image
16 | f.writelines(image_name + '\n')
17 | f.close()
18 |
19 |
20 | if __name__ == '__main__':
21 | parser = argparse.ArgumentParser("Params")
22 | parser.add_argument('--image_path', type=str, default='', help='the path of the images')
23 |     parser.add_argument('--output_path', type=str, default='', help='the output path of the list file')
24 | args = parser.parse_args()
25 |
26 | get_list(args.image_path, args.output_path)
27 |
--------------------------------------------------------------------------------
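Example (not part of the repo) — a hedged usage sketch for mk_img_list.py, assuming an
ImageNet-style layout such as /data/imagenet/train/n01440764/img1.JPEG (paths hypothetical).
It writes one 'class_dir/image_name' entry per line to a <split>_datalist file per split:

    from dataset.mk_img_list import get_list

    # Scans /data/imagenet/{train,val}/<class_dir>/<image> and writes
    # ./lists/train_datalist and ./lists/val_datalist. Note the files are
    # opened in append mode ('a+'), so reruns extend the existing lists.
    get_list('/data/imagenet', './lists')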
/dataset/mk_split_img_list.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 |
5 | def get_list(data_path, output_path):
6 | for split in os.listdir(data_path):
7 | if split == 'train':
8 | split_path = os.path.join(data_path, split)
9 | if not os.path.isdir(split_path):
10 | continue
11 | f_train = open(os.path.join(output_path, split + '_datalist'), 'w')
12 | f_val = open(os.path.join(output_path, 'val' + '_datalist'), 'w')
13 | class_list = os.listdir(split_path)
14 | for sub in class_list[:100]:
15 | sub_path = os.path.join(split_path, sub)
16 | if not os.path.isdir(sub_path):
17 | continue
18 | img_list = os.listdir(sub_path)
19 | train_len = int(0.8*len(img_list))
20 | for image in img_list[:train_len]:
21 | image_name = os.path.join(sub, image)
22 | f_train.writelines(image_name + '\n')
23 | for image in img_list[train_len:]:
24 | image_name = os.path.join(sub, image)
25 | f_val.writelines(image_name + '\n')
26 | f_train.close()
27 | f_val.close()
28 |
29 | if __name__ == '__main__':
30 | parser = argparse.ArgumentParser("Params")
31 | parser.add_argument('--image_path', type=str, default='', help='the path of the images')
32 | parser.add_argument('--output_path', type=str, default='.', help='the output path of the list file')
33 | args = parser.parse_args()
34 |
35 | get_list(args.image_path, args.output_path)
36 |
--------------------------------------------------------------------------------
/models/search_space_mbv2.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from .operations import OPS
4 | from .search_space_base import Conv1_1_Block, Block
5 | from .search_space_base import Network as BaseSearchSpace
6 |
7 | class Network(BaseSearchSpace):
8 | def __init__(self, init_ch, dataset, config):
9 | super(Network, self).__init__(init_ch, dataset, config)
10 |
11 | self.input_block = nn.Sequential(
12 | nn.Conv2d(in_channels=3, out_channels=self._C_input, kernel_size=3,
13 | stride=2, padding=1, bias=False),
14 | nn.BatchNorm2d(self._C_input, affine=False, track_running_stats=True),
15 | nn.ReLU6(inplace=True)
16 | )
17 |
18 | self.head_block = OPS['mbconv_k3_t1'](self._C_input, self._head_dim, 1, affine=False, track_running_stats=True)
19 |
20 | self.blocks = nn.ModuleList()
21 |
22 | for i in range(self.num_blocks):
23 | input_config = self.input_configs[i]
24 | self.blocks.append(Block(
25 | input_config['in_chs'],
26 | input_config['ch'],
27 | input_config['strides'],
28 | input_config['num_stack_layers'],
29 | self.config
30 | ))
31 |
32 | self.conv1_1_block = Conv1_1_Block(self.input_configs[-1]['in_chs'],
33 | self.config.optim.last_dim)
34 |
35 | self.global_pooling = nn.AdaptiveAvgPool2d(1)
36 | self.classifier = nn.Linear(self.config.optim.last_dim, self._num_classes)
37 |
38 | self.init_model(model_init=config.optim.init_mode)
39 | self.set_bn_param(self.config.optim.bn_momentum, self.config.optim.bn_eps)
40 |
--------------------------------------------------------------------------------
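Example (not part of the repo) — a minimal instantiation sketch for this search space,
mirroring run_apis/latency_measure.py; it assumes the search defaults have first been
merged into the base config:

    import importlib

    from configs.imagenet_train_cfg import cfg
    from configs.search_config import search_cfg
    from tools.config_yaml import update_cfg_from_cfg

    update_cfg_from_cfg(search_cfg, cfg)  # overlay search defaults onto the base config
    SearchSpace = importlib.import_module('models.search_space_mbv2').Network
    super_model = SearchSpace(cfg.optim.init_dim, cfg.data.dataset, cfg)
    super_model.eval()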
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
--------------------------------------------------------------------------------
/dataset/torchvision_extension.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torchvision.transforms as transforms
3 | from torchvision.transforms import functional as F
4 |
5 | # This file adds a few more transformations (beyond those defined in torchvision.transforms),
6 | # particularly helpful for training on ImageNet, in the style of the transforms
7 | # used by fb.resnet: https://github.com/facebook/fb.resnet.torch/blob/master/datasets/imagenet.lua
8 | 
9 | # This file is taken from a proposed pull request on the torchvision GitHub project.
10 | # That pull request has not been accepted yet, which is why it is reproduced here.
11 | # Link to the pull request: https://github.com/pytorch/vision/pull/27/files
12 |
13 | class Lighting(object):
14 |
15 | """Lighting noise(AlexNet - style PCA - based noise)"""
16 |
17 | def __init__(self, alphastd, eigval, eigvec):
18 | self.alphastd = alphastd
19 | self.eigval = eigval
20 | self.eigvec = eigvec
21 |
22 | def __call__(self, img):
23 |         # img is supposed to be a torch tensor
24 |
25 | if self.alphastd == 0:
26 | return img
27 |
28 | alpha = img.new().resize_(3).normal_(0, self.alphastd)
29 | rgb = self.eigvec.type_as(img).clone()\
30 | .mul(alpha.view(1, 3).expand(3, 3))\
31 | .mul(self.eigval.view(1, 3).expand(3, 3))\
32 | .sum(1).squeeze()
33 |
34 | return img.add(rgb.view(3, 1, 1).expand_as(img))
35 |
36 |
37 | class RandomScale(object):
38 |
39 | """ResNet style data augmentation"""
40 |
41 | def __init__(self, minSize, maxSize):
42 | self.minSize = minSize
43 | self.maxSize = maxSize
44 |
45 | def __call__(self, img):
46 |
47 | targetSz = int(round(random.uniform(self.minSize, self.maxSize)))
48 |
49 | return F.resize(img, targetSz)
50 |
51 |
--------------------------------------------------------------------------------
/run_apis/derive_arch_res.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from .derive_arch import BaseArchGenerate
3 |
4 | class ArchGenerate(BaseArchGenerate):
5 | def __init__(self, super_network, config):
6 | super(ArchGenerate, self).__init__(super_network, config)
7 |
8 | def derive_archs(self, betas, head_alphas, stack_alphas, if_display=True):
9 |
10 | self.update_arch_params(betas, head_alphas, stack_alphas)
11 |
12 | # [[ch, head_op, [stack_op], num_layers, stride], ..., [...]]
13 | derived_archs = []
14 | ch_path, derived_chs = self.derive_chs()
15 |
16 | layer_count = 0
17 | for i, (ch_idx, ch) in enumerate(zip(ch_path, derived_chs)):
18 | if ch_idx == 0 or i == len(derived_chs)-1:
19 | continue
20 |
21 | block_idx = ch_idx - 1
22 | input_config = self.input_configs[block_idx]
23 |
24 | head_id = input_config['in_block_idx'].index(ch_path[i-1])
25 | head_alpha = self.head_alphas[block_idx][head_id]
26 | head_op = self.derive_ops(head_alpha, 'head')
27 |
28 | stride = input_config['strides'][input_config['in_block_idx'].index(ch_path[i-1])]
29 |
30 | stack_ops = []
31 | for stack_alpha in self.stack_alphas[block_idx]:
32 | stack_op = self.derive_ops(stack_alpha, 'stack')
33 | if stack_op != 'skip_connect':
34 | stack_ops.append(stack_op)
35 | layer_count += 1
36 |
37 | derived_archs.append(
38 | [[derived_chs[i-1], ch], head_op, stack_ops, len(stack_ops), stride]
39 | )
40 |
41 | layer_count += len(derived_archs)
42 | if if_display:
43 | logging.info('Derived arch: \n' + '|\n'.join(map(str, derived_archs)))
44 | logging.info('Total {} layers.'.format(layer_count))
45 |
46 | return derived_archs
--------------------------------------------------------------------------------
/run_apis/derive_arch_mbv2.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from .derive_arch import BaseArchGenerate
3 |
4 | class ArchGenerate(BaseArchGenerate):
5 | def __init__(self, super_network, config):
6 | super(ArchGenerate, self).__init__(super_network, config)
7 |
8 | def derive_archs(self, betas, head_alphas, stack_alphas, if_display=True):
9 |
10 | self.update_arch_params(betas, head_alphas, stack_alphas)
11 |
12 | # [[ch, head_op, [stack_op], num_layers, stride], ..., [...]]
13 | derived_archs = [[[self.config.optim.init_dim, self.config.optim.head_dim], 'mbconv_k3_t1', [], 0, 1],]
14 | ch_path, derived_chs = self.derive_chs()
15 |
16 | layer_count = 0
17 | for i, (ch_idx, ch) in enumerate(zip(ch_path, derived_chs)):
18 | if ch_idx == 0 or i == len(derived_chs)-1:
19 | continue
20 |
21 | block_idx = ch_idx - 1
22 | input_config = self.input_configs[block_idx]
23 |
24 | head_id = input_config['in_block_idx'].index(ch_path[i-1])
25 | head_alpha = self.head_alphas[block_idx][head_id]
26 | head_op = self.derive_ops(head_alpha, 'head')
27 |
28 | stride = input_config['strides'][input_config['in_block_idx'].index(ch_path[i-1])]
29 |
30 | stack_ops = []
31 | for stack_alpha in self.stack_alphas[block_idx]:
32 | stack_op = self.derive_ops(stack_alpha, 'stack')
33 | if stack_op != 'skip_connect':
34 | stack_ops.append(stack_op)
35 | layer_count += 1
36 |
37 | derived_archs.append(
38 | [[derived_chs[i-1], ch], head_op, stack_ops, len(stack_ops), stride]
39 | )
40 | derived_archs.append([[derived_chs[-2], self.config.optim.last_dim], 'conv1_1'])
41 |
42 | layer_count += len(derived_archs)
43 | if if_display:
44 | logging.info('Derived arch: \n' + '|\n'.join(map(str, derived_archs)))
45 | logging.info('Total {} layers.'.format(layer_count))
46 |
47 | return derived_archs
--------------------------------------------------------------------------------
/configs/imagenet_search_cfg_resnet.yaml:
--------------------------------------------------------------------------------
1 | net_type: res
2 | train_params:
3 | epochs: 70
4 | use_seed: True
5 | seed: 2
6 |
7 | search_params:
8 | val_start_epoch: 50
9 | arch_update_epoch: 10
10 | sample_policy: prob # prob uniform
11 | weight_sample_num: 1
12 | softmax_temp: 0.9
13 |
14 | PRIMITIVES_stack: ['basic_block',
15 | 'skip_connect',]
16 | PRIMITIVES_head: ['basic_block',
17 | ]
18 |
19 | adjoin_connect_nums: [10, 10, 10, 10, 10, 10, 10]
20 | net_scale:
21 | chs: [32,
22 | 48, 56, 64,
23 | 72, 96, 112,
24 | 128, 144, 160, 176, 192, 208, 224,
25 | 240, 256, 272, 288, 480, 496, 512]
26 | fm_sizes: [112,
27 | 56, 56, 56,
28 | 28, 28, 28,
29 | 14, 14, 14, 14, 14, 14, 14,
30 | 7, 7, 7, 7, 7, 7, 7]
31 | stage: [0,
32 | 1, 1, 1,
33 | 2, 2, 2,
34 | 3, 3, 3, 3, 4, 4, 4,
35 | 5, 5, 5, 5, 6, 6, 6,
36 | 7]
37 | num_layers: [0,
38 | 0, 0, 0,
39 | 5, 5, 5,
40 | 15, 15, 15, 15, 5, 5, 5,
41 | 5, 5, 5, 5, 1, 1, 1]
42 |
43 | optim:
44 | last_dim: 512
45 | init_dim: 32
46 | bn_momentum: 0.1
47 | bn_eps: 0.001
48 | weight:
49 | init_lr: 0.2
50 | min_lr: 0.0001
51 | lr_decay_type: cosine
52 | momentum: 0.9
53 | weight_decay: 0.00004
54 | arch:
55 | alpha_lr: 0.0003
56 | beta_lr: 0.0003
57 | weight_decay: 0.001
58 |
59 | if_sub_obj: True
60 | sub_obj:
61 | type: flops
62 | skip_reg: True
63 | log_base: 3500.
64 | sub_loss_factor: 0.2
65 |
66 | if_resume: False
67 | resume:
68 | load_path: ''
69 | load_epoch: 9
70 |
71 | data:
72 | batch_size: 512
73 | num_workers: 16
74 | dataset: imagenet
75 | train_data_type: lmdb
76 | val_data_type: lmdb
77 | patch_dataset: False
78 | input_size: (3,224,224)
79 | type_of_data_aug: random_sized
80 | color: False
81 | random_sized:
82 | min_scale: 0.08
83 |
84 |
--------------------------------------------------------------------------------
/configs/search_config.py:
--------------------------------------------------------------------------------
1 | from tools.collections import AttrDict
2 |
3 | __C = AttrDict()
4 | search_cfg = __C
5 |
6 |
7 | __C.search_params=AttrDict()
8 | __C.search_params.arch_update_epoch=10
9 | __C.search_params.val_start_epoch=120
10 | __C.search_params.sample_policy='prob' # prob uniform
11 | __C.search_params.weight_sample_num=1
12 | __C.search_params.softmax_temp=1.
13 |
14 | __C.search_params.adjoin_connect_nums = []
15 | __C.search_params.net_scale = AttrDict()
16 | __C.search_params.net_scale.chs = []
17 | __C.search_params.net_scale.fm_sizes = []
18 | __C.search_params.net_scale.stage = []
19 | __C.search_params.net_scale.num_layers = []
20 |
21 | __C.search_params.PRIMITIVES_stack = [
22 | 'mbconv_k3_t3',
23 | 'mbconv_k3_t6',
24 | 'mbconv_k5_t3',
25 | 'mbconv_k5_t6',
26 | 'mbconv_k7_t3',
27 | 'mbconv_k7_t6',
28 | 'skip_connect',
29 | ]
30 | __C.search_params.PRIMITIVES_head = [
31 | 'mbconv_k3_t3',
32 | 'mbconv_k3_t6',
33 | 'mbconv_k5_t3',
34 | 'mbconv_k5_t6',
35 | 'mbconv_k7_t3',
36 | 'mbconv_k7_t6',
37 | ]
38 |
39 | __C.optim=AttrDict()
40 | __C.optim.init_dim=16
41 | __C.optim.head_dim=16
42 | __C.optim.last_dim=1984
43 | __C.optim.weight=AttrDict()
44 | __C.optim.weight.init_lr=0.1
45 | __C.optim.weight.min_lr=1e-4
46 | __C.optim.weight.lr_decay_type='cosine'
47 | __C.optim.weight.momentum=0.9
48 | __C.optim.weight.weight_decay=4e-5
49 |
50 | __C.optim.arch=AttrDict()
51 | __C.optim.arch.alpha_lr=3e-4
52 | __C.optim.arch.beta_lr=3e-4
53 | __C.optim.arch.weight_decay=1e-3
54 |
55 | __C.optim.if_sub_obj=True
56 | __C.optim.sub_obj=AttrDict()
57 | __C.optim.sub_obj.type='latency' # latency / flops
58 | __C.optim.sub_obj.skip_reg=False
59 | __C.optim.sub_obj.log_base=15.0
60 | __C.optim.sub_obj.sub_loss_factor=0.1
61 | __C.optim.sub_obj.latency_list_path=''
62 |
63 |
--------------------------------------------------------------------------------
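Example (not part of the repo) — how these Python defaults combine with the YAML files,
following the pattern in run_apis/latency_measure.py (tools/config_yaml.py is not shown in
this dump; the YAML path below is illustrative):

    from configs.imagenet_train_cfg import cfg
    from configs.search_config import search_cfg
    from tools.config_yaml import merge_cfg_from_file, update_cfg_from_cfg

    update_cfg_from_cfg(search_cfg, cfg)   # 1) Python search defaults over the base cfg
    merge_cfg_from_file('configs/imagenet_search_cfg_mbv2.yaml', cfg)  # 2) YAML overrides
    print(cfg.optim.sub_obj.type)          # 'latency' for the mbv2 search config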
/run_apis/derive_arch.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 |
4 | class BaseArchGenerate(object):
5 | def __init__(self, super_network, config):
6 | self.config = config
7 | self.num_blocks = len(super_network.block_chs) # including the input 32 and 1280 block
8 | self.super_chs = super_network.block_chs
9 | self.input_configs = super_network.input_configs
10 |
11 |
12 | def update_arch_params(self, betas, head_alphas, stack_alphas):
13 | self.betas = betas
14 | self.head_alphas = head_alphas
15 | self.stack_alphas = stack_alphas
16 |
17 |
18 | def derive_chs(self):
19 | """
20 |         Use a Viterbi-style max-product pass to choose the best channel path through the supernet.
21 | """
22 | path_p_max = [] # [[max_last_state_id, trans_prob], ...]
23 |
24 | path_p_max.append([0, 1])
25 |
26 | for input_config in self.input_configs:
27 | block_path_prob_max = [None, 0]
28 | for in_block_id, beta_id in zip(input_config['in_block_idx'], input_config['beta_idx']):
29 | path_prob = path_p_max[in_block_id][1]*self.betas[in_block_id][beta_id]
30 | if path_prob > block_path_prob_max[1]:
31 |                     block_path_prob_max = [in_block_id, path_prob]
32 |
33 | path_p_max.append(block_path_prob_max)
34 |
35 | ch_idx = len(path_p_max) - 1
36 | ch_path = []
37 | ch_path.append(ch_idx)
38 |
39 |         while True:
40 | ch_idx = path_p_max[ch_idx][0]
41 | ch_path.append(ch_idx)
42 | if ch_idx == 0:
43 | break
44 |
45 | derived_chs = [self.super_chs[ch_id] for ch_id in ch_path]
46 |
47 | ch_path = ch_path[::-1]
48 | derived_chs = derived_chs[::-1]
49 |
50 | return ch_path, derived_chs
51 |
52 |
53 | def derive_ops(self, alpha, alpha_type):
54 | assert alpha_type in ['head', 'stack']
55 |
56 | if alpha_type == 'head':
57 | op_type = self.config.search_params.PRIMITIVES_head[alpha.index(max(alpha))]
58 | elif alpha_type == 'stack':
59 | op_type = self.config.search_params.PRIMITIVES_stack[alpha.index(max(alpha))]
60 |
61 | return op_type
62 |
63 |
64 | def derive_archs(self, betas, head_alphas, stack_alphas, if_display=True):
65 | raise NotImplementedError
--------------------------------------------------------------------------------
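Example (not part of the repo) — a self-contained toy of what derive_chs computes: a
max-product (Viterbi-style) forward pass over per-block transition probabilities, then
backtracking. The block graph and beta values below are made up for illustration:

    input_configs = [
        {'in_block_idx': [0],    'beta_idx': [0]},     # block 1 reachable only from block 0
        {'in_block_idx': [0, 1], 'beta_idx': [1, 0]},  # block 2 reachable from block 0 or 1
    ]
    betas = [[0.9, 0.4], [0.7]]  # betas[b][k]: prob of the k-th outgoing edge of block b

    path_p_max = [[0, 1]]        # per block: [best predecessor, best path probability]
    for ic in input_configs:
        best = [None, 0]
        for in_id, b_id in zip(ic['in_block_idx'], ic['beta_idx']):
            p = path_p_max[in_id][1] * betas[in_id][b_id]
            if p > best[1]:
                best = [in_id, p]
        path_p_max.append(best)

    # backtrack from the last block to block 0
    idx, path = len(path_p_max) - 1, []
    while idx != 0:
        path.append(idx)
        idx = path_p_max[idx][0]
    path.append(0)
    print(path[::-1])  # [0, 1, 2]: 0->1 (0.9) then 1->2 (0.7) beats 0->2 directly (0.4)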
/configs/imagenet_search_cfg_mbv2.yaml:
--------------------------------------------------------------------------------
1 | net_type: mbv2
2 | train_params:
3 | epochs: 150
4 | use_seed: True
5 | seed: 2
6 |
7 | search_params:
8 | arch_update_epoch: 50
9 | val_start_epoch: 120
10 | sample_policy: prob # prob uniform
11 | weight_sample_num: 1
12 | softmax_temp: 1.
13 |
14 | PRIMITIVES_stack: ['mbconv_k3_t3',
15 | 'mbconv_k3_t6',
16 | 'mbconv_k5_t3',
17 | 'mbconv_k5_t6',
18 | 'mbconv_k7_t3',
19 | 'mbconv_k7_t6',
20 | 'skip_connect',]
21 | PRIMITIVES_head: ['mbconv_k3_t3',
22 | 'mbconv_k3_t6',
23 | 'mbconv_k5_t3',
24 | 'mbconv_k5_t6',
25 | 'mbconv_k7_t3',
26 | 'mbconv_k7_t6',
27 | ]
28 | # search space
29 | adjoin_connect_nums: [4, 4, 4, 4, 4, 4, 4]
30 | net_scale:
31 | chs: [16, 24, 32, 40, 48, 56, 64, 72, 96, 112, 128, 160, 176, 192, 320, 352, 384]
32 | fm_sizes: [112, 56, 28, 28, 28, 14, 14, 14, 14, 14, 14, 7, 7, 7, 7, 7, 7]
33 | stage: [0, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6, 7]
34 | num_layers: [0, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 0, 0]
35 |
36 | optim:
37 | if_sub_obj: True
38 | sub_obj:
39 | type: latency
40 | skip_reg: False
41 | log_base: 15.0
42 | sub_loss_factor: 0.15
43 | latency_list_path: lat_list_densenas_mbv2_xp32
44 |
45 | if_resume: False
46 | resume:
47 | load_path: ''
48 | load_epoch: 49
49 |
50 | init_dim: 16
51 | head_dim: 16
52 | last_dim: 1984
53 | bn_momentum: 0.1
54 | bn_eps: 0.001
55 | weight:
56 | init_lr: 0.2
57 | min_lr: 0.0001
58 | lr_decay_type: cosine
59 | momentum: 0.9
60 | weight_decay: 0.00004
61 | arch:
62 | alpha_lr: 0.0003
63 | beta_lr: 0.0003
64 | weight_decay: 0.001
65 |
66 | data:
67 | dataset: imagenet
68 | batch_size: 352
69 | num_workers: 16
70 | train_data_type: lmdb
71 | val_data_type: lmdb
72 | input_size: (3,224,224)
73 | type_of_data_aug: random_sized # random_sized / rand_scale
74 | color: False
75 | random_sized:
76 | min_scale: 0.08
77 |
--------------------------------------------------------------------------------
/models/search_space_res.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from .operations import OPS
4 | from .search_space_base import Conv1_1_Block, Block
5 | from .search_space_base import Network as BaseSearchSpace
6 |
7 | class Network(BaseSearchSpace):
8 | def __init__(self, init_ch, dataset, config, groups=1, base_width=64, dilation=1, norm_layer=None):
9 | super(Network, self).__init__(init_ch, dataset, config)
10 |
11 | if norm_layer is None:
12 | norm_layer = nn.BatchNorm2d
13 | if groups != 1 or base_width != 64:
14 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
15 | if dilation > 1:
16 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
17 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
18 | self.input_block = nn.Sequential(
19 | nn.Conv2d(3, self._C_input, kernel_size=3, stride=2, padding=1, bias=False),
20 | norm_layer(self._C_input),
21 | nn.ReLU(inplace=True),
22 | )
23 |
24 | self.blocks = nn.ModuleList()
25 |
26 | for i in range(self.num_blocks):
27 | input_config = self.input_configs[i]
28 | self.blocks.append(Block(
29 | input_config['in_chs'],
30 | input_config['ch'],
31 | input_config['strides'],
32 | input_config['num_stack_layers'],
33 | self.config
34 | ))
35 |
36 | if 'bottle_neck' in self.config.search_params.PRIMITIVES_stack:
37 | conv1_1_input_dim = [ch * 4 for ch in self.input_configs[-1]['in_chs']]
38 | last_dim = self.config.optim.last_dim * 4
39 | else:
40 | conv1_1_input_dim = self.input_configs[-1]['in_chs']
41 | last_dim = self.config.optim.last_dim
42 | self.conv1_1_block = Conv1_1_Block(conv1_1_input_dim, last_dim)
43 | self.global_pooling = nn.AdaptiveAvgPool2d((1, 1))
44 | self.classifier = nn.Linear(last_dim, self._num_classes)
45 |
46 | for m in self.modules():
47 | if isinstance(m, nn.Conv2d):
48 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
49 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
50 |                 if m.affine:
51 | nn.init.constant_(m.weight, 1)
52 | nn.init.constant_(m.bias, 0)
53 |
54 |
--------------------------------------------------------------------------------
/configs/imagenet_train_cfg.py:
--------------------------------------------------------------------------------
1 | from tools.collections import AttrDict
2 |
3 | __C = AttrDict()
4 |
5 | cfg = __C
6 |
7 | __C.net_type='mbv2' # mbv2 / res
8 | __C.net_config="""[[16, 16], 'mbconv_k3_t1', [], 0, 1]|
9 | [[16, 24], 'mbconv_k3_t6', [], 0, 2]|
10 | [[24, 48], 'mbconv_k7_t6', ['mbconv_k3_t3'], 1, 2]|
11 | [[48, 72], 'mbconv_k5_t6', ['mbconv_k3_t6', 'mbconv_k3_t3'], 2, 2]|
12 | [[72, 128], 'mbconv_k3_t6', ['mbconv_k3_t3', 'mbconv_k3_t3'], 2, 1]|
13 | [[128, 160], 'mbconv_k3_t6', ['mbconv_k7_t3', 'mbconv_k5_t6', 'mbconv_k7_t3'], 3, 2]|
14 | [[160, 176], 'mbconv_k3_t3', ['mbconv_k3_t6', 'mbconv_k3_t6', 'mbconv_k3_t6'], 3, 1]|
15 | [[176, 384], 'mbconv_k7_t6', [], 0, 1]|
16 | [[384, 1984], 'conv1_1']"""
17 |
18 | __C.train_params=AttrDict()
19 | __C.train_params.epochs=240
20 | __C.train_params.use_seed=True
21 | __C.train_params.seed=0
22 |
23 | __C.optim=AttrDict()
24 | __C.optim.init_lr=0.5
25 | __C.optim.min_lr=1e-5
26 | __C.optim.lr_schedule='cosine' # cosine poly
27 | __C.optim.momentum=0.9
28 | __C.optim.weight_decay=4e-5
29 | __C.optim.use_grad_clip=False
30 | __C.optim.grad_clip=10
31 | __C.optim.label_smooth=True
32 | __C.optim.smooth_alpha=0.1
33 | __C.optim.init_mode='he_fout'
34 |
35 | __C.optim.if_resume=False
36 | __C.optim.resume=AttrDict()
37 | __C.optim.resume.load_path=''
38 | __C.optim.resume.load_epoch=0
39 |
40 | __C.optim.use_warm_up=False
41 | __C.optim.warm_up=AttrDict()
42 | __C.optim.warm_up.epoch=5
43 | __C.optim.warm_up.init_lr=0.0001
44 | __C.optim.warm_up.target_lr=0.1
45 |
46 | __C.optim.use_multi_stage=False
47 | __C.optim.multi_stage=AttrDict()
48 | __C.optim.multi_stage.stage_epochs=330
49 |
50 | __C.optim.cosine=AttrDict()
51 | __C.optim.cosine.use_restart=False
52 | __C.optim.cosine.restart=AttrDict()
53 | __C.optim.cosine.restart.lr_period=[10, 20, 40, 80, 160, 320]
54 | __C.optim.cosine.restart.lr_step=[0, 10, 30, 70, 150, 310]
55 |
56 | __C.optim.bn_momentum=0.1
57 | __C.optim.bn_eps=0.001
58 |
59 | __C.data=AttrDict()
60 | __C.data.num_workers=16
61 | __C.data.batch_size=1024
62 | __C.data.dataset='imagenet' #imagenet
63 | __C.data.train_data_type='lmdb'
64 | __C.data.val_data_type='img'
65 | __C.data.patch_dataset=False
66 | # __C.data.num_examples=1281167
67 | __C.data.input_size=(3,224,224)
68 | __C.data.type_of_data_aug='random_sized' # random_sized / rand_scale
69 | __C.data.random_sized=AttrDict()
70 | __C.data.random_sized.min_scale=0.08
71 | __C.data.mean=[0.485, 0.456, 0.406]
72 | __C.data.std=[0.229, 0.224, 0.225]
73 | __C.data.color=False
--------------------------------------------------------------------------------
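Example (not part of the repo) — a hedged way to inspect __C.net_config: each '|'-separated
segment is a Python literal [[ch_in, ch_out], head_op, [stack_ops], num_stack_layers, stride]
(the final conv1_1 entry has only two fields). ast.literal_eval is one safe way to peek at it;
the repo's real parser lives in models/model_derived.py:

    import ast

    from configs.imagenet_train_cfg import cfg

    blocks = [ast.literal_eval(seg.strip()) for seg in cfg.net_config.split('|')]
    print(blocks[0])   # [[16, 16], 'mbconv_k3_t1', [], 0, 1]
    print(blocks[-1])  # [[384, 1984], 'conv1_1']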
/tools/collections.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | ##############################################################################
15 |
16 | """A simple attribute dictionary used for representing configuration options."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | from __future__ import unicode_literals
22 |
23 |
24 | class AttrDict(dict):
25 |
26 | IMMUTABLE = '__immutable__'
27 |
28 | def __init__(self, *args, **kwargs):
29 | super(AttrDict, self).__init__(*args, **kwargs)
30 | self.__dict__[AttrDict.IMMUTABLE] = False
31 |
32 | def __getattr__(self, name):
33 | if name in self.__dict__:
34 | return self.__dict__[name]
35 | elif name in self:
36 | return self[name]
37 | else:
38 | raise AttributeError(name)
39 |
40 | def __setattr__(self, name, value):
41 | if not self.__dict__[AttrDict.IMMUTABLE]:
42 | if name in self.__dict__:
43 | self.__dict__[name] = value
44 | else:
45 | self[name] = value
46 | else:
47 | raise AttributeError(
48 | 'Attempted to set "{}" to "{}", but AttrDict is immutable'.
49 | format(name, value)
50 | )
51 |
52 | def immutable(self, is_immutable):
53 | """Set immutability to is_immutable and recursively apply the setting
54 | to all nested AttrDicts.
55 | """
56 | self.__dict__[AttrDict.IMMUTABLE] = is_immutable
57 | # Recursively set immutable state
58 | for v in self.__dict__.values():
59 | if isinstance(v, AttrDict):
60 | v.immutable(is_immutable)
61 | for v in self.values():
62 | if isinstance(v, AttrDict):
63 | v.immutable(is_immutable)
64 |
65 | def is_immutable(self):
66 | return self.__dict__[AttrDict.IMMUTABLE]
67 |
--------------------------------------------------------------------------------
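Example (not part of the repo) — a minimal usage sketch of AttrDict: dict keys double as
attributes, and immutable(True) locks the dict (and all nested AttrDicts) against
accidental edits:

    from tools.collections import AttrDict

    cfg = AttrDict()
    cfg.data = AttrDict()
    cfg.data.batch_size = 512
    assert cfg['data']['batch_size'] == cfg.data.batch_size

    cfg.immutable(True)
    try:
        cfg.data.batch_size = 1024
    except AttributeError as e:
        print(e)  # writes are refused while immutable
    cfg.immutable(False)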
/run_apis/latency_measure.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import importlib
3 | import logging
4 | import os
5 | import sys
6 |
7 | import torch
8 | import torch.backends.cudnn as cudnn
9 |
10 | from configs.imagenet_train_cfg import cfg
11 | from configs.search_config import search_cfg
12 | from tools import utils
13 | from tools.config_yaml import merge_cfg_from_file, update_cfg_from_cfg
14 |
15 | if __name__ == '__main__':
16 |
17 | parser = argparse.ArgumentParser("Params")
18 |     parser.add_argument('--save', type=str, default='./', help='the output directory')
19 | parser.add_argument('--input_size', type=str, default='[32, 3, 224, 224]', help='data input size')
20 | parser.add_argument('--meas_times', type=int, default=5000, help='measure times')
21 | parser.add_argument('--list_name', type=str, default='', help='output list name')
22 | parser.add_argument('--device', choices=['gpu', 'cpu'])
23 | parser.add_argument('-c', '--config', metavar='C', default=None, help='The Configuration file')
24 |
25 | args = parser.parse_args()
26 |
27 | update_cfg_from_cfg(search_cfg, cfg)
28 | if args.config is not None:
29 | merge_cfg_from_file(args.config, cfg)
30 | config = cfg
31 |
32 | args.save = os.path.join(args.save, 'output')
33 | utils.create_exp_dir(args.save)
34 |
35 | args.input_size = eval(args.input_size)
36 | if len(args.input_size) != 4:
37 |         raise ValueError('input_size must have 4 entries: [batch, channels, height, width].')
38 |
39 | log_format = '%(asctime)s %(message)s'
40 | logging.basicConfig(stream=sys.stdout, level=logging.INFO,
41 | format=log_format, datefmt='%m/%d %I:%M:%S %p')
42 | fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
43 | fh.setFormatter(logging.Formatter(log_format))
44 | logging.getLogger().addHandler(fh)
45 |
46 | if not torch.cuda.is_available():
47 | logging.info('no gpu device available')
48 | sys.exit(1)
49 |
50 | cudnn.benchmark = True
51 | cudnn.enabled = True
52 |
53 | SearchSpace = importlib.import_module('models.search_space_'+config.net_type).Network
54 | super_model = SearchSpace(config.optim.init_dim, config.data.dataset, config)
55 |
56 | super_model.eval()
57 | logging.info("Params = %fMB" % utils.count_parameters_in_MB(super_model))
58 |
59 | if args.device == 'gpu':
60 | super_model = super_model.cuda()
61 |
62 | latency_list, total_latency = super_model.get_cost_list(
63 | args.input_size, cost_type='latency',
64 | use_gpu = (args.device == 'gpu'),
65 | meas_times = args.meas_times)
66 |
67 | logging.info('latency_list:\n' + str(latency_list))
68 | logging.info('total latency: ' + str(total_latency) + 'ms')
69 |
70 | with open(os.path.join(args.save, args.list_name), 'w') as f:
71 | f.write(str(latency_list))
72 |
--------------------------------------------------------------------------------
/run_apis/validation.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import os
4 | import pprint
5 | import sys
6 | import time
7 |
8 | import torch
9 | import torch.backends.cudnn as cudnn
10 | import torch.nn as nn
11 |
12 | from configs.imagenet_val_cfg import cfg
13 | from dataset import imagenet_data
14 | from models import model_derived
15 | from tools import utils
16 | from tools.multadds_count import comp_multadds
17 |
18 | from .trainer import Trainer
19 |
20 | if __name__ == '__main__':
21 |
22 | parser = argparse.ArgumentParser("Params")
23 | parser.add_argument('--report_freq', type=float, default=50, help='report frequency')
24 | parser.add_argument('--data_path', type=str, default='../data', help='location of the dataset')
25 | parser.add_argument('--load_path', type=str, default='./model_path', help='model loading path')
26 | parser.add_argument('--save', type=str, default='./', help='the path of output')
27 |
28 | args = parser.parse_args()
29 | config = cfg
30 |
31 | args.save = os.path.join(args.save, 'output')
32 | utils.create_exp_dir(args.save)
33 |
34 | log_format = '%(asctime)s %(message)s'
35 | logging.basicConfig(stream=sys.stdout, level=logging.INFO,
36 | format=log_format, datefmt='%m/%d %I:%M:%S %p')
37 | fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
38 | fh.setFormatter(logging.Formatter(log_format))
39 | logging.getLogger().addHandler(fh)
40 |
41 | if not torch.cuda.is_available():
42 | logging.info('no gpu device available')
43 | sys.exit(1)
44 |
45 | cudnn.benchmark = True
46 | cudnn.enabled = True
47 |
48 | logging.info("args = %s", args)
49 | logging.info('Training with config:')
50 | logging.info(pprint.pformat(config))
51 |
52 | config.net_config, net_type = utils.load_net_config(os.path.join(args.load_path, 'net_config'))
53 |
54 | derivedNetwork = getattr(model_derived, '%s_Net' % net_type.upper())
55 | model = derivedNetwork(config.net_config, config=config)
56 |
57 | logging.info("Network Structure: \n" + '\n'.join(map(str, model.net_config)))
58 | logging.info("Params = %.2fMB" % utils.count_parameters_in_MB(model))
59 | logging.info("Mult-Adds = %.2fMB" % comp_multadds(model, input_size=config.data.input_size))
60 |
61 | model = model.cuda()
62 | model = nn.DataParallel(model)
63 | utils.load_model(model, os.path.join(args.load_path, 'weights.pt'))
64 |
65 | imagenet = imagenet_data.ImageNet12(trainFolder=os.path.join(args.data_path, 'train'),
66 | testFolder=os.path.join(args.data_path, 'val'),
67 | num_workers=config.data.num_workers,
68 | data_config=config.data)
69 | valid_queue = imagenet.getTestLoader(config.data.batch_size)
70 | trainer = Trainer(None, valid_queue, None, None,
71 | None, config, args.report_freq)
72 |
73 | with torch.no_grad():
74 | val_acc_top1, val_acc_top5, valid_obj, batch_time = trainer.infer(model)
75 |
--------------------------------------------------------------------------------
/dataset/img2lmdb.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import os.path as osp
4 | import sys
5 |
6 | import cv2
7 | import lmdb
8 | import msgpack
9 | import numpy as np
10 | from PIL import Image
11 | from tqdm import tqdm
12 |
13 | image_size=None
14 |
15 | class Datum(object):
16 | def __init__(self, shape=None, image=None, label=None):
17 | self.shape = shape
18 | self.image = image
19 | self.label = label
20 |
21 | def SerializeToString(self, img=None):
22 | image_data = self.image.astype(np.uint8).tobytes()
23 | label_data = np.uint16(self.label).tobytes()
24 | return msgpack.packb(image_data+label_data, use_bin_type=True)
25 |
26 | def ParseFromString(self, raw_data, orig_img):
27 | raw_data = msgpack.unpackb(raw_data, raw=False)
28 | raw_img_data = raw_data[:-2]
29 |         # frombuffer shares the memory of the data, while fromstring would copy it
30 | image_data = np.frombuffer(raw_img_data, dtype=np.uint8)
31 | self.image = cv2.imdecode(image_data, cv2.IMREAD_COLOR)
32 |
33 | raw_label_data = raw_data[-2:]
34 | self.label = np.frombuffer(raw_label_data, dtype=np.uint16)
35 |
36 |
37 | def create_dataset(output_path, image_folder, image_list, image_size):
38 | image_name_list = [i.strip() for i in open(image_list)]
39 | n_samples = len(image_name_list)
40 | env = lmdb.open(output_path, map_size=1099511627776, meminit=False, map_async=True) # 1TB
41 |
42 | txn = env.begin(write=True)
43 | classes = [d for d in os.listdir(image_folder) if os.path.isdir(os.path.join(image_folder, d))]
44 | for idx, image_name in enumerate(tqdm(image_name_list)):
45 | image_path = os.path.join(image_folder, image_name)
46 | label_name = image_name.split('/')[0]
47 | label = classes.index(label_name)
48 | if not os.path.isfile(image_path):
49 | raise RuntimeError('%s does not exist' % image_path)
50 |
51 | img = cv2.imread(image_path)
52 | img_orig = img
53 |
54 | if image_size:
55 | resize_ratio = float(image_size)/min(img.shape[0:2])
56 | new_size = (int(img.shape[1]*resize_ratio), int(img.shape[0]*resize_ratio)) #inverse order for cv2
57 | img = cv2.resize(src=img, dsize=new_size)
58 | img = cv2.imencode('.JPEG', img)[1]
59 |
60 | image = np.asarray(img)
61 | datum = Datum(image.shape, image, label)
62 | txn.put(image_name.encode('ascii'), datum.SerializeToString())
63 |
64 | if (idx + 1) % 1000 == 0:
65 | txn.commit()
66 | txn = env.begin(write=True)
67 | txn.commit()
68 | env.sync()
69 | env.close()
70 |
71 | print(f'Created dataset with {n_samples:d} samples')
72 |
73 |
74 | if __name__ == '__main__':
75 | parser = argparse.ArgumentParser("Params")
76 |     parser.add_argument('--image_size', type=int, default=None, help='the size of the image you want to pack')
77 | parser.add_argument('--image_path', type=str, default='', help='the path of the images')
78 | parser.add_argument('--list_path', type=str, default='', help='the path of the image list')
79 | parser.add_argument('--output_path', type=str, default='', help='the output path of the lmdb file')
80 | parser.add_argument('--split', type=str, default='', help='the split path of the images: train / val')
81 | args = parser.parse_args()
82 |
83 | image_size = args.image_size if args.image_size else image_size
84 | image_folder = osp.join(args.image_path, args.split)
85 | image_list = osp.join(args.list_path, '{}_datalist'.format(args.split))
86 |
87 | if image_size:
88 | output_path = osp.join(args.output_path, '{}_{}'.format(args.split, image_size))
89 | else:
90 | output_path = osp.join(args.output_path, args.split)
91 |
92 | create_dataset(output_path, image_folder, image_list, image_size)
93 |
--------------------------------------------------------------------------------
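Example (not part of the repo) — a minimal serialize/parse round trip for the Datum
container above (assumes the repo root is on sys.path so dataset.img2lmdb is importable;
the image content is toy data):

    import cv2
    import numpy as np

    from dataset.img2lmdb import Datum

    img = np.zeros((8, 8, 3), dtype=np.uint8)            # toy BGR image
    encoded = np.asarray(cv2.imencode('.JPEG', img)[1])  # JPEG bytes as a uint8 array
    raw = Datum(encoded.shape, encoded, label=5).SerializeToString()

    out = Datum()
    out.ParseFromString(raw, None)  # the second arg (orig_img) is unused by the class
    assert out.label[0] == 5 and out.image.shape == (8, 8, 3)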
/dataset/prefetch_data.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import time
4 |
5 | class data_prefetcher():
6 | def __init__(self, loader, mean=None, std=None, is_cutout=False, cutout_length=16):
7 | self.loader = iter(loader)
8 | self.stream = torch.cuda.Stream()
9 | if mean is None:
10 | self.mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1,3,1,1)
11 | else:
12 | self.mean = torch.tensor([m * 255 for m in mean]).cuda().view(1,3,1,1)
13 | if std is None:
14 | self.std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1,3,1,1)
15 | else:
16 | self.std = torch.tensor([s * 255 for s in std]).cuda().view(1,3,1,1)
17 | self.is_cutout = is_cutout
18 | self.cutout_length = cutout_length
19 | self.preload()
20 |
21 | def normalize(self, data):
22 | data = data.float()
23 | data = data.sub_(self.mean).div_(self.std)
24 | return data
25 |
26 | def cutout(self, data):
27 | batch_size, h, w = data.shape[0], data.shape[2], data.shape[3]
28 | mask = torch.ones(batch_size, h, w).cuda()
29 | y = torch.randint(low=0, high=h, size=(batch_size,))
30 | x = torch.randint(low=0, high=w, size=(batch_size,))
31 |
32 | y1 = torch.clamp(y - self.cutout_length // 2, 0, h)
33 | y2 = torch.clamp(y + self.cutout_length // 2, 0, h)
34 | x1 = torch.clamp(x - self.cutout_length // 2, 0, w)
35 | x2 = torch.clamp(x + self.cutout_length // 2, 0, w)
36 | for i in range(batch_size):
37 | mask[i][y1[i]: y2[i], x1[i]: x2[i]] = 0.
38 | mask = mask.expand_as(data.transpose(0,1)).transpose(0,1)
39 | data *= mask
40 | return data
41 |
42 | def preload(self):
43 | try:
44 | self.next_input, self.next_target = next(self.loader)
45 | except StopIteration:
46 | self.next_input = None
47 | self.next_target = None
48 | return
49 | # if record_stream() doesn't work, another option is to make sure device inputs are created
50 | # on the main stream.
51 | # self.next_input_gpu = torch.empty_like(self.next_input, device='cuda')
52 | # self.next_target_gpu = torch.empty_like(self.next_target, device='cuda')
53 | # Need to make sure the memory allocated for next_* is not still in use by the main stream
54 | # at the time we start copying to next_*:
55 | # self.stream.wait_stream(torch.cuda.current_stream())
56 | with torch.cuda.stream(self.stream):
57 | self.next_input = self.next_input.cuda(non_blocking=True)
58 | self.next_target = self.next_target.cuda(non_blocking=True)
59 | self.next_input = self.normalize(self.next_input)
60 | if self.is_cutout:
61 | self.next_input = self.cutout(self.next_input)
62 |
63 | def next(self):
64 | torch.cuda.current_stream().wait_stream(self.stream)
65 | input = self.next_input
66 | target = self.next_target
67 | if input is not None:
68 | input.record_stream(torch.cuda.current_stream())
69 | if target is not None:
70 | target.record_stream(torch.cuda.current_stream())
71 | self.preload()
72 | return input, target
73 |
74 |
75 | def fast_collate(batch):
76 | imgs = [img[0] for img in batch]
77 | targets = torch.tensor([target[1] for target in batch], dtype=torch.int64)
78 | w = imgs[0].size[0]
79 | h = imgs[0].size[1]
80 | tensor = torch.zeros((len(imgs), 3, h, w), dtype=torch.uint8)
81 | for i, img in enumerate(imgs):
82 | nump_array = np.asarray(img, dtype=np.uint8)
83 |         if nump_array.ndim < 3:
84 | nump_array = np.expand_dims(nump_array, axis=-1)
85 | nump_array = np.rollaxis(nump_array, 2)
86 |
87 | tensor[i] += torch.from_numpy(nump_array)
88 |
89 | return tensor, targets
90 |
--------------------------------------------------------------------------------
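Example (not part of the repo) — a usage sketch for fast_collate + data_prefetcher. It
assumes a CUDA device is available and that `dataset` is any map-style dataset yielding
(PIL.Image, int) pairs, e.g. torchvision.datasets.ImageFolder('/data/imagenet/train'):

    from torch.utils.data import DataLoader

    from dataset.prefetch_data import data_prefetcher, fast_collate

    loader = DataLoader(dataset, batch_size=256, shuffle=True, num_workers=8,
                        pin_memory=True, collate_fn=fast_collate)  # uint8 tensors, no CPU normalization
    prefetcher = data_prefetcher(loader)  # normalization (and optional cutout) run on a side CUDA stream
    inputs, targets = prefetcher.next()
    while inputs is not None:
        # ... one training step on (inputs, targets) ...
        inputs, targets = prefetcher.next()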
/tools/io.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | ##############################################################################
15 |
16 | """IO utilities."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | from __future__ import unicode_literals
22 |
23 | import pickle
24 | import hashlib
25 | import logging
26 | import os
27 | import re
28 | import sys
29 | import urllib.request
30 |
31 | logger = logging.getLogger(__name__)
32 |
33 | _DETECTRON_S3_BASE_URL = 'https://s3-us-west-2.amazonaws.com/detectron'
34 |
35 |
36 | def save_object(obj, file_name):
37 | """Save a Python object by pickling it."""
38 | file_name = os.path.abspath(file_name)
39 | with open(file_name, 'wb') as f:
40 | pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)
41 |
42 |
43 | def cache_url(url_or_file, cache_dir):
44 | """Download the file specified by the URL to the cache_dir and return the
45 | path to the cached file. If the argument is not a URL, simply return it as
46 | is.
47 | """
48 | is_url = re.match(r'^(?:http)s?://', url_or_file, re.IGNORECASE) is not None
49 |
50 | if not is_url:
51 | return url_or_file
52 |
53 | url = url_or_file
54 | assert url.startswith(_DETECTRON_S3_BASE_URL), \
55 | ('Detectron only automatically caches URLs in the Detectron S3 '
56 | 'bucket: {}').format(_DETECTRON_S3_BASE_URL)
57 |
58 | cache_file_path = url.replace(_DETECTRON_S3_BASE_URL, cache_dir)
59 | if os.path.exists(cache_file_path):
60 | assert_cache_file_is_ok(url, cache_file_path)
61 | return cache_file_path
62 |
63 | cache_file_dir = os.path.dirname(cache_file_path)
64 | if not os.path.exists(cache_file_dir):
65 | os.makedirs(cache_file_dir)
66 |
67 | logger.info('Downloading remote file {} to {}'.format(url, cache_file_path))
68 | download_url(url, cache_file_path)
69 | assert_cache_file_is_ok(url, cache_file_path)
70 | return cache_file_path
71 |
72 |
73 | def assert_cache_file_is_ok(url, file_path):
74 | """Check that cache file has the correct hash."""
75 | # File is already in the cache, verify that the md5sum matches and
76 | # return local path
77 | cache_file_md5sum = _get_file_md5sum(file_path)
78 | ref_md5sum = _get_reference_md5sum(url)
79 | assert cache_file_md5sum == ref_md5sum, \
80 | ('Target URL {} appears to be downloaded to the local cache file '
81 | '{}, but the md5 hash of the local file does not match the '
82 | 'reference (actual: {} vs. expected: {}). You may wish to delete '
83 | 'the cached file and try again to trigger automatic '
84 | 'download.').format(url, file_path, cache_file_md5sum, ref_md5sum)
85 |
86 |
87 | def _progress_bar(count, total):
88 | """Report download progress.
89 | Credit:
90 | https://stackoverflow.com/questions/3173320/text-progress-bar-in-the-console/27871113
91 | """
92 | bar_len = 60
93 | filled_len = int(round(bar_len * count / float(total)))
94 |
95 | percents = round(100.0 * count / float(total), 1)
96 | bar = '=' * filled_len + '-' * (bar_len - filled_len)
97 |
98 | sys.stdout.write(
99 | ' [{}] {}% of {:.1f}MB file \r'.
100 | format(bar, percents, total / 1024 / 1024)
101 | )
102 | sys.stdout.flush()
103 | if count >= total:
104 | sys.stdout.write('\n')
105 |
106 |
107 | def download_url(
108 | url, dst_file_path, chunk_size=8192, progress_hook=_progress_bar
109 | ):
110 | """Download url and write it to dst_file_path.
111 | Credit:
112 | https://stackoverflow.com/questions/2028517/python-urllib2-progress-hook
113 | """
114 | response = urllib.request.urlopen(url)
115 |     total_size = response.info().get('Content-Length').strip()  # Python 3: headers expose get(), not getheader()
116 | total_size = int(total_size)
117 | bytes_so_far = 0
118 |
119 | with open(dst_file_path, 'wb') as f:
120 |         while True:
121 | chunk = response.read(chunk_size)
122 | bytes_so_far += len(chunk)
123 | if not chunk:
124 | break
125 | if progress_hook:
126 | progress_hook(bytes_so_far, total_size)
127 | f.write(chunk)
128 |
129 | return bytes_so_far
130 |
131 |
132 | def _get_file_md5sum(file_name):
133 | """Compute the md5 hash of a file."""
134 | hash_obj = hashlib.md5()
135 |     with open(file_name, 'rb') as f:  # md5 must be computed over bytes
136 | hash_obj.update(f.read())
137 | return hash_obj.hexdigest()
138 |
139 |
140 | def _get_reference_md5sum(url):
141 | """By convention the md5 hash for url is stored in url + '.md5sum'."""
142 | url_md5sum = url + '.md5sum'
143 | md5sum = urllib.request.urlopen(url_md5sum).read().strip()
144 | return md5sum
145 |
--------------------------------------------------------------------------------
/dataset/lmdb_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import cv2
4 | import msgpack
5 | import numpy as np
6 | import torch.utils.data as data
7 | from PIL import Image
8 |
9 | import lmdb
10 |
11 |
12 | class Datum(object):
13 | def __init__(self, shape=None, image=None, label=None):
14 | self.shape = shape
15 | self.image = image
16 | self.label = label
17 |
18 | def SerializeToString(self):
19 | image_data = self.image.astype(np.uint8).tobytes()
20 | label_data = np.uint16(self.label).tobytes()
21 | return msgpack.packb(image_data+label_data, use_bin_type=True)
22 |
23 | def ParseFromString(self, raw_data):
24 | raw_data = msgpack.unpackb(raw_data, raw=False)
25 | raw_img_data = raw_data[:-2]
26 |         image_data = np.frombuffer(raw_img_data, dtype=np.uint8) # frombuffer shares the memory of the data, while fromstring would copy it
27 | self.image = cv2.imdecode(image_data, cv2.IMREAD_COLOR)
28 |
29 | raw_label_data = raw_data[-2:]
30 | self.label = np.frombuffer(raw_label_data, dtype=np.uint16)
31 |
32 |
33 | class DatasetFolder(data.Dataset):
34 | """
35 | Args:
36 | root (string): Root directory path.
37 | transform (callable, optional): A function/transform that takes in
38 | a sample and returns a transformed version.
39 | E.g, ``transforms.RandomCrop`` for images.
40 | target_transform (callable, optional): A function/transform that takes
41 | in the target and transforms it.
42 |
43 | Attributes:
44 | samples (list): List of (sample path, class_index) tuples
45 | """
46 |
47 | def __init__(self, root, list_path, transform=None, target_transform=None, patch_dataset=False):
48 | self.root = root
49 | self.patch_dataset = patch_dataset
50 |
51 | if patch_dataset:
52 | self.txn = []
53 | for path in os.listdir(root):
54 | lmdb_path = os.path.join(root, path)
55 | if os.path.isdir(lmdb_path):
56 | env = lmdb.open(lmdb_path,
57 | readonly=True,
58 | lock=False,
59 | readahead=False,
60 | meminit=False)
61 | txn = env.begin(write=False)
62 | self.txn.append(txn)
63 |
64 | else:
65 | self.env = lmdb.open(root,
66 | readonly=True,
67 | lock=False,
68 | readahead=False,
69 | meminit=False)
70 | self.txn = self.env.begin(write=False)
71 |
72 | self.list_path = list_path
73 | self.samples = [image_name.strip() for image_name in open(list_path)]
74 |
75 | if len(self.samples) == 0:
76 | raise(RuntimeError("Found 0 files in subfolders of: " + root + "\n"))
77 |
78 | self.transform = transform
79 | self.target_transform = target_transform
80 |
81 | def __getitem__(self, index):
82 | """
83 | Args:
84 | index (int): Index
85 |
86 | Returns:
87 | tuple: (sample, target) where target is class_index of the target class.
88 | """
89 |
90 | img_name = self.samples[index]
91 |
92 | if self.patch_dataset:
93 | txn_index = index // (len(self.samples) // 10)
94 | if txn_index==10:
95 | txn_index = 9
96 | txn = self.txn[txn_index]
97 | else:
98 | txn = self.txn
99 |
100 | datum = Datum()
101 | data_bin = txn.get(img_name.encode('ascii'))
102 | if data_bin is None:
103 | raise RuntimeError(f'Key {img_name} not found')
104 | datum.ParseFromString(data_bin)
105 |
106 | sample = Image.fromarray(cv2.cvtColor(datum.image, cv2.COLOR_BGR2RGB))
107 |         target = int(datum.label)  # np.int was removed in NumPy 1.24
108 |
109 | if self.transform is not None:
110 | sample = self.transform(sample)
111 | if self.target_transform is not None:
112 | target = self.target_transform(target)
113 |
114 | return sample, target
115 |
116 | def __len__(self):
117 | return len(self.samples)
118 |
119 | def __repr__(self):
120 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
121 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
122 | fmt_str += ' Root Location: {}\n'.format(self.root)
123 | tmp = ' Transforms (if any): '
124 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
125 | tmp = ' Target Transforms (if any): '
126 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
127 | return fmt_str
128 |
129 |
130 | class ImageFolder(DatasetFolder):
131 | def __init__(self, root, list_path, transform=None, target_transform=None, patch_dataset=False):
132 | super(ImageFolder, self).__init__(root, list_path,
133 | transform=transform,
134 | target_transform=target_transform,
135 | patch_dataset=patch_dataset)
136 | self.imgs = self.samples
137 |
--------------------------------------------------------------------------------
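Example (not part of the repo) — reading one sample from an LMDB built by
dataset/img2lmdb.py (paths hypothetical):

    import torchvision.transforms as transforms

    from dataset.lmdb_dataset import ImageFolder

    train_set = ImageFolder(root='/data/imagenet_lmdb/train',        # lmdb dir from img2lmdb.py
                            list_path='/data/lists/train_datalist',  # from mk_img_list.py
                            transform=transforms.Compose([
                                transforms.RandomResizedCrop(224),
                                transforms.RandomHorizontalFlip(),
                                transforms.ToTensor(),
                            ]))
    sample, target = train_set[0]  # PIL image transformed to a 3x224x224 tensor, int label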
/tools/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.optim.lr_scheduler import CosineAnnealingLR
3 | from torch.optim.optimizer import Optimizer
4 | import math
5 |
6 | class CosineRestartAnnealingLR(object):
7 | # Decays the learning rate per update step (iteration), not per epoch.
8 | # T_max refers to the maximum number of update steps.
9 |
10 | def __init__(self, optimizer, T_max, lr_period, lr_step, eta_min=0, last_step=-1,
11 | use_warmup=False, warmup_mode='linear', warmup_steps=0, warmup_startlr=0,
12 | warmup_targetlr=0, use_restart=False):
13 |
14 | self.use_warmup = use_warmup
15 | self.warmup_mode = warmup_mode
16 | self.warmup_steps = warmup_steps
17 | self.warmup_startlr = warmup_startlr
18 | self.warmup_targetlr = warmup_targetlr
19 | self.use_restart = use_restart
20 | self.T_max = T_max
21 | self.eta_min = eta_min
22 |
23 |         if not self.use_restart:
24 | self.lr_period = [self.T_max - self.warmup_steps]
25 | self.lr_step = [self.warmup_steps]
26 | else:
27 | self.lr_period = lr_period
28 | self.lr_step = lr_step
29 |
30 | self.last_step = last_step
31 | self.cycle_length = self.lr_period[0]
32 | self.cur = 0
33 |
34 | if not isinstance(optimizer, Optimizer):
35 | raise TypeError('{} is not an Optimizer'.format(
36 | type(optimizer).__name__))
37 | self.optimizer = optimizer
38 | if last_step == -1:
39 | for group in optimizer.param_groups:
40 | group.setdefault('initial_lr', group['lr'])
41 | else:
42 | for i, group in enumerate(optimizer.param_groups):
43 | if 'initial_lr' not in group:
44 | raise KeyError("param 'initial_lr' is not specified "
45 | "in param_groups[{}] when resuming an optimizer".format(i))
46 | self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups))
47 |
48 |
49 | def step(self, step=None):
50 |
51 | if step is not None:
52 | self.last_step = step
53 | else:
54 | self.last_step += 1
55 |
56 | for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
57 | param_group['lr'] = lr
58 |
59 |
60 | def get_lr(self):
61 |
62 | lrs = []
63 | for base_lr in self.base_lrs:
64 | if self.use_warmup and self.last_step < self.warmup_steps:
65 | if self.warmup_mode == 'constant':
66 | lrs.append(self.warmup_startlr)
67 | elif self.warmup_mode =='linear':
68 | cur_lr = self.warmup_startlr + \
69 | float(self.warmup_targetlr-self.warmup_startlr)/self.warmup_steps*self.last_step
70 | lrs.append(cur_lr)
71 | else:
72 | raise NotImplementedError
73 |
74 | else:
75 | if (self.last_step) in self.lr_step:
76 | self.cycle_length = self.lr_period[self.lr_step.index(self.last_step)]
77 | self.cur = self.last_step
78 |
79 | peri_iter = self.last_step-self.cur
80 |
81 | if peri_iter <= self.cycle_length:
82 | unit_cycle = (1 + math.cos(peri_iter * math.pi / self.cycle_length)) / 2
83 | adjusted_cycle = unit_cycle * (base_lr - self.eta_min) + self.eta_min
84 | lrs.append(adjusted_cycle)
85 | else:
86 | lrs.append(self.eta_min)
87 |
88 | return lrs
89 |
90 |
91 | def display_lr_curve(self, total_steps):
92 | lrs = []
93 | for _ in range(total_steps):
94 | self.step()
95 | lrs.append(self.get_lr()[0])
96 | import matplotlib.pyplot as plt
97 | plt.plot(lrs)
98 | plt.show()
99 |
100 |
101 | def get_lr_scheduler(config, optimizer, num_examples=None):
102 |
103 | if num_examples is None:
104 | num_examples = config.data.num_examples
105 | epoch_steps = num_examples // config.data.batch_size + 1
106 |
107 | if config.optim.use_multi_stage:
108 | max_steps = epoch_steps * config.optim.multi_stage.stage_epochs
109 | else:
110 | max_steps = epoch_steps * config.train_params.epochs
111 |
112 | period_steps = [epoch_steps * x for x in config.optim.cosine.restart.lr_period]
113 | step_steps = [epoch_steps * x for x in config.optim.cosine.restart.lr_step]
114 |
115 | init_lr = config.optim.init_lr
116 |
117 | use_warmup = config.optim.use_warm_up
118 | if use_warmup:
119 | warmup_steps = config.optim.warm_up.epoch * epoch_steps
120 | warmup_startlr = config.optim.warm_up.init_lr
121 | warmup_targetlr = config.optim.warm_up.target_lr
122 | else:
123 | warmup_steps = 0
124 | warmup_startlr = init_lr
125 | warmup_targetlr = init_lr
126 |
127 | if config.optim.lr_schedule == 'cosine':
128 | scheduler = CosineRestartAnnealingLR(optimizer,
129 | float(max_steps),
130 | period_steps,
131 | step_steps,
132 | eta_min=config.optim.min_lr,
133 | use_warmup=use_warmup,
134 | warmup_steps=warmup_steps,
135 | warmup_startlr=warmup_startlr,
136 | warmup_targetlr=warmup_targetlr,
137 | use_restart=config.optim.cosine.use_restart)
138 | # scheduler = CosineAnnealingLR(optimizer, config.train_params.epochs, config.optim.min_lr)
139 | elif config.optim.lr_schedule == 'poly':
140 | raise NotImplementedError
141 | else:
142 | raise NotImplementedError
143 |
144 | return scheduler
145 |
146 |
--------------------------------------------------------------------------------
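A brief usage sketch of the scheduler above may help: `CosineRestartAnnealingLR` is stepped once per optimization step, not per epoch. The model and hyper-parameters below are placeholders chosen only to exercise the warm-up and cosine phases; with `use_restart=False` the `lr_period`/`lr_step` arguments are ignored and a single cosine cycle is used.

```
import torch
from tools.lr_scheduler import CosineRestartAnnealingLR

model = torch.nn.Linear(10, 2)  # placeholder model
opt = torch.optim.SGD(model.parameters(), lr=0.1)

sched = CosineRestartAnnealingLR(opt, T_max=1000, lr_period=[], lr_step=[],
                                 eta_min=0., use_warmup=True, warmup_mode='linear',
                                 warmup_steps=100, warmup_startlr=0.,
                                 warmup_targetlr=0.1, use_restart=False)

for step in range(1000):
    # ... forward / backward / opt.step() ...
    sched.step()  # advances last_step and writes the new lr into the param groups
```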
/run_apis/optimizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from models.dropped_model import Dropped_Network
5 |
6 |
7 | class Optimizer(object):
8 |
9 | def __init__(self, model, criterion, config):
10 | self.config = config
11 | self.weight_sample_num = self.config.search_params.weight_sample_num
12 | self.criterion = criterion
13 | self.Dropped_Network = lambda model: Dropped_Network(
14 | model, softmax_temp=config.search_params.softmax_temp)
15 |
16 | arch_params_id = list(map(id, model.module.arch_parameters))
17 | weight_params = filter(lambda p: id(p) not in arch_params_id, model.parameters())
18 |
19 | self.weight_optimizer = torch.optim.SGD(
20 | weight_params,
21 | config.optim.weight.init_lr,
22 | momentum=config.optim.weight.momentum,
23 | weight_decay=config.optim.weight.weight_decay)
24 |
25 | self.arch_optimizer = torch.optim.Adam(
26 | [{'params': model.module.arch_alpha_params, 'lr': config.optim.arch.alpha_lr},
27 | {'params': model.module.arch_beta_params, 'lr': config.optim.arch.beta_lr}],
28 | betas=(0.5, 0.999),
29 | weight_decay=config.optim.arch.weight_decay)
30 |
31 |
32 | def arch_step(self, input_valid, target_valid, model, search_stage):
33 | head_sampled_w_old, alpha_head_index = \
34 | model.module.sample_branch('head', 2, search_stage=search_stage)
35 | stack_sampled_w_old, alpha_stack_index = \
36 | model.module.sample_branch('stack', 2, search_stage=search_stage)
37 | self.arch_optimizer.zero_grad()
38 |
39 | dropped_model = nn.DataParallel(self.Dropped_Network(model))
40 | logits, sub_obj = dropped_model(input_valid)
41 | sub_obj = torch.mean(sub_obj)
42 | loss = self.criterion(logits, target_valid)
43 | if self.config.optim.if_sub_obj:
44 | loss_sub_obj = torch.log(sub_obj) / torch.log(torch.tensor(self.config.optim.sub_obj.log_base))
45 | sub_loss_factor = self.config.optim.sub_obj.sub_loss_factor
46 | loss += loss_sub_obj * sub_loss_factor
47 | loss.backward()
48 | self.arch_optimizer.step()
49 |
50 | self.rescale_arch_params(head_sampled_w_old,
51 | stack_sampled_w_old,
52 | alpha_head_index,
53 | alpha_stack_index,
54 | model)
55 | return logits.detach(), loss.item(), sub_obj.item()
56 |
57 |
58 | def weight_step(self, *args, **kwargs):
59 | return self.weight_step_(*args, **kwargs)
60 |
61 |
62 | def weight_step_(self, input_train, target_train, model, search_stage):
63 | _, _ = model.module.sample_branch('head', self.weight_sample_num, search_stage=search_stage)
64 | _, _ = model.module.sample_branch('stack', self.weight_sample_num, search_stage=search_stage)
65 |
66 | self.weight_optimizer.zero_grad()
67 | dropped_model = nn.DataParallel(self.Dropped_Network(model))
68 | logits, sub_obj = dropped_model(input_train)
69 | sub_obj = torch.mean(sub_obj)
70 | loss = self.criterion(logits, target_train)
71 | loss.backward()
72 | self.weight_optimizer.step()
73 |
74 | return logits.detach(), loss.item(), sub_obj.item()
75 |
76 |
77 | def valid_step(self, input_valid, target_valid, model):
78 | _, _ = model.module.sample_branch('head', 1, training=False)
79 | _, _ = model.module.sample_branch('stack', 1, training=False)
80 |
81 | dropped_model = nn.DataParallel(self.Dropped_Network(model))
82 | logits, sub_obj = dropped_model(input_valid)
83 | sub_obj = torch.mean(sub_obj)
84 | loss = self.criterion(logits, target_valid)
85 |
86 | return logits, loss.item(), sub_obj.item()
87 |
88 |
89 | def rescale_arch_params(self, alpha_head_weights_drop,
90 | alpha_stack_weights_drop,
91 | alpha_head_index,
92 | alpha_stack_index,
93 | model):
94 |
95 | def comp_rescale_value(old_weights, new_weights, index):
96 | old_exp_sum = old_weights.exp().sum()
97 | new_drop_arch_params = torch.gather(new_weights, dim=-1, index=index)
98 | new_exp_sum = new_drop_arch_params.exp().sum()
99 | rescale_value = torch.log(old_exp_sum / new_exp_sum)
100 | rescale_mat = torch.zeros_like(new_weights).scatter_(0, index, rescale_value)
101 | return rescale_value, rescale_mat
102 |
103 | def rescale_params(old_weights, new_weights, indices):
104 | for i, (old_weights_block, indices_block) in enumerate(zip(old_weights, indices)):
105 | for j, (old_weights_branch, indices_branch) in enumerate(
106 | zip(old_weights_block, indices_block)):
107 | rescale_value, rescale_mat = comp_rescale_value(old_weights_branch,
108 | new_weights[i][j],
109 | indices_branch)
110 | new_weights[i][j].data.add_(rescale_mat)
111 |
112 | # rescale the arch params for head layers
113 | rescale_params(alpha_head_weights_drop, model.module.alpha_head_weights, alpha_head_index)
114 | # rescale the arch params for stack layers
115 | rescale_params(alpha_stack_weights_drop, model.module.alpha_stack_weights, alpha_stack_index)
116 |
117 |
118 | def set_param_grad_state(self, stage):
119 | def set_grad_state(params, state):
120 | for group in params:
121 | for param in group['params']:
122 | param.requires_grad_(state)
123 | if stage == 'Arch':
124 | state_list = [True, False] # [arch, weight]
125 | elif stage == 'Weights':
126 | state_list = [False, True]
127 | else:
128 | state_list = [False, False]
129 | set_grad_state(self.arch_optimizer.param_groups, state_list[0])
130 | set_grad_state(self.weight_optimizer.param_groups, state_list[1])
131 |
--------------------------------------------------------------------------------
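The key invariant in `rescale_arch_params` above is that, after an architecture update, each sampled branch's logits are shifted by `log(old_exp_sum / new_exp_sum)`, restoring the exponential sum over the sampled branches to its pre-update value so they remain comparable with the branches that were not sampled in this step. A minimal numeric check of that identity (the tensors are made up for illustration):

```
import torch

old = torch.tensor([1.0, 2.0])        # sampled arch logits before the update
new = torch.tensor([1.5, 2.5, 0.3])   # full logit vector after the update
idx = torch.tensor([0, 1])            # indices of the sampled branches

new_drop = torch.gather(new, dim=-1, index=idx)
rescale = torch.log(old.exp().sum() / new_drop.exp().sum())
new.scatter_add_(0, idx, rescale.expand(idx.numel()))

# the exp-sum over the sampled entries matches its pre-update value
assert torch.isclose(new[idx].exp().sum(), old.exp().sum())
```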
/tools/utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import shutil
4 | import sys
5 | import time
6 |
7 | import numpy as np
8 | import torch
9 | import torch.nn as nn
10 |
11 |
12 | class AverageMeter(object):
13 | def __init__(self):
14 | self.reset()
15 |
16 | def reset(self):
17 | self.avg = 0
18 | self.sum = 0
19 | self.cnt = 0
20 |
21 | def update(self, val, n=1):
22 | self.cur = val
23 | self.sum += val * n
24 | self.cnt += n
25 | self.avg = self.sum / self.cnt
26 |
27 |
28 | def accuracy(output, target, topk=(1, 5)):
29 | maxk = max(topk)
30 | batch_size = target.size(0)
31 |
32 | _, pred = output.topk(maxk, 1, True, True)
33 | pred = pred.t()
34 | correct = pred.eq(target.view(1, -1).expand_as(pred))
35 |
36 | res = []
37 | for k in topk:
38 | correct_k = correct[:k].view(-1).float().sum(0)
39 | res.append(correct_k.mul_(100.0/batch_size))
40 | return res
41 |
42 |
43 | def count_parameters_in_MB(model):
44 | return sum(np.prod(v.size()) for name, v in model.named_parameters() if "aux" not in name)/1e6
45 |
46 |
47 | def save_checkpoint(state, is_best, save):
48 | filename = os.path.join(save, 'checkpoint.pth.tar')
49 | torch.save(state, filename)
50 | if is_best:
51 | best_filename = os.path.join(save, 'model_best.pth.tar')
52 | shutil.copyfile(filename, best_filename)
53 |
54 |
55 | def save(model, model_path):
56 | torch.save(model.state_dict(), model_path)
57 |
58 |
59 | def load_net_config(path):
60 | with open(path, 'r') as f:
61 | net_config = ''
62 | while True:
63 | line = f.readline().strip()
64 | if 'net_type' in line:
65 | net_type = line.split(': ')[-1]
66 | break
67 | else:
68 | net_config += line
69 | return net_config, net_type
70 |
71 |
72 | def load_model(model, model_path):
73 | logging.info('Start loading the model from ' + model_path)
74 | if 'http' in model_path:
75 | model_addr = model_path
76 | model_path = model_path.split('/')[-1]
77 | if os.path.isfile(model_path):
78 | os.system('rm ' + model_path)
79 | os.system('wget -q ' + model_addr)
80 | model.load_state_dict(torch.load(model_path))
81 | logging.info('Loading the model finished!')
82 |
83 |
84 | def create_exp_dir(path):
85 | if not os.path.exists(path):
86 | os.mkdir(path)
87 | print('Experiment dir : {}'.format(path))
88 |
89 |
90 | def cross_entropy_with_label_smoothing(pred, target, label_smoothing=0.):
91 | """
92 | Label smoothing implementation.
93 | This function is taken from https://github.com/MIT-HAN-LAB/ProxylessNAS/blob/master/proxyless_nas/utils.py
94 | """
95 |
96 | logsoftmax = nn.LogSoftmax(dim=1).cuda()  # normalize over the class dimension
97 | n_classes = pred.size(1)
98 | # convert to one-hot
99 | target = torch.unsqueeze(target, 1)
100 | soft_target = torch.zeros_like(pred)
101 | soft_target.scatter_(1, target, 1)
102 | # label smoothing
103 | soft_target = soft_target * (1 - label_smoothing) + label_smoothing / n_classes
104 | return torch.mean(torch.sum(- soft_target * logsoftmax(pred), 1))
105 |
106 |
107 | def parse_net_config(net_config):
108 | str_configs = net_config.split('|')
109 | return [eval(str_config) for str_config in str_configs]
110 |
111 |
112 | def set_seed(seed):
113 | np.random.seed(seed)
114 | torch.manual_seed(seed)
115 | torch.cuda.manual_seed(seed)
116 |
117 |
118 | def set_logging(save_path, log_name='log.txt'):
119 | log_format = '%(asctime)s %(message)s'
120 | date_format = '%m/%d %H:%M:%S'
121 | logging.basicConfig(stream=sys.stdout, level=logging.INFO,
122 | format=log_format, datefmt=date_format)
123 | fh = logging.FileHandler(os.path.join(save_path, log_name))
124 | fh.setFormatter(logging.Formatter(log_format, date_format))
125 | logging.getLogger().addHandler(fh)
126 |
127 |
128 | def create_save_dir(save_path, job_name):
129 | if job_name != '':
130 | job_name = time.strftime("%Y%m%d-%H%M%S-") + job_name
131 | save_path = os.path.join(save_path, job_name)
132 | create_exp_dir(save_path)
133 | os.system('cp -r ./* '+save_path)
134 | save_path = os.path.join(save_path, 'output')
135 | create_exp_dir(save_path)
136 | else:
137 | save_path = os.path.join(save_path, 'output')
138 | create_exp_dir(save_path)
139 | return save_path, job_name
140 |
141 |
142 | def latency_measure(module, input_size, batch_size, meas_times, mode='gpu'):
143 | assert mode in ['gpu', 'cpu']
144 |
145 | latency = []
146 | module.eval()
147 | input_size = (batch_size,) + tuple(input_size)
148 | input_data = torch.randn(input_size)
149 | if mode=='gpu':
150 | input_data = input_data.cuda()
151 | module.cuda()
152 |
153 | for i in range(meas_times):
154 | with torch.no_grad():
155 | start = time.time()
156 | _ = module(input_data)
157 | torch.cuda.synchronize()
158 | if i >= 100:  # discard the first 100 runs as warm-up
159 | latency.append(time.time() - start)
160 | print(np.mean(latency) * 1e3, 'ms')
161 | return np.mean(latency) * 1e3
162 |
163 |
164 | def latency_measure_fw(module, input_data, meas_times):
165 | latency = []
166 | module.eval()
167 |
168 | for i in range(meas_times):
169 | with torch.no_grad():
170 | start = time.time()
171 | output_data = module(input_data)
172 | torch.cuda.synchronize()
173 | if i >= 100:  # discard the first 100 runs as warm-up
174 | latency.append(time.time() - start)
175 | print(np.mean(latency) * 1e3, 'ms')
176 | return np.mean(latency) * 1e3, output_data
177 |
178 |
179 | def record_topk(k, rec_list, data, comp_attr, check_attr):
180 | def get_insert_idx(orig_list, data, comp_attr):
181 | start = 0
182 | end = len(orig_list)
183 | while start < end:
184 | mid = (start + end) // 2
185 | if data[comp_attr] < orig_list[mid][comp_attr]:
186 | start = mid + 1
187 | else:
188 | end = mid
189 | return start
190 |
191 | if_insert = False
192 | insert_idx = get_insert_idx(rec_list, data, comp_attr)
193 | if insert_idx < k:
194 | rec_list.insert(insert_idx, data)
195 | if_insert = True
196 | while len(rec_list) > k:
197 | rec_list.pop()
198 | return if_insert
199 |
--------------------------------------------------------------------------------
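`record_topk` above keeps `rec_list` sorted in descending order of `data[comp_attr]` and capped at `k` entries, using a binary search for the insertion point (`check_attr` is accepted but unused in the current implementation). A small sketch of the behaviour with made-up accuracy records:

```
from tools.utils import record_topk

recs = []
for acc in [71.2, 73.5, 72.8, 70.1]:
    record_topk(2, recs, {'acc': acc}, comp_attr='acc', check_attr=None)
print(recs)   # [{'acc': 73.5}, {'acc': 72.8}]
```

`cross_entropy_with_label_smoothing` builds the smoothed target `soft_target = onehot * (1 - eps) + eps / n_classes` and averages `-sum(soft_target * log_softmax(pred))` over the batch; with `eps = 0` it reduces to ordinary cross-entropy. A quick check (requires a GPU, since the helper moves `LogSoftmax` to CUDA):

```
import torch
import torch.nn.functional as F
from tools.utils import cross_entropy_with_label_smoothing

pred = torch.randn(4, 10).cuda()
target = torch.randint(0, 10, (4,)).cuda()
loss = cross_entropy_with_label_smoothing(pred, target, label_smoothing=0.)
assert torch.isclose(loss, F.cross_entropy(pred, target))
```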
/run_apis/retrain.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import ast
3 | import logging
4 | import os
5 | import pprint
6 | import sys
7 | import time
8 |
9 | import numpy as np
10 | import torch
11 | import torch.backends.cudnn as cudnn
12 | import torch.nn as nn
13 | from tensorboardX import SummaryWriter
14 |
15 | from configs.imagenet_train_cfg import cfg as config
16 | from dataset import imagenet_data
17 | from models import model_derived
18 | from tools import utils
19 | from tools.lr_scheduler import get_lr_scheduler
20 | from tools.multadds_count import comp_multadds
21 |
22 | from .trainer import Trainer
23 |
24 | if __name__ == '__main__':
25 | parser = argparse.ArgumentParser("Train_Params")
26 | parser.add_argument('--report_freq', type=float, default=500, help='report frequency')
27 | parser.add_argument('--data_path', type=str, default='../data', help='location of the data corpus')
28 | parser.add_argument('--load_path', type=str, default='./model_path', help='model loading path')
29 | parser.add_argument('--save', type=str, default='../', help='experiment name')
30 | parser.add_argument('--tb_path', type=str, default='', help='tensorboard output path')
31 | parser.add_argument('--meas_lat', type=ast.literal_eval, default='False', help='whether to measure the latency of the model')
32 | parser.add_argument('--job_name', type=str, default='', help='job_name')
33 | args = parser.parse_args()
34 |
35 | if args.job_name != '':
36 | args.job_name = time.strftime("%Y%m%d-%H%M%S-") + args.job_name
37 | args.save = os.path.join(args.save, args.job_name)
38 | utils.create_exp_dir(args.save)
39 | os.system('cp -r ./* '+args.save)
40 | else:
41 | args.save = os.path.join(args.save, 'output')
42 | utils.create_exp_dir(args.save)
43 |
44 | if args.tb_path == '':
45 | args.tb_path = args.save
46 | writer = SummaryWriter(args.tb_path)
47 |
48 | utils.set_logging(args.save)
49 |
50 | if not torch.cuda.is_available():
51 | logging.info('no gpu device available')
52 | sys.exit(1)
53 |
54 | cudnn.benchmark = True
55 | cudnn.enabled = True
56 |
57 | if config.train_params.use_seed:
58 | utils.set_seed(config.train_params.seed)
59 |
60 | logging.info("args = %s", args)
61 | logging.info('Training with config:')
62 | logging.info(pprint.pformat(config))
63 |
64 | if os.path.isfile(os.path.join(args.load_path, 'net_config')):
65 | config.net_config, config.net_type = utils.load_net_config(
66 | os.path.join(args.load_path, 'net_config'))
67 | derivedNetwork = getattr(model_derived, '%s_Net' % config.net_type.upper())
68 | model = derivedNetwork(config.net_config, config=config)
69 |
70 | model.eval()
71 | if hasattr(model, 'net_config'):
72 | logging.info("Network Structure: \n" + '|\n'.join(map(str, model.net_config)))
73 | if args.meas_lat:
74 | latency_cpu = utils.latency_measure(model, (3, 224, 224), 1, 2000, mode='cpu')
75 | logging.info('latency_cpu (batch 1): %.2fms' % latency_cpu)
76 | latency_gpu = utils.latency_measure(model, (3, 224, 224), 32, 5000, mode='gpu')
77 | logging.info('latency_gpu (batch 32): %.2fms' % latency_gpu)
78 | params = utils.count_parameters_in_MB(model)
79 | logging.info("Params = %.2fMB" % params)
80 | mult_adds = comp_multadds(model, input_size=config.data.input_size)
81 | logging.info("Mult-Adds = %.2fMB" % mult_adds)
82 |
83 | model = nn.DataParallel(model)
84 |
85 | # whether to resume from a checkpoint
86 | if config.optim.if_resume:
87 | utils.load_model(model, config.optim.resume.load_path)
88 | start_epoch = config.optim.resume.load_epoch + 1
89 | else:
90 | start_epoch = 0
91 |
92 | model = model.cuda()
93 |
94 | if config.optim.label_smooth:
95 | criterion = utils.cross_entropy_with_label_smoothing
96 | else:
97 | criterion = nn.CrossEntropyLoss()
98 | criterion = criterion.cuda()
99 |
100 | optimizer = torch.optim.SGD(
101 | model.parameters(),
102 | config.optim.init_lr,
103 | momentum=config.optim.momentum,
104 | weight_decay=config.optim.weight_decay
105 | )
106 |
107 | imagenet = imagenet_data.ImageNet12(trainFolder=os.path.join(args.data_path, 'train'),
108 | testFolder=os.path.join(args.data_path, 'val'),
109 | num_workers=config.data.num_workers,
110 | type_of_data_augmentation=config.data.type_of_data_aug,
111 | data_config=config.data)
112 |
113 | if config.optim.use_multi_stage:
114 | (train_queue, week_train_queue), valid_queue = imagenet.getSetTrainTestLoader(config.data.batch_size)
115 | else:
116 | train_queue, valid_queue = imagenet.getTrainTestLoader(config.data.batch_size)
117 |
118 | scheduler = get_lr_scheduler(config, optimizer, len(train_queue.dataset))
119 | scheduler.last_step = start_epoch * (len(train_queue.dataset) // config.data.batch_size + 1) - 1
120 |
121 | trainer = Trainer(train_queue, valid_queue, optimizer, criterion, scheduler, config, args.report_freq)
122 |
123 | best_epoch = [0, 0, 0] # [epoch, acc_top1, acc_top5]
124 | for epoch in range(start_epoch, config.train_params.epochs):
125 |
126 | if config.optim.use_multi_stage and epoch >= config.optim.multi_stage.stage_epochs:
127 | train_data = week_train_queue
128 | else:
129 | train_data = train_queue
130 |
131 | train_acc_top1, train_acc_top5, train_obj, batch_time, data_time = trainer.train(model, epoch)
132 |
133 | with torch.no_grad():
134 | val_acc_top1, val_acc_top5, batch_time, data_time = trainer.infer(model, epoch)
135 |
136 | if val_acc_top1 > best_epoch[1]:
137 | best_epoch = [epoch, val_acc_top1, val_acc_top5]
138 | utils.save(model, os.path.join(args.save, 'weights.pt'))
139 | logging.info('BEST EPOCH %d val_top1 %.2f val_top5 %.2f', best_epoch[0], best_epoch[1], best_epoch[2])
140 |
141 | writer.add_scalar('train_acc_top1', train_acc_top1, epoch)
142 | writer.add_scalar('train_loss', train_obj, epoch)
143 | writer.add_scalar('val_acc_top1', val_acc_top1, epoch)
144 |
145 | if hasattr(model.module, 'net_config'):
146 | logging.info("Network Structure: \n" + '|\n'.join(map(str, model.module.net_config)))
147 | logging.info("Params = %.2fMB" % params)
148 | logging.info("Mult-Adds = %.2fMB" % mult_adds)
149 |
--------------------------------------------------------------------------------
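The resume path in `retrain.py` is driven entirely by the training config. A hypothetical fragment of `configs/imagenet_train_cfg.py` matching the fields read above (`if_resume`, `resume.load_path`, `resume.load_epoch`); the values are placeholders:

```
# hypothetical resume settings; the field names follow the reads in run_apis/retrain.py
__C.optim.if_resume = True
__C.optim.resume.load_path = './output/weights.pt'  # state_dict saved by utils.save()
__C.optim.resume.load_epoch = 89                    # last finished epoch
```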
/README.md:
--------------------------------------------------------------------------------
1 | # DenseNAS
2 |
3 | The code of the CVPR2020 paper [Densely Connected Search Space for More Flexible Neural Architecture Search](https://arxiv.org/abs/1906.09607).
4 |
5 | Neural architecture search (NAS) has dramatically advanced the development of neural network design. We revisit the search space design in most previous NAS methods and find that the number of blocks and the widths of blocks are set manually. However, block counts and block widths determine the network scale (depth and width) and have a great influence on both the accuracy and the model cost (FLOPs/latency).
6 |
7 | We propose to search block counts and block widths by designing a densely connected search space, i.e., DenseNAS. The new search space is represented as a dense super network, which is built upon our designed routing blocks. In the super network, routing blocks are densely connected and we search for the best path between them to derive the final architecture. We further propose a chained cost estimation algorithm to approximate the model cost during the search. Both the accuracy and model cost are optimized in DenseNAS.
8 | 
9 |
10 | ## Updates
11 | * 2020.6 The search code is released, including both the MobileNetV2- and ResNet-based search spaces.
12 |
13 | ## Requirements
14 |
15 | * pytorch >= 1.0.1
16 | * python >= 3.6
17 |
18 | ## Search
19 |
20 | 1. Prepare the image set for search, which contains 100 classes of the original ImageNet dataset. 20% of the images are used as the validation set and 80% as the training set.
21 |
22 | 1). Generate the split list of the image data.
23 | `python dataset/mk_split_img_list.py --image_path 'the path of your ImageNet data' --output_path 'the path to output the list file'`
24 |
25 | 2). Use the image list obtained above to make the lmdb file.
26 | `python dataset/img2lmdb.py --image_path 'the path of your ImageNet data' --list_path 'the path of your image list generated above' --output_path 'the path to output the lmdb file' --split 'split folder (train/val)'`
27 |
28 | 2. Build the latency lookup table (lut) of the search space using the following script, or directly use the ones provided in `./latency_list/`.
29 | `python -m run_apis.latency_measure --save 'output path' --input_size 'the input image size' --meas_times 'the times of op measurement' --list_name 'the name of the output lut' --device 'gpu or cpu' --config 'the path of the yaml config'`
30 |
31 | 3. Search for the architectures. (We perform the search process on 4 32G V100 GPUs.)
32 | For MobileNetV2 search:
33 | `python -m run_apis.search --data_path 'the path of the split dataset' --config configs/imagenet_search_cfg_mbv2.yaml`
34 | For ResNet search:
35 | `python -m run_apis.search --data_path 'the path of the split dataset' --config configs/imagenet_search_cfg_resnet.yaml`
36 |
37 | ## Train
38 |
39 | 1. (Optional) We pack the ImageNet data as lmdb files for faster IO. The lmdb files can be made as follows. If you don't want to use lmdb data, just set `__C.data.train_data_type='img'` in the training config file `imagenet_train_cfg.py`.
40 |
41 | 1). Generate the list of the image data.
42 | `python dataset/mk_img_list.py --image_path 'the path of your image data' --output_path 'the path to output the list file'`
43 |
44 | 2). Use the image list obtained above to make the lmdb file.
45 | `python dataset/img2lmdb.py --image_path 'the path of your image data' --list_path 'the path of your image list' --output_path 'the path to output the lmdb file' --split 'split folder (train/val)'`
46 |
47 | 2. Train the searched model with the following script by assigning `__C.net_config` with the architecture obtained in the above search process. You can also train your customized model by redefining the variable `model` in `retrain.py`. (A sketch of the `net_config` format is given below.)
48 | `python -m run_apis.retrain --data_path 'The path of ImageNet data' --load_path 'The path you put the net_config of the model'`
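For reference, `parse_net_config` in `tools/utils.py` expects `__C.net_config` to be a `|`-separated sequence of block descriptions of the form `[[in_ch, out_ch], head_op, [stack_ops], num_stack_layers, stride]` (see the docstrings in `models/model_derived.py`). A purely illustrative two-block fragment with made-up channel widths, not a searched (or complete) architecture:

```
__C.net_config = "[[16, 24], 'mbconv_k3_t3', ['mbconv_k3_t3'], 1, 2]|" \
                 "[[24, 32], 'mbconv_k5_t3', ['mbconv_k5_t3', 'mbconv_k3_t3'], 2, 2]"
```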
49 |
50 | ## Evaluate
51 |
52 | 1. Download the related files of the pretrained model and put `net_config` and `weights.pt` into the `model_path` directory.
53 | 2. `python -m run_apis.validation --data_path 'The path of ImageNet data' --load_path 'The path you put the pre-trained model'`
54 |
55 | ## Results
56 |
57 | For experiments on the MobileNetV2-based search space, DenseNAS achieves 75.3\% top-1 accuracy on ImageNet with only 361M FLOPs and 17.9ms latency on a single TITAN-XP. The larger model searched by DenseNAS achieves 76.1\% accuracy with only 479M FLOPs. DenseNAS further improves the ImageNet classification accuracies of ResNet-18, -34 and -50-B by 1.5\%, 0.5\% and 0.3\% with 200M, 600M and 680M fewer FLOPs respectively.
58 |
59 | The comparison of model performance on ImageNet under the MobileNetV2-based search spaces.
60 |
61 |
62 |
63 |
64 |
65 |
66 | The comparison of model performance on ImageNet under the ResNet-based search spaces.
67 |
68 |
69 |
70 |
71 |
72 |
73 | Our pre-trained models can be downloaded from the following links. The complete list of the models can be found in [DenseNAS_modelzoo](https://drive.google.com/open?id=183oIMF6IowZrj81kenVBkQoIMlip9kLo).
74 |
75 | | Model | FLOPs | Latency | Top-1(%)|
76 | |----------------------|-------|---------|---------|
77 | | [DenseNAS-Large](https://drive.google.com/open?id=14Zgc-IlxjaRtGyDHJSdMpLHVvOd0Km1u) | 479M | 28.9ms | 76.1 |
78 | | [DenseNAS-A](https://drive.google.com/open?id=1ZdephrAY4GVRqv9SvOXoJDUmO-kWhhml) | 251M | 13.6ms | 73.1 |
79 | | [DenseNAS-B](https://drive.google.com/open?id=1djhL5P1vsWVqWuT5lR7UCxEhw4cET__7) | 314M | 15.4ms | 74.6 |
80 | | [DenseNAS-C](https://drive.google.com/open?id=1L2mqir89b1UiBkePmrtjG6QLi9MqzRdQ) | 361M | 17.9ms | 75.3 |
81 | | [DenseNAS-R1](https://drive.google.com/open?id=1YaMWb1LKpgSS5mgBcB3CthTGtTIOtxWw) | 1.61B | 12.0ms | 73.5 |
82 | | [DenseNAS-R2](https://drive.google.com/open?id=1Qawst3E2hqdam2TiTFo2BhBXS-M6AWdh) | 3.06B | 22.2ms | 75.8 |
83 | | [DenseNAS-R3](https://drive.google.com/open?id=14RwIGWsurNvevhxL9AcnlngU0KR8WeX-) | 3.41B | 41.7ms | 78.0 |
84 |
85 | 
86 |
87 | ## Citation
88 | If you find this repository/work helpful in your research, you are welcome to cite it.
89 | ```
90 | @inproceedings{fang2019densely,
91 | title={Densely connected search space for more flexible neural architecture search},
92 | author={Fang, Jiemin and Sun, Yuzhu and Zhang, Qian and Li, Yuan and Liu, Wenyu and Wang, Xinggang},
93 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
94 | year={2020}
95 | }
96 | ```
--------------------------------------------------------------------------------
/latency_list/lat_list_resv2_32_reg10:
--------------------------------------------------------------------------------
1 | [0.7470053614992084, [[[[1.7158220271871547]], []], [[[1.6621649144875883], [1.7133765027980612]], []], [[[1.9134202629628807], [1.8592351855653704], [1.8933377362260915]], []], [[[1.0193982991305264], [0.9558140388642898], [0.8666880443842724]], [[0.6160835545472425, 0.06160835545472425], [0.6116041270169345, 0.061160412701693444], [0.6122468938731184, 0.061224689387311834], [0.6109231168573552, 0.06109231168573552], [0.6104896766970856, 0.06104896766970856]]], [[[0.877161074166346], [1.1209101147121854], [1.0645518158421372], [1.0146694231514979]], [[0.8019653474441683, 0.08019653474441682], [0.7993539415224634, 0.07993539415224635], [0.8006674593145197, 0.08006674593145197], [0.799418796192516, 0.0799418796192516], [0.8003230046744298, 0.08003230046744299]]], [[[1.2716174125671387], [1.1164091813443888], [1.2395282466002184], [1.195189832436918], [1.1547757399202596]], [[1.094416416052616, 0.1094416416052616], [1.0937702535378813, 0.10937702535378813], [1.0924938712457213, 0.10924938712457213], [1.0943670465488626, 0.10943670465488627], [1.092695009828818, 0.1092695009828818]]], [[[0.5315265029367775], [0.49857038440126356], [0.4235670301649306]], [[0.41330046123928493, 0.041330046123928495], [0.41206465827094185, 0.04120646582709418], [0.4113578796386719, 0.04113578796386719], [0.4114802196772412, 0.041148021967724116], [0.41177434150618736, 0.04117743415061874], [0.41203164091013894, 0.04120316409101389], [0.4116343488596907, 0.04116343488596907], [0.4124539307873658, 0.04124539307873658], [0.41144332500419234, 0.04114433250041923], [0.41133061803952614, 0.04113306180395261], [0.41104555130004883, 0.04110455513000488], [0.41221972667809686, 0.041221972667809685], [0.4108423897714326, 0.041084238977143264], [0.4113530149363508, 0.041135301493635076], [0.41111775118895255, 0.04111177511889526]]], [[[0.5795735301393451], [0.7294338158886843], [0.6747800412804189], [0.5880859163072375]], [[0.4475170915777033, 0.04475170915777033], [0.4473568453933253, 0.044735684539332535], [0.44746613261675594, 0.044746613261675595], [0.4469970982484143, 0.044699709824841435], [0.4433615520747021, 0.04433615520747021], [0.4466982321305708, 0.04466982321305708], [0.44393339542427446, 0.04439333954242745], [0.44414787581472687, 0.04441478758147269], [0.44315029876400724, 0.04431502987640072], [0.44436098349214803, 0.0444360983492148], [0.4429972292196871, 0.04429972292196871], [0.4465308093061351, 0.04465308093061351], [0.4408782660359084, 0.04408782660359084], [0.44157577283454663, 0.044157577283454666], [0.44270563607264046, 0.044270563607264043]]], [[[0.5950747355066165], [0.5492691319398205], [0.7678800881511033], [0.7344900237189399], [0.6177222849142672]], [[0.4978300826718109, 0.049783008267181086], [0.4936368296844791, 0.04936368296844791], [0.4946338287507645, 0.04946338287507645], [0.4952423259465381, 0.049524232594653814], [0.4930364483534688, 0.04930364483534688], [0.49611074755890205, 0.0496110747558902], [0.4933389991220802, 0.04933389991220802], [0.4946485673538362, 0.049464856735383624], [0.4925500985347864, 0.04925500985347864], [0.4923339082737162, 0.04923339082737162], [0.4922379869403261, 0.04922379869403261], [0.4916638798183865, 0.04916638798183865], [0.4915951960014574, 0.04915951960014574], [0.49248155921396586, 0.049248155921396586], [0.4925919060755257, 0.04925919060755257]]], [[[0.7117494188173853], [0.6895356467275908], [0.6656108480511289], [0.8351398236823805], [0.7782968366988982], [0.6874867400737724]], [[0.6256991685038865, 0.06256991685038865], [0.6235957145690918, 
0.06235957145690918], [0.6226729383372297, 0.06226729383372297], [0.6236689981788095, 0.06236689981788095], [0.6236180873832318, 0.06236180873832318], [0.6223808153711184, 0.06223808153711184], [0.6238544589341288, 0.06238544589341288], [0.621514826109915, 0.0621514826109915], [0.6216128185541943, 0.06216128185541943], [0.6210656358738138, 0.06210656358738138], [0.6469456595603865, 0.06469456595603865], [0.6190218106664792, 0.061902181066647924], [0.61852570736047, 0.061852570736047], [0.6181354715366557, 0.06181354715366557], [0.621467214642149, 0.062146721464214905]]], [[[0.7768191954102179], [0.7580275005764431], [0.7367050045668477], [0.7167703455144709]], [[0.6815665900105178, 0.06815665900105178], [0.6803986038824524, 0.06803986038824525], [0.6794075291566174, 0.06794075291566173], [0.6797389309815686, 0.06797389309815685], [0.6779417847142075, 0.06779417847142075]]], [[[0.9306495839899237], [0.8969544160245645], [0.8677374955379602], [0.8321288378551753], [0.8342863814999358]], [[0.8310986046839242, 0.08310986046839242], [0.8355484345946649, 0.08355484345946648], [0.8296608443212028, 0.08296608443212028], [0.8263708605910793, 0.08263708605910794], [0.8277203338314788, 0.08277203338314788]]], [[[1.0617327449297664], [0.9705225867454452], [0.9498220983177724], [0.9233483882865521], [0.8980551392141014], [0.8740888701544868]], [[0.8907880927577163, 0.08907880927577164], [0.8912024112662884, 0.08912024112662884], [0.8915229999657833, 0.08915229999657834], [0.8909804170781916, 0.08909804170781917], [0.890022504209268, 0.0890022504209268]]], [[[0.5085483223500877], [0.49751910296353424], [0.4761004688763859]], [[0.41085341964105165, 0.04108534196410517], [0.4084869827887025, 0.04084869827887025], [0.40876537862450185, 0.040876537862450185], [0.40820923718539154, 0.040820923718539154], [0.4086827509330981, 0.04086827509330981]]], [[[0.5120689218694514], [0.5295760944636182], [0.5133399337229102], [0.4979153353758532]], [[0.44370985994435325, 0.04437098599443533], [0.4397995785029248, 0.043979957850292475], [0.43866104549831814, 0.043866104549831815], [0.44069872962103945, 0.04406987296210395], [0.4398404468189586, 0.04398404468189586]]], [[[0.5302517823498658], [0.5341541646706938], [0.6771185903838186], [0.6394316692544957], [0.6240752008226182]], [[0.4810813701514042, 0.04810813701514042], [0.4783633261015921, 0.04783633261015921], [0.47798166371355155, 0.04779816637135516], [0.47773036089810456, 0.047773036089810456], [0.4787054447212605, 0.04787054447212605]]], [[[0.5921873420175879], [0.5748705671291159], [0.5476833593965781], [0.6850687662760417], [0.6735895137594203], [0.6462683340515754]], [[0.51674510493423, 0.051674510493423], [0.5135976425325027, 0.05135976425325027], [0.5137101086703213, 0.05137101086703213], [0.5135496457417805, 0.051354964574178055], [0.5139973669341117, 0.05139973669341117]]], [[[1.1904039045776984], [1.171813565071183], [1.13944036792023], [1.1230145078716856]], [[1.400846206780636, 0.1400846206780636]]], [[[1.5440915810941447], [1.2339090578483813], [1.1897495780328307], [1.1610502666897244], [1.1645580060554275]], [[1.4041198865331783, 0.14041198865331783]]], [[[1.448338104016853], [1.4705032050007523], [1.1230347614095668], [1.1128196330985638], [1.0730700059370561], [1.0587836516023885]], [[1.2187801226220951, 0.12187801226220951]]]], [0.13975803298179548, 0.1369504976754237, 0.13584091205789584], 0.19902790435636888]
--------------------------------------------------------------------------------
/latency_list/lat_list_resv2_32:
--------------------------------------------------------------------------------
1 | [0.7470053614992084, [[[[1.7158220271871547]], []], [[[1.6621649144875883], [1.7133765027980612]], []], [[[1.9134202629628807], [1.8592351855653704], [1.8933377362260915]], []], [[[1.0193982991305264], [0.9558140388642898], [0.8666880443842724]], [[0.6160835545472425, 0.005083878835042318], [0.6116041270169345, 0.004423748363148083], [0.6122468938731184, 0.004710428642504143], [0.6109231168573552, 0.005227267140089863], [0.6104896766970856, 0.004845990075005426]]], [[[0.877161074166346], [1.1209101147121854], [1.0645518158421372], [1.0146694231514979]], [[0.8019653474441683, 0.005188060529304274], [0.7993539415224634, 0.004809480724912701], [0.8006674593145197, 0.005086792839898003], [0.799418796192516, 0.0048404992228806626], [0.8003230046744298, 0.005279695144807449]]], [[[1.2716174125671387], [1.1164091813443888], [1.2395282466002184], [1.195189832436918], [1.1547757399202596]], [[1.094416416052616, 0.004327369458747632], [1.0937702535378813, 0.0049243069658375755], [1.0924938712457213, 0.00507874922318892], [1.0943670465488626, 0.005123567099523063], [1.092695009828818, 0.005048718115296026]]], [[[0.5315265029367775], [0.49857038440126356], [0.4235670301649306]], [[0.41330046123928493, 0.0046339902010830965], [0.41206465827094185, 0.004920863141917219], [0.4113578796386719, 0.005258743209068222], [0.4114802196772412, 0.0052337212996049366], [0.41177434150618736, 0.004909568362765842], [0.41203164091013894, 0.004904535081651476], [0.4116343488596907, 0.00530339250660906], [0.4124539307873658, 0.004672931902336352], [0.41144332500419234, 0.00468335970483645], [0.41133061803952614, 0.005149263324159564], [0.41104555130004883, 0.004897117614746094], [0.41221972667809686, 0.004751080214375197], [0.4108423897714326, 0.005267075817994397], [0.4113530149363508, 0.004888375600179036], [0.41111775118895255, 0.0051637851830684785]]], [[[0.5795735301393451], [0.7294338158886843], [0.6747800412804189], [0.5880859163072375]], [[0.4475170915777033, 0.004560899252843376], [0.4473568453933253, 0.005116848030475655], [0.44746613261675594, 0.0050102580677379265], [0.4469970982484143, 0.004821449819237295], [0.4433615520747021, 0.004868603715992937], [0.4466982321305708, 0.005760216953778508], [0.44393339542427446, 0.004536623906607579], [0.44414787581472687, 0.004877586557407572], [0.44315029876400724, 0.004928280608822601], [0.44436098349214803, 0.004839006096425683], [0.4429972292196871, 0.004484942465117484], [0.4465308093061351, 0.005215683368721394], [0.4408782660359084, 0.004838741186893348], [0.44157577283454663, 0.004812563308561691], [0.44270563607264046, 0.005174333398992365]]], [[[0.5950747355066165], [0.5492691319398205], [0.7678800881511033], [0.7344900237189399], [0.6177222849142672]], [[0.4978300826718109, 0.004748093961465239], [0.4936368296844791, 0.004589581730389836], [0.4946338287507645, 0.005070175787415167], [0.4952423259465381, 0.004672811488912563], [0.4930364483534688, 0.004822509457366635], [0.49611074755890205, 0.0048711564805772566], [0.4933389991220802, 0.004833105838660038], [0.4946485673538362, 0.004776294785316544], [0.4925500985347864, 0.004769792460431956], [0.4923339082737162, 0.005009198429608586], [0.4922379869403261, 0.00521529804576527], [0.4916638798183865, 0.004769985121910018], [0.4915951960014574, 0.004877273482505721], [0.49248155921396586, 0.005180137326018979], [0.4925919060755257, 0.005257298247982757]]], [[[0.7117494188173853], [0.6895356467275908], [0.6656108480511289], [0.8351398236823805], [0.7782968366988982], [0.6874867400737724]], 
[[0.6256991685038865, 0.0046522930414989744], [0.6235957145690918, 0.005195285334731593], [0.6226729383372297, 0.00455521573924055], [0.6236689981788095, 0.0046892840452868526], [0.6236180873832318, 0.004846182736483488], [0.6223808153711184, 0.005256985173080907], [0.6238544589341288, 0.004499079001070274], [0.621514826109915, 0.00439465647996074], [0.6216128185541943, 0.005834873276527482], [0.6210656358738138, 0.005216092774362275], [0.6469456595603865, 0.0049324469132856884], [0.6190218106664792, 0.0048540096090297505], [0.61852570736047, 0.005814354829113893], [0.6181354715366557, 0.004709561665852864], [0.621467214642149, 0.004667681877059166]]], [[[0.7768191954102179], [0.7580275005764431], [0.7367050045668477], [0.7167703455144709]], [[0.6815665900105178, 0.00501138995392154], [0.6803986038824524, 0.004988993057096848], [0.6794075291566174, 0.005249977111816406], [0.6797389309815686, 0.004805507081927675], [0.6779417847142075, 0.004855791727701823]]], [[[0.9306495839899237], [0.8969544160245645], [0.8677374955379602], [0.8321288378551753], [0.8342863814999358]], [[0.8310986046839242, 0.004771863571321121], [0.8355484345946649, 0.004956047944348268], [0.8296608443212028, 0.004708622441147313], [0.8263708605910793, 0.004478680967080473], [0.8277203338314788, 0.004992942617397115]]], [[[1.0617327449297664], [0.9705225867454452], [0.9498220983177724], [0.9233483882865521], [0.8980551392141014], [0.8740888701544868]], [[0.8907880927577163, 0.0046964125199751424], [0.8912024112662884, 0.004646922602798], [0.8915229999657833, 0.004888905419243706], [0.8909804170781916, 0.0051665065264461015], [0.890022504209268, 0.004725648899271031]]], [[[0.5085483223500877], [0.49751910296353424], [0.4761004688763859]], [[0.41085341964105165, 0.004508663909603851], [0.4084869827887025, 0.004639240226360282], [0.40876537862450185, 0.005059145917796126], [0.40820923718539154, 0.005027236360492128], [0.4086827509330981, 0.004860897256870463]]], [[[0.5120689218694514], [0.5295760944636182], [0.5133399337229102], [0.4979153353758532]], [[0.44370985994435325, 0.00485039720631609], [0.4397995785029248, 0.0051300935070924085], [0.43866104549831814, 0.00495421766030668], [0.44069872962103945, 0.00507662994693024], [0.4398404468189586, 0.0051092860674617265]]], [[[0.5302517823498658], [0.5341541646706938], [0.6771185903838186], [0.6394316692544957], [0.6240752008226182]], [[0.4810813701514042, 0.005186904560435902], [0.4783633261015921, 0.005180474483605587], [0.47798166371355155, 0.005193695877537583], [0.47773036089810456, 0.004894829759694109], [0.4787054447212605, 0.004925438852021189]]], [[[0.5921873420175879], [0.5748705671291159], [0.5476833593965781], [0.6850687662760417], [0.6735895137594203], [0.6462683340515754]], [[0.51674510493423, 0.004966836987119733], [0.5135976425325027, 0.005099749324297664], [0.5137101086703213, 0.004888158856016217], [0.5135496457417805, 0.005136788493455059], [0.5139973669341117, 0.004399376686173255]]], [[[1.1904039045776984], [1.171813565071183], [1.13944036792023], [1.1230145078716856]], [[1.400846206780636, 0.005112127824263139]]], [[[1.5440915810941447], [1.2339090578483813], [1.1897495780328307], [1.1610502666897244], [1.1645580060554275]], [[1.4041198865331783, 0.004999589438390251]]], [[[1.448338104016853], [1.4705032050007523], [1.1230347614095668], [1.1128196330985638], [1.0730700059370561], [1.0587836516023885]], [[1.2187801226220951, 0.005037182509297072]]]], [0.13975803298179548, 0.1369504976754237, 0.13584091205789584], 0.19902790435636888]
--------------------------------------------------------------------------------
/models/model_derived.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from .operations import OPS
7 | from tools.utils import parse_net_config
8 |
9 |
10 | class Block(nn.Module):
11 |
12 | def __init__(self, in_ch, block_ch, head_op, stack_ops, stride):
13 | super(Block, self).__init__()
14 | self.head_layer = OPS[head_op](in_ch, block_ch, stride,
15 | affine=True, track_running_stats=True)
16 |
17 | modules = []
18 | for stack_op in stack_ops:
19 | modules.append(OPS[stack_op](block_ch, block_ch, 1,
20 | affine=True, track_running_stats=True))
21 | self.stack_layers = nn.Sequential(*modules)
22 |
23 | def forward(self, x):
24 | x = self.head_layer(x)
25 | x = self.stack_layers(x)
26 | return x
27 |
28 |
29 | class Conv1_1_Block(nn.Module):
30 |
31 | def __init__(self, in_ch, block_ch):
32 | super(Conv1_1_Block, self).__init__()
33 | self.conv1_1 = nn.Sequential(
34 | nn.Conv2d(in_channels=in_ch, out_channels=block_ch,
35 | kernel_size=1, stride=1, padding=0, bias=False),
36 | nn.BatchNorm2d(block_ch),
37 | nn.ReLU6(inplace=True)
38 | )
39 |
40 | def forward(self, x):
41 | return self.conv1_1(x)
42 |
43 |
44 | class MBV2_Net(nn.Module):
45 | def __init__(self, net_config, config=None):
46 | """
47 | net_config=[[in_ch, out_ch], head_op, [stack_ops], num_stack_layers, stride]
48 | """
49 | super(MBV2_Net, self).__init__()
50 | self.config = config
51 | self.net_config = parse_net_config(net_config)
52 | self.in_chs = self.net_config[0][0][0]
53 | self._num_classes = 1000
54 |
55 | self.input_block = nn.Sequential(
56 | nn.Conv2d(in_channels=3, out_channels=self.in_chs, kernel_size=3,
57 | stride=2, padding=1, bias=False),
58 | nn.BatchNorm2d(self.in_chs),
59 | nn.ReLU6(inplace=True)
60 | )
61 | self.blocks = nn.ModuleList()
62 | for block_config in self.net_config:
63 | if block_config[1] == 'conv1_1':
64 | continue
65 | self.blocks.append(Block(block_config[0][0], block_config[0][1],
66 | block_config[1], block_config[2], block_config[-1]))
67 |
68 | if self.net_config[-1][1] == 'conv1_1':
69 | block_last_dim = self.net_config[-1][0][0]
70 | last_dim = self.net_config[-1][0][1]
71 | else:  # no trailing conv1_1 block in the config; keep the last block's output width
72 | block_last_dim = last_dim = self.net_config[-1][0][1]
73 | self.conv1_1_block = Conv1_1_Block(block_last_dim, last_dim)
74 |
75 | self.global_pooling = nn.AdaptiveAvgPool2d(1)
76 | self.classifier = nn.Linear(last_dim, self._num_classes)
77 |
78 | self.init_model()
79 | self.set_bn_param(0.1, 0.001)
80 |
81 |
82 | def forward(self,x):
83 | block_data = self.input_block(x)
84 | for i, block in enumerate(self.blocks):
85 | block_data = block(block_data)
86 | block_data = self.conv1_1_block(block_data)
87 |
88 | out = self.global_pooling(block_data)
89 | logits = self.classifier(out.view(out.size(0),-1))
90 |
91 | return logits
92 |
93 | def init_model(self, model_init='he_fout', init_div_groups=True):
94 | for m in self.modules():
95 | if isinstance(m, nn.Conv2d):
96 | if model_init == 'he_fout':
97 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
98 | if init_div_groups:
99 | n /= m.groups
100 | m.weight.data.normal_(0, math.sqrt(2. / n))
101 | elif model_init == 'he_fin':
102 | n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels
103 | if init_div_groups:
104 | n /= m.groups
105 | m.weight.data.normal_(0, math.sqrt(2. / n))
106 | else:
107 | raise NotImplementedError
108 | elif isinstance(m, nn.BatchNorm2d):
109 | m.weight.data.fill_(1)
110 | m.bias.data.zero_()
111 | elif isinstance(m, nn.Linear):
112 | if m.bias is not None:
113 | m.bias.data.zero_()
114 | elif isinstance(m, nn.BatchNorm1d):
115 | m.weight.data.fill_(1)
116 | m.bias.data.zero_()
117 |
118 | def set_bn_param(self, bn_momentum, bn_eps):
119 | for m in self.modules():
120 | if isinstance(m, nn.BatchNorm2d):
121 | m.momentum = bn_momentum
122 | m.eps = bn_eps
123 | return
124 |
125 |
126 | class RES_Net(nn.Module):
127 | def __init__(self, net_config, config=None):
128 | """
129 | net_config=[[in_ch, out_ch], head_op, [stack_ops], num_stack_layers, stride]
130 | """
131 | super(RES_Net, self).__init__()
132 | self.config = config
133 | self.net_config = parse_net_config(net_config)
134 | self.in_chs = self.net_config[0][0][0]
135 | self._num_classes = 1000
136 |
137 | self.input_block = nn.Sequential(
138 | nn.Conv2d(in_channels=3, out_channels=self.in_chs, kernel_size=3,
139 | stride=2, padding=1, bias=False),
140 | nn.BatchNorm2d(self.in_chs),
141 | nn.ReLU6(inplace=True),
142 | )
143 | self.blocks = nn.ModuleList()
144 | for block_config in self.net_config:
145 | self.blocks.append(Block(block_config[0][0], block_config[0][1],
146 | block_config[1], block_config[2], block_config[-1]))
147 |
148 | self.global_pooling = nn.AdaptiveAvgPool2d((1, 1))
149 | if self.net_config[-1][1] == 'bottle_neck':
150 | last_dim = self.net_config[-1][0][-1] * 4
151 | else:
152 | last_dim = self.net_config[-1][0][1]
153 | self.classifier = nn.Linear(last_dim, self._num_classes)
154 |
155 | for m in self.modules():
156 | if isinstance(m, nn.Conv2d):
157 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
158 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
159 | if m.affine:
160 | nn.init.constant_(m.weight, 1)
161 | nn.init.constant_(m.bias, 0)
162 |
163 | def forward(self, x):
164 | block_data = self.input_block(x)
165 | for i, block in enumerate(self.blocks):
166 | block_data = block(block_data)
167 |
168 | out = self.global_pooling(block_data)
169 | out = torch.flatten(out, 1)
170 | logits = self.classifier(out)
171 | return logits
--------------------------------------------------------------------------------
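As a sanity check of the `net_config` format documented in the docstrings above, a derived network can be instantiated directly from such a string. A minimal sketch with hypothetical channel widths (not a searched architecture):

```
import torch
from models.model_derived import RES_Net

net_config = ("[[32, 64], 'basic_block', ['basic_block'], 1, 2]|"
              "[[64, 128], 'basic_block', ['basic_block', 'basic_block'], 2, 2]")

model = RES_Net(net_config)
logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])
```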
/dataset/imagenet_data.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | import torchvision
5 | import torchvision.transforms as transforms
6 |
7 | from . import lmdb_dataset
8 | from . import torchvision_extension as transforms_extension
9 | from .prefetch_data import fast_collate
10 |
11 | class ImageNet12(object):
12 |
13 | def __init__(self, trainFolder, testFolder, num_workers=8, pin_memory=True,
14 | size_images=224, scaled_size=256, type_of_data_augmentation='rand_scale',
15 | data_config=None):
16 |
17 | self.data_config = data_config
18 | self.trainFolder = trainFolder
19 | self.testFolder = testFolder
20 | self.num_workers = num_workers
21 | self.pin_memory = pin_memory
22 | self.patch_dataset = self.data_config.patch_dataset
23 |
24 | #images will be rescaled to match this size
25 | if not isinstance(size_images, int):
26 | raise ValueError('size_images must be an int; images are rescaled to a square of this size')
27 | self.size_images = size_images
28 | self.scaled_size = scaled_size
29 |
30 | type_of_data_augmentation = type_of_data_augmentation.lower()
31 | if type_of_data_augmentation not in ('rand_scale', 'random_sized'):
32 | raise ValueError("type_of_data_augmentation must be either 'rand_scale' or 'random_sized'")
33 | self.type_of_data_augmentation = type_of_data_augmentation
34 |
35 |
36 | def _getTransformList(self, aug_type):
37 |
38 | assert aug_type in ['rand_scale', 'random_sized', 'week_train', 'validation']  # 'week_train' denotes the weakly-augmented training pipeline
39 | list_of_transforms = []
40 |
41 | if aug_type == 'validation':
42 | list_of_transforms.append(transforms.Resize(self.scaled_size))
43 | list_of_transforms.append(transforms.CenterCrop(self.size_images))
44 |
45 | elif aug_type == 'week_train':
46 | list_of_transforms.append(transforms.Resize(256))
47 | list_of_transforms.append(transforms.RandomCrop(self.size_images))
48 | list_of_transforms.append(transforms.RandomHorizontalFlip())
49 |
50 | else:
51 | if aug_type == 'rand_scale':
52 | list_of_transforms.append(transforms_extension.RandomScale(256, 480))
53 | list_of_transforms.append(transforms.RandomCrop(self.size_images))
54 | list_of_transforms.append(transforms.RandomHorizontalFlip())
55 |
56 | elif aug_type == 'random_sized':
57 | list_of_transforms.append(transforms.RandomResizedCrop(self.size_images,
58 | scale=(self.data_config.random_sized.min_scale, 1.0)))
59 | list_of_transforms.append(transforms.RandomHorizontalFlip())
60 |
61 | if self.data_config.color:
62 | list_of_transforms.append(transforms.ColorJitter(brightness=0.4,
63 | contrast=0.4,
64 | saturation=0.4))
65 | return transforms.Compose(list_of_transforms)
66 |
67 |
68 | def _getTrainSet(self):
69 |
70 | train_transform = self._getTransformList(self.type_of_data_augmentation)
71 |
72 | if self.data_config.train_data_type == 'img':
73 | train_set = torchvision.datasets.ImageFolder(self.trainFolder, train_transform)
74 | elif self.data_config.train_data_type == 'lmdb':
75 | train_set = lmdb_dataset.ImageFolder(self.trainFolder,
76 | os.path.join(self.trainFolder, '..', 'train_datalist'),
77 | train_transform,
78 | patch_dataset=self.patch_dataset)
79 | self.train_num_examples = train_set.__len__()
80 |
81 | return train_set
82 |
83 |
84 | def _getWeekTrainSet(self):
85 |
86 | train_transform = self._getTransformList('week_train')
87 | if self.data_config.train_data_type == 'img':
88 | train_set = torchvision.datasets.ImageFolder(self.trainFolder, train_transform)
89 | elif self.data_config.train_data_type == 'lmdb':
90 | train_set = lmdb_dataset.ImageFolder(self.trainFolder,
91 | os.path.join(self.trainFolder, '..', 'train_datalist'),
92 | train_transform,
93 | patch_dataset=self.patch_dataset)
94 | self.train_num_examples = train_set.__len__()
95 | return train_set
96 |
97 |
98 | def _getTestSet(self):
99 |
100 | test_transform = self._getTransformList('validation')
101 | if self.data_config.val_data_type == 'img':
102 | test_set = torchvision.datasets.ImageFolder(self.testFolder, test_transform)
103 | elif self.data_config.val_data_type == 'lmdb':
104 | test_set = lmdb_dataset.ImageFolder(self.testFolder,
105 | os.path.join(self.testFolder, '..', 'val_datalist'),
106 | test_transform)
107 | self.test_num_examples = test_set.__len__()
108 | return test_set
109 |
110 |
111 | def getTrainLoader(self, batch_size, shuffle=True):
112 |
113 | train_set = self._getTrainSet()
114 | train_loader = torch.utils.data.DataLoader(
115 | train_set, batch_size=batch_size, shuffle=shuffle,
116 | num_workers=self.num_workers, pin_memory=self.pin_memory,
117 | sampler=None, collate_fn=fast_collate)
118 | return train_loader
119 |
120 |
121 | def getWeekTrainLoader(self, batch_size, shuffle=True):
122 |
123 | train_set = self._getWeekTrainSet()
124 | train_loader = torch.utils.data.DataLoader(train_set,
125 | batch_size=batch_size,
126 | shuffle=shuffle,
127 | num_workers=self.num_workers,
128 | pin_memory=self.pin_memory,
129 | collate_fn=fast_collate)
130 | return train_loader
131 |
132 |
133 | def getTestLoader(self, batch_size, shuffle=False):
134 |
135 | test_set = self._getTestSet()
136 |
137 | test_loader = torch.utils.data.DataLoader(
138 | test_set, batch_size=batch_size, shuffle=shuffle,
139 | num_workers=self.num_workers, pin_memory=self.pin_memory, sampler=None,
140 | collate_fn=fast_collate)
141 | return test_loader
142 |
143 |
144 | def getTrainTestLoader(self, batch_size, train_shuffle=True, val_shuffle=False):
145 |
146 | train_loader = self.getTrainLoader(batch_size, train_shuffle)
147 | test_loader = self.getTestLoader(batch_size, val_shuffle)
148 | return train_loader, test_loader
149 |
150 |
151 | def getSetTrainTestLoader(self, batch_size, train_shuffle=True, val_shuffle=False):
152 |
153 | train_loader = self.getTrainLoader(batch_size, train_shuffle)
154 | week_train_loader = self.getWeekTrainLoader(batch_size, train_shuffle)
155 | test_loader = self.getTestLoader(batch_size, val_shuffle)
156 | return (train_loader, week_train_loader), test_loader
157 |
--------------------------------------------------------------------------------
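`ImageNet12` only touches a handful of fields on `data_config`, so a lightweight namespace is enough to stand in for the full config object when experimenting. A sketch with plain image folders and made-up paths:

```
from types import SimpleNamespace
from dataset.imagenet_data import ImageNet12

data_cfg = SimpleNamespace(
    patch_dataset=False,
    train_data_type='img', val_data_type='img',
    color=False,
    random_sized=SimpleNamespace(min_scale=0.08))

imagenet = ImageNet12('/path/to/imagenet/train', '/path/to/imagenet/val',
                      num_workers=8,
                      type_of_data_augmentation='random_sized',
                      data_config=data_cfg)
train_loader, val_loader = imagenet.getTrainTestLoader(batch_size=256)
```

Note that the loaders use `fast_collate` from `dataset/prefetch_data.py` as the collate function, so the batches come back in the format expected by the prefetcher rather than as normalized float tensors.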
/models/operations.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | OPS = {
5 | 'mbconv_k3_t1': lambda C_in, C_out, stride, affine, track_running_stats: MBConv(C_in, C_out, 3, stride, 1, t=1, affine=affine, track_running_stats=track_running_stats),
6 | 'mbconv_k3_t3': lambda C_in, C_out, stride, affine, track_running_stats: MBConv(C_in, C_out, 3, stride, 1, t=3, affine=affine, track_running_stats=track_running_stats),
7 | 'mbconv_k3_t6': lambda C_in, C_out, stride, affine, track_running_stats: MBConv(C_in, C_out, 3, stride, 1, t=6, affine=affine, track_running_stats=track_running_stats),
8 | 'mbconv_k5_t3': lambda C_in, C_out, stride, affine, track_running_stats: MBConv(C_in, C_out, 5, stride, 2, t=3, affine=affine, track_running_stats=track_running_stats),
9 | 'mbconv_k5_t6': lambda C_in, C_out, stride, affine, track_running_stats: MBConv(C_in, C_out, 5, stride, 2, t=6, affine=affine, track_running_stats=track_running_stats),
10 | 'mbconv_k7_t3': lambda C_in, C_out, stride, affine, track_running_stats: MBConv(C_in, C_out, 7, stride, 3, t=3, affine=affine, track_running_stats=track_running_stats),
11 | 'mbconv_k7_t6': lambda C_in, C_out, stride, affine, track_running_stats: MBConv(C_in, C_out, 7, stride, 3, t=6, affine=affine, track_running_stats=track_running_stats),
12 | 'basic_block': lambda C_in, C_out, stride, affine, track_running_stats: BasicBlock(C_in, C_out, stride, affine=affine, track_running_stats=track_running_stats),
13 | 'bottle_neck': lambda C_in, C_out, stride, affine, track_running_stats: Bottleneck(C_in, C_out, stride, affine=affine, track_running_stats=track_running_stats),
14 | 'skip_connect': lambda C_in, C_out, stride, affine, track_running_stats: Skip(C_in, C_out, 1, affine=affine, track_running_stats=track_running_stats),
15 | }
16 |
17 |
18 | class MBConv(nn.Module):
19 | def __init__(self, C_in, C_out, kernel_size, stride, padding, t=3, affine=True,
20 | track_running_stats=True, use_se=False):
21 | super(MBConv, self).__init__()
22 | self.t = t
23 | if self.t > 1:
24 | self._expand_conv = nn.Sequential(
25 | nn.Conv2d(C_in, C_in*self.t, kernel_size=1, stride=1, padding=0, groups=1, bias=False),
26 | nn.BatchNorm2d(C_in*self.t, affine=affine, track_running_stats=track_running_stats),
27 | nn.ReLU6(inplace=True))
28 |
29 | self._depthwise_conv = nn.Sequential(
30 | nn.Conv2d(C_in*self.t, C_in*self.t, kernel_size=kernel_size, stride=stride, padding=padding, groups=C_in*self.t, bias=False),
31 | nn.BatchNorm2d(C_in*self.t, affine=affine, track_running_stats=track_running_stats),
32 | nn.ReLU6(inplace=True))
33 |
34 | self._project_conv = nn.Sequential(
35 | nn.Conv2d(C_in*self.t, C_out, kernel_size=1, stride=1, padding=0, groups=1, bias=False),
36 | nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats))
37 | else:
38 | self._expand_conv = None
39 |
40 | self._depthwise_conv = nn.Sequential(
41 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, groups=C_in, bias=False),
42 | nn.BatchNorm2d(C_in, affine=affine, track_running_stats=track_running_stats),
43 | nn.ReLU6(inplace=True))
44 |
45 | self._project_conv = nn.Sequential(
46 | nn.Conv2d(C_in, C_out, 1, 1, 0, bias=False),
47 | nn.BatchNorm2d(C_out))
48 |
49 | def forward(self, x):
50 | input_data = x
51 | if self._expand_conv is not None:
52 | x = self._expand_conv(x)
53 | x = self._depthwise_conv(x)
54 | out_data = self._project_conv(x)
55 |
56 | if out_data.shape == input_data.shape:
57 | return out_data + input_data
58 | else:
59 | return out_data
60 |
61 |
62 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
63 | """3x3 convolution with padding"""
64 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
65 | padding=dilation, groups=groups, bias=False, dilation=dilation)
66 |
67 | def conv1x1(in_planes, out_planes, stride=1):
68 | """1x1 convolution"""
69 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
70 |
71 | class BasicBlock(nn.Module):
72 | def __init__(self, inplanes, planes, stride=1, groups=1,
73 | base_width=64, dilation=1, norm_layer=None,
74 | affine=True, track_running_stats=True):
75 | super(BasicBlock, self).__init__()
76 | if norm_layer is None:
77 | norm_layer = nn.BatchNorm2d
78 | if groups != 1 or base_width != 64:
79 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
80 | if dilation > 1:
81 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
82 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
83 | self.conv1 = conv3x3(inplanes, planes, stride)
84 | self.bn1 = norm_layer(planes, affine=affine, track_running_stats=track_running_stats)
85 | self.relu = nn.ReLU(inplace=True)
86 | self.conv2 = conv3x3(planes, planes)
87 | self.bn2 = norm_layer(planes, affine=affine, track_running_stats=track_running_stats)
88 | self.downsample = None
89 | if stride != 1 or inplanes != planes:
90 | self.downsample = nn.Sequential(
91 | conv1x1(inplanes, planes, stride),
92 | norm_layer(planes, affine=affine, track_running_stats=track_running_stats),
93 | )
94 |
95 | def forward(self, x):
96 | identity = x
97 | out = self.conv1(x)
98 | out = self.bn1(out)
99 | out = self.relu(out)
100 | out = self.conv2(out)
101 | out = self.bn2(out)
102 |
103 | if self.downsample is not None:
104 | identity = self.downsample(x)
105 | out += identity
106 | out = self.relu(out)
107 |
108 | return out
109 |
110 |
111 | class Bottleneck(nn.Module):
112 | def __init__(self, inplanes, planes, stride=1, affine=True, track_running_stats=True):
113 | super(Bottleneck, self).__init__()
114 |         if inplanes != 32:  # inputs other than the 32-ch stem carry the 4x bottleneck expansion
115 |             inplanes *= 4
116 |         self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
117 |         self.bn1 = nn.BatchNorm2d(planes, affine=affine, track_running_stats=track_running_stats)
118 |         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
119 |                                padding=1, bias=False)
120 |         self.bn2 = nn.BatchNorm2d(planes, affine=affine, track_running_stats=track_running_stats)
121 |         self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
122 |         self.bn3 = nn.BatchNorm2d(planes * 4, affine=affine, track_running_stats=track_running_stats)
123 | self.relu = nn.ReLU(inplace=True)
124 | self.stride = stride
125 | self.downsample = None
126 | if stride != 1 or inplanes != planes*4:
127 | self.downsample = nn.Sequential(
128 | conv1x1(inplanes, planes * 4, stride),
129 | nn.BatchNorm2d(planes * 4, affine=affine, track_running_stats=track_running_stats),
130 | )
131 |
132 | def forward(self, x):
133 | residual = x
134 |
135 | out = self.conv1(x)
136 | out = self.bn1(out)
137 | out = self.relu(out)
138 |
139 | out = self.conv2(out)
140 | out = self.bn2(out)
141 | out = self.relu(out)
142 |
143 | out = self.conv3(out)
144 | out = self.bn3(out)
145 |
146 | if self.downsample is not None:
147 | residual = self.downsample(x)
148 |
149 | out += residual
150 | out = self.relu(out)
151 |
152 | return out
153 |
154 |
155 | class Skip(nn.Module):
156 | def __init__(self, C_in, C_out, stride, affine=True, track_running_stats=True):
157 | super(Skip, self).__init__()
158 |         if C_in != C_out:
159 |             # project to the target width; the 1x1 conv also applies the stride
160 |             skip_conv = nn.Sequential(
161 |                 nn.Conv2d(C_in, C_out, kernel_size=1, stride=stride, padding=0, groups=1, bias=False),
162 |                 nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats))
163 |             self.op = nn.Sequential(skip_conv, Identity(1))
164 |         else:
165 |             self.op = Identity(stride)
166 | 
167 | 
168 |     def forward(self, x):
169 | return self.op(x)
170 |
171 | class Identity(nn.Module):
172 | def __init__(self, stride):
173 | super(Identity, self).__init__()
174 | self.stride = stride
175 |
176 | def forward(self, x):
177 | if self.stride == 1:
178 | return x
179 | else:
180 | return x[:, :, ::self.stride, ::self.stride]
--------------------------------------------------------------------------------
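
A minimal usage sketch for the operation registry defined in operations.py above (a hypothetical snippet; it assumes the module is importable as models.operations). Every OPS entry shares the same constructor signature, so the search space can instantiate any candidate op uniformly:

import torch
from models.operations import OPS

op = OPS['mbconv_k3_t6'](C_in=32, C_out=32, stride=1, affine=True, track_running_stats=True)
x = torch.randn(2, 32, 56, 56)
# with C_in == C_out and stride == 1, MBConv adds the residual, so shapes match
assert op(x).shape == x.shape
# a stride or width change skips the residual and alters the shape
op2 = OPS['mbconv_k5_t3'](32, 64, 2, True, True)
assert op2(x).shape == (2, 64, 28, 28)
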
/run_apis/trainer.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import time
3 |
4 | import torch.nn as nn
5 |
6 | from dataset.prefetch_data import data_prefetcher
7 | from tools import utils
8 |
9 |
10 | class Trainer(object):
11 | def __init__(self, train_data, val_data, optimizer=None, criterion=None,
12 | scheduler=None, config=None, report_freq=None):
13 | self.train_data = train_data
14 | self.val_data = val_data
15 | self.optimizer = optimizer
16 | self.criterion = criterion
17 | self.scheduler = scheduler
18 | self.config = config
19 | self.report_freq = report_freq
20 |
21 | def train(self, model, epoch):
22 | objs = utils.AverageMeter()
23 | top1 = utils.AverageMeter()
24 | top5 = utils.AverageMeter()
25 | data_time = utils.AverageMeter()
26 | batch_time = utils.AverageMeter()
27 | model.train()
28 | start = time.time()
29 |
30 | prefetcher = data_prefetcher(self.train_data)
31 | input, target = prefetcher.next()
32 | step = 0
33 | while input is not None:
34 | data_t = time.time() - start
35 | self.scheduler.step()
36 | n = input.size(0)
37 | if step==0:
38 | logging.info('epoch %d lr %e', epoch, self.optimizer.param_groups[0]['lr'])
39 | self.optimizer.zero_grad()
40 |
41 |             logits = model(input)
42 | if self.config.optim.label_smooth:
43 | loss = self.criterion(logits, target, self.config.optim.smooth_alpha)
44 | else:
45 | loss = self.criterion(logits, target)
46 |
47 | loss.backward()
48 | if self.config.optim.use_grad_clip:
49 | nn.utils.clip_grad_norm_(model.parameters(), self.config.optim.grad_clip)
50 | self.optimizer.step()
51 |
52 | prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
53 |
54 | batch_t = time.time() - start
55 | start = time.time()
56 |
57 | objs.update(loss.item(), n)
58 | top1.update(prec1.item(), n)
59 | top5.update(prec5.item(), n)
60 | data_time.update(data_t)
61 | batch_time.update(batch_t)
62 | if step!=0 and step % self.report_freq == 0:
63 | logging.info(
64 | 'Train epoch %03d step %03d | loss %.4f top1_acc %.2f top5_acc %.2f | batch_time %.3f data_time %.3f',
65 | epoch, step, objs.avg, top1.avg, top5.avg, batch_time.avg, data_time.avg)
66 | input, target = prefetcher.next()
67 | step += 1
68 | logging.info('EPOCH%d Train_acc top1 %.2f top5 %.2f batch_time %.3f data_time %.3f',
69 | epoch, top1.avg, top5.avg, batch_time.avg, data_time.avg)
70 |
71 | return top1.avg, top5.avg, objs.avg, batch_time.avg, data_time.avg
72 |
73 |
74 | def infer(self, model, epoch=0):
75 | top1 = utils.AverageMeter()
76 | top5 = utils.AverageMeter()
77 | data_time = utils.AverageMeter()
78 | batch_time = utils.AverageMeter()
79 | model.eval()
80 |
81 | start = time.time()
82 | prefetcher = data_prefetcher(self.val_data)
83 | input, target = prefetcher.next()
84 | step = 0
85 | while input is not None:
86 | step += 1
87 | data_t = time.time() - start
88 | n = input.size(0)
89 |
90 | logits = model(input)
91 |
92 | prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
93 |
94 | batch_t = time.time() - start
95 | top1.update(prec1.item(), n)
96 | top5.update(prec5.item(), n)
97 | data_time.update(data_t)
98 | batch_time.update(batch_t)
99 |
100 | if step % self.report_freq == 0:
101 | logging.info(
102 | 'Val epoch %03d step %03d | top1_acc %.2f top5_acc %.2f | batch_time %.3f data_time %.3f',
103 | epoch, step, top1.avg, top5.avg, batch_time.avg, data_time.avg)
104 | start = time.time()
105 | input, target = prefetcher.next()
106 |
107 | logging.info('EPOCH%d Valid_acc top1 %.2f top5 %.2f batch_time %.3f data_time %.3f',
108 | epoch, top1.avg, top5.avg, batch_time.avg, data_time.avg)
109 | return top1.avg, top5.avg, batch_time.avg, data_time.avg
110 |
111 |
112 | class SearchTrainer(object):
113 | def __init__(self, train_data, val_data, search_optim, criterion, scheduler, config, args):
114 | self.train_data = train_data
115 | self.val_data = val_data
116 | self.search_optim = search_optim
117 | self.criterion = criterion
118 | self.scheduler = scheduler
119 | self.sub_obj_type = config.optim.sub_obj.type
120 | self.args = args
121 |
122 |
123 | def train(self, model, epoch, optim_obj='Weights', search_stage=0):
124 | assert optim_obj in ['Weights', 'Arch']
125 | objs = utils.AverageMeter()
126 | top1 = utils.AverageMeter()
127 | top5 = utils.AverageMeter()
128 | sub_obj_avg = utils.AverageMeter()
129 | data_time = utils.AverageMeter()
130 | batch_time = utils.AverageMeter()
131 | model.train()
132 |
133 | start = time.time()
134 | if optim_obj == 'Weights':
135 | prefetcher = data_prefetcher(self.train_data)
136 | elif optim_obj == 'Arch':
137 | prefetcher = data_prefetcher(self.val_data)
138 |
139 | input, target = prefetcher.next()
140 | step = 0
141 | while input is not None:
142 | input, target = input.cuda(), target.cuda()
143 | data_t = time.time() - start
144 | n = input.size(0)
145 | if optim_obj == 'Weights':
146 | self.scheduler.step()
147 | if step==0:
148 | logging.info('epoch %d weight_lr %e', epoch, self.search_optim.weight_optimizer.param_groups[0]['lr'])
149 | logits, loss, sub_obj = self.search_optim.weight_step(input, target, model, search_stage)
150 | elif optim_obj == 'Arch':
151 | if step==0:
152 | logging.info('epoch %d arch_lr %e', epoch, self.search_optim.arch_optimizer.param_groups[0]['lr'])
153 | logits, loss, sub_obj = self.search_optim.arch_step(input, target, model, search_stage)
154 |
155 | prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
156 | del logits, input, target
157 |
158 | batch_t = time.time() - start
159 | objs.update(loss, n)
160 | top1.update(prec1.item(), n)
161 | top5.update(prec5.item(), n)
162 | sub_obj_avg.update(sub_obj)
163 | data_time.update(data_t)
164 | batch_time.update(batch_t)
165 |
166 | if step!=0 and step % self.args.report_freq == 0:
167 | logging.info(
168 | 'Train%s epoch %03d step %03d | loss %.4f %s %.2f top1_acc %.2f top5_acc %.2f | batch_time %.3f data_time %.3f',
169 |                     optim_obj, epoch, step, objs.avg, self.sub_obj_type, sub_obj_avg.avg,
170 | top1.avg, top5.avg, batch_time.avg, data_time.avg)
171 | start = time.time()
172 | step += 1
173 | input, target = prefetcher.next()
174 | return top1.avg, top5.avg, objs.avg, sub_obj_avg.avg, batch_time.avg
175 |
176 |
177 | def infer(self, model, epoch):
178 | objs = utils.AverageMeter()
179 | top1 = utils.AverageMeter()
180 | top5 = utils.AverageMeter()
181 | sub_obj_avg = utils.AverageMeter()
182 | data_time = utils.AverageMeter()
183 | batch_time = utils.AverageMeter()
184 |
185 | model.train() # don't use running_mean and running_var during search
186 | start = time.time()
187 | prefetcher = data_prefetcher(self.val_data)
188 | input, target = prefetcher.next()
189 | step = 0
190 | while input is not None:
191 | step += 1
192 | data_t = time.time() - start
193 | n = input.size(0)
194 |
195 | logits, loss, sub_obj = self.search_optim.valid_step(input, target, model)
196 | prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
197 |
198 | batch_t = time.time() - start
199 | objs.update(loss, n)
200 | top1.update(prec1.item(), n)
201 | top5.update(prec5.item(), n)
202 | sub_obj_avg.update(sub_obj)
203 | data_time.update(data_t)
204 | batch_time.update(batch_t)
205 |
206 | if step % self.args.report_freq == 0:
207 | logging.info(
208 | 'Val epoch %03d step %03d | loss %.4f %s %.2f top1_acc %.2f top5_acc %.2f | batch_time %.3f data_time %.3f',
209 | epoch, step, objs.avg, self.sub_obj_type, sub_obj_avg.avg, top1.avg, top5.avg,
210 | batch_time.avg, data_time.avg)
211 | start = time.time()
212 | input, target = prefetcher.next()
213 |
214 | return top1.avg, top5.avg, objs.avg, sub_obj_avg.avg, batch_time.avg
215 |
--------------------------------------------------------------------------------
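
The trainers above lean on utils.AverageMeter for all of their running statistics. As a rough sketch (this minimal version only assumes the update(val, n)/avg interface used above, not the actual implementation in tools/utils.py):

class AverageMeter:
    def __init__(self):
        self.sum, self.cnt, self.avg = 0.0, 0, 0.0

    def update(self, val, n=1):
        # accumulate a value observed for a batch of size n
        self.sum += val * n
        self.cnt += n
        self.avg = self.sum / self.cnt

top1 = AverageMeter()
for acc, batch_size in [(70.0, 256), (74.0, 256), (71.0, 128)]:
    top1.update(acc, batch_size)
print(round(top1.avg, 2))  # 71.8, a batch-size-weighted mean

This weighting is why the trainers pass n = input.size(0) on every update: the logged averages stay exact even when the last batch is smaller.
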
/tools/config_yaml.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | ##############################################################################
15 | #
16 | # Based on:
17 | # --------------------------------------------------------
18 | # Fast R-CNN
19 | # Copyright (c) 2015 Microsoft
20 | # Licensed under The MIT License [see LICENSE for details]
21 | # Written by Ross Girshick
22 | # --------------------------------------------------------
23 |
24 | from __future__ import absolute_import
25 | from __future__ import division
26 | from __future__ import print_function
27 | from __future__ import unicode_literals
28 |
29 | from ast import literal_eval
30 | from io import IOBase
31 | #from future.utils import iteritems
32 | import copy
33 | import logging
34 | import numpy as np
35 | import os
36 | import os.path as osp
37 | import yaml
38 |
39 | from .collections import AttrDict
40 | from .io import cache_url
41 |
42 | logger = logging.getLogger(__name__)
43 | 
44 | # Key registries referenced by the deprecation/rename checks below; they were
45 | # previously undefined, so define empty defaults to keep those checks runnable.
46 | _DEPRECATED_KEYS = set()
47 | _RENAMED_KEYS = {}
48 | 
49 | def load_cfg(cfg_to_load):
50 |     """Wrapper around yaml.load used for maintaining backward compatibility"""
51 |     if isinstance(cfg_to_load, IOBase):
52 |         cfg_to_load = ''.join(cfg_to_load.readlines())
53 |     return yaml.load(cfg_to_load, Loader=yaml.FullLoader)  # explicit Loader; bare yaml.load is deprecated
49 |
50 |
51 | def load_cfg_to_dict(cfg_filename):
52 | with open(cfg_filename, 'r') as f:
53 | # yaml_cfg = AttrDict(load_cfg(f))
54 | yaml_cfg = load_cfg(f)
55 | return yaml_cfg
56 |
57 |
58 | def merge_cfg_from_file(cfg_filename, global_config):
59 | """Load a yaml config file and merge it into the global config."""
60 | with open(cfg_filename, 'r') as f:
61 | yaml_cfg = AttrDict(load_cfg(f))
62 | _merge_a_into_b(yaml_cfg, global_config)
63 |
64 |
65 | def merge_cfg_from_cfg(cfg_other, global_config):
66 | """Merge `cfg_other` into the global config."""
67 | _merge_a_into_b(cfg_other, global_config)
68 |
69 | def update_cfg_from_file(cfg_filename, global_config):
70 | with open(cfg_filename, 'r') as f:
71 | yaml_cfg = AttrDict(load_cfg(f))
72 | update_cfg_from_cfg(yaml_cfg, global_config)
73 |
74 | def update_cfg_from_cfg(cfg_other, global_config, stack=None):
75 |     assert isinstance(cfg_other, AttrDict), \
76 |         '`cfg_other` (cur type {}) must be an instance of {}'.format(type(cfg_other), AttrDict)
77 |     assert isinstance(global_config, AttrDict), \
78 |         '`global_config` (cur type {}) must be an instance of {}'.format(type(global_config), AttrDict)
79 |
80 | for k, v_ in cfg_other.items():
81 | full_key = '.'.join(stack) + '.' + k if stack is not None else k
82 |
83 | v = copy.deepcopy(v_)
84 | v = _decode_cfg_value(v)
85 |
86 | if k not in global_config:
87 | global_config[k] = v
88 | if isinstance(v, AttrDict):
89 | try:
90 | stack_push = [k] if stack is None else stack + [k]
91 | update_cfg_from_cfg(v, global_config[k], stack=stack_push)
92 | except BaseException:
93 | raise
94 |
95 |         else:
96 |             # key exists in both: coerce the value type, then merge dicts recursively
97 | v = _check_and_coerce_cfg_value_type(v, global_config[k], k, full_key)
98 | if isinstance(v, AttrDict):
99 | try:
100 | stack_push = [k] if stack is None else stack + [k]
101 | update_cfg_from_cfg(v, global_config[k], stack=stack_push)
102 | except BaseException:
103 | raise
104 | else:
105 | global_config[k] = v
106 |
107 |
108 | def merge_cfg_from_list(cfg_list, global_config):
109 | """Merge config keys, values in a list (e.g., from command line) into the
110 | global config. For example, `cfg_list = ['TEST.NMS', 0.5]`.
111 | """
112 | assert len(cfg_list) % 2 == 0
113 | for full_key, v in zip(cfg_list[0::2], cfg_list[1::2]):
114 | if _key_is_deprecated(full_key):
115 | continue
116 | if _key_is_renamed(full_key):
117 | _raise_key_rename_error(full_key)
118 | key_list = full_key.split('.')
119 | d = global_config
120 | for subkey in key_list[:-1]:
121 | assert subkey in d, 'Non-existent key: {}'.format(full_key)
122 | d = d[subkey]
123 | subkey = key_list[-1]
124 | assert subkey in d, 'Non-existent key: {}'.format(full_key)
125 | value = _decode_cfg_value(v)
126 | value = _check_and_coerce_cfg_value_type(
127 | value, d[subkey], subkey, full_key
128 | )
129 | d[subkey] = value
130 |
131 |
132 | def _merge_a_into_b(a, b, stack=None):
133 | """Merge config dictionary a into config dictionary b, clobbering the
134 | options in b whenever they are also specified in a.
135 | """
136 | assert isinstance(a, AttrDict), \
137 | '`a` (cur type {}) must be an instance of {}'.format(type(a), AttrDict)
138 | assert isinstance(b, AttrDict), \
139 | '`b` (cur type {}) must be an instance of {}'.format(type(b), AttrDict)
140 |
141 | for k, v_ in a.items():
142 | full_key = '.'.join(stack) + '.' + k if stack is not None else k
143 | # a must specify keys that are in b
144 | if k not in b:
145 | raise KeyError('Non-existent config key: {}'.format(full_key))
146 |
147 | v = copy.deepcopy(v_)
148 | v = _decode_cfg_value(v)
149 | v = _check_and_coerce_cfg_value_type(v, b[k], k, full_key)
150 |
151 | # Recursively merge dicts
152 | if isinstance(v, AttrDict):
153 | try:
154 | stack_push = [k] if stack is None else stack + [k]
155 | _merge_a_into_b(v, b[k], stack=stack_push)
156 | except BaseException:
157 | raise
158 | else:
159 | b[k] = v
160 |
161 | def _key_is_deprecated(full_key):
162 | if full_key in _DEPRECATED_KEYS:
163 |         logger.warning(
164 | 'Deprecated config key (ignoring): {}'.format(full_key)
165 | )
166 | return True
167 | return False
168 |
169 |
170 | def _key_is_renamed(full_key):
171 | return full_key in _RENAMED_KEYS
172 |
173 |
174 | def _raise_key_rename_error(full_key):
175 | new_key = _RENAMED_KEYS[full_key]
176 | if isinstance(new_key, tuple):
177 | msg = ' Note: ' + new_key[1]
178 | new_key = new_key[0]
179 | else:
180 | msg = ''
181 | raise KeyError(
182 | 'Key {} was renamed to {}; please update your config.{}'.
183 | format(full_key, new_key, msg)
184 | )
185 |
186 |
187 | def _decode_cfg_value(v):
188 | """Decodes a raw config value (e.g., from a yaml config files or command
189 | line argument) into a Python object.
190 | """
191 | # Configs parsed from raw yaml will contain dictionary keys that need to be
192 | # converted to AttrDict objects
193 | if isinstance(v, dict):
194 | return AttrDict(v)
195 | # All remaining processing is only applied to strings
196 | # Try to interpret `v` as a:
197 | # string, number, tuple, list, dict, boolean, or None
198 | try:
199 | v = literal_eval(v)
200 | # The following two excepts allow v to pass through when it represents a
201 | # string.
202 | #
203 | # Longer explanation:
204 | # The type of v is always a string (before calling literal_eval), but
205 | # sometimes it *represents* a string and other times a data structure, like
206 | # a list. In the case that v represents a string, what we got back from the
207 | # yaml parser is 'foo' *without quotes* (so, not '"foo"'). literal_eval is
208 | # ok with '"foo"', but will raise a ValueError if given 'foo'. In other
209 | # cases, like paths (v = 'foo/bar' and not v = '"foo/bar"'), literal_eval
210 | # will raise a SyntaxError.
211 | except ValueError:
212 | pass
213 | except SyntaxError:
214 | pass
215 | return v
216 |
217 |
218 | def _check_and_coerce_cfg_value_type(value_a, value_b, key, full_key):
219 | """Checks that `value_a`, which is intended to replace `value_b` is of the
220 | right type. The type is correct if it matches exactly or is one of a few
221 | cases in which the type can be easily coerced.
222 | """
223 | # The types must match (with some exceptions)
224 | type_b = type(value_b)
225 | type_a = type(value_a)
226 | if type_a is type_b:
227 | return value_a
228 |
229 | # Exceptions: numpy arrays, strings, tuple<->list
230 | if isinstance(value_b, np.ndarray):
231 | value_a = np.array(value_a, dtype=value_b.dtype)
232 | elif isinstance(value_a, tuple) and isinstance(value_b, list):
233 | value_a = list(value_a)
234 | elif isinstance(value_a, list) and isinstance(value_b, tuple):
235 | value_a = tuple(value_a)
236 | else:
237 | raise ValueError(
238 | 'Type mismatch ({} vs. {}) with values ({} vs. {}) for config '
239 | 'key: {}'.format(type_b, type_a, value_b, value_a, full_key)
240 | )
241 | return value_a
242 |
--------------------------------------------------------------------------------
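
To make the value-decoding behavior above concrete, here is a small standalone sketch of the literal_eval path taken by _decode_cfg_value (the helper name is hypothetical; only the two except clauses mirror the real function):

from ast import literal_eval

def decode(v):
    try:
        return literal_eval(v)      # numbers, lists, tuples, booleans, None
    except (ValueError, SyntaxError):
        return v                    # plain strings and paths pass through

print(decode('0.5'))        # 0.5 (float)
print(decode('[1, 2, 3]'))  # [1, 2, 3] (list)
print(decode('foo'))        # 'foo' (ValueError path: stays a string)
print(decode('foo/bar'))    # 'foo/bar' (SyntaxError path: stays a string)
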
/models/dropped_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class MixedOp(nn.Module):
7 | def __init__(self, dropped_mixed_ops, softmax_temp=1.):
8 | super(MixedOp, self).__init__()
9 | self.softmax_temp = softmax_temp
10 | self._ops = nn.ModuleList()
11 | for op in dropped_mixed_ops:
12 | self._ops.append(op)
13 |
14 | def forward(self, x, alphas, branch_indices, mixed_sub_obj):
15 | op_weights = torch.stack([alphas[branch_index] for branch_index in branch_indices])
16 | op_weights = F.softmax(op_weights / self.softmax_temp, dim=-1)
17 | return sum(op_weight * op(x) for op_weight, op in zip(op_weights, self._ops)), \
18 | sum(op_weight * mixed_sub_obj[branch_index] for op_weight, branch_index in zip(
19 | op_weights, branch_indices))
20 |
21 |
22 | class HeadLayer(nn.Module):
23 | def __init__(self, dropped_mixed_ops, softmax_temp=1.):
24 | super(HeadLayer, self).__init__()
25 | self.head_branches = nn.ModuleList()
26 | for mixed_ops in dropped_mixed_ops:
27 | self.head_branches.append(MixedOp(mixed_ops, softmax_temp))
28 |
29 | def forward(self, inputs, betas, alphas, head_index, head_sub_obj):
30 | head_data = []
31 | count_sub_obj = []
32 | for input_data, head_branch, alpha, head_idx, branch_sub_obj in zip(
33 | inputs, self.head_branches, alphas, head_index, head_sub_obj):
34 | data, sub_obj = head_branch(input_data, alpha, head_idx, branch_sub_obj)
35 | head_data.append(data)
36 | count_sub_obj.append(sub_obj)
37 |
38 | return sum(branch_weight * data for branch_weight, data in zip(betas, head_data)), \
39 | count_sub_obj
40 |
41 |
42 | class StackLayers(nn.Module):
43 | def __init__(self, num_block_layers, dropped_mixed_ops, softmax_temp=1.):
44 | super(StackLayers, self).__init__()
45 |
46 | if num_block_layers != 0:
47 | self.stack_layers = nn.ModuleList()
48 | for i in range(num_block_layers):
49 | self.stack_layers.append(MixedOp(dropped_mixed_ops[i], softmax_temp))
50 | else:
51 | self.stack_layers = None
52 |
53 | def forward(self, x, alphas, stack_index, stack_sub_obj):
54 |
55 | if self.stack_layers is not None:
56 | count_sub_obj = 0
57 | for stack_layer, alpha, stack_idx, layer_sub_obj in zip(self.stack_layers, alphas, stack_index, stack_sub_obj):
58 | x, sub_obj = stack_layer(x, alpha, stack_idx, layer_sub_obj)
59 | count_sub_obj += sub_obj
60 | return x, count_sub_obj
61 |
62 | else:
63 | return x, 0
64 |
65 |
66 | class Block(nn.Module):
67 | def __init__(self, num_block_layers, dropped_mixed_ops, softmax_temp=1.):
68 | super(Block, self).__init__()
69 | self.head_layer = HeadLayer(dropped_mixed_ops[0], softmax_temp)
70 | self.stack_layers = StackLayers(num_block_layers, dropped_mixed_ops[1], softmax_temp)
71 |
72 | def forward(self, inputs, betas, head_alphas, stack_alphas, head_index, stack_index, block_sub_obj):
73 | x, head_sub_obj = self.head_layer(inputs, betas, head_alphas, head_index, block_sub_obj[0])
74 | x, stack_sub_obj = self.stack_layers(x, stack_alphas, stack_index, block_sub_obj[1])
75 |
76 | return x, [head_sub_obj, stack_sub_obj]
77 |
78 |
79 | class Dropped_Network(nn.Module):
80 | def __init__(self, super_model, alpha_head_index=None, alpha_stack_index=None, softmax_temp=1.):
81 | super(Dropped_Network, self).__init__()
82 |
83 | self.softmax_temp = softmax_temp
84 | # static modules loading
85 | self.input_block = super_model.module.input_block
86 | if hasattr(super_model.module, 'head_block'):
87 | self.head_block = super_model.module.head_block
88 | self.conv1_1_block = super_model.module.conv1_1_block
89 | self.global_pooling = super_model.module.global_pooling
90 | self.classifier = super_model.module.classifier
91 |
92 | # architecture parameters loading
93 | self.alpha_head_weights = super_model.module.alpha_head_weights
94 | self.alpha_stack_weights = super_model.module.alpha_stack_weights
95 | self.beta_weights = super_model.module.beta_weights
96 | self.alpha_head_index = alpha_head_index if alpha_head_index is not None else \
97 | super_model.module.alpha_head_index
98 | self.alpha_stack_index = alpha_stack_index if alpha_stack_index is not None else \
99 | super_model.module.alpha_stack_index
100 |
101 | # config loading
102 | self.config = super_model.module.config
103 | self.input_configs = super_model.module.input_configs
104 | self.output_configs = super_model.module.output_configs
105 | self.sub_obj_list = super_model.module.sub_obj_list
106 |
107 | # dynamic blocks loading
108 | self.blocks = nn.ModuleList()
109 |
110 | for i, block in enumerate(super_model.module.blocks):
111 | input_config = self.input_configs[i]
112 |
113 | dropped_mixed_ops = []
114 | # for the head layers
115 | head_mixed_ops = []
116 | for j, head_index in enumerate(self.alpha_head_index[i]):
117 | head_mixed_ops.append([block.head_layer.head_branches[j]._ops[k] for k in head_index])
118 | dropped_mixed_ops.append(head_mixed_ops)
119 |
120 | stack_mixed_ops = []
121 | for j, stack_index in enumerate(self.alpha_stack_index[i]):
122 | stack_mixed_ops.append([block.stack_layers.stack_layers[j]._ops[k] for k in stack_index])
123 | dropped_mixed_ops.append(stack_mixed_ops)
124 |
125 | self.blocks.append(Block(
126 | input_config['num_stack_layers'],
127 | dropped_mixed_ops
128 | ))
129 |
130 | def forward(self, x):
131 |         '''
132 |         To approximate the total sub_obj (latency/flops), we first create the obj list for the blocks
133 |         as follows:
134 |             [[[head_flops_1, head_flops_2, ...], stack_flops], ...]
135 |         Then we compute the whole-network approximation from the end to the beginning. For block b,
136 |             flops'_b = sum(beta_{bi} * (head_flops_{bi} + stack_flops_i) for i in out_idx[b])
137 |         The total flops equals flops'_0.
138 |         '''
139 | sub_obj_list = []
140 | block_datas = []
141 | branch_weights = []
142 | for betas in self.beta_weights:
143 | branch_weights.append(F.softmax(betas / self.softmax_temp, dim=-1))
144 |
145 | block_data = self.input_block(x)
146 | if hasattr(self, 'head_block'):
147 | block_data = self.head_block(block_data)
148 |
149 | block_datas.append(block_data)
150 | sub_obj_list.append([[],torch.tensor(self.sub_obj_list[0]).cuda()])
151 |
152 | for i in range(len(self.blocks)+1):
153 | config = self.input_configs[i]
154 |             inputs = [block_datas[idx] for idx in config['in_block_idx']]  # avoid reusing the outer loop index
155 | betas = [branch_weights[block_id][beta_id]
156 | for block_id, beta_id in zip(config['in_block_idx'], config['beta_idx'])]
157 |
158 | if i == len(self.blocks):
159 | block_data, block_sub_obj = self.conv1_1_block(inputs, betas, self.sub_obj_list[2])
160 |
161 | else:
162 | block_data, block_sub_obj = self.blocks[i](inputs,
163 | betas,
164 | self.alpha_head_weights[i],
165 | self.alpha_stack_weights[i],
166 | self.alpha_head_index[i],
167 | self.alpha_stack_index[i],
168 | self.sub_obj_list[1][i])
169 | block_datas.append(block_data)
170 | sub_obj_list.append(block_sub_obj)
171 |
172 | out = self.global_pooling(block_datas[-1])
173 | logits = self.classifier(out.view(out.size(0),-1))
174 |
175 | # chained cost estimation
176 | for i, out_config in enumerate(self.output_configs[::-1]):
177 | block_id = len(self.output_configs)-i-1
178 | sum_obj = []
179 | for j, out_id in enumerate(out_config['out_id']):
180 | head_id = self.input_configs[out_id-1]['in_block_idx'].index(block_id)
181 | head_obj = sub_obj_list[out_id][0][head_id]
182 | stack_obj = sub_obj_list[out_id][1]
183 | sub_obj_j = branch_weights[block_id][j] * (head_obj + stack_obj)
184 | sum_obj.append(sub_obj_j)
185 | sub_obj_list[-i-2][1] += sum(sum_obj)
186 |
187 | net_sub_obj = torch.tensor(self.sub_obj_list[-1]).cuda() + sub_obj_list[0][1]
188 | return logits, net_sub_obj.expand(1)
189 |
190 | @property
191 | def arch_parameters(self):
192 | arch_params = nn.ParameterList()
193 | arch_params.extend(self.beta_weights)
194 | arch_params.extend(self.alpha_head_weights)
195 | arch_params.extend(self.alpha_stack_weights)
196 | return arch_params
197 |
198 | @property
199 | def arch_alpha_params(self):
200 | alpha_params = nn.ParameterList()
201 | alpha_params.extend(self.alpha_head_weights)
202 | alpha_params.extend(self.alpha_stack_weights)
203 | return alpha_params
204 |
--------------------------------------------------------------------------------
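
The chained cost estimation in Dropped_Network.forward can be checked by hand on a toy DAG. The sketch below (all costs and beta logits are hypothetical) accumulates the expected sub-objective from the last block back to the first, exactly as the docstring's formula describes:

import torch
import torch.nn.functional as F

# Toy DAG: block0 -> {block1, block2}, block1 -> {block2}, block2 is last.
# head[(b, t)]: cost of block t's head branch that consumes block b's output.
# stack[t]:     cost of block t's stack layers.
head = {(0, 1): 1.0, (0, 2): 0.5, (1, 2): 2.0}
stack = {0: 3.0, 1: 4.0, 2: 5.0}
beta = {0: F.softmax(torch.tensor([0.2, -0.1]), dim=-1),  # path weights for 0->1, 0->2
        1: F.softmax(torch.tensor([0.0]), dim=-1)}        # path weight for 1->2

# accumulate the expected downstream cost from the end to the beginning
acc = dict(stack)  # acc[t] = stack cost plus expected cost of everything after t
acc[1] = stack[1] + beta[1][0].item() * (head[(1, 2)] + acc[2])
acc[0] = stack[0] + (beta[0][0].item() * (head[(0, 1)] + acc[1])
                     + beta[0][1].item() * (head[(0, 2)] + acc[2]))
print(acc[0])  # expected total cost of the beta-weighted super network
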
/tools/multadds_count.py:
--------------------------------------------------------------------------------
1 | import torch
2 | # Original implementation:
3 | # https://github.com/warmspringwinds/pytorch-segmentation-detection/blob/master/pytorch_segmentation_detection/utils/flops_benchmark.py
4 |
5 | # ---- Public functions
6 |
7 | def comp_multadds(model, input_size=(3,224,224)):
8 | input_size = (1,) + tuple(input_size)
9 | model = model.cuda()
10 | input_data = torch.randn(input_size).cuda()
11 | model = add_flops_counting_methods(model)
12 | model.start_flops_count()
13 | with torch.no_grad():
14 | _ = model(input_data)
15 |
16 | mult_adds = model.compute_average_flops_cost() / 1e6
17 | return mult_adds
18 |
19 |
20 | def comp_multadds_fw(model, input_data, use_gpu=True):
21 | model = add_flops_counting_methods(model)
22 | if use_gpu:
23 | model = model.cuda()
24 | model.start_flops_count()
25 | with torch.no_grad():
26 | output_data = model(input_data)
27 |
28 | mult_adds = model.compute_average_flops_cost() / 1e6
29 | return mult_adds, output_data
30 |
31 |
32 | def add_flops_counting_methods(net_main_module):
33 | """Adds flops counting functions to an existing model. After that
34 | the flops count should be activated and the model should be run on an input
35 | image.
36 | Example:
37 | fcn = add_flops_counting_methods(fcn)
38 | fcn = fcn.cuda().train()
39 | fcn.start_flops_count()
40 | _ = fcn(batch)
41 | fcn.compute_average_flops_cost() / 1e9 / 2 # Result in GFLOPs per image in batch
42 | Important: dividing by 2 only works for resnet models -- see below for the details
43 | of flops computation.
44 |     Attention: we are counting a multiply-add as two flops in this work, because in
45 |     most resnet models convolutions are bias-free (BN layers act as the bias there),
46 |     so it makes sense to count multiply and add as separate flops.
47 |     This is why in the above example we divide by 2, in order to be consistent with
48 |     most modern benchmarks. For example, in "Spatially Adaptive Computation Time for
49 |     Residual Networks" by Figurnov et al., a multiply-add was counted as two flops.
50 | This module computes the average flops which is necessary for dynamic networks which
51 | have different number of executed layers. For static networks it is enough to run the network
52 | once and get statistics (above example).
53 | Implementation:
54 | The module works by adding batch_count to the main module which tracks the sum
55 | of all batch sizes that were run through the network.
56 | Also each convolutional layer of the network tracks the overall number of flops
57 | performed.
58 | The parameters are updated with the help of registered hook-functions which
59 | are being called each time the respective layer is executed.
60 | Parameters
61 | ----------
62 | net_main_module : torch.nn.Module
63 | Main module containing network
64 | Returns
65 | -------
66 | net_main_module : torch.nn.Module
67 | Updated main module with new methods/attributes that are used
68 | to compute flops.
69 | """
70 |
71 |     # Add the methods to the existing module object; binding them via __get__
72 |     # gives each function access to the module as `self`.
73 | net_main_module.start_flops_count = start_flops_count.__get__(net_main_module)
74 | net_main_module.stop_flops_count = stop_flops_count.__get__(net_main_module)
75 | net_main_module.reset_flops_count = reset_flops_count.__get__(net_main_module)
76 | net_main_module.compute_average_flops_cost = compute_average_flops_cost.__get__(net_main_module)
77 |
78 | net_main_module.reset_flops_count()
79 |
80 |     # Add the variables necessary for masked flops computation
81 | net_main_module.apply(add_flops_mask_variable_or_reset)
82 |
83 | return net_main_module
84 |
85 |
86 | def compute_average_flops_cost(self):
87 | """
88 | A method that will be available after add_flops_counting_methods() is called
89 | on a desired net object.
90 | Returns current mean flops consumption per image.
91 | """
92 |
93 |     images_count = self.__batch_counter__  # sum of batch sizes, i.e. total images seen
94 |
95 | flops_sum = 0
96 |
97 | for module in self.modules():
98 |
99 | if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear):
100 | flops_sum += module.__flops__
101 |
102 |
103 |     return flops_sum / images_count
104 |
105 |
106 | def start_flops_count(self):
107 | """
108 | A method that will be available after add_flops_counting_methods() is called
109 | on a desired net object.
110 | Activates the computation of mean flops consumption per image.
111 | Call it before you run the network.
112 | """
113 |
114 | add_batch_counter_hook_function(self)
115 |
116 | self.apply(add_flops_counter_hook_function)
117 |
118 |
119 | def stop_flops_count(self):
120 | """
121 | A method that will be available after add_flops_counting_methods() is called
122 | on a desired net object.
123 | Stops computing the mean flops consumption per image.
124 | Call whenever you want to pause the computation.
125 | """
126 |
127 | remove_batch_counter_hook_function(self)
128 |
129 | self.apply(remove_flops_counter_hook_function)
130 |
131 |
132 | def reset_flops_count(self):
133 | """
134 | A method that will be available after add_flops_counting_methods() is called
135 | on a desired net object.
136 | Resets statistics computed so far.
137 | """
138 |
139 | add_batch_counter_variables_or_reset(self)
140 |
141 | self.apply(add_flops_counter_variable_or_reset)
142 |
143 |
144 | def add_flops_mask(module, mask):
145 | def add_flops_mask_func(module):
146 | if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear):
147 | module.__mask__ = mask
148 |
149 | module.apply(add_flops_mask_func)
150 |
151 |
152 | def remove_flops_mask(module):
153 | module.apply(add_flops_mask_variable_or_reset)
154 |
155 |
156 | # ---- Internal functions
157 |
158 |
159 | def conv_flops_counter_hook(conv_module, input, output):
160 | # Can have multiple inputs, getting the first one
161 | input = input[0]
162 |
163 | batch_size = input.shape[0]
164 | output_height, output_width = output.shape[2:]
165 |
166 | kernel_height, kernel_width = conv_module.kernel_size
167 | in_channels = conv_module.in_channels
168 | out_channels = conv_module.out_channels
169 |
170 | conv_per_position_flops = (kernel_height * kernel_width * in_channels * out_channels) / conv_module.groups
171 |
172 | active_elements_count = batch_size * output_height * output_width
173 |
174 | if conv_module.__mask__ is not None:
175 | # (b, 1, h, w)
176 | flops_mask = conv_module.__mask__.expand(batch_size, 1, output_height, output_width)
177 | active_elements_count = flops_mask.sum()
178 |
179 | overall_conv_flops = conv_per_position_flops * active_elements_count
180 |
181 | bias_flops = 0
182 |
183 | if conv_module.bias is not None:
184 | bias_flops = out_channels * active_elements_count
185 |
186 | overall_flops = overall_conv_flops + bias_flops
187 |
188 | conv_module.__flops__ += overall_flops
189 |
190 |
191 | def linear_flops_counter_hook(linear_module, input, output):
192 |
193 | input = input[0]
194 | batch_size = input.shape[0]
195 | overall_flops = linear_module.in_features * linear_module.out_features * batch_size
196 |
197 | # bias_flops = 0
198 |
199 | # if conv_module.bias is not None:
200 | # bias_flops = out_channels * active_elements_count
201 |
202 | # overall_flops = overall_conv_flops + bias_flops
203 |
204 | linear_module.__flops__ += overall_flops
205 |
206 |
207 | def batch_counter_hook(module, input, output):
208 | # Can have multiple inputs, getting the first one
209 | input = input[0]
210 |
211 | batch_size = input.shape[0]
212 |
213 | module.__batch_counter__ += batch_size
214 |
215 |
216 | def add_batch_counter_variables_or_reset(module):
217 | module.__batch_counter__ = 0
218 |
219 |
220 | def add_batch_counter_hook_function(module):
221 | if hasattr(module, '__batch_counter_handle__'):
222 | return
223 |
224 | handle = module.register_forward_hook(batch_counter_hook)
225 | module.__batch_counter_handle__ = handle
226 |
227 |
228 | def remove_batch_counter_hook_function(module):
229 | if hasattr(module, '__batch_counter_handle__'):
230 | module.__batch_counter_handle__.remove()
231 |
232 | del module.__batch_counter_handle__
233 |
234 |
235 | def add_flops_counter_variable_or_reset(module):
236 | if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear):
237 | module.__flops__ = 0
238 |
239 |
240 | def add_flops_counter_hook_function(module):
241 | if isinstance(module, torch.nn.Conv2d):
242 | if hasattr(module, '__flops_handle__'):
243 | return
244 |
245 | handle = module.register_forward_hook(conv_flops_counter_hook)
246 | module.__flops_handle__ = handle
247 | elif isinstance(module, torch.nn.Linear):
248 |
249 | if hasattr(module, '__flops_handle__'):
250 | return
251 |
252 | handle = module.register_forward_hook(linear_flops_counter_hook)
253 | module.__flops_handle__ = handle
254 |
255 |
256 | def remove_flops_counter_hook_function(module):
257 | if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear):
258 |
259 | if hasattr(module, '__flops_handle__'):
260 | module.__flops_handle__.remove()
261 |
262 | del module.__flops_handle__
263 |
264 |
265 | # --- Masked flops counting
266 |
267 |
268 | # Also being run in the initialization
269 | def add_flops_mask_variable_or_reset(module):
270 | if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear):
271 | module.__mask__ = None
272 |
273 |
--------------------------------------------------------------------------------
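
The per-layer formula in conv_flops_counter_hook is easy to sanity-check by hand. A small sketch (CPU-only, deliberately not using the CUDA-dependent comp_multadds wrapper above):

import torch
import torch.nn as nn

conv = nn.Conv2d(16, 32, kernel_size=3, padding=1, bias=False)
y = conv(torch.randn(1, 16, 8, 8))

# per-position mult-adds: k_h * k_w * C_in * C_out / groups
per_position = 3 * 3 * 16 * 32 / 1
active_elements = 1 * y.shape[2] * y.shape[3]   # batch * H_out * W_out
print(per_position * active_elements)           # 294912.0, what the hook would add to __flops__
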
/run_apis/search.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import ast
3 | import importlib
4 | import logging
5 | import os
6 | import pprint
7 | import sys
8 | import time
9 |
10 | import numpy as np
11 | import torch
12 | import torch.backends.cudnn as cudnn
13 | import torch.nn as nn
14 | from tensorboardX import SummaryWriter
15 |
16 | from configs.search_config import search_cfg
17 | from configs.imagenet_train_cfg import cfg
18 | from dataset import imagenet_data
19 | from models import model_derived
20 | from tools import utils
21 | from tools.config_yaml import merge_cfg_from_file, update_cfg_from_cfg
22 | from tools.lr_scheduler import get_lr_scheduler
23 | from tools.multadds_count import comp_multadds
24 |
25 | from .optimizer import Optimizer
26 | from .trainer import SearchTrainer
27 |
28 | if __name__ == '__main__':
29 |
30 | parser = argparse.ArgumentParser("Search_Configs")
31 |     parser.add_argument('--report_freq', type=int, default=100, help='report frequency (in steps)')
32 | parser.add_argument('--data_path', type=str, default='../data', help='location of the data corpus')
33 | parser.add_argument('--save', type=str, default='../', help='experiment name')
34 | parser.add_argument('--tb_path', type=str, default='', help='tensorboard output path')
35 | parser.add_argument('--job_name', type=str, default='', help='job_name')
36 | parser.add_argument('-c', '--config', metavar='C', default=None, help='The Configuration file')
37 |
38 | args = parser.parse_args()
39 |
40 | update_cfg_from_cfg(search_cfg, cfg)
41 | if args.config is not None:
42 | merge_cfg_from_file(args.config, cfg)
43 | config = cfg
44 |
45 | if args.job_name != '':
46 | args.job_name = time.strftime("%Y%m%d-%H%M%S-") + args.job_name
47 | args.save = os.path.join(args.save, args.job_name)
48 | utils.create_exp_dir(args.save)
49 | os.system('cp -r ./* '+args.save)
50 | args.save = os.path.join(args.save, 'output')
51 | utils.create_exp_dir(args.save)
52 | else:
53 | args.save = os.path.join(args.save, 'output')
54 | utils.create_exp_dir(args.save)
55 |
56 | if args.tb_path == '':
57 | args.tb_path = args.save
58 |
59 | log_format = '%(asctime)s %(message)s'
60 | date_format = '%m/%d %H:%M:%S'
61 | logging.basicConfig(stream=sys.stdout, level=logging.INFO,
62 | format=log_format, datefmt=date_format)
63 | fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
64 | fh.setFormatter(logging.Formatter(log_format, date_format))
65 | logging.getLogger().addHandler(fh)
66 |
67 | if not torch.cuda.is_available():
68 | logging.info('No gpu device available')
69 | sys.exit(1)
70 | cudnn.benchmark = True
71 | cudnn.enabled = True
72 |
73 | if config.train_params.use_seed:
74 | np.random.seed(config.train_params.seed)
75 | torch.manual_seed(config.train_params.seed)
76 | torch.cuda.manual_seed(config.train_params.seed)
77 |
78 | logging.info("args = %s", args)
79 | logging.info('Training with config:')
80 | logging.info(pprint.pformat(config))
81 |     job_name = os.path.basename(os.path.dirname(os.path.realpath(args.save)))  # parent dir of the output dir
82 |
83 | writer = SummaryWriter(args.tb_path)
84 |
85 | criterion = nn.CrossEntropyLoss()
86 | criterion = criterion.cuda()
87 |
88 | SearchSpace = importlib.import_module('models.search_space_'+config.net_type).Network
89 |     ArchGenerator = importlib.import_module('.derive_arch_'+config.net_type, __package__).ArchGenerate
90 | derivedNetwork = getattr(model_derived, '%s_Net' % config.net_type.upper())
91 |
92 | super_model = SearchSpace(config.optim.init_dim, config.data.dataset, config)
93 |     arch_gener = ArchGenerator(super_model, config)
94 | der_Net = lambda net_config: derivedNetwork(net_config,
95 | config=config)
96 | super_model = nn.DataParallel(super_model)
97 |
98 | # whether to resume from a checkpoint
99 | if config.optim.if_resume:
100 | utils.load_model(super_model, config.optim.resume.load_path)
101 | start_epoch = config.optim.resume.load_epoch + 1
102 | else:
103 | start_epoch = 0
104 |
105 | super_model = super_model.cuda()
106 |
107 | if config.optim.sub_obj.type=='flops':
108 | flops_list, total_flops = super_model.module.get_cost_list(
109 | config.data.input_size, cost_type='flops')
110 | super_model.module.sub_obj_list = flops_list
111 | logging.info("Super Network flops (M) list: \n")
112 | logging.info(str(flops_list))
113 | logging.info("Total flops: " + str(total_flops))
114 | elif config.optim.sub_obj.type=='latency':
115 | with open(os.path.join('latency_list', config.optim.sub_obj.latency_list_path), 'r') as f:
116 |             latency_list = ast.literal_eval(f.readline())  # parse the list safely instead of eval
117 | super_model.module.sub_obj_list = latency_list
118 | logging.info("Super Network latency (ms) list: \n")
119 | logging.info(str(latency_list))
120 | else:
121 | raise NotImplementedError
122 | logging.info("Num Params = %.2fMB", utils.count_parameters_in_MB(super_model))
123 |
124 | if config.data.dataset == 'imagenet':
125 | imagenet = imagenet_data.ImageNet12(trainFolder=os.path.join(args.data_path, 'train'),
126 | testFolder=os.path.join(args.data_path, 'val'),
127 | num_workers=config.data.num_workers,
128 | type_of_data_augmentation=config.data.type_of_data_aug,
129 | data_config=config.data)
130 | train_queue, valid_queue = imagenet.getTrainTestLoader(config.data.batch_size,
131 | train_shuffle=True,
132 | val_shuffle=True)
133 | else:
134 | raise NotImplementedError
135 |
136 | search_optim = Optimizer(super_model, criterion, config)
137 |
138 | scheduler = get_lr_scheduler(config, search_optim.weight_optimizer, imagenet.train_num_examples)
139 | scheduler.last_step = start_epoch * (imagenet.train_num_examples // config.data.batch_size + 1)
140 |
141 | search_trainer = SearchTrainer(train_queue, valid_queue, search_optim, criterion, scheduler, config, args)
142 |
143 | betas, head_alphas, stack_alphas = super_model.module.display_arch_params()
144 | derived_archs = arch_gener.derive_archs(betas, head_alphas, stack_alphas)
145 | derived_model = der_Net('|'.join(map(str, derived_archs)))
146 |     logging.info("Derived Model Mult-Adds = %.2fM" % comp_multadds(derived_model,
147 | input_size=config.data.input_size))
148 | logging.info("Derived Model Num Params = %.2fMB", utils.count_parameters_in_MB(derived_model))
149 |
150 | best_epoch = [0, 0, 0] # [epoch, acc_top1, acc_top5]
151 | rec_list = []
152 | for epoch in range(start_epoch, config.train_params.epochs):
153 | # training part1: update the architecture parameters
154 | if epoch >= config.search_params.arch_update_epoch:
155 | search_stage = 1
156 | search_optim.set_param_grad_state('Arch')
157 | train_acc_top1, train_acc_top5, train_obj, sub_obj, batch_time = search_trainer.train(
158 | super_model, epoch, 'Arch', search_stage)
159 | logging.info('EPOCH%d Arch Train_acc top1 %.2f top5 %.2f loss %.4f %s %.2f batch_time %.3f',
160 | epoch, train_acc_top1, train_acc_top5, train_obj, config.optim.sub_obj.type, sub_obj, batch_time)
161 | writer.add_scalar('arch_train_acc_top1', train_acc_top1, epoch)
162 | writer.add_scalar('arch_train_loss', train_obj, epoch)
163 | else:
164 | search_stage = 0
165 |
166 | # training part2: update the operator parameters
167 | search_optim.set_param_grad_state('Weights')
168 | train_acc_top1, train_acc_top5, train_obj, sub_obj, batch_time = search_trainer.train(
169 | super_model, epoch, 'Weights', search_stage)
170 | logging.info('EPOCH%d Weights Train_acc top1 %.2f top5 %.2f loss %.4f %s %.2f | batch_time %.3f',
171 | epoch, train_acc_top1, train_acc_top5, train_obj, config.optim.sub_obj.type, sub_obj, batch_time)
172 | writer.add_scalar('weight_train_acc_top1', train_acc_top1, epoch)
173 | writer.add_scalar('weight_train_loss', train_obj, epoch)
174 |
175 | # validation
176 | if epoch >= config.search_params.val_start_epoch:
177 | with torch.no_grad():
178 | val_acc_top1, val_acc_top5, valid_obj, sub_obj, batch_time = search_trainer.infer(super_model, epoch)
179 | logging.info('EPOCH%d Valid_acc top1 %.2f top5 %.2f %s %.2f batch_time %.3f',
180 | epoch, val_acc_top1, val_acc_top5, config.optim.sub_obj.type, sub_obj, batch_time)
181 | writer.add_scalar('arch_val_acc', val_acc_top1, epoch)
182 | writer.add_scalar('arch_whole_{}'.format(config.optim.sub_obj.type), sub_obj, epoch)
183 |
184 | if val_acc_top1 > best_epoch[1]:
185 | best_epoch = [epoch, val_acc_top1, val_acc_top5]
186 | utils.save(super_model, os.path.join(args.save, 'weights_best.pt'))
187 | logging.info('BEST EPOCH %d val_top1 %.2f val_top5 %.2f', best_epoch[0], best_epoch[1], best_epoch[2])
188 | else:
189 | utils.save(super_model, os.path.join(args.save, 'weights_best.pt'))
190 |
191 | betas, head_alphas, stack_alphas = super_model.module.display_arch_params()
192 | derived_arch = arch_gener.derive_archs(betas, head_alphas, stack_alphas)
193 | derived_arch_str = '|\n'.join(map(str, derived_arch))
194 | derived_model = der_Net(derived_arch_str)
195 | derived_flops = comp_multadds(derived_model, input_size=config.data.input_size)
196 | derived_params = utils.count_parameters_in_MB(derived_model)
197 |         logging.info("Derived Model Mult-Adds = %.2fM" % derived_flops)
198 | logging.info("Derived Model Num Params = %.2fMB" % derived_params)
199 | writer.add_scalar('derived_flops', derived_flops, epoch)
200 |
201 | if (epoch+1)==config.search_params.arch_update_epoch:
202 | utils.save(super_model, os.path.join(args.save, 'weights_{}.pt'.format(epoch)))
203 |
204 | if epoch >= config.search_params.val_start_epoch:
205 | epoch_rec = {'top1_acc': val_acc_top1,
206 | 'epoch': epoch,
207 | 'multadds': derived_flops,
208 | 'params': derived_params,
209 | 'arch': derived_arch_str}
210 | if_update = utils.record_topk(2, rec_list, epoch_rec, 'top1_acc', 'arch')
211 | if if_update:
212 | with open(os.path.join(args.save, 'top_results'), 'w') as f:
213 | f.write(str(rec_list) + '\n')
214 | f.write(job_name)
215 | with open(os.path.join(args.save, 'excel_record'), 'w') as f:
216 | for record in rec_list:
217 | f.write(',,,{:.2f}MB,{:.2f}MB,,,,{},{}\n'.format(
218 | record['multadds'], record['params'],
219 | job_name, record['epoch']))
220 | f.write(record['arch']+'\n')
221 |
222 | logging.info('\nTop2 arch records for Excel: ')
223 | for record in rec_list:
224 | logging.info('\n,,,{:.2f}MB,{:.2f}MB,,,,{},{}'.format(
225 | record['multadds'], record['params'], job_name, record['epoch']))
226 | logging.info('\n'+record['arch'])
227 |
--------------------------------------------------------------------------------
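
search.py alternates between two optimization targets each epoch: architecture parameters (on the validation split, once arch_update_epoch is reached) and operator weights (on the training split). A toy sketch of that freeze-one-group, step-the-other pattern (model and names hypothetical; set_param_grad_state in run_apis/optimizer.py plays the same role):

import torch
import torch.nn as nn

weights = nn.Parameter(torch.randn(4))
alphas = nn.Parameter(torch.zeros(4))
w_opt = torch.optim.SGD([weights], lr=0.1)
a_opt = torch.optim.Adam([alphas], lr=3e-4)

def set_grad_state(mode):
    # only one parameter group learns at a time
    weights.requires_grad = (mode == 'Weights')
    alphas.requires_grad = (mode == 'Arch')

for mode, opt in [('Arch', a_opt), ('Weights', w_opt)]:
    set_grad_state(mode)
    opt.zero_grad()
    loss = ((weights * torch.softmax(alphas, -1)).sum() - 1.0) ** 2
    loss.backward()  # gradients flow only to the unfrozen group
    opt.step()

Decoupling the two updates keeps the weight optimizer's momentum and the arch optimizer's statistics from interfering, which is why the script toggles the grad state rather than using a single optimizer over all parameters.
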
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/latency_list/lat_list_densenas_mbv2_xp32:
--------------------------------------------------------------------------------
1 | [2.130230027015763, [[[[2.0424278577168784, 3.9147443482370092, 2.826663460394349, 5.422255848393296, 3.750695026282108, 7.1249667803446455]], [[1.6153342073613948, 3.0373611353864574, 2.3987865447998047, 4.654994853819259, 3.261269535681214, 6.318014583202324, 0.004162114075940065], [1.6315834690826108, 3.0268227933633205, 2.4445987951875936, 4.614120661610305, 3.1950888730058766, 6.186397436893348, 0.004321204291449653], [1.6260789139102203, 3.0519497274148346, 2.389646034048061, 4.737014577846335, 3.23591497209337, 6.18449805962919, 0.004463557041052616]]], [[[0.7256638883340238, 1.3191255415328826, 0.9336039755079482, 1.7637216201936357, 1.2678955540512546, 2.410705258147885]], [[0.4845246160873259, 0.871481124800865, 0.6886357490462486, 1.2964955002370506, 0.9334284849841186, 1.7795484475415162, 0.00426448956884519], [0.4868063782200669, 0.8848115892121287, 0.6880599079710064, 1.2923753863633283, 0.9581135017703278, 1.7851776787728975, 0.0040594736735026045], [0.4841990422720861, 0.8834096398016419, 0.676188324436997, 1.2965060725356594, 0.9642027122805816, 1.8043483628167047, 0.005086961418691307]]], [[[0.47651384816025244, 0.8765797904043486, 0.6815321276886294, 1.2988936780679106, 0.9366233902748184, 1.7686583538248082], [0.7290506122088192, 1.3409103046764026, 0.9578582493945805, 1.766649472593057, 1.2713261324949938, 2.4222339764989984]], [[0.714959664778276, 1.1088877494889076, 0.9645739709488069, 1.6069072665590227, 1.283542893149636, 2.241900617426092, 0.004182656606038411], [0.7255952285997795, 1.1247719899572508, 0.9729856192463576, 1.6400691234704219, 1.2840147933574637, 2.223358057966136, 0.004817090853296145], [0.7210639992145577, 1.118303236335215, 0.9568171067671342, 1.6137872320232969, 1.2853490463410964, 2.220366771775063, 0.0044962613269536175]]], [[[0.7221531867980957, 1.099681348511667, 0.9537572571725557, 1.6170255583946151, 1.258736475549563, 2.228106127844917], [0.5150983550331809, 0.8821482610220861, 0.7235027563692343, 1.285614919180822, 0.9779198241956306, 1.7995608214176062], [0.7342606120639378, 1.3387289432564167, 0.9407694171173404, 1.778323794856216, 1.2683106913711084, 2.3811431364579634]], [[0.7432978803461249, 1.3567979648859814, 1.0377877890461622, 1.9703037811048103, 1.43352544668949, 2.7093362085746997, 0.004813309871789181], [0.7404989184755267, 1.3694960179955067, 1.0372354526712437, 1.943115730478306, 1.4028442989696157, 2.726303037970957, 0.004229617841315992], [0.7309939885380292, 1.3608129819234211, 1.0305414536986688, 1.9552293449941307, 1.401155861941251, 2.7030082664104422, 0.00416184916640773]]], [[[0.3884150042678371, 0.7135931891624373, 0.5162300244726317, 0.954577417084665, 0.6934213397478817, 1.298581903631037], [0.4412888998937125, 0.59423191378815, 0.5425221751434635, 0.789706056768244, 0.6770676795882408, 1.076465303247625], [0.32727501609108667, 0.45259521465108854, 0.3506885634528266, 0.6308338377210828, 0.4525994291209211, 0.8446315081432612]], [[0.3733000129160255, 0.40592995556918055, 0.37138982252641156, 0.6014351411299271, 0.43380426638054126, 0.8142890593018195, 0.00443302019677981], [0.3568933467672329, 0.4109610210765492, 0.3373644087049696, 0.5994894769456651, 0.43327100349195075, 0.8130553515270503, 0.004622550925823173], [0.34756833856756036, 0.4080691000427863, 0.33895901959351815, 0.6000231251572117, 0.42908610719622986, 0.8126938945115214, 0.004636446634928385]]], [[[0.31285572533655653, 0.41354148074834035, 0.3234596685929732, 0.5967205702656447, 0.42944281992286143, 0.8081769702410456], [0.38233056212916516, 
0.7180399124068444, 0.506961490168716, 0.958221440363412, 0.6713989286711722, 1.2972089497729986], [0.4516297879845205, 0.584752680075289, 0.5289154341726592, 0.7857127623124556, 0.6657846287043407, 1.0508054434651077], [0.3216169338033657, 0.4714334613145, 0.341750852989428, 0.6193477216393056, 0.4556712237271396, 0.8496491595952198]], [[0.3577318336024429, 0.46187034761062773, 0.36865022447374135, 0.6880141508699668, 0.490064789550473, 0.9283094213466452, 0.004251918407401653], [0.35744508107503253, 0.4682911526073109, 0.3670787811279297, 0.6878670297487818, 0.4916710323757596, 0.9247834995539501, 0.004101233048872515], [0.34304789822511, 0.46506821507155294, 0.3702619340684679, 0.6890249252319336, 0.48857850257796476, 0.9265595493894635, 0.004756330239652383]]], [[[0.32409022552798494, 0.4537436938045001, 0.3600639044636428, 0.6765010621812608, 0.4840817355146312, 0.919375443699384], [0.3252748287085331, 0.39878895788481744, 0.32009264435430973, 0.5979235967000326, 0.4323175218370226, 0.8091085125701596], [0.3775752433622726, 0.7120554856579713, 0.507318853127836, 0.9536376866427335, 0.6691536277231545, 1.2858475819982664], [0.4203124961467704, 0.5663527382744684, 0.5275549310626406, 0.7793375940033884, 0.6773896169180822, 1.0505064328511555]], [[0.3410032301238089, 0.5246755089422669, 0.4104647251090618, 0.7753114748482752, 0.5516608796938501, 1.0461353774022573, 0.004761459851505781], [0.3467666500746602, 0.5258405810654765, 0.41413808109784367, 0.7768700580404262, 0.5466712364042648, 1.0479942475906527, 0.004405325109308416], [0.3509985316883434, 0.52610416605015, 0.4151724324081883, 0.7738488611548838, 0.5471140928942747, 1.0415510697798296, 0.004083628606314611]]], [[[0.3314778780696368, 0.5213186716792559, 0.4044876194963552, 0.7711166564864342, 0.5433830107101287, 1.0415975734440968], [0.31709625263406777, 0.4603457450866699, 0.3640068661082875, 0.6819324059919878, 0.49100427916555694, 0.9225822940017239], [0.30831818628792806, 0.403487417432997, 0.3342929030909683, 0.5993133121066623, 0.43180643910109395, 0.8127121973519373]], [[0.39097884688714535, 0.714596136651858, 0.5550253030025598, 1.0415667476076067, 0.734896683933759, 1.4084796953682948, 0.00444214753430299], [0.39187493950429586, 0.7178884323197182, 0.5559180962918985, 1.0453465731457028, 0.7339226597487325, 1.4049145669648142, 0.004976397812968553], [0.3937760748044409, 0.7125301072091768, 0.5555001894632975, 1.0429809551046352, 0.7313448010068951, 1.4042837932856396, 0.004143618574046126]]], [[[0.3865939679771963, 0.7147598989082105, 0.5532094444891419, 1.0452751920680807, 0.7332464661261049, 1.4074049092302419], [0.3078347263914166, 0.5324024865121553, 0.413363076219655, 0.7835227070432721, 0.5587832373802109, 1.0567584904757412], [0.3132787136116413, 0.46954424694331004, 0.37312808662954006, 0.6916724310980903, 0.5062123019285876, 0.9339909120039507], [0.3129829541601316, 0.41637312282215466, 0.33040607818449386, 0.6098940637376573, 0.447241272589173, 0.8206562562422318]], [[0.4596819781293773, 0.8472787250172008, 0.6554982638118243, 1.2365547334304965, 0.8659759916440405, 1.6601162968259868, 0.004585535839350537], [0.4650660717126095, 0.8608341217041016, 0.6551676567154701, 1.2411773566043738, 0.8663152685069074, 1.657198920394435, 0.0049291475854738795], [0.46498640619143095, 0.8596966001722547, 0.6567297078142262, 1.2392559918490322, 0.8636613566466051, 1.6618470952968405, 0.004425771308667732]]], [[[0.4479286887428977, 0.8394645440458047, 0.6429801083574391, 1.2274719729568018, 0.8534268658570568, 
1.6510065155799942], [0.38627870155103283, 0.7220117251078287, 0.5556040821653424, 1.0479230832571935, 0.7360771689752135, 1.4107511741946441], [0.31525987567323627, 0.5351946811483365, 0.41544396467883177, 0.7865141618131387, 0.5511918934908779, 1.0561841425269543], [0.32904622530696365, 0.4712159465057681, 0.37493383041535966, 0.6946381655606356, 0.497800147894657, 0.9343978612109869]], [[0.5242839244881061, 0.9681367151664966, 0.7472212386853767, 1.4109373333478215, 0.9878720659198184, 1.9006856041725235, 0.004231592621466126], [0.5298834858518657, 0.9775502513153386, 0.7481799703655821, 1.4128783014085557, 0.9842759431010545, 1.889745057231248, 0.004668765597873264], [0.5280137543726449, 0.9755541339065089, 0.7452876399261783, 1.408308395231613, 0.9832095618199821, 1.8874812848640212, 0.005188975671325067]]], [[[0.33862624505553585, 0.6294144765295164, 0.4290954512779159, 0.810597280059198, 0.5483022121467976, 1.0510839356316461], [0.31762077350809115, 0.559200927464649, 0.37940246890289614, 0.7181714038656216, 0.481732493699199, 0.9220814223241324], [0.31610031320591164, 0.47471125920613605, 0.3283519937534525, 0.6014666171989055, 0.4142167833116319, 0.7825871188231188]], [[0.35007797106347904, 0.4412787610834295, 0.3423054049713443, 0.5852765025514545, 0.40640749112524166, 0.7427115873856978, 0.004194794279156309], [0.33576348815301454, 0.4444064275182859, 0.33754336713540434, 0.5864705461444277, 0.4062691601839932, 0.741335888101597, 0.0042497991311429735], [0.347997010356248, 0.4433897047331839, 0.34060092887493093, 0.5851883599252411, 0.40573748675259674, 0.7406246300899622, 0.00453204819650361]]], [[[0.3260928693443838, 0.440192800579649, 0.3165682879361239, 0.5835583715727835, 0.38191544889199613, 0.7402725894041736], [0.3473482709942442, 0.636373818522752, 0.4305076839947941, 0.8120978721464523, 0.5494835400822187, 1.0491823668431755], [0.32828465856686984, 0.5603574261520848, 0.400460705612645, 0.7195479942090584, 0.5029239558210277, 0.9257183893762454], [0.3226006633103496, 0.4737337671145044, 0.32909169341578626, 0.606823545513731, 0.41159104819249626, 0.7856156368448277]], [[0.34432806149877687, 0.48627826902601456, 0.35802836369986485, 0.6468107483603737, 0.43973238781245066, 0.8166859125850177, 0.004473575437911833], [0.34155412153764203, 0.4882547590467665, 0.3579827510949337, 0.6446643790813408, 0.4383250197978935, 0.813951227400038, 0.0045983478276416506], [0.3423182892076897, 0.49038797917992183, 0.3526753849453397, 0.6453903275306778, 0.4380230951790858, 0.814479264346036, 0.00436650382147895]]], [[[0.31611798989652384, 0.5243616152291346, 0.34948548885306924, 0.6844820879926585, 0.4360723254656551, 0.8539520851289384], [0.3155406797775115, 0.441239843464861, 0.3129776318868001, 0.5864422971552069, 0.3826266346555767, 0.7415137387285329], [0.341466942218819, 0.6324205013236613, 0.43117410004741014, 0.813953322593612, 0.549957872641207, 1.0534283849928114], [0.32031817869706586, 0.5669151412116156, 0.3825205985945885, 0.7180473298737498, 0.4875175399009628, 0.9277093530905367]], [[0.33806482950846356, 0.5306571180170233, 0.38618906579836454, 0.7049235189803923, 0.4787800769613247, 0.8875630841110692, 0.004487085824060922], [0.3448331958115703, 0.5330200869627674, 0.3988961739973588, 0.7003687848948469, 0.47722561190826723, 0.8862311912305427, 0.004113322556620896], [0.3474274789444124, 0.5329236117276278, 0.3846857282850477, 0.7030708139592952, 0.4763580813552394, 0.8857826993922995, 0.0040633269030638415]]], [[[0.32171962237117263, 0.5723524575281626, 
0.40756129255198464, 0.7432415991118461, 0.4988236619968607, 0.927507034455887], [0.3153900185016671, 0.5259105412646977, 0.37463693907766626, 0.6845260388923413, 0.4606634679466787, 0.8534405448219993], [0.3186584241462476, 0.48466046651204425, 0.3262169433362556, 0.6259828625303326, 0.4041823473843661, 0.7803911873788545]], [[0.5497075572158351, 1.0372076612530332, 0.6927506851427483, 1.3213739250645493, 0.8462854587670529, 1.630257669121328, 0.004380182786421342], [0.5530304860587072, 1.0435495232090806, 0.693162354555997, 1.321483693941675, 0.8439054392805003, 1.6276801475370772, 0.004326261655248777], [0.5519224176503191, 1.042814471504905, 0.6919608694134336, 1.3225178766732264, 0.8431200065998116, 1.6249998410542805, 0.004217167093296244]]], [[[0.5501972545276989, 1.058707164995598, 0.694795594070897, 1.3447087220471314, 0.8489726047323208, 1.6523896082483156], [0.3268535691078263, 0.5884998013274838, 0.4162517460909757, 0.7617164380622633, 0.508760129562532, 0.9446223576863607], [0.31738830335212476, 0.5434732726126006, 0.383788262954866, 0.7018130716651377, 0.46994878788187044, 0.8708676665720314], [0.31652166385843294, 0.49711326155999697, 0.3372591192072088, 0.640807007298325, 0.41489528887199634, 0.7945932282341851]], [[0.6180085317052976, 1.1803230372342197, 0.7760107878482703, 1.4968219670382412, 0.9470603923604946, 1.8356977809559216, 0.00415854983859592], [0.6226346468684649, 1.1885866733512493, 0.7775900339839434, 1.4975251573504824, 0.944289607231063, 1.833727817342739, 0.005115378986705434], [0.6210861784039121, 1.187923966032086, 0.776205351858428, 1.4959932096076733, 0.9441822225397283, 1.8313705559932825, 0.0047388221278335104]]], [[[0.6076182259453667, 1.1761999130249023, 0.7682572229944095, 1.4921375236125907, 0.9383929377854472, 1.827945179409451], [0.5541018524555245, 1.0647979408803612, 0.6995696973319006, 1.348778743936558, 0.8533838782647644, 1.6592941621337274], [0.3311891025967068, 0.5921483280682804, 0.4187226054644344, 0.7656114269988705, 0.5120331109172166, 0.9497349671643189], [0.31480916822799526, 0.5449382223264136, 0.38729677296648124, 0.7054818278611309, 0.4725636857928652, 0.8741613590356074]], [[0.674338870578342, 1.2978931870123354, 0.8485467024523803, 1.6414815970141476, 1.0351639564591224, 2.0140134445344557, 0.004651305651423907], [0.6819091662011966, 1.305962861186326, 0.8498003747728136, 1.645497866351195, 1.032505324392608, 2.0091032259392017, 0.00428734403668028], [0.6797817259123831, 1.3053699936529604, 0.8498629415878142, 1.642856477486967, 1.032006981396916, 2.0077070322903716, 0.004319518503516611]]]], [0.3001111926454486, 0.28059167091292564, 0.26766254444314974], 0.4041198769001046]
--------------------------------------------------------------------------------
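
Note on the latency lists: each file stores a single Python literal whose top-level layout appears to mirror the cost_list returned by get_cost_list in models/search_space_base.py below, i.e. [stem cost, per-block op costs, conv1_1 branch costs, classifier cost]. A minimal loading sketch, assuming the file holds exactly one such literal (this loader is an illustration, not an official utility from this repo):

    import ast

    # Parse the latency-list file as one Python literal; the unpacking
    # below follows the assumed cost_list structure described above.
    with open('latency_list/lat_list_densenas_mbv2_xp32') as f:
        stem_cost, block_costs, conv1_1_costs, cls_cost = ast.literal_eval(f.read())

    print(stem_cost)          # latency of the input stem
    print(len(block_costs))   # one nested entry per searchable block
    print(conv1_1_costs)      # per-branch latencies of the final 1x1 conv block
    print(cls_cost)           # classifier latency
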
/models/search_space_base.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import math
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 | from tools.multadds_count import comp_multadds_fw
10 | from tools.utils import latency_measure_fw
11 | from . import operations
12 | from .operations import OPS
13 |
14 |
15 | class MixedOp(nn.Module):  # container of candidate ops; forward is implemented by the concrete search spaces
16 | def __init__(self, C_in, C_out, stride, primitives):
17 | super(MixedOp, self).__init__()
18 | self._ops = nn.ModuleList()
19 | for primitive in primitives:
20 | op = OPS[primitive](C_in, C_out, stride, affine=False, track_running_stats=True)
21 | self._ops.append(op)
22 |
23 |
24 | class HeadLayer(nn.Module):
25 | def __init__(self, in_chs, ch, strides, config):
26 | super(HeadLayer, self).__init__()
27 | self.head_branches = nn.ModuleList()
28 | for in_ch, stride in zip(in_chs, strides):
29 | self.head_branches.append(
30 | MixedOp(in_ch, ch, stride,
31 | config.search_params.PRIMITIVES_head)
32 | )
33 |
34 |
35 | class StackLayers(nn.Module):
36 | def __init__(self, ch, num_block_layers, config, primitives):
37 | super(StackLayers, self).__init__()
38 |
39 | if num_block_layers != 0:
40 | self.stack_layers = nn.ModuleList()
41 | for i in range(num_block_layers):
42 | self.stack_layers.append(MixedOp(ch, ch, 1, primitives))
43 | else:
44 | self.stack_layers = None
45 |
46 |
47 | class Block(nn.Module):
48 | def __init__(self, in_chs, block_ch, strides, num_block_layers, config):
49 | super(Block, self).__init__()
50 | assert len(in_chs) == len(strides)
51 | self.head_layer = HeadLayer(in_chs, block_ch, strides, config)
52 | self.stack_layers = StackLayers(block_ch, num_block_layers, config, config.search_params.PRIMITIVES_stack)
53 |
54 |
55 | class Conv1_1_Branch(nn.Module):
56 | def __init__(self, in_ch, block_ch):
57 | super(Conv1_1_Branch, self).__init__()
58 | self.conv1_1 = nn.Sequential(
59 | nn.Conv2d(in_channels=in_ch, out_channels=block_ch,
60 | kernel_size=1, stride=1, padding=0, bias=False),
61 | nn.BatchNorm2d(block_ch, affine=False, track_running_stats=True),
62 | nn.ReLU6(inplace=True)
63 | )
64 |
65 | def forward(self, x):
66 | return self.conv1_1(x)
67 |
68 |
69 | class Conv1_1_Block(nn.Module):
70 | def __init__(self, in_chs, block_ch):
71 | super(Conv1_1_Block, self).__init__()
72 | self.conv1_1_branches = nn.ModuleList()
73 | for in_ch in in_chs:
74 | self.conv1_1_branches.append(Conv1_1_Branch(in_ch, block_ch))
75 |
76 | def forward(self, inputs, betas, block_sub_obj):
77 | branch_weights = F.softmax(torch.stack(betas), dim=-1)
78 | return sum(branch_weight * branch(input_data) for input_data, branch, branch_weight in zip(
79 | inputs, self.conv1_1_branches, branch_weights)), \
80 | [block_sub_obj, 0]
81 |
82 |
83 | class Network(nn.Module):
84 | def __init__(self, init_ch, dataset, config):
85 | super(Network, self).__init__()
86 | self.config = config
87 | self._C_input = init_ch
88 | self._head_dim = self.config.optim.head_dim
89 | self._dataset = dataset
90 | # search uses a 100-class subset of the dataset
91 | self._num_classes = 100
92 |
93 | self.initialize()
94 |
95 |
96 | def initialize(self):
97 | self._init_block_config()
98 | self._create_output_list()
99 | self._create_input_list()
100 | self._init_betas()
101 | self._init_alphas()
102 | self._init_sample_branch()
103 |
104 |
105 | def init_model(self, model_init='he_fout', init_div_groups=True):
106 | for m in self.modules():
107 | if isinstance(m, nn.Conv2d):
108 | if model_init == 'he_fout':
109 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels  # fan-out
110 | if init_div_groups:
111 | n /= m.groups
112 | m.weight.data.normal_(0, math.sqrt(2. / n))  # He init: std = sqrt(2/n)
113 | elif model_init == 'he_fin':
114 | n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels  # fan-in
115 | if init_div_groups:
116 | n /= m.groups
117 | m.weight.data.normal_(0, math.sqrt(2. / n))
118 | else:
119 | raise NotImplementedError
120 | elif isinstance(m, nn.BatchNorm2d):
121 | if m.affine:
122 | m.weight.data.fill_(1)
123 | m.bias.data.zero_()
124 | elif isinstance(m, nn.Linear):
125 | if m.bias is not None:
126 | m.bias.data.zero_()
127 | elif isinstance(m, nn.BatchNorm1d):
128 | if m.affine:
129 | m.weight.data.fill_(1)
130 | m.bias.data.zero_()
131 |
132 |
133 | def set_bn_param(self, bn_momentum, bn_eps):
134 | for m in self.modules():
135 | if isinstance(m, nn.BatchNorm2d):
136 | m.momentum = bn_momentum
137 | m.eps = bn_eps
138 | return
139 |
140 |
141 | def _init_betas(self):
142 | r"""
143 | Beta weights over the output channel choices in the head layer of each block.
144 | """
145 | self.beta_weights = nn.ParameterList()
146 | for block in self.output_configs:
147 | num_betas = len(block['out_chs'])
148 | self.beta_weights.append(
149 | nn.Parameter(1e-3 * torch.randn(num_betas))
150 | )
151 |
152 |
153 | def _init_alphas(self):
154 | r"""
155 | Alpha weights over the candidate op types in each block (head and stack layers).
156 | """
157 | self.alpha_head_weights = nn.ParameterList()
158 | self.alpha_stack_weights = nn.ParameterList()
159 |
160 | for block in self.input_configs[:-1]:
161 | num_head_alpha = len(block['in_block_idx'])
162 | self.alpha_head_weights.append(nn.Parameter(
163 | 1e-3*torch.randn(num_head_alpha, len(self.config.search_params.PRIMITIVES_head)))
164 | )
165 |
166 | num_layers = block['num_stack_layers']
167 | self.alpha_stack_weights.append(nn.Parameter(
168 | 1e-3*torch.randn(num_layers, len(self.config.search_params.PRIMITIVES_stack)))
169 | )
170 |
171 | @property
172 | def arch_parameters(self):
173 | arch_params = nn.ParameterList()
174 | arch_params.extend(self.beta_weights)
175 | arch_params.extend(self.alpha_head_weights)
176 | arch_params.extend(self.alpha_stack_weights)
177 | return arch_params
178 |
179 | @property
180 | def arch_beta_params(self):
181 | return self.beta_weights
182 |
183 | @property
184 | def arch_alpha_params(self):
185 | alpha_params = nn.ParameterList()
186 | alpha_params.extend(self.alpha_head_weights)
187 | alpha_params.extend(self.alpha_stack_weights)
188 | return alpha_params
189 |
190 |
191 | def display_arch_params(self, display=True):
192 | branch_weights = []
193 | head_op_weights = []
194 | stack_op_weights = []
195 | for betas in self.beta_weights:
196 | branch_weights.append(F.softmax(betas, dim=-1))
197 | for head_alpha in self.alpha_head_weights:
198 | head_op_weights.append(F.softmax(head_alpha, dim=-1))
199 | for stack_alpha in self.alpha_stack_weights:
200 | stack_op_weights.append(F.softmax(stack_alpha, dim=-1))
201 |
202 | if display:
203 | logging.info('branch_weights \n' + '\n'.join(map(str, branch_weights)))
204 | if len(self.config.search_params.PRIMITIVES_head) > 1:
205 | logging.info('head_op_weights \n' + '\n'.join(map(str, head_op_weights)))
206 | logging.info('stack_op_weights \n' + '\n'.join(map(str, stack_op_weights)))
207 |
208 | return [x.tolist() for x in branch_weights], \
209 | [x.tolist() for x in head_op_weights], \
210 | [x.tolist() for x in stack_op_weights]
211 |
212 |
213 | def _init_sample_branch(self):
214 | _, _ = self.sample_branch('head', 1, training=False)
215 | _, _ = self.sample_branch('stack', 1, training=False)
216 |
217 |
218 | def sample_branch(self, params_type, sample_num, training=True, search_stage=1, if_sort=True):
219 | r"""
220 | the sampling computing is based on torch
221 | input: params_type
222 | output: sampled params
223 | """
224 |
225 | def sample(param, weight, sample_num, sample_policy='prob', if_sort=True):
226 | if sample_num >= weight.shape[-1]:
227 | sample_policy = 'all'
228 | assert param.shape == weight.shape
229 | assert sample_policy in ['prob', 'uniform', 'all']
230 | if param.shape[0] == 0:
231 | return [], []
232 | if sample_policy == 'prob':
233 | sampled_index = torch.multinomial(weight, num_samples=sample_num, replacement=False)
234 | elif sample_policy == 'uniform':
235 | weight = torch.ones_like(weight)
236 | sampled_index = torch.multinomial(weight, num_samples=sample_num, replacement=False)
237 | else:
238 | sampled_index = torch.arange(start=0, end=weight.shape[-1], step=1, device=weight.device
239 | ).repeat(param.shape[0], 1)
240 | if if_sort:
241 | sampled_index, _ = torch.sort(sampled_index, descending=False)
242 | sampled_param_old = torch.gather(param, dim=-1, index=sampled_index)
243 | return sampled_param_old, sampled_index
244 |
245 | if params_type=='head':
246 | params = self.alpha_head_weights
247 | elif params_type=='stack':
248 | params = self.alpha_stack_weights
249 | else:
250 | raise TypeError
251 |
252 | weights = []
253 | sampled_params_old = []
254 | sampled_indices = []
255 | if training:
256 | sample_policy = self.config.search_params.sample_policy if search_stage==1 else 'uniform'
257 | else:
258 | sample_policy = 'all'
259 |
260 | for param in params:
261 | weights.append(F.softmax(param, dim=-1))
262 |
263 | for param, weight in zip(params, weights): # iterate over blocks (the list dimension)
264 | sampled_param_old, sampled_index = sample(
265 | param, weight, sample_num, sample_policy, if_sort)
266 | sampled_params_old.append(sampled_param_old)
267 | sampled_indices.append(sampled_index)
268 |
269 | if params_type=='head':
270 | self.alpha_head_index = sampled_indices
271 | elif params_type=='stack':
272 | self.alpha_stack_index = sampled_indices
273 | return sampled_params_old, sampled_indices
274 |
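# Illustration of the sampling semantics above (hypothetical sizes): with
# sample_num=2 over 6 stack primitives and sample_policy='prob',
# torch.multinomial draws 2 op indices per stack layer, weighted by
# softmax(alpha); 'uniform' draws them with equal probability; 'all'
# (used at eval time and by _init_sample_branch) keeps every index.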
275 |
276 | def _init_block_config(self):
277 | self.block_chs = self.config.search_params.net_scale.chs
278 | self.block_fm_sizes = self.config.search_params.net_scale.fm_sizes
279 | self.num_blocks = len(self.block_chs) - 1 # excluding the head and tail blocks
280 | self.num_block_layers = self.config.search_params.net_scale.num_layers
281 | if hasattr(self.config.search_params.net_scale, 'stage'):
282 | self.block_stage = self.config.search_params.net_scale.stage
283 |
284 | self.block_chs.append(self.config.optim.last_dim)
285 | self.block_fm_sizes.append(self.block_fm_sizes[-1])
286 | self.num_block_layers.append(0)
287 |
288 |
289 | def _create_output_list(self):
290 | r"""
291 | Generate the output config of each block, which contains:
292 | 'ch': the channel number of the block
293 | 'out_chs': the possible output channel numbers
294 | 'strides': the corresponding stride
295 | """
296 |
297 | self.output_configs = []
298 | for i in range(len(self.block_chs)-1):
299 | if hasattr(self, 'block_stage'):
300 | stage = self.block_stage[i]
301 | output_config = {'ch': self.block_chs[i],
302 | 'fm_size': self.block_fm_sizes[i],
303 | 'out_chs': [],
304 | 'out_fms': [],
305 | 'strides': [],
306 | 'out_id': [],
307 | 'num_stack_layers': self.num_block_layers[i]}
308 | for j in range(self.config.search_params.adjoin_connect_nums[stage]):
309 | out_index = i + j + 1
310 | if out_index >= len(self.block_chs):
311 | break
312 | if hasattr(self, 'block_stage'):
313 | # do not connect blocks that are more than one stage apart
314 | if self.block_stage[out_index] - self.block_stage[i] > 1:
315 | break
316 | fm_size_ratio = self.block_fm_sizes[i] / self.block_fm_sizes[out_index]
317 | if fm_size_ratio == 2:
318 | output_config['strides'].append(2)
319 | elif fm_size_ratio == 1:
320 | output_config['strides'].append(1)
321 | else:
322 | break # only connect to blocks whose feature-map size ratio is 1 or 2
323 | output_config['out_chs'].append(self.block_chs[out_index])
324 | output_config['out_fms'].append(self.block_fm_sizes[out_index])
325 | output_config['out_id'].append(out_index)
326 |
327 | self.output_configs.append(output_config)
328 |
329 | logging.info('Network output configs: \n' + '\n'.join(map(str, self.output_configs)))
330 |
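# A hypothetical output_config entry, for illustration only (actual values
# come from the net_scale section of the search config):
# {'ch': 32, 'fm_size': 56, 'out_chs': [48, 64], 'out_fms': [56, 28],
#  'strides': [1, 2], 'out_id': [3, 4], 'num_stack_layers': 2}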
331 |
332 | def _create_input_list(self):
333 | r"""
334 | Generate the input config of each block for constructing the whole network.
335 | Each config dict contains:
336 | 'ch': the channel number of the block
337 | 'in_chs': all the possible input channel numbers
338 | 'strides': the corresponding stride
339 | 'in_block_idx': the index of the input block
340 | 'beta_idx': the corresponding beta weight index.
341 | """
342 |
343 | self.input_configs = []
344 | for i in range(1, len(self.block_chs)):
345 | input_config = {'ch': self.block_chs[i],
346 | 'fm_size': self.block_fm_sizes[i],
347 | 'in_chs': [],
348 | 'in_fms': [],
349 | 'strides': [],
350 | 'in_block_idx': [],
351 | 'beta_idx': [],
352 | 'num_stack_layers': self.num_block_layers[i]}
353 | for j in range(i):
354 | in_index = i - j - 1
355 | if in_index < 0:
356 | break
357 | output_config = self.output_configs[in_index]
358 | if i in output_config['out_id']:
359 | beta_idx = output_config['out_id'].index(i)
360 | input_config['in_block_idx'].append(in_index)
361 | input_config['in_chs'].append(output_config['ch'])
362 | input_config['in_fms'].append(output_config['fm_size'])
363 | input_config['beta_idx'].append(beta_idx)
364 | input_config['strides'].append(output_config['strides'][beta_idx])
365 | else:
366 | continue
367 |
368 | self.input_configs.append(input_config)
369 |
370 | logging.info('Network input configs: \n' + '\n'.join(map(str, self.input_configs)))
371 |
372 |
373 | def get_cost_list(self, data_shape, cost_type='flops',
374 | use_gpu=True, meas_times=1000):
375 | cost_list = []
376 | block_datas = []
377 | total_cost = 0
378 | if cost_type == 'flops':
379 | cost_func = lambda module, data: comp_multadds_fw(
380 | module, data, use_gpu)
381 | elif cost_type == 'latency':
382 | cost_func = lambda module, data: latency_measure_fw(
383 | module, data, meas_times)
384 | else:
385 | raise NotImplementedError
386 |
387 | if len(data_shape) == 3:
388 | input_data = torch.randn((1,) + tuple(data_shape))
389 | else:
390 | input_data = torch.randn(tuple(data_shape))
391 | if use_gpu:
392 | input_data = input_data.cuda()
393 |
394 | cost, block_data = cost_func(self.input_block, input_data)
395 | cost_list.append(cost)
396 | block_datas.append(block_data)
397 | total_cost += cost
398 | if hasattr(self, 'head_block'):
399 | cost, block_data = cost_func(self.head_block, block_data)
400 | cost_list[0] += cost
401 | block_datas[0] = block_data
402 |
403 | block_flops = []  # per-block op costs (flops or latency, depending on cost_type)
404 | for block_id, block in enumerate(self.blocks):
405 | input_config = self.input_configs[block_id]
406 | inputs = [block_datas[i] for i in input_config['in_block_idx']]
407 |
408 | head_branch_flops = []
409 | for branch_id, head_branch in enumerate(block.head_layer.head_branches):
410 | op_flops = []
411 | for op in head_branch._ops:
412 | cost, block_data = cost_func(op, inputs[branch_id])
413 | op_flops.append(cost)
414 | total_cost += cost
415 |
416 | head_branch_flops.append(op_flops)
417 |
418 | stack_layer_flops = []
419 | if block.stack_layers.stack_layers is not None:
420 | for stack_layer in block.stack_layers.stack_layers:
421 | op_flops = []
422 | for op in stack_layer._ops:
423 | cost, block_data = cost_func(op, block_data)
424 | if isinstance(op, operations.Skip) and \
425 | self.config.optim.sub_obj.skip_reg:
426 | # skip_reg: regularize the skip op, whose measured cost is otherwise too small
427 | cost = op_flops[0] / 10.
428 | op_flops.append(cost)
429 | total_cost += cost
430 | stack_layer_flops.append(op_flops)
431 | block_flops.append([head_branch_flops, stack_layer_flops])
432 | block_datas.append(block_data)
433 |
434 | cost_list.append(block_flops)
435 |
436 | conv1_1_flops = []
437 | input_config = self.input_configs[-1]
438 | inputs = [block_datas[i] for i in input_config['in_block_idx']]
439 | for branch_id, branch in enumerate(self.conv1_1_block.conv1_1_branches):
440 | cost, block_data = cost_func(branch, inputs[branch_id])
441 | conv1_1_flops.append(cost)
442 | total_cost += cost
443 | block_datas.append(block_data)
444 |
445 | cost_list.append(conv1_1_flops)
446 | out = block_datas[-1]
447 | out = self.global_pooling(out)
448 |
449 | cost, out = cost_func(self.classifier, out.view(out.size(0), -1))
450 | cost_list.append(cost)
451 | total_cost += cost
452 |
453 | return cost_list, total_cost
454 |
--------------------------------------------------------------------------------
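
As a standalone illustration of the architecture-weight computations used above (the beta-weighted branch sum in Conv1_1_Block.forward and the op sampling in Network.sample_branch), a minimal sketch with made-up sizes; it is not code from this repo:

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)

    # Alphas for one block: 3 stack layers x 6 candidate ops.
    alpha = 1e-3 * torch.randn(3, 6)
    weight = F.softmax(alpha, dim=-1)

    # 'prob' policy: draw 2 op indices per layer without replacement,
    # weighted by the softmaxed alphas, then sort and gather as in sample().
    idx = torch.multinomial(weight, num_samples=2, replacement=False)
    idx, _ = torch.sort(idx, descending=False)
    sampled_alpha = torch.gather(alpha, dim=-1, index=idx)
    print(idx.shape, sampled_alpha.shape)  # torch.Size([3, 2]) twice

    # Beta relaxation: a branch-weighted sum over two input branches.
    betas = 1e-3 * torch.randn(2)
    branch_w = F.softmax(betas, dim=-1)
    branch_outs = [torch.randn(1, 8, 4, 4), torch.randn(1, 8, 4, 4)]
    mixed = sum(w * o for w, o in zip(branch_w, branch_outs))
    print(mixed.shape)  # torch.Size([1, 8, 4, 4])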