├── README.md
├── efficientnet_model.py
└── train_imagenet
    ├── train.py
    └── train_dali.py
/README.md:
--------------------------------------------------------------------------------
1 | # EfficientNet-Gluon
2 | [EfficientNet](https://arxiv.org/abs/1905.11946) Gluon implementation
3 |
4 | ## ImageNet experiments
5 |
6 | ### Requirements
7 | Python 3.7 or later with packages:
8 | - `mxnet >= 1.5.0`
9 | - `gluoncv >= 0.6.0`
10 | - `nvidia-dali >= 0.19.0`
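
For example, with CUDA 10.0 the first two can be installed via pip (the exact `mxnet-cu*` variant depends on your CUDA version; the DALI wheel is distributed from NVIDIA's own package index, so follow the DALI installation docs for that one):

```
pip3 install mxnet-cu100 gluoncv
```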
11 |
12 | ### Usage
13 | #### Prepare ImageNet dataset
14 | 1. Download and extract dataset following this tutorial:
15 | https://gluon-cv.mxnet.io/build/examples_datasets/imagenet.html
16 | 2. Create mxnet-record files following this tutorial (a quick sanity check of the resulting files is sketched below):
17 | https://gluon-cv.mxnet.io/build/examples_datasets/recordio.html#imagerecord-file-for-imagenet
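
You can verify the generated record files by reading one batch with `mx.io.ImageRecordIter` (the paths below are placeholders):

```
import mxnet as mx

# Read one batch from the generated record files (placeholder paths).
data_iter = mx.io.ImageRecordIter(
    path_imgrec='path/to/imagenet/record/files/train.rec',
    path_imgidx='path/to/imagenet/record/files/train.idx',
    batch_size=4,
    resize=256,
    data_shape=(3, 224, 224))
batch = data_iter.next()
print(batch.data[0].shape, batch.label[0].shape)  # (4, 3, 224, 224) (4,)
```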
18 |
19 | #### Clone this repo
20 | ```
21 | git clone https://github.com/mnikitin/EfficientNet.git
22 | cd EfficientNet/train_imagenet
23 | ```
24 |
25 | #### Train your model
26 | Example of training *efficientnet-b0* with the *nvidia-dali data loader* on 4 GPUs (`--batch-size` is per device, so the effective batch size here is 256):
27 | ```
28 | IMAGENET_RECORD_ROOT='path/to/imagenet/record/files'
29 | MODEL='efficientnet-b0'
30 | python3 train_dali.py --rec-train $IMAGENET_RECORD_ROOT/train --rec-val $IMAGENET_RECORD_ROOT/val --input-size 224 --batch-size 64 --num-gpus 4 --num-epochs 80 --lr 0.1 --lr-decay-epoch 40,60 --save-dir params-$MODEL --logging-file params-$MODEL/log.txt --save-frequency 5 --mode hybrid --model $MODEL
31 | ```
32 |
33 | ### Results
34 | The code in this repo was used to train the *efficientnet-b0* and *efficientnet-lite0* models.
35 | Pretrained params are available (18.8 MB in total = 13.7 MB for the *extractor* + 5.1 MB for the *classifier*).
36 |
37 |
38 | | model              | err-top1 | err-top5 | pretrained params |
39 | |--------------------|----------|----------|-------------------|
40 | | efficientnet-b0    | 0.335842 | 0.128043 | dropbox link      |
41 | | efficientnet-lite0 | 0.305316 | 0.106322 | dropbox link      |
42 |
58 | **Note** that due to limited computational resources the obtained results are worse than those reported in the original paper.
59 | Moreover, *efficientnet-lite0* was trained with more GPUs and a bigger effective batch size, so despite its simpler architecture (ReLU6 instead of Swish) its results are better than those of the *efficientnet-b0* model.
60 | Still, the provided pretrained params can serve as a good initialization for your own task.
61 |
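Loading them for fine-tuning takes a few lines; a minimal sketch, assuming a single merged params file (the filename is a placeholder; the actual download is split into extractor and classifier parts, so adapt the loading to the files you get). Run it from the repo root:

```
import mxnet as mx
from efficientnet_model import get_efficientnet

# Rebuild the architecture and load the downloaded weights (placeholder filename).
net, input_size = get_efficientnet('efficientnet-b0', num_classes=1000)
net.load_parameters('efficientnet-b0.params', ctx=mx.cpu())

# Sanity check: one forward pass on a dummy batch.
x = mx.nd.random.uniform(shape=(1, 3, input_size, input_size))
print(net(x).shape)  # (1, 1000)
```
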
62 | These are the exact commands used to train *efficientnet-b0* and *efficientnet-lite0*:
63 | ```
64 | MODEL='efficientnet-b0'
65 | python3 train_dali.py --rec-train $IMAGENET_RECORD_ROOT/train --rec-val $IMAGENET_RECORD_ROOT/val --input-size 224 --batch-size 56 --num-gpus 4 --num-epochs 50 --lr 0.1 --lr-decay-epoch 20,30,40 --save-dir params-$MODEL --logging-file params-$MODEL/log.txt --save-frequency 5 --mode hybrid --model $MODEL
66 | ```
67 | ```
68 | MODEL='efficientnet-lite0'
69 | python3 train_dali.py --rec-train $IMAGENET_RECORD_ROOT/train --rec-val $IMAGENET_RECORD_ROOT/val --input-size 224 --batch-size 72 --num-gpus 6 --num-epochs 60 --lr 0.1 --lr-decay-epoch 20,35,50 --save-dir params-$MODEL --logging-file params-$MODEL/log.txt --save-frequency 5 --mode hybrid --model $MODEL
70 | ```
71 |
--------------------------------------------------------------------------------
/efficientnet_model.py:
--------------------------------------------------------------------------------
1 | from mxnet.gluon.block import HybridBlock
2 | from mxnet.gluon import nn
3 | from math import ceil
4 |
5 |
6 | class ReLU6(nn.HybridBlock):
7 | def __init__(self, **kwargs):
8 | super(ReLU6, self).__init__(**kwargs)
9 |
10 | def hybrid_forward(self, F, x):
11 | return F.clip(x, 0, 6, name="relu6")
12 |
13 |
14 | def _add_conv(out, channels=1, kernel=1, stride=1, pad=0,
15 | num_group=1, active=True, lite=False):
16 | out.add(nn.Conv2D(channels, kernel, stride, pad, groups=num_group, use_bias=False))
17 | out.add(nn.BatchNorm(scale=True, momentum=0.99, epsilon=1e-3))
18 | if active:
19 | if lite:
20 | out.add(ReLU6())
21 | else:
22 | out.add(nn.Swish())
23 |
24 |
25 | class MBConv(nn.HybridBlock):
26 | def __init__(self, in_channels, channels, t, kernel, stride, lite, **kwargs):
27 | super(MBConv, self).__init__(**kwargs)
28 | self.use_shortcut = stride == 1 and in_channels == channels
29 | with self.name_scope():
30 | self.out = nn.HybridSequential()
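            # Inverted residual block: 1x1 expansion (x t channels), depthwise
            # k x k convolution, then a linear 1x1 projection to `channels`.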
31 | _add_conv(self.out, in_channels * t, active=True, lite=lite)
32 | _add_conv(self.out, in_channels * t, kernel=kernel, stride=stride,
33 | pad=int((kernel-1)/2), num_group=in_channels * t,
34 | active=True, lite=lite)
35 | _add_conv(self.out, channels, active=False, lite=lite)
36 |
37 | def hybrid_forward(self, F, x):
38 | out = self.out(x)
39 | if self.use_shortcut:
40 | out = F.elemwise_add(out, x)
41 | return out
42 |
43 |
44 | class EfficientNet(nn.HybridBlock):
45 | r"""
46 | Parameters
47 | ----------
48 | alpha : float, default 1.0
49 |         The depth multiplier for controlling the model size. The actual number of layers on each channel_size level
50 | is equal to the original number of layers multiplied by alpha.
51 | beta : float, default 1.0
52 |         The width multiplier for controlling the model size. The actual number of channels
53 | is equal to the original channel size multiplied by beta.
54 | dropout_rate : float, default 0.0
55 | Dropout probability for the final features layer.
56 | classes : int, default 1000
57 | Number of classes for the output layer.
58 | """
59 |
60 | def __init__(self, alpha=1.0, beta=1.0, lite=False,
61 | dropout_rate=0.0, classes=1000, **kwargs):
62 | super(EfficientNet, self).__init__(**kwargs)
63 | with self.name_scope():
64 | self.features = nn.HybridSequential(prefix='features_')
65 | with self.features.name_scope():
66 | # stem conv
67 | channels = 32 if lite else int(32 * beta)
68 | _add_conv(self.features, channels, kernel=3, stride=2, pad=1,
69 | active=True, lite=lite)
70 |
71 | # base model settings
72 | repeats = [1, 2, 2, 3, 3, 4, 1]
73 | channels_num = [16, 24, 40, 80, 112, 192, 320]
74 | kernels_num = [3, 3, 5, 3, 5, 5, 3]
75 | t_num = [1, 6, 6, 6, 6, 6, 6]
76 | strides_first = [1, 2, 2, 1, 2, 2, 1]
77 |
78 | # determine params of MBConv layers
79 | in_channels_group = []
80 | for rep, ch_num in zip([1] + repeats[:-1], [32] + channels_num[:-1]):
81 | in_channels_group += [int(ch_num * beta)] * int(ceil(alpha * rep))
82 | channels_group, kernels, ts, strides = [], [], [], []
83 | for rep, ch, kernel, t, s in zip(repeats, channels_num, kernels_num, t_num, strides_first):
84 | rep = int(ceil(alpha * rep))
85 | channels_group += [int(ch * beta)] * rep
86 | kernels += [kernel] * rep
87 | ts += [t] * rep
88 | strides += [s] + [1] * (rep - 1)
89 |
90 | # add MBConv layers
91 | for in_c, c, t, k, s in zip(in_channels_group, channels_group, ts, kernels, strides):
92 | self.features.add(MBConv(in_channels=in_c, channels=c, t=t, kernel=k,
93 | stride=s, lite=lite))
94 |
95 | # head layers
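                # (lite variants keep the 1280-channel head unscaled;
                #  otherwise it is widened only when beta > 1)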
96 | last_channels = int(1280 * beta) if not lite and beta > 1.0 else 1280
97 | _add_conv(self.features, last_channels, active=True, lite=lite)
98 | self.features.add(nn.GlobalAvgPool2D())
99 |
100 | # features dropout
101 | self.dropout = nn.Dropout(dropout_rate) if dropout_rate > 0.0 else None
102 |
103 | # output layer
104 | self.output = nn.HybridSequential(prefix='output_')
105 | with self.output.name_scope():
106 | self.output.add(
107 | nn.Conv2D(classes, 1, use_bias=False, prefix='pred_'),
108 | nn.Flatten()
109 | )
110 |
111 | def hybrid_forward(self, F, x):
112 | x = self.features(x)
113 | if self.dropout:
114 | x = self.dropout(x)
115 | x = self.output(x)
116 | return x
117 |
118 |
119 | def get_efficientnet(model_name, num_classes=1000):
120 | params_dict = { # (width_coefficient, depth_coefficient, input_resolution, dropout_rate)
121 | 'efficientnet-b0': (1.0, 1.0, 224, 0.2),
122 | 'efficientnet-b1': (1.0, 1.1, 240, 0.2),
123 | 'efficientnet-b2': (1.1, 1.2, 260, 0.3),
124 | 'efficientnet-b3': (1.2, 1.4, 300, 0.3),
125 | 'efficientnet-b4': (1.4, 1.8, 380, 0.4),
126 | 'efficientnet-b5': (1.6, 2.2, 456, 0.4),
127 | 'efficientnet-b6': (1.8, 2.6, 528, 0.5),
128 | 'efficientnet-b7': (2.0, 3.1, 600, 0.5)
129 | }
130 | width_coeff, depth_coeff, input_resolution, dropout_rate = params_dict[model_name]
131 | model = EfficientNet(alpha=depth_coeff, beta=width_coeff, lite=False,
132 | dropout_rate=dropout_rate, classes=num_classes)
133 | return model, input_resolution
134 |
135 |
136 | def get_efficientnet_lite(model_name, num_classes=1000):
137 | params_dict = { # (width_coefficient, depth_coefficient, input_resolution, dropout_rate)
138 | 'efficientnet-lite0': (1.0, 1.0, 224, 0.2),
139 | 'efficientnet-lite1': (1.0, 1.1, 240, 0.2),
140 | 'efficientnet-lite2': (1.1, 1.2, 260, 0.3),
141 | 'efficientnet-lite3': (1.2, 1.4, 280, 0.3),
142 | 'efficientnet-lite4': (1.4, 1.8, 300, 0.3)
143 | }
144 | width_coeff, depth_coeff, input_resolution, dropout_rate = params_dict[model_name]
145 | model = EfficientNet(alpha=depth_coeff, beta=width_coeff, lite=True,
146 | dropout_rate=dropout_rate, classes=num_classes)
147 | return model, input_resolution
148 |
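
if __name__ == '__main__':
    # Minimal smoke test: build both default variants, initialize with Xavier
    # (BatchNorm params keep their own constant initializers), and run a dummy
    # forward pass at each model's native input resolution.
    import mxnet as mx
    for name, getter in [('efficientnet-b0', get_efficientnet),
                         ('efficientnet-lite0', get_efficientnet_lite)]:
        net, resolution = getter(name)
        net.initialize(mx.init.Xavier())
        x = mx.nd.random.uniform(shape=(1, 3, resolution, resolution))
        print(name, net(x).shape)  # expect (1, 1000)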
--------------------------------------------------------------------------------
/train_imagenet/train.py:
--------------------------------------------------------------------------------
1 | import argparse, time, logging, os, sys, math
2 |
3 | import mxnet as mx
4 | from mxnet import gluon, nd
5 | from mxnet import autograd as ag
6 | from mxnet.gluon import nn
7 | from mxnet.gluon.data.vision import transforms
8 |
9 | import gluoncv as gcv
10 | gcv.utils.check_version('0.6.0')
11 | from gluoncv.data import imagenet
12 | from gluoncv.utils import makedirs, LRSequential, LRScheduler
13 |
14 | sys.path.insert(1, os.path.join(sys.path[0], '..'))
15 | from efficientnet_model import get_efficientnet, get_efficientnet_lite
16 |
17 | # CLI
18 | def parse_args():
19 | parser = argparse.ArgumentParser(description='Train a model for image classification.')
20 | parser.add_argument('--data-dir', type=str, default='~/.mxnet/datasets/imagenet',
21 | help='training and validation pictures to use.')
22 | parser.add_argument('--rec-train', type=str, default='~/.mxnet/datasets/imagenet/rec/train.rec',
23 | help='the training data')
24 | parser.add_argument('--rec-train-idx', type=str, default='~/.mxnet/datasets/imagenet/rec/train.idx',
25 | help='the index of training data')
26 | parser.add_argument('--rec-val', type=str, default='~/.mxnet/datasets/imagenet/rec/val.rec',
27 | help='the validation data')
28 | parser.add_argument('--rec-val-idx', type=str, default='~/.mxnet/datasets/imagenet/rec/val.idx',
29 | help='the index of validation data')
30 | parser.add_argument('--use-rec', action='store_true',
31 | help='use image record iter for data input. default is false.')
32 | parser.add_argument('--batch-size', type=int, default=32,
33 | help='training batch size per device (CPU/GPU).')
34 | parser.add_argument('--dtype', type=str, default='float32',
35 | help='data type for training. default is float32')
36 | parser.add_argument('--num-gpus', type=int, default=0,
37 | help='number of gpus to use.')
38 | parser.add_argument('-j', '--num-data-workers', dest='num_workers', default=4, type=int,
39 | help='number of preprocessing workers')
40 | parser.add_argument('--num-epochs', type=int, default=3,
41 | help='number of training epochs.')
42 | parser.add_argument('--lr', type=float, default=0.1,
43 | help='learning rate. default is 0.1.')
44 | parser.add_argument('--momentum', type=float, default=0.9,
45 | help='momentum value for optimizer, default is 0.9.')
46 | parser.add_argument('--wd', type=float, default=0.0001,
47 | help='weight decay rate. default is 0.0001.')
48 | parser.add_argument('--lr-mode', type=str, default='step',
49 | help='learning rate scheduler mode. options are step, poly and cosine.')
50 | parser.add_argument('--lr-decay', type=float, default=0.1,
51 | help='decay rate of learning rate. default is 0.1.')
52 | parser.add_argument('--lr-decay-period', type=int, default=0,
53 | help='interval for periodic learning rate decays. default is 0 to disable.')
54 | parser.add_argument('--lr-decay-epoch', type=str, default='40,60',
55 | help='epochs at which learning rate decays. default is 40,60.')
56 | parser.add_argument('--warmup-lr', type=float, default=0.0,
57 | help='starting warmup learning rate. default is 0.0.')
58 | parser.add_argument('--warmup-epochs', type=int, default=0,
59 | help='number of warmup epochs.')
60 | parser.add_argument('--last-gamma', action='store_true',
61 | help='whether to init gamma of the last BN layer in each bottleneck to 0.')
62 | parser.add_argument('--mode', type=str,
63 | help='mode in which to train the model. options are symbolic, imperative, hybrid')
64 | parser.add_argument('--model', type=str, required=True,
65 |                         help='type of model to use: efficientnet-b0 ... efficientnet-b7 or efficientnet-lite0 ... efficientnet-lite4.')
66 | parser.add_argument('--input-size', type=int, default=224,
67 | help='size of the input image size. default is 224')
68 | parser.add_argument('--crop-ratio', type=float, default=0.875,
69 | help='Crop ratio during validation. default is 0.875')
70 | parser.add_argument('--no-wd', action='store_true',
71 | help='whether to remove weight decay on bias, and beta/gamma for batchnorm layers.')
72 | parser.add_argument('--batch-norm', action='store_true',
73 | help='enable batch normalization or not in vgg. default is false.')
74 | parser.add_argument('--save-frequency', type=int, default=10,
75 | help='frequency of model saving.')
76 | parser.add_argument('--save-dir', type=str, default='params',
77 | help='directory of saved models')
78 | parser.add_argument('--resume-epoch', type=int, default=0,
79 | help='epoch to resume training from.')
80 | parser.add_argument('--resume-params', type=str, default='',
81 | help='path of parameters to load from.')
82 | parser.add_argument('--resume-states', type=str, default='',
83 | help='path of trainer state to load from.')
84 | parser.add_argument('--log-interval', type=int, default=50,
85 | help='Number of batches to wait before logging.')
86 | parser.add_argument('--logging-file', type=str, default='train_imagenet.log',
87 | help='name of training log file')
88 | opt = parser.parse_args()
89 | return opt
90 |
91 |
92 | def main():
93 | opt = parse_args()
94 |
95 | filehandler = logging.FileHandler(opt.logging_file)
96 | streamhandler = logging.StreamHandler()
97 |
98 | logger = logging.getLogger('')
99 | logger.setLevel(logging.INFO)
100 | logger.addHandler(filehandler)
101 | logger.addHandler(streamhandler)
102 |
103 | logger.info(opt)
104 |
105 | batch_size = opt.batch_size
106 | classes = 1000
107 | num_training_samples = 1281167
108 |
109 | num_gpus = opt.num_gpus
110 | batch_size *= max(1, num_gpus)
111 | context = [mx.gpu(i) for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]
112 | num_workers = opt.num_workers
113 |
114 | lr_decay = opt.lr_decay
115 | lr_decay_period = opt.lr_decay_period
116 | if opt.lr_decay_period > 0:
117 | lr_decay_epoch = list(range(lr_decay_period, opt.num_epochs, lr_decay_period))
118 | else:
119 | lr_decay_epoch = [int(i) for i in opt.lr_decay_epoch.split(',')]
120 | lr_decay_epoch = [e - opt.warmup_epochs for e in lr_decay_epoch]
121 | num_batches = num_training_samples // batch_size
122 |
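    # LR schedule: linear warmup from 0 to opt.lr over warmup_epochs, then the
    # main decay (step/poly/cosine per --lr-mode) from opt.lr down to 0 over
    # the remaining epochs.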
123 | lr_scheduler = LRSequential([
124 | LRScheduler('linear', base_lr=0, target_lr=opt.lr,
125 | nepochs=opt.warmup_epochs, iters_per_epoch=num_batches),
126 | LRScheduler(opt.lr_mode, base_lr=opt.lr, target_lr=0,
127 | nepochs=opt.num_epochs - opt.warmup_epochs,
128 | iters_per_epoch=num_batches,
129 | step_epoch=lr_decay_epoch,
130 | step_factor=lr_decay, power=2)
131 | ])
132 |
133 | optimizer = 'nag'
134 | optimizer_params = {'wd': opt.wd, 'momentum': opt.momentum, 'lr_scheduler': lr_scheduler}
135 | if opt.dtype != 'float32':
136 | optimizer_params['multi_precision'] = True
137 |
138 | model_name = opt.model
139 | if 'lite' in model_name:
140 | net, input_size = get_efficientnet_lite(model_name, num_classes=classes)
141 | else:
142 | net, input_size = get_efficientnet(model_name, num_classes=classes)
143 | assert input_size == opt.input_size
144 | net.cast(opt.dtype)
145 |     if opt.resume_params != '':
146 |         net.load_parameters(opt.resume_params, ctx=context)
147 |
148 | # Two functions for reading data from record file or raw images
149 | def get_data_rec(rec_train, rec_train_idx, rec_val, rec_val_idx, batch_size, num_workers):
150 | rec_train = os.path.expanduser(rec_train)
151 | rec_train_idx = os.path.expanduser(rec_train_idx)
152 | rec_val = os.path.expanduser(rec_val)
153 | rec_val_idx = os.path.expanduser(rec_val_idx)
154 | jitter_param = 0.4
155 | lighting_param = 0.1
156 | input_size = opt.input_size
157 | crop_ratio = opt.crop_ratio if opt.crop_ratio > 0 else 0.875
158 | resize = int(math.ceil(input_size / crop_ratio))
159 | mean_rgb = [123.68, 116.779, 103.939]
160 | std_rgb = [58.393, 57.12, 57.375]
161 |
162 | def batch_fn(batch, ctx):
163 | data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0)
164 | label = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0)
165 | return data, label
166 |
167 | train_data = mx.io.ImageRecordIter(
168 | path_imgrec = rec_train,
169 | path_imgidx = rec_train_idx,
170 | preprocess_threads = num_workers,
171 | shuffle = True,
172 | batch_size = batch_size,
173 |
174 | data_shape = (3, input_size, input_size),
175 | mean_r = mean_rgb[0],
176 | mean_g = mean_rgb[1],
177 | mean_b = mean_rgb[2],
178 | std_r = std_rgb[0],
179 | std_g = std_rgb[1],
180 | std_b = std_rgb[2],
181 | rand_mirror = True,
182 | random_resized_crop = True,
183 | max_aspect_ratio = 4. / 3.,
184 | min_aspect_ratio = 3. / 4.,
185 | max_random_area = 1,
186 | min_random_area = 0.08,
187 | brightness = jitter_param,
188 | saturation = jitter_param,
189 | contrast = jitter_param,
190 | pca_noise = lighting_param,
191 | )
192 | val_data = mx.io.ImageRecordIter(
193 | path_imgrec = rec_val,
194 | path_imgidx = rec_val_idx,
195 | preprocess_threads = num_workers,
196 | shuffle = False,
197 | batch_size = batch_size,
198 |
199 | resize = resize,
200 | data_shape = (3, input_size, input_size),
201 | mean_r = mean_rgb[0],
202 | mean_g = mean_rgb[1],
203 | mean_b = mean_rgb[2],
204 | std_r = std_rgb[0],
205 | std_g = std_rgb[1],
206 | std_b = std_rgb[2],
207 | )
208 | return train_data, val_data, batch_fn
209 |
210 | def get_data_loader(data_dir, batch_size, num_workers):
211 | normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
212 | jitter_param = 0.4
213 | lighting_param = 0.1
214 | input_size = opt.input_size
215 | crop_ratio = opt.crop_ratio if opt.crop_ratio > 0 else 0.875
216 | resize = int(math.ceil(input_size / crop_ratio))
217 |
218 | def batch_fn(batch, ctx):
219 | data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
220 | label = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)
221 | return data, label
222 |
223 | transform_train = transforms.Compose([
224 | transforms.RandomResizedCrop(input_size),
225 | transforms.RandomFlipLeftRight(),
226 | transforms.RandomColorJitter(brightness=jitter_param, contrast=jitter_param,
227 | saturation=jitter_param),
228 | transforms.RandomLighting(lighting_param),
229 | transforms.ToTensor(),
230 | normalize
231 | ])
232 | transform_test = transforms.Compose([
233 | transforms.Resize(resize, keep_ratio=True),
234 | transforms.CenterCrop(input_size),
235 | transforms.ToTensor(),
236 | normalize
237 | ])
238 |
239 | train_data = gluon.data.DataLoader(
240 | imagenet.classification.ImageNet(data_dir, train=True).transform_first(transform_train),
241 | batch_size=batch_size, shuffle=True, last_batch='discard', num_workers=num_workers)
242 | val_data = gluon.data.DataLoader(
243 | imagenet.classification.ImageNet(data_dir, train=False).transform_first(transform_test),
244 | batch_size=batch_size, shuffle=False, num_workers=num_workers)
245 |
246 | return train_data, val_data, batch_fn
247 |
248 | if opt.use_rec:
249 | train_data, val_data, batch_fn = get_data_rec(opt.rec_train, opt.rec_train_idx,
250 | opt.rec_val, opt.rec_val_idx,
251 | batch_size, num_workers)
252 | else:
253 | train_data, val_data, batch_fn = get_data_loader(opt.data_dir, batch_size, num_workers)
254 |
255 | train_metric = mx.metric.Accuracy()
256 | acc_top1 = mx.metric.Accuracy()
257 | acc_top5 = mx.metric.TopKAccuracy(5)
258 |
259 | save_frequency = opt.save_frequency
260 | if opt.save_dir and save_frequency:
261 | save_dir = opt.save_dir
262 | makedirs(save_dir)
263 | else:
264 | save_dir = ''
265 | save_frequency = 0
266 |
267 | def test(ctx, val_data):
268 | if opt.use_rec:
269 | val_data.reset()
270 | acc_top1.reset()
271 | acc_top5.reset()
272 | for i, batch in enumerate(val_data):
273 | data, label = batch_fn(batch, ctx)
274 | outputs = [net(X.astype(opt.dtype, copy=False)) for X in data]
275 | acc_top1.update(label, outputs)
276 | acc_top5.update(label, outputs)
277 |
278 | _, top1 = acc_top1.get()
279 | _, top5 = acc_top5.get()
280 | return (1-top1, 1-top5)
281 |
282 | def train(ctx):
283 | if isinstance(ctx, mx.Context):
284 | ctx = [ctx]
285 |         if opt.resume_params == '':
286 | net.initialize(mx.init.MSRAPrelu(), ctx=ctx)
287 |
288 | if opt.no_wd:
289 | for k, v in net.collect_params('.*beta|.*gamma|.*bias').items():
290 | v.wd_mult = 0.0
291 |
292 | trainer = gluon.Trainer(net.collect_params(), optimizer, optimizer_params)
293 |         if opt.resume_states != '':
294 | trainer.load_states(opt.resume_states)
295 |
296 | L = gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=True)
297 |
298 | best_val_score = 1
299 |
300 | for epoch in range(opt.resume_epoch, opt.num_epochs):
301 | tic = time.time()
302 | if opt.use_rec:
303 | train_data.reset()
304 | train_metric.reset()
305 | btic = time.time()
306 |
307 | for i, batch in enumerate(train_data):
308 | data, label = batch_fn(batch, ctx)
309 | with ag.record():
310 | outputs = [net(X.astype(opt.dtype, copy=False)) for X in data]
311 | loss = [L(yhat, y.astype(opt.dtype, copy=False)) for yhat, y in zip(outputs, label)]
312 | for l in loss:
313 | l.backward()
314 | trainer.step(batch_size)
315 |
316 | train_metric.update(label, outputs)
317 |
318 | if opt.log_interval and not (i+1)%opt.log_interval:
319 | train_metric_name, train_metric_score = train_metric.get()
320 | logger.info('Epoch[%d] Batch [%d]\tSpeed: %f samples/sec\t%s=%f\tlr=%f'%(
321 | epoch, i, batch_size*opt.log_interval/(time.time()-btic),
322 | train_metric_name, train_metric_score, trainer.learning_rate))
323 | btic = time.time()
324 |
325 | train_metric_name, train_metric_score = train_metric.get()
326 |             throughput = int(batch_size * (i + 1) / (time.time() - tic))
327 |
328 | err_top1_val, err_top5_val = test(ctx, val_data)
329 |
330 | logger.info('[Epoch %d] training: %s=%f'%(epoch, train_metric_name, train_metric_score))
331 | logger.info('[Epoch %d] speed: %d samples/sec\ttime cost: %f'%(epoch, throughput, time.time()-tic))
332 | logger.info('[Epoch %d] validation: err-top1=%f err-top5=%f'%(epoch, err_top1_val, err_top5_val))
333 |
334 | if err_top1_val < best_val_score:
335 | best_val_score = err_top1_val
336 | net.save_parameters('%s/%.4f-imagenet-%s-%d-best.params'%(save_dir, best_val_score, model_name, epoch))
337 | trainer.save_states('%s/%.4f-imagenet-%s-%d-best.states'%(save_dir, best_val_score, model_name, epoch))
338 |
339 | if save_frequency and save_dir and (epoch + 1) % save_frequency == 0:
340 | net.save_parameters('%s/imagenet-%s-%d.params'%(save_dir, model_name, epoch))
341 | trainer.save_states('%s/imagenet-%s-%d.states'%(save_dir, model_name, epoch))
342 |
343 | if save_frequency and save_dir:
344 | net.save_parameters('%s/imagenet-%s-%d.params'%(save_dir, model_name, opt.num_epochs-1))
345 | trainer.save_states('%s/imagenet-%s-%d.states'%(save_dir, model_name, opt.num_epochs-1))
346 |
347 | if opt.mode == 'hybrid':
348 | net.hybridize(static_alloc=True, static_shape=True)
349 |
350 | train(context)
351 |
352 | if __name__ == '__main__':
353 | main()
354 |
--------------------------------------------------------------------------------
/train_imagenet/train_dali.py:
--------------------------------------------------------------------------------
1 | import argparse, time, logging, os, sys, math
2 |
3 | import mxnet as mx
4 | from mxnet import gluon, autograd
5 |
6 | from nvidia.dali.pipeline import Pipeline
7 | import nvidia.dali.ops as ops
8 | import nvidia.dali.types as types
9 | from nvidia.dali.plugin.mxnet import DALIClassificationIterator
10 |
11 | sys.path.insert(1, os.path.join(sys.path[0], '..'))
12 | from efficientnet_model import get_efficientnet, get_efficientnet_lite
13 |
14 |
15 | # DALI RECORD PIPELINE
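# Reads MXNet .rec/.idx shards; decoding runs on the "mixed" (CPU+GPU) backend
# and the remaining ops on GPU. Training applies random resized crop, color
# jitter and random mirroring; validation only decodes, resizes and normalizes.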
16 | class HybridRecPipe(Pipeline):
17 | def __init__(self, db_prefix, for_train, input_size, batch_size, num_threads, device_id, num_gpus):
18 | super(HybridRecPipe, self).__init__(batch_size, num_threads, device_id, seed=12+device_id, prefetch_queue_depth=2)
19 | self.for_train = for_train
20 | self.input = ops.MXNetReader(path=[db_prefix + ".rec"], index_path=[db_prefix + ".idx"],
21 | random_shuffle=for_train, shard_id=device_id, num_shards=num_gpus)
22 | self.resize = ops.Resize(device="gpu", resize_x=input_size, resize_y=input_size)
23 | self.cmnp = ops.CropMirrorNormalize(device = "gpu",
24 | output_dtype = types.FLOAT,
25 | output_layout = types.NCHW,
26 | crop = (input_size, input_size),
27 | image_type = types.RGB,
28 | mean = [0.485 * 255,0.456 * 255,0.406 * 255],
29 | std = [0.229 * 255,0.224 * 255,0.225 * 255])
30 | if self.for_train:
31 | self.decode = ops.ImageDecoderRandomCrop(device="mixed",
32 | output_type=types.RGB,
33 | random_aspect_ratio=[3/4, 4/3],
34 | random_area=[0.08, 1.0],
35 | num_attempts=100)
36 | self.color = ops.ColorTwist(device='gpu')
37 | self.rng_brightness = ops.Uniform(range=(0.6, 1.4))
38 | self.rng_contrast = ops.Uniform(range=(0.6, 1.4))
39 | self.rng_saturation = ops.Uniform(range=(0.6, 1.4))
40 | self.mirror_coin = ops.CoinFlip(probability=0.5)
41 | else:
42 | self.decode = ops.ImageDecoder(device="mixed", output_type=types.RGB)
43 |
44 | def define_graph(self):
45 | inputs, labels = self.input(name = "Reader")
46 | images = self.decode(inputs)
47 | images = self.resize(images)
48 | if self.for_train:
49 | images = self.color(images, brightness=self.rng_brightness(),
50 | contrast=self.rng_contrast(), saturation=self.rng_saturation())
51 | output = self.cmnp(images, mirror=self.mirror_coin())
52 | else:
53 | output = self.cmnp(images)
54 | return [output, labels.gpu()]
55 |
56 |
57 | def get_rec_data_iterators(train_db_prefix, val_db_prefix, input_size, batch_size, devices):
58 | num_threads = 2
59 | num_shards = len(devices)
60 | train_pipes = [HybridRecPipe(train_db_prefix, True, input_size, batch_size,
61 | num_threads, device_id, num_shards) for device_id in range(num_shards)]
62 | # Build train pipeline to get the epoch size out of the reader
63 | train_pipes[0].build()
64 | print("Training pipeline epoch size: {}".format(train_pipes[0].epoch_size("Reader")))
65 | # Make train MXNet iterators out of rec pipelines
66 | dali_train_iter = DALIClassificationIterator(train_pipes, train_pipes[0].epoch_size("Reader"))
67 | if val_db_prefix:
68 | val_pipes = [HybridRecPipe(val_db_prefix, False, input_size, batch_size,
69 | num_threads, device_id, num_shards) for device_id in range(num_shards)]
70 |         # Build val pipeline to get the epoch size out of the reader
71 | val_pipes[0].build()
72 | print("Validation pipeline epoch size: {}".format(val_pipes[0].epoch_size("Reader")))
73 | # Make val MXNet iterators out of rec pipelines
74 | dali_val_iter = DALIClassificationIterator(val_pipes, val_pipes[0].epoch_size("Reader"))
75 | else:
76 | dali_val_iter = None
77 | return dali_train_iter, dali_val_iter
78 |
79 |
80 | # CLI
81 | def parse_args():
82 | parser = argparse.ArgumentParser(description='Train a model for image classification.')
83 | parser.add_argument('--rec-train', type=str, default='~/.mxnet/datasets/imagenet/rec/train',
84 | help='the training data')
85 | parser.add_argument('--rec-val', type=str, default='~/.mxnet/datasets/imagenet/rec/val',
86 | help='the validation data')
87 | parser.add_argument('--batch-size', type=int, default=32,
88 | help='training batch size per device (CPU/GPU).')
89 | parser.add_argument('--num-gpus', type=int, default=0,
90 | help='number of gpus to use.')
91 | parser.add_argument('--num-epochs', type=int, default=3,
92 | help='number of training epochs.')
93 | parser.add_argument('--lr', type=float, default=0.1,
94 | help='learning rate. default is 0.1.')
95 | parser.add_argument('--momentum', type=float, default=0.9,
96 | help='momentum value for optimizer, default is 0.9.')
97 | parser.add_argument('--wd', type=float, default=0.0001,
98 | help='weight decay rate. default is 0.0001.')
99 | parser.add_argument('--lr-decay', type=float, default=0.1,
100 | help='decay rate of learning rate. default is 0.1.')
101 | parser.add_argument('--lr-decay-epoch', type=str, default='40,60',
102 | help='epochs at which learning rate decays. default is 40,60.')
103 | parser.add_argument('--mode', type=str,
104 | help='mode in which to train the model. options are symbolic, imperative, hybrid')
105 | parser.add_argument('--model', type=str, required=True,
106 |                         help='type of model to use: efficientnet-b0 ... efficientnet-b7 or efficientnet-lite0 ... efficientnet-lite4.')
107 | parser.add_argument('--input-size', type=int, default=224,
108 | help='size of the input image size. default is 224')
109 | parser.add_argument('--no-wd', action='store_true',
110 | help='whether to remove weight decay on bias, and beta/gamma for batchnorm layers.')
111 | parser.add_argument('--save-frequency', type=int, default=10,
112 | help='frequency of model saving.')
113 | parser.add_argument('--save-dir', type=str, default='params',
114 | help='directory of saved models')
115 | parser.add_argument('--resume-epoch', type=int, default=0,
116 | help='epoch to resume training from.')
117 | parser.add_argument('--resume-params', type=str, default='',
118 | help='path of parameters to load from.')
119 | parser.add_argument('--resume-states', type=str, default='',
120 | help='path of trainer state to load from.')
121 | parser.add_argument('--log-interval', type=int, default=50,
122 | help='Number of batches to wait before logging.')
123 | parser.add_argument('--logging-file', type=str, default='train_imagenet.log',
124 | help='name of training log file')
125 | opt = parser.parse_args()
126 | return opt
127 |
128 |
129 | def main():
130 | opt = parse_args()
131 |
132 | save_dir = opt.save_dir
133 | os.makedirs(save_dir, exist_ok=True)
134 |
135 | filehandler = logging.FileHandler(opt.logging_file)
136 | streamhandler = logging.StreamHandler()
137 |
138 | logger = logging.getLogger('')
139 | logger.setLevel(logging.INFO)
140 | logger.addHandler(filehandler)
141 | logger.addHandler(streamhandler)
142 |
143 | logger.info(opt)
144 |
145 | batch_size = opt.batch_size
146 | classes = 1000
147 | num_training_samples = 1281167
148 |
149 | num_gpus = opt.num_gpus
150 | batch_size *= max(1, num_gpus)
151 | context = [mx.gpu(i) for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]
152 |
153 | lr_decay = opt.lr_decay
154 | lr_decay_epoch = [int(i) for i in opt.lr_decay_epoch.split(',')]
155 | num_batches = num_training_samples // batch_size
156 |
157 | optimizer = 'nag'
158 | optimizer_params = {'wd': opt.wd, 'momentum': opt.momentum}
159 |
160 | model_name = opt.model
161 | if 'lite' in model_name:
162 | net, input_size = get_efficientnet_lite(model_name, num_classes=classes)
163 | else:
164 | net, input_size = get_efficientnet(model_name, num_classes=classes)
165 | assert input_size == opt.input_size
166 |     if opt.resume_params != '':
167 | net.load_parameters(opt.resume_params, ctx=context)
168 |
169 | if opt.mode == 'hybrid':
170 | net.hybridize(static_alloc=True, static_shape=True)
171 |
172 | # Two functions for reading data from record file or raw images
173 | def get_data_rec(rec_train_prefix, rec_val_prefix, batch_size, devices):
174 | rec_train_prefix = os.path.expanduser(rec_train_prefix)
175 | rec_val_prefix = os.path.expanduser(rec_val_prefix)
176 | input_size = opt.input_size
177 |
178 | def batch_fn(batch):
179 | data = [b.data[0] for b in batch]
180 | label = [b.label[0] for b in batch]
181 | return data, label
182 |
183 | train_data, val_data = get_rec_data_iterators(rec_train_prefix, rec_val_prefix, input_size, batch_size // len(devices), devices)
184 | return train_data, val_data, batch_fn
185 |
186 | train_data, val_data, batch_fn = get_data_rec(opt.rec_train, opt.rec_val, batch_size, context)
187 |
188 | train_metric = mx.metric.Accuracy()
189 | acc_top1 = mx.metric.Accuracy()
190 | acc_top5 = mx.metric.TopKAccuracy(5)
191 |
192 | save_frequency = opt.save_frequency
193 |
194 | def test(ctx, val_data):
195 | val_data.reset()
196 | acc_top1.reset()
197 | acc_top5.reset()
198 | for i, batch in enumerate(val_data):
199 | data, label = batch_fn(batch)
200 | outputs = [net(X) for X in data]
201 | acc_top1.update(label, outputs)
202 | acc_top5.update(label, outputs)
203 |
204 | _, top1 = acc_top1.get()
205 | _, top5 = acc_top5.get()
206 | return (1-top1, 1-top5)
207 |
208 | def train(ctx):
209 | if isinstance(ctx, mx.Context):
210 | ctx = [ctx]
211 |         if opt.resume_params == '':
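            # Per-parameter-type init: gammas/alphas and BN running stats to 1,
            # betas and biases to 0, conv weights via Xavier.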
212 | net.collect_params('.*gamma|.*alpha|.*running_mean|.*running_var').initialize(mx.init.Constant(1), ctx=ctx)
213 | net.collect_params('.*beta|.*bias').initialize(mx.init.Constant(0.0), ctx=ctx)
214 | net.collect_params('.*weight').initialize(mx.init.Xavier(), ctx=ctx)
215 |
216 | if opt.no_wd:
217 | for k, v in net.collect_params('.*beta|.*gamma|.*bias').items():
218 | v.wd_mult = 0.0
219 |
220 | trainer = gluon.Trainer(net.collect_params(), optimizer, optimizer_params, kvstore='local')
221 |         if opt.resume_states != '':
222 | trainer.load_states(opt.resume_states)
223 |
224 | L = gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=True)
225 |
226 | best_val_score = 1
227 |
228 | trainer.set_learning_rate(opt.lr)
229 | for epoch in range(opt.resume_epoch, opt.num_epochs):
230 | tic = time.time()
231 | train_data.reset()
232 | train_metric.reset()
233 | btic = time.time()
234 |
235 | if epoch in lr_decay_epoch:
236 | trainer.set_learning_rate(trainer.learning_rate * lr_decay)
237 | logger.info('Learning rate has been changed to %f' % trainer.learning_rate)
238 |
239 | for i, batch in enumerate(train_data):
240 | data, label = batch_fn(batch)
241 | with autograd.record():
242 | outputs = [net(X) for X in data]
243 | loss = [L(yhat, y) for yhat, y in zip(outputs, label)]
244 | for l in loss:
245 | l.backward()
246 | trainer.step(batch_size)
247 |
248 | train_metric.update(label, outputs)
249 |
250 | if opt.log_interval and not (i+1)%opt.log_interval:
251 | train_metric_name, train_metric_score = train_metric.get()
252 | logger.info('Epoch[%d/%d] Batch [%d/%d]\tSpeed: %f samples/sec\t%s=%f\tlr=%f'%(
253 | epoch, opt.num_epochs, i, num_batches,
254 | batch_size*opt.log_interval/(time.time()-btic),
255 | train_metric_name, train_metric_score, trainer.learning_rate))
256 | btic = time.time()
257 |
258 | train_metric_name, train_metric_score = train_metric.get()
259 |             throughput = int(batch_size * (i + 1) / (time.time() - tic))
260 |
261 | err_top1_val, err_top5_val = test(ctx, val_data)
262 |
263 | logger.info('[Epoch %d] training: %s=%f'%(epoch, train_metric_name, train_metric_score))
264 | logger.info('[Epoch %d] speed: %d samples/sec\ttime cost: %f'%(epoch, throughput, time.time()-tic))
265 | logger.info('[Epoch %d] validation: err-top1=%f err-top5=%f'%(epoch, err_top1_val, err_top5_val))
266 |
267 | if err_top1_val < best_val_score:
268 | best_val_score = err_top1_val
269 | net.save_parameters('%s/%.4f-imagenet-%s-%d-best.params'%(save_dir, best_val_score, model_name, epoch))
270 | trainer.save_states('%s/%.4f-imagenet-%s-%d-best.states'%(save_dir, best_val_score, model_name, epoch))
271 |
272 | if save_frequency and save_dir and (epoch + 1) % save_frequency == 0:
273 | net.save_parameters('%s/imagenet-%s-%d.params'%(save_dir, model_name, epoch))
274 | trainer.save_states('%s/imagenet-%s-%d.states'%(save_dir, model_name, epoch))
275 |
276 | if save_frequency and save_dir:
277 | net.save_parameters('%s/imagenet-%s-%d.params'%(save_dir, model_name, opt.num_epochs-1))
278 | trainer.save_states('%s/imagenet-%s-%d.states'%(save_dir, model_name, opt.num_epochs-1))
279 |
280 | train(context)
281 |
282 | if __name__ == '__main__':
283 | main()
284 |
--------------------------------------------------------------------------------