├── README.md
├── efficientnet_model.py
└── train_imagenet
    ├── train.py
    └── train_dali.py


/README.md:
--------------------------------------------------------------------------------
# EfficientNet-Gluon
[EfficientNet](https://arxiv.org/abs/1905.11946) Gluon implementation

## ImageNet experiments

### Requirements
Python 3.7 or later with packages:
- `mxnet >= 1.5.0`
- `gluoncv >= 0.6.0`
- `nvidia-dali >= 0.19.0`

### Usage
#### Prepare ImageNet dataset
1. Download and extract the dataset following this tutorial:
https://gluon-cv.mxnet.io/build/examples_datasets/imagenet.html
2. Create MXNet record files following this tutorial:
https://gluon-cv.mxnet.io/build/examples_datasets/recordio.html#imagerecord-file-for-imagenet

#### Clone this repo
```
git clone https://github.com/mnikitin/EfficientNet.git
cd EfficientNet/train_imagenet
```

#### Train your model
Example of training *efficientnet-b0* with the *NVIDIA DALI data loader* on 4 GPUs (`--batch-size` is per GPU, so the total batch size here is 4 x 64 = 256):
```
IMAGENET_RECORD_ROOT='path/to/imagenet/record/files'
MODEL='efficientnet-b0'
python3 train_dali.py --rec-train $IMAGENET_RECORD_ROOT/train --rec-val $IMAGENET_RECORD_ROOT/val --input-size 224 --batch-size 64 --num-gpus 4 --num-epochs 80 --lr 0.1 --lr-decay-epoch 40,60 --save-dir params-$MODEL --logging-file params-$MODEL/log.txt --save-frequency 5 --mode hybrid --model $MODEL
```

### Results
Code in this repo was used to train *efficientnet-b0* and *efficientnet-lite0* models.
Pretrained params are available (18.8 MB in total: 13.7 MB for the *extractor* + 5.1 MB for the *classifier*).

| model | err-top1 | err-top5 | pretrained params |
|---|---|---|---|
| efficientnet-b0 | 0.335842 | 0.128043 | dropbox link |
| efficientnet-lite0 | 0.305316 | 0.106322 | dropbox link |
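
The *err-top1* / *err-top5* columns are validation error rates, i.e. 1 minus the top-1 / top-5 accuracy that `test()` in the training scripts obtains from `mx.metric.Accuracy()` and `mx.metric.TopKAccuracy(5)`. As an illustration only, here is a minimal pure-Python sketch of that quantity. `topk_error` is a hypothetical helper that is not part of this repo, and its tie-breaking may differ from MXNet's metric:

```python
from typing import List, Sequence


def topk_error(scores: Sequence[Sequence[float]], labels: Sequence[int], k: int) -> float:
    """Fraction of samples whose true label is NOT among the k highest-scoring classes."""
    misses = 0
    for row, label in zip(scores, labels):
        # class indices sorted by descending score, truncated to the top k
        top_k: List[int] = sorted(range(len(row)), key=lambda c: row[c], reverse=True)[:k]
        if label not in top_k:
            misses += 1
    return misses / len(labels)


# toy batch: 3 samples, 3 classes (hypothetical scores)
scores = [[0.1, 0.7, 0.2],
          [0.5, 0.3, 0.2],
          [0.2, 0.3, 0.5]]
labels = [1, 1, 0]
print(topk_error(scores, labels, k=1))  # 2 of 3 samples miss at top-1
print(topk_error(scores, labels, k=2))  # 1 of 3 samples misses at top-2
```

In the scripts the same error is accumulated over the whole validation set rather than a single batch.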

**Note** that, due to limited computational resources, the obtained results are worse than those reported in the original paper.
Moreover, *efficientnet-lite0* was trained with more GPUs and a larger batch size, so despite its simpler architecture (ReLU6 instead of Swish) it achieves better results than the *efficientnet-b0* model.
Nevertheless, I believe the provided pretrained params can serve as a good initialization for your task.

These are the exact commands used to train *efficientnet-b0* and *efficientnet-lite0*:
```
MODEL='efficientnet-b0'
python3 train_dali.py --rec-train $IMAGENET_RECORD_ROOT/train --rec-val $IMAGENET_RECORD_ROOT/val --input-size 224 --batch-size 56 --num-gpus 4 --num-epochs 50 --lr 0.1 --lr-decay-epoch 20,30,40 --save-dir params-$MODEL --logging-file params-$MODEL/log.txt --save-frequency 5 --mode hybrid --model $MODEL
```

```
MODEL='efficientnet-lite0'
python3 train_dali.py --rec-train $IMAGENET_RECORD_ROOT/train --rec-val $IMAGENET_RECORD_ROOT/val --input-size 224 --batch-size 72 --num-gpus 6 --num-epochs 60 --lr 0.1 --lr-decay-epoch 20,35,50 --save-dir params-$MODEL --logging-file params-$MODEL/log.txt --save-frequency 5 --mode hybrid --model $MODEL
```

--------------------------------------------------------------------------------
/efficientnet_model.py:
--------------------------------------------------------------------------------
from mxnet.gluon.block import HybridBlock
from mxnet.gluon import nn
from math import ceil


class ReLU6(nn.HybridBlock):
    def __init__(self, **kwargs):
        super(ReLU6, self).__init__(**kwargs)

    def hybrid_forward(self, F, x):
        return F.clip(x, 0, 6, name="relu6")


def _add_conv(out, channels=1, kernel=1, stride=1, pad=0,
              num_group=1, active=True, lite=False):
    out.add(nn.Conv2D(channels, kernel, stride, pad, groups=num_group, use_bias=False))
    out.add(nn.BatchNorm(scale=True, momentum=0.99, epsilon=1e-3))
    if active:
        if lite:
            out.add(ReLU6())
        else:
            out.add(nn.Swish())


class MBConv(nn.HybridBlock):
    def __init__(self, in_channels, channels, t, kernel, stride, lite, **kwargs):
        super(MBConv, self).__init__(**kwargs)
        self.use_shortcut = stride == 1 and in_channels == channels
        with self.name_scope():
            self.out = nn.HybridSequential()
            _add_conv(self.out, in_channels * t, active=True, lite=lite)
            _add_conv(self.out, in_channels * t, kernel=kernel, stride=stride,
                      pad=int((kernel-1)/2), num_group=in_channels * t,
                      active=True, lite=lite)
            _add_conv(self.out, channels, active=False, lite=lite)

    def hybrid_forward(self, F, x):
        out = self.out(x)
        if self.use_shortcut:
            out = F.elemwise_add(out, x)
        return out


class EfficientNet(nn.HybridBlock):
    r"""
    Parameters
    ----------
    alpha : float, default 1.0
        The depth multiplier for controlling the model size. The actual number of layers on each channel_size level
        is equal to the original number of layers multiplied by alpha (rounded up).
    beta : float, default 1.0
        The width multiplier for controlling the model size. The actual number of channels
        is equal to the original channel size multiplied by beta.
    lite : bool, default False
        If True, build an EfficientNet-Lite variant: ReLU6 activations instead of Swish,
        and stem (32) and head (1280) widths that are not scaled by beta.
    dropout_rate : float, default 0.0
        Dropout probability for the final features layer.
    classes : int, default 1000
        Number of classes for the output layer.
    """

    def __init__(self, alpha=1.0, beta=1.0, lite=False,
                 dropout_rate=0.0, classes=1000, **kwargs):
        super(EfficientNet, self).__init__(**kwargs)
        with self.name_scope():
            self.features = nn.HybridSequential(prefix='features_')
            with self.features.name_scope():
                # stem conv
                channels = 32 if lite else int(32 * beta)
                _add_conv(self.features, channels, kernel=3, stride=2, pad=1,
                          active=True, lite=lite)

                # base model settings
                repeats = [1, 2, 2, 3, 3, 4, 1]
                channels_num = [16, 24, 40, 80, 112, 192, 320]
                kernels_num = [3, 3, 5, 3, 5, 5, 3]
                t_num = [1, 6, 6, 6, 6, 6, 6]
                strides_first = [1, 2, 2, 1, 2, 2, 1]

                # determine params of MBConv layers
                channels_group, kernels, ts, strides = [], [], [], []
                for rep, ch, kernel, t, s in zip(repeats, channels_num, kernels_num, t_num, strides_first):
                    rep = int(ceil(alpha * rep))
                    channels_group += [int(ch * beta)] * rep
                    kernels += [kernel] * rep
                    ts += [t] * rep
                    strides += [s] + [1] * (rep - 1)
                # each block consumes the previous block's output (the stem output for the first
                # block); this also keeps in_channels correct when alpha > 1 or for lite models
                in_channels_group = [channels] + channels_group[:-1]

                # add MBConv layers
                for in_c, c, t, k, s in zip(in_channels_group, channels_group, ts, kernels, strides):
                    self.features.add(MBConv(in_channels=in_c, channels=c, t=t, kernel=k,
                                             stride=s, lite=lite))

                # head layers
                last_channels = int(1280 * beta) if not lite and beta > 1.0 else 1280
                _add_conv(self.features, last_channels, active=True, lite=lite)
                self.features.add(nn.GlobalAvgPool2D())

            # features dropout
            self.dropout = nn.Dropout(dropout_rate) if dropout_rate > 0.0 else None

            # output layer
            self.output = nn.HybridSequential(prefix='output_')
            with self.output.name_scope():
                self.output.add(
                    nn.Conv2D(classes, 1, use_bias=False, prefix='pred_'),
                    nn.Flatten()
                )

    def hybrid_forward(self, F, x):
        x = self.features(x)
        if self.dropout:
            x = self.dropout(x)
        x = self.output(x)
        return x


def get_efficientnet(model_name, num_classes=1000):
    params_dict = {  # (width_coefficient, depth_coefficient, input_resolution, dropout_rate)
        'efficientnet-b0': (1.0, 1.0, 224, 0.2),
        'efficientnet-b1': (1.0, 1.1, 240, 0.2),
        'efficientnet-b2': (1.1, 1.2, 260, 0.3),
        'efficientnet-b3': (1.2, 1.4, 300, 0.3),
        'efficientnet-b4': (1.4, 1.8, 380, 0.4),
        'efficientnet-b5': (1.6, 2.2, 456, 0.4),
        'efficientnet-b6': (1.8, 2.6, 528, 0.5),
        'efficientnet-b7': (2.0, 3.1, 600, 0.5)
    }
    width_coeff, depth_coeff, input_resolution, dropout_rate = params_dict[model_name]
    model = EfficientNet(alpha=depth_coeff, beta=width_coeff, lite=False,
                         dropout_rate=dropout_rate, classes=num_classes)
    return model, input_resolution


def get_efficientnet_lite(model_name, num_classes=1000):
    params_dict = {  # (width_coefficient, depth_coefficient, input_resolution, dropout_rate)
        'efficientnet-lite0': (1.0, 1.0, 224, 0.2),
        'efficientnet-lite1': (1.0, 1.1, 240, 0.2),
        'efficientnet-lite2': (1.1, 1.2, 260, 0.3),
        'efficientnet-lite3': (1.2, 1.4, 280, 0.3),
        'efficientnet-lite4': (1.4, 1.8, 300, 0.3)
    }
    width_coeff, depth_coeff, input_resolution, dropout_rate = params_dict[model_name]
    model = EfficientNet(alpha=depth_coeff, beta=width_coeff, lite=True,
                         dropout_rate=dropout_rate, classes=num_classes)
    return model, input_resolution

--------------------------------------------------------------------------------
/train_imagenet/train.py:
--------------------------------------------------------------------------------
import argparse, time, logging, os, sys, math

import mxnet as mx
from mxnet import gluon, nd
from mxnet import autograd as ag
from mxnet.gluon import nn
from mxnet.gluon.data.vision import transforms

import gluoncv as gcv
gcv.utils.check_version('0.6.0')
from gluoncv.data import imagenet
from gluoncv.utils import makedirs, LRSequential, LRScheduler

sys.path.insert(1, os.path.join(sys.path[0], '..'))
from efficientnet_model import get_efficientnet, get_efficientnet_lite

# CLI
def parse_args():
    parser = argparse.ArgumentParser(description='Train a model for image classification.')
    parser.add_argument('--data-dir', type=str, default='~/.mxnet/datasets/imagenet',
                        help='training and validation pictures to use.')
    parser.add_argument('--rec-train', type=str, default='~/.mxnet/datasets/imagenet/rec/train.rec',
                        help='the training data')
    parser.add_argument('--rec-train-idx', type=str, default='~/.mxnet/datasets/imagenet/rec/train.idx',
                        help='the index of training data')
    parser.add_argument('--rec-val', type=str, default='~/.mxnet/datasets/imagenet/rec/val.rec',
                        help='the validation data')
    parser.add_argument('--rec-val-idx', type=str, default='~/.mxnet/datasets/imagenet/rec/val.idx',
                        help='the index of validation data')
    parser.add_argument('--use-rec', action='store_true',
                        help='use image record iter for data input. default is false.')
    parser.add_argument('--batch-size', type=int, default=32,
                        help='training batch size per device (CPU/GPU).')
    parser.add_argument('--dtype', type=str, default='float32',
                        help='data type for training. default is float32')
    parser.add_argument('--num-gpus', type=int, default=0,
                        help='number of gpus to use.')
    parser.add_argument('-j', '--num-data-workers', dest='num_workers', default=4, type=int,
                        help='number of preprocessing workers')
    parser.add_argument('--num-epochs', type=int, default=3,
                        help='number of training epochs.')
    parser.add_argument('--lr', type=float, default=0.1,
                        help='learning rate. default is 0.1.')
    parser.add_argument('--momentum', type=float, default=0.9,
                        help='momentum value for optimizer, default is 0.9.')
    parser.add_argument('--wd', type=float, default=0.0001,
                        help='weight decay rate. default is 0.0001.')
    parser.add_argument('--lr-mode', type=str, default='step',
                        help='learning rate scheduler mode. options are step, poly and cosine.')
    parser.add_argument('--lr-decay', type=float, default=0.1,
                        help='decay rate of learning rate. default is 0.1.')
    parser.add_argument('--lr-decay-period', type=int, default=0,
                        help='interval for periodic learning rate decays. default is 0 to disable.')
    parser.add_argument('--lr-decay-epoch', type=str, default='40,60',
                        help='epochs at which learning rate decays. default is 40,60.')
    parser.add_argument('--warmup-lr', type=float, default=0.0,
                        help='starting warmup learning rate. default is 0.0.')
    parser.add_argument('--warmup-epochs', type=int, default=0,
                        help='number of warmup epochs.')
    parser.add_argument('--last-gamma', action='store_true',
                        help='whether to init gamma of the last BN layer in each bottleneck to 0.')
    parser.add_argument('--mode', type=str,
                        help='mode in which to train the model. options are symbolic, imperative, hybrid')
    parser.add_argument('--model', type=str, required=True,
                        help='name of the model to use, e.g. efficientnet-b0 or efficientnet-lite0.')
    parser.add_argument('--input-size', type=int, default=224,
                        help='size of the input image. default is 224')
    parser.add_argument('--crop-ratio', type=float, default=0.875,
                        help='Crop ratio during validation. default is 0.875')
    parser.add_argument('--no-wd', action='store_true',
                        help='whether to remove weight decay on bias, and beta/gamma for batchnorm layers.')
    parser.add_argument('--batch-norm', action='store_true',
                        help='enable batch normalization or not in vgg. default is false.')
    parser.add_argument('--save-frequency', type=int, default=10,
                        help='frequency of model saving.')
    parser.add_argument('--save-dir', type=str, default='params',
                        help='directory of saved models')
    parser.add_argument('--resume-epoch', type=int, default=0,
                        help='epoch to resume training from.')
    parser.add_argument('--resume-params', type=str, default='',
                        help='path of parameters to load from.')
    parser.add_argument('--resume-states', type=str, default='',
                        help='path of trainer state to load from.')
    parser.add_argument('--log-interval', type=int, default=50,
                        help='Number of batches to wait before logging.')
    parser.add_argument('--logging-file', type=str, default='train_imagenet.log',
                        help='name of training log file')
    opt = parser.parse_args()
    return opt


def main():
    opt = parse_args()

    filehandler = logging.FileHandler(opt.logging_file)
    streamhandler = logging.StreamHandler()

    logger = logging.getLogger('')
    logger.setLevel(logging.INFO)
    logger.addHandler(filehandler)
    logger.addHandler(streamhandler)

    logger.info(opt)

    batch_size = opt.batch_size
    classes = 1000
    num_training_samples = 1281167

    num_gpus = opt.num_gpus
    batch_size *= max(1, num_gpus)
    context = [mx.gpu(i) for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]
    num_workers = opt.num_workers

    lr_decay = opt.lr_decay
    lr_decay_period = opt.lr_decay_period
    if opt.lr_decay_period > 0:
        lr_decay_epoch = list(range(lr_decay_period, opt.num_epochs, lr_decay_period))
    else:
        lr_decay_epoch = [int(i) for i in opt.lr_decay_epoch.split(',')]
    lr_decay_epoch = [e - opt.warmup_epochs for e in lr_decay_epoch]
    num_batches = num_training_samples // batch_size

    lr_scheduler = LRSequential([
        LRScheduler('linear', base_lr=0, target_lr=opt.lr,
                    nepochs=opt.warmup_epochs, iters_per_epoch=num_batches),
        LRScheduler(opt.lr_mode, base_lr=opt.lr, target_lr=0,
                    nepochs=opt.num_epochs - opt.warmup_epochs,
                    iters_per_epoch=num_batches,
                    step_epoch=lr_decay_epoch,
                    step_factor=lr_decay, power=2)
    ])

    optimizer = 'nag'
    optimizer_params = {'wd': opt.wd, 'momentum': opt.momentum, 'lr_scheduler': lr_scheduler}
    if opt.dtype != 'float32':
        optimizer_params['multi_precision'] = True

    model_name = opt.model
    if 'lite' in model_name:
        net, input_size = get_efficientnet_lite(model_name, num_classes=classes)
    else:
        net, input_size = get_efficientnet(model_name, num_classes=classes)
    assert input_size == opt.input_size
    net.cast(opt.dtype)
    if opt.resume_params:
        net.load_parameters(opt.resume_params, ctx=context)

    # Two functions for reading data from record file or raw images
    def get_data_rec(rec_train, rec_train_idx, rec_val, rec_val_idx, batch_size, num_workers):
        rec_train = os.path.expanduser(rec_train)
        rec_train_idx = os.path.expanduser(rec_train_idx)
        rec_val = os.path.expanduser(rec_val)
        rec_val_idx = os.path.expanduser(rec_val_idx)
        jitter_param = 0.4
        lighting_param = 0.1
        input_size = opt.input_size
        crop_ratio = opt.crop_ratio if opt.crop_ratio > 0 else 0.875
        resize = int(math.ceil(input_size / crop_ratio))
        mean_rgb = [123.68, 116.779, 103.939]
        std_rgb = [58.393, 57.12, 57.375]

        def batch_fn(batch, ctx):
            data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0)
            label = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0)
            return data, label

        train_data = mx.io.ImageRecordIter(
            path_imgrec=rec_train,
            path_imgidx=rec_train_idx,
            preprocess_threads=num_workers,
            shuffle=True,
            batch_size=batch_size,
            data_shape=(3, input_size, input_size),
            mean_r=mean_rgb[0],
            mean_g=mean_rgb[1],
            mean_b=mean_rgb[2],
            std_r=std_rgb[0],
            std_g=std_rgb[1],
            std_b=std_rgb[2],
            rand_mirror=True,
            random_resized_crop=True,
            max_aspect_ratio=4. / 3.,
            min_aspect_ratio=3. / 4.,
            max_random_area=1,
            min_random_area=0.08,
            brightness=jitter_param,
            saturation=jitter_param,
            contrast=jitter_param,
            pca_noise=lighting_param,
        )
        val_data = mx.io.ImageRecordIter(
            path_imgrec=rec_val,
            path_imgidx=rec_val_idx,
            preprocess_threads=num_workers,
            shuffle=False,
            batch_size=batch_size,
            resize=resize,
            data_shape=(3, input_size, input_size),
            mean_r=mean_rgb[0],
            mean_g=mean_rgb[1],
            mean_b=mean_rgb[2],
            std_r=std_rgb[0],
            std_g=std_rgb[1],
            std_b=std_rgb[2],
        )
        return train_data, val_data, batch_fn

    def get_data_loader(data_dir, batch_size, num_workers):
        normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        jitter_param = 0.4
        lighting_param = 0.1
        input_size = opt.input_size
        crop_ratio = opt.crop_ratio if opt.crop_ratio > 0 else 0.875
        resize = int(math.ceil(input_size / crop_ratio))

        def batch_fn(batch, ctx):
            data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
            label = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)
            return data, label

        transform_train = transforms.Compose([
            transforms.RandomResizedCrop(input_size),
            transforms.RandomFlipLeftRight(),
            transforms.RandomColorJitter(brightness=jitter_param, contrast=jitter_param,
                                         saturation=jitter_param),
            transforms.RandomLighting(lighting_param),
            transforms.ToTensor(),
            normalize
        ])
        transform_test = transforms.Compose([
            transforms.Resize(resize, keep_ratio=True),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(),
            normalize
        ])

        train_data = gluon.data.DataLoader(
            imagenet.classification.ImageNet(data_dir, train=True).transform_first(transform_train),
            batch_size=batch_size, shuffle=True, last_batch='discard', num_workers=num_workers)
        val_data = gluon.data.DataLoader(
            imagenet.classification.ImageNet(data_dir, train=False).transform_first(transform_test),
            batch_size=batch_size, shuffle=False, num_workers=num_workers)

        return train_data, val_data, batch_fn

    if opt.use_rec:
        train_data, val_data, batch_fn = get_data_rec(opt.rec_train, opt.rec_train_idx,
                                                      opt.rec_val, opt.rec_val_idx,
                                                      batch_size, num_workers)
    else:
        train_data, val_data, batch_fn = get_data_loader(opt.data_dir, batch_size, num_workers)

    train_metric = mx.metric.Accuracy()
    acc_top1 = mx.metric.Accuracy()
    acc_top5 = mx.metric.TopKAccuracy(5)

    save_frequency = opt.save_frequency
    if opt.save_dir and save_frequency:
        save_dir = opt.save_dir
        makedirs(save_dir)
    else:
        save_dir = ''
        save_frequency = 0

    def test(ctx, val_data):
        if opt.use_rec:
            val_data.reset()
        acc_top1.reset()
        acc_top5.reset()
        for i, batch in enumerate(val_data):
            data, label = batch_fn(batch, ctx)
            outputs = [net(X.astype(opt.dtype, copy=False)) for X in data]
            acc_top1.update(label, outputs)
            acc_top5.update(label, outputs)

        _, top1 = acc_top1.get()
        _, top5 = acc_top5.get()
        return (1-top1, 1-top5)

    def train(ctx):
        if isinstance(ctx, mx.Context):
            ctx = [ctx]
        if not opt.resume_params:
            net.initialize(mx.init.MSRAPrelu(), ctx=ctx)

        if opt.no_wd:
            for k, v in net.collect_params('.*beta|.*gamma|.*bias').items():
                v.wd_mult = 0.0

        trainer = gluon.Trainer(net.collect_params(), optimizer, optimizer_params)
        if opt.resume_states:
            trainer.load_states(opt.resume_states)

        L = gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=True)

        best_val_score = 1

        for epoch in range(opt.resume_epoch, opt.num_epochs):
            tic = time.time()
            if opt.use_rec:
                train_data.reset()
            train_metric.reset()
            btic = time.time()

            for i, batch in enumerate(train_data):
                data, label = batch_fn(batch, ctx)
                with ag.record():
                    outputs = [net(X.astype(opt.dtype, copy=False)) for X in data]
                    loss = [L(yhat, y.astype(opt.dtype, copy=False)) for yhat, y in zip(outputs, label)]
                for l in loss:
                    l.backward()
                trainer.step(batch_size)

                train_metric.update(label, outputs)

                if opt.log_interval and not (i+1)%opt.log_interval:
                    train_metric_name, train_metric_score = train_metric.get()
                    logger.info('Epoch[%d] Batch [%d]\tSpeed: %f samples/sec\t%s=%f\tlr=%f'%(
                                epoch, i, batch_size*opt.log_interval/(time.time()-btic),
                                train_metric_name, train_metric_score, trainer.learning_rate))
                    btic = time.time()

            train_metric_name, train_metric_score = train_metric.get()
            throughput = int(batch_size * i /(time.time() - tic))

            err_top1_val, err_top5_val = test(ctx, val_data)

            logger.info('[Epoch %d] training: %s=%f'%(epoch, train_metric_name, train_metric_score))
            logger.info('[Epoch %d] speed: %d samples/sec\ttime cost: %f'%(epoch, throughput, time.time()-tic))
            logger.info('[Epoch %d] validation: err-top1=%f err-top5=%f'%(epoch, err_top1_val, err_top5_val))

            if err_top1_val < best_val_score:
                best_val_score = err_top1_val
                net.save_parameters('%s/%.4f-imagenet-%s-%d-best.params'%(save_dir, best_val_score, model_name, epoch))
                trainer.save_states('%s/%.4f-imagenet-%s-%d-best.states'%(save_dir, best_val_score, model_name, epoch))

            if save_frequency and save_dir and (epoch + 1) % save_frequency == 0:
                net.save_parameters('%s/imagenet-%s-%d.params'%(save_dir, model_name, epoch))
                trainer.save_states('%s/imagenet-%s-%d.states'%(save_dir, model_name, epoch))

        if save_frequency and save_dir:
            net.save_parameters('%s/imagenet-%s-%d.params'%(save_dir, model_name, opt.num_epochs-1))
            trainer.save_states('%s/imagenet-%s-%d.states'%(save_dir, model_name, opt.num_epochs-1))

    if opt.mode == 'hybrid':
        net.hybridize(static_alloc=True, static_shape=True)

    train(context)

if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/train_imagenet/train_dali.py:
--------------------------------------------------------------------------------
import argparse, time, logging, os, sys, math

import mxnet as mx
from mxnet import gluon, autograd

from nvidia.dali.pipeline import Pipeline
import nvidia.dali.ops as ops
import nvidia.dali.types as types
from nvidia.dali.plugin.mxnet import DALIClassificationIterator

sys.path.insert(1, os.path.join(sys.path[0], '..'))
from efficientnet_model import get_efficientnet, get_efficientnet_lite


# DALI RECORD PIPELINE
class HybridRecPipe(Pipeline):
    def __init__(self, db_prefix, for_train, input_size, batch_size, num_threads, device_id, num_gpus):
        super(HybridRecPipe, self).__init__(batch_size, num_threads, device_id, seed=12+device_id, prefetch_queue_depth=2)
        self.for_train = for_train
        self.input = ops.MXNetReader(path=[db_prefix + ".rec"], index_path=[db_prefix + ".idx"],
                                     random_shuffle=for_train, shard_id=device_id, num_shards=num_gpus)
        self.resize = ops.Resize(device="gpu", resize_x=input_size, resize_y=input_size)
        self.cmnp = ops.CropMirrorNormalize(device="gpu",
                                            output_dtype=types.FLOAT,
                                            output_layout=types.NCHW,
                                            crop=(input_size, input_size),
                                            image_type=types.RGB,
                                            mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
                                            std=[0.229 * 255, 0.224 * 255, 0.225 * 255])
        if self.for_train:
            self.decode = ops.ImageDecoderRandomCrop(device="mixed",
                                                     output_type=types.RGB,
                                                     random_aspect_ratio=[3/4, 4/3],
                                                     random_area=[0.08, 1.0],
                                                     num_attempts=100)
            self.color = ops.ColorTwist(device='gpu')
            self.rng_brightness = ops.Uniform(range=(0.6, 1.4))
            self.rng_contrast = ops.Uniform(range=(0.6, 1.4))
            self.rng_saturation = ops.Uniform(range=(0.6, 1.4))
            self.mirror_coin = ops.CoinFlip(probability=0.5)
        else:
            self.decode = ops.ImageDecoder(device="mixed", output_type=types.RGB)

    def define_graph(self):
        inputs, labels = self.input(name="Reader")
        images = self.decode(inputs)
        images = self.resize(images)
        if self.for_train:
            images = self.color(images, brightness=self.rng_brightness(),
                                contrast=self.rng_contrast(), saturation=self.rng_saturation())
            output = self.cmnp(images, mirror=self.mirror_coin())
        else:
            output = self.cmnp(images)
        return [output, labels.gpu()]


def get_rec_data_iterators(train_db_prefix, val_db_prefix, input_size, batch_size, devices):
    num_threads = 2
    num_shards = len(devices)
    train_pipes = [HybridRecPipe(train_db_prefix, True, input_size, batch_size,
                                 num_threads, device_id, num_shards) for device_id in range(num_shards)]
    # Build train pipeline to get the epoch size out of the reader
    train_pipes[0].build()
    print("Training pipeline epoch size: {}".format(train_pipes[0].epoch_size("Reader")))
    # Make train MXNet iterators out of rec pipelines
    dali_train_iter = DALIClassificationIterator(train_pipes, train_pipes[0].epoch_size("Reader"))
    if val_db_prefix:
        val_pipes = [HybridRecPipe(val_db_prefix, False, input_size, batch_size,
                                   num_threads, device_id, num_shards) for device_id in range(num_shards)]
        # Build val pipeline to get the epoch size out of the reader
        val_pipes[0].build()
        print("Validation pipeline epoch size: {}".format(val_pipes[0].epoch_size("Reader")))
        # Make val MXNet iterators out of rec pipelines
        dali_val_iter = DALIClassificationIterator(val_pipes, val_pipes[0].epoch_size("Reader"))
    else:
        dali_val_iter = None
    return dali_train_iter, dali_val_iter


# CLI
def parse_args():
    parser = argparse.ArgumentParser(description='Train a model for image classification.')
    parser.add_argument('--rec-train', type=str, default='~/.mxnet/datasets/imagenet/rec/train',
                        help='prefix of the training .rec/.idx files')
    parser.add_argument('--rec-val', type=str, default='~/.mxnet/datasets/imagenet/rec/val',
                        help='prefix of the validation .rec/.idx files')
    parser.add_argument('--batch-size', type=int, default=32,
                        help='training batch size per device (CPU/GPU).')
    parser.add_argument('--num-gpus', type=int, default=0,
                        help='number of gpus to use.')
    parser.add_argument('--num-epochs', type=int, default=3,
                        help='number of training epochs.')
    parser.add_argument('--lr', type=float, default=0.1,
                        help='learning rate. default is 0.1.')
    parser.add_argument('--momentum', type=float, default=0.9,
                        help='momentum value for optimizer, default is 0.9.')
    parser.add_argument('--wd', type=float, default=0.0001,
                        help='weight decay rate. default is 0.0001.')
    parser.add_argument('--lr-decay', type=float, default=0.1,
                        help='decay rate of learning rate. default is 0.1.')
    parser.add_argument('--lr-decay-epoch', type=str, default='40,60',
                        help='epochs at which learning rate decays. default is 40,60.')
    parser.add_argument('--mode', type=str,
                        help='mode in which to train the model. options are symbolic, imperative, hybrid')
    parser.add_argument('--model', type=str, required=True,
                        help='name of the model to use, e.g. efficientnet-b0 or efficientnet-lite0.')
    parser.add_argument('--input-size', type=int, default=224,
                        help='size of the input image. default is 224')
    parser.add_argument('--no-wd', action='store_true',
                        help='whether to remove weight decay on bias, and beta/gamma for batchnorm layers.')
    parser.add_argument('--save-frequency', type=int, default=10,
                        help='frequency of model saving.')
    parser.add_argument('--save-dir', type=str, default='params',
                        help='directory of saved models')
    parser.add_argument('--resume-epoch', type=int, default=0,
                        help='epoch to resume training from.')
    parser.add_argument('--resume-params', type=str, default='',
                        help='path of parameters to load from.')
    parser.add_argument('--resume-states', type=str, default='',
                        help='path of trainer state to load from.')
    parser.add_argument('--log-interval', type=int, default=50,
                        help='Number of batches to wait before logging.')
    parser.add_argument('--logging-file', type=str, default='train_imagenet.log',
                        help='name of training log file')
    opt = parser.parse_args()
    return opt


def main():
    opt = parse_args()

    save_dir = opt.save_dir
    os.makedirs(save_dir, exist_ok=True)

    filehandler = logging.FileHandler(opt.logging_file)
    streamhandler = logging.StreamHandler()

    logger = logging.getLogger('')
    logger.setLevel(logging.INFO)
    logger.addHandler(filehandler)
    logger.addHandler(streamhandler)

    logger.info(opt)

    batch_size = opt.batch_size
    classes = 1000
    num_training_samples = 1281167

    num_gpus = opt.num_gpus
    batch_size *= max(1, num_gpus)
    context = [mx.gpu(i) for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]

    lr_decay = opt.lr_decay
    lr_decay_epoch = [int(i) for i in opt.lr_decay_epoch.split(',')]
    num_batches = num_training_samples // batch_size

    optimizer = 'nag'
    optimizer_params = {'wd': opt.wd, 'momentum': opt.momentum}

    model_name = opt.model
    if 'lite' in model_name:
        net, input_size = get_efficientnet_lite(model_name, num_classes=classes)
    else:
        net, input_size = get_efficientnet(model_name, num_classes=classes)
    assert input_size == opt.input_size
    if opt.resume_params:
        net.load_parameters(opt.resume_params, ctx=context)

    if opt.mode == 'hybrid':
        net.hybridize(static_alloc=True, static_shape=True)

    # Function for reading data from record files with DALI
    def get_data_rec(rec_train_prefix, rec_val_prefix, batch_size, devices):
        rec_train_prefix = os.path.expanduser(rec_train_prefix)
        rec_val_prefix = os.path.expanduser(rec_val_prefix)
        input_size = opt.input_size

        def batch_fn(batch):
            data = [b.data[0] for b in batch]
            label = [b.label[0] for b in batch]
            return data, label

        train_data, val_data = get_rec_data_iterators(rec_train_prefix, rec_val_prefix, input_size, batch_size // len(devices), devices)
        return train_data, val_data, batch_fn

    train_data, val_data, batch_fn = get_data_rec(opt.rec_train, opt.rec_val, batch_size, context)

    train_metric = mx.metric.Accuracy()
    acc_top1 = mx.metric.Accuracy()
    acc_top5 = mx.metric.TopKAccuracy(5)

    save_frequency = opt.save_frequency

    def test(ctx, val_data):
        val_data.reset()
        acc_top1.reset()
        acc_top5.reset()
        for i, batch in enumerate(val_data):
            data, label = batch_fn(batch)
            outputs = [net(X) for X in data]
            acc_top1.update(label, outputs)
            acc_top5.update(label, outputs)

        _, top1 = acc_top1.get()
        _, top5 = acc_top5.get()
| return (1-top1, 1-top5) 207 | 208 | def train(ctx): 209 | if isinstance(ctx, mx.Context): 210 | ctx = [ctx] 211 | if opt.resume_params is '': 212 | net.collect_params('.*gamma|.*alpha|.*running_mean|.*running_var').initialize(mx.init.Constant(1), ctx=ctx) 213 | net.collect_params('.*beta|.*bias').initialize(mx.init.Constant(0.0), ctx=ctx) 214 | net.collect_params('.*weight').initialize(mx.init.Xavier(), ctx=ctx) 215 | 216 | if opt.no_wd: 217 | for k, v in net.collect_params('.*beta|.*gamma|.*bias').items(): 218 | v.wd_mult = 0.0 219 | 220 | trainer = gluon.Trainer(net.collect_params(), optimizer, optimizer_params, kvstore='local') 221 | if opt.resume_states is not '': 222 | trainer.load_states(opt.resume_states) 223 | 224 | L = gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=True) 225 | 226 | best_val_score = 1 227 | 228 | trainer.set_learning_rate(opt.lr) 229 | for epoch in range(opt.resume_epoch, opt.num_epochs): 230 | tic = time.time() 231 | train_data.reset() 232 | train_metric.reset() 233 | btic = time.time() 234 | 235 | if epoch in lr_decay_epoch: 236 | trainer.set_learning_rate(trainer.learning_rate * lr_decay) 237 | logger.info('Learning rate has been changed to %f' % trainer.learning_rate) 238 | 239 | for i, batch in enumerate(train_data): 240 | data, label = batch_fn(batch) 241 | with autograd.record(): 242 | outputs = [net(X) for X in data] 243 | loss = [L(yhat, y) for yhat, y in zip(outputs, label)] 244 | for l in loss: 245 | l.backward() 246 | trainer.step(batch_size) 247 | 248 | train_metric.update(label, outputs) 249 | 250 | if opt.log_interval and not (i+1)%opt.log_interval: 251 | train_metric_name, train_metric_score = train_metric.get() 252 | logger.info('Epoch[%d/%d] Batch [%d/%d]\tSpeed: %f samples/sec\t%s=%f\tlr=%f'%( 253 | epoch, opt.num_epochs, i, num_batches, 254 | batch_size*opt.log_interval/(time.time()-btic), 255 | train_metric_name, train_metric_score, trainer.learning_rate)) 256 | btic = time.time() 257 | 258 | train_metric_name, 
train_metric_score = train_metric.get() 259 | throughput = int(batch_size * i /(time.time() - tic)) 260 | 261 | err_top1_val, err_top5_val = test(ctx, val_data) 262 | 263 | logger.info('[Epoch %d] training: %s=%f'%(epoch, train_metric_name, train_metric_score)) 264 | logger.info('[Epoch %d] speed: %d samples/sec\ttime cost: %f'%(epoch, throughput, time.time()-tic)) 265 | logger.info('[Epoch %d] validation: err-top1=%f err-top5=%f'%(epoch, err_top1_val, err_top5_val)) 266 | 267 | if err_top1_val < best_val_score: 268 | best_val_score = err_top1_val 269 | net.save_parameters('%s/%.4f-imagenet-%s-%d-best.params'%(save_dir, best_val_score, model_name, epoch)) 270 | trainer.save_states('%s/%.4f-imagenet-%s-%d-best.states'%(save_dir, best_val_score, model_name, epoch)) 271 | 272 | if save_frequency and save_dir and (epoch + 1) % save_frequency == 0: 273 | net.save_parameters('%s/imagenet-%s-%d.params'%(save_dir, model_name, epoch)) 274 | trainer.save_states('%s/imagenet-%s-%d.states'%(save_dir, model_name, epoch)) 275 | 276 | if save_frequency and save_dir: 277 | net.save_parameters('%s/imagenet-%s-%d.params'%(save_dir, model_name, opt.num_epochs-1)) 278 | trainer.save_states('%s/imagenet-%s-%d.states'%(save_dir, model_name, opt.num_epochs-1)) 279 | 280 | train(context) 281 | 282 | if __name__ == '__main__': 283 | main() 284 | --------------------------------------------------------------------------------
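The learning-rate handling in `train()` is a step schedule: the rate is set to `opt.lr` once, then multiplied by `lr_decay` at the start of every epoch listed in `lr_decay_epoch`. The plain-Python sketch below (no MXNet needed) restates that rule in closed form and replays the loop to check it. `lr_at_epoch` and `simulate_loop` are hypothetical helpers, and `lr_decay = 0.1` is an assumed value, since the default of `opt.lr_decay` is not shown in this file; the base rate `0.1`, the decay epochs `40,60` and the 80 epochs come from the README example.

```python
def lr_at_epoch(base_lr, lr_decay, lr_decay_epoch, epoch):
    """Closed form of the step schedule: base_lr times lr_decay raised to the
    number of decay epochs already reached (a decay applies at the start of
    the listed epoch, hence e <= epoch)."""
    num_decays = sum(1 for e in lr_decay_epoch if e <= epoch)
    return base_lr * lr_decay ** num_decays


def simulate_loop(base_lr, lr_decay, lr_decay_epoch, num_epochs, start_epoch=0):
    """Replays the loop in train(): the rate is set to base_lr once, then
    decayed at the top of every listed epoch from start_epoch onward.
    Returns {epoch: learning rate used during that epoch}."""
    lr = base_lr
    schedule = {}
    for epoch in range(start_epoch, num_epochs):
        if epoch in lr_decay_epoch:
            lr *= lr_decay
        schedule[epoch] = lr
    return schedule


if __name__ == '__main__':
    # README example: --lr 0.1 --lr-decay-epoch 40,60 --num-epochs 80.
    # lr_decay = 0.1 is an ASSUMPTION (its default is not shown here).
    base_lr, lr_decay, decay_epochs, num_epochs = 0.1, 0.1, [40, 60], 80

    fresh = simulate_loop(base_lr, lr_decay, decay_epochs, num_epochs)
    for epoch in (0, 39, 40, 59, 60, 79):
        print(epoch, fresh[epoch], lr_at_epoch(base_lr, lr_decay, decay_epochs, epoch))

    # Replaying from epoch 50: the loop starts again from base_lr and never
    # re-applies the decay scheduled for epoch 40.
    resumed = simulate_loop(base_lr, lr_decay, decay_epochs, num_epochs, start_epoch=50)
    print(50, resumed[50], lr_at_epoch(base_lr, lr_decay, decay_epochs, 50))
```

Because `train()` calls `trainer.set_learning_rate(opt.lr)` after `trainer.load_states(...)` and before the epoch loop, a run resumed from an `opt.resume_epoch` past a decay epoch starts from `opt.lr` again, as the second replay shows; passing the closed-form value as `opt.lr` is presumably what keeps a resumed run on the original schedule.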