├── FKD_train.py
├── README.md
├── extensions
├── __init__.py
├── data_parallel.py
├── kd_loss.py
└── teacher_wrapper.py
├── hubconf.py
├── imagenet.py
├── images
├── MEAL-V2_more_tricks_top1.png
├── MEAL-V2_more_tricks_top5.png
└── comparison.png
├── inference.py
├── loss.py
├── models
├── __init__.py
├── blocks.py
├── discriminator.py
└── model_factory.py
├── opts.py
├── script
├── resume_train.sh
└── train.sh
├── test.py
├── train.py
├── utils.py
└── utils_FKD.py
/FKD_train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """Script to train a model through soft labels on ImageNet's train set."""
3 |
4 | import argparse
5 | import logging
6 | import pprint
7 | import os
8 | import sys
9 | import time
10 | import math
11 | import numpy as np
12 |
13 | import torch
14 | from torch import nn
15 |
16 | from loss import discriminatorLoss
17 |
18 | import imagenet
19 | from models import model_factory
20 | from models import discriminator
21 | import opts
22 | import test
23 | import utils
24 | from utils_FKD import Recover_soft_label
25 |
26 | def parse_args(argv):
27 | """Parse arguments @argv and return the flags needed for training."""
28 | parser = argparse.ArgumentParser(description=__doc__, allow_abbrev=False)
29 |
30 | group = parser.add_argument_group('General Options')
31 | opts.add_general_flags(group)
32 |
33 | group = parser.add_argument_group('Dataset Options')
34 | opts.add_dataset_flags(group)
35 |
36 | group = parser.add_argument_group('Model Options')
37 | opts.add_model_flags(group)
38 |
39 | group = parser.add_argument_group('Soft Label Options')
40 | opts.add_teacher_flags(group)
41 |
42 | group = parser.add_argument_group('Training Options')
43 | opts.add_training_flags(group)
44 |
45 | group = parser.add_argument_group('CutMix Training Options')
46 | opts.add_cutmix_training_flags(group)
47 |
48 | args = parser.parse_args(argv)
49 |
50 | return args
51 |
52 |
53 | class LearningRateRegime:
54 | """Encapsulates the learning rate regime for training a model.
55 |
56 | Args:
57 | @intervals (list): A list of triples (start, end, lr). The intervals
58 | are inclusive (for start <= epoch <= end, lr will be used). The
59 | start of each interval must be right after the end of its previous
60 | interval.
61 | """
62 |
63 | def __init__(self, regime):
64 | if len(regime) % 3 != 0:
65 |             raise ValueError("Regime length should be divisible by 3.")
66 | intervals = list(zip(regime[0::3], regime[1::3], regime[2::3]))
67 | self._validate_intervals(intervals)
68 | self.intervals = intervals
69 | self.num_epochs = intervals[-1][1]
70 |
71 | @classmethod
72 | def _validate_intervals(cls, intervals):
73 | if type(intervals) is not list:
74 | raise TypeError("Intervals must be a list of triples.")
75 |         elif len(intervals) == 0:
76 |             raise ValueError("Intervals must be a non-empty list.")
77 |         elif intervals[0][0] != 1:
78 |             raise ValueError("Intervals must start from 1: {}".format(intervals))
79 |         elif any(end < start for (start, end, lr) in intervals):
80 |             raise ValueError("End of each interval must be greater than or "
81 |                              "equal to its start: {}".format(intervals))
82 |         elif any(intervals[i][1] + 1 != intervals[i + 1][0]
83 |                  for i in range(len(intervals) - 1)):
84 |             raise ValueError("Start of each interval must be the end of its "
85 |                              "previous interval plus one: {}".format(intervals))
86 |
87 | def get_lr(self, epoch):
88 | for (start, end, lr) in self.intervals:
89 | if start <= epoch <= end:
90 | return lr
91 | raise ValueError("Invalid epoch {} for regime {!r}".format(
92 | epoch, self.intervals))
93 |
94 |
95 | def _set_learning_rate(optimizer, lr):
96 | for param_group in optimizer.param_groups:
97 | param_group['lr'] = lr
98 |
99 | def adjust_learning_rate(optimizer, epoch, args):
100 | """Decay the learning rate based on schedule"""
101 | lr = args.lr
102 | if args.cos: # cosine lr schedule
103 | lr *= 0.5 * (1. + math.cos(math.pi * epoch / args.epochs))
104 | else: # stepwise lr schedule
105 | for milestone in args.schedule:
106 | lr *= 0.1 if epoch >= milestone else 1.
107 | for param_group in optimizer.param_groups:
108 | param_group['lr'] = lr
109 |
110 | def _get_learning_rate(optimizer):
111 | return max(param_group['lr'] for param_group in optimizer.param_groups)
112 |
113 |
114 | def train_for_one_epoch(model, g_loss, discriminator_loss, train_loader, optimizer, epoch_number, args):
115 | model.train()
116 | g_loss.train()
117 |
118 | data_time_meter = utils.AverageMeter()
119 | batch_time_meter = utils.AverageMeter()
120 | g_loss_meter = utils.AverageMeter(recent=100)
121 | d_loss_meter = utils.AverageMeter(recent=100)
122 | top1_meter = utils.AverageMeter(recent=100)
123 | top5_meter = utils.AverageMeter(recent=100)
124 |
125 | timestamp = time.time()
126 | for i, (images, labels, soft_labels) in enumerate(train_loader):
127 | batch_size = args.batch_size
128 |
129 | # Record data time
130 | data_time_meter.update(time.time() - timestamp)
131 |
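    |         # Each loader item is a list of num_crops crops per image (with matching
    |         # hard labels and soft labels); flatten them into one batch along dim 0.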
132 | images = torch.cat(images, dim=0)
133 | soft_labels = torch.cat(soft_labels, dim=0)
134 | labels = torch.cat(labels, dim=0)
135 |
136 | if args.soft_label_type == 'ori':
137 | soft_labels = soft_labels.cuda()
138 | else:
139 | soft_labels = Recover_soft_label(soft_labels, args.soft_label_type, args.num_classes)
140 | soft_labels = soft_labels.cuda()
141 |
142 | if utils.is_model_cuda(model):
143 | images = images.cuda()
144 | labels = labels.cuda()
145 |
146 |         if args.w_cutmix:
147 |             r = np.random.rand(1)
148 |             if args.beta > 0 and r < args.cutmix_prob:
149 |                 # generate mixed sample
150 |                 lam = np.random.beta(args.beta, args.beta)
151 |                 rand_index = torch.randperm(images.size()[0]).cuda()
152 |                 target_a = soft_labels
153 |                 target_b = soft_labels[rand_index]
154 |                 bbx1, bby1, bbx2, bby2 = utils.rand_bbox(images.size(), lam)
155 |                 images[:, :, bbx1:bbx2, bby1:bby2] = images[rand_index, :, bbx1:bbx2, bby1:bby2]
156 |                 lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (images.size()[-1] * images.size()[-2]))
    |             else:
    |                 # CutMix skipped this batch: keep lam/target_a/target_b defined
    |                 # so the loss computation below does not raise a NameError.
    |                 lam = 1.0
    |                 target_a = target_b = soft_labels
157 |
158 | # Forward pass, backward pass, and update parameters.
159 | output = model(images)
160 | # output, soft_label, soft_no_softmax = outputs
161 |         if args.w_cutmix:
162 | g_loss_output1 = g_loss((output, target_a), labels)
163 | g_loss_output2 = g_loss((output, target_b), labels)
164 | else:
165 | g_loss_output = g_loss((output, soft_labels), labels)
166 |         if args.use_discriminator_loss:
167 |             # Our stored labels are "after softmax"; this is slightly different from the
168 |             # original MEAL V2, which used probabilities "before softmax" for the discriminator.
169 |             output_softmax = nn.functional.softmax(output, dim=1)
170 |             if args.w_cutmix:
171 |                 d_loss_value = discriminator_loss([output_softmax], [target_a]) * lam + discriminator_loss([output_softmax], [target_b]) * (1 - lam)
172 |             else:
173 |                 d_loss_value = discriminator_loss([output_softmax], [soft_labels])
174 |
175 | # Sometimes loss function returns a modified version of the output,
176 | # which must be used to compute the model accuracy.
177 |         if args.w_cutmix:
178 | if isinstance(g_loss_output1, tuple):
179 | g_loss_value1, output1 = g_loss_output1
180 | g_loss_value2, output2 = g_loss_output2
181 | g_loss_value = g_loss_value1 * lam + g_loss_value2 * (1 - lam)
182 | else:
183 | g_loss_value = g_loss_output1 * lam + g_loss_output2 * (1 - lam)
184 | else:
185 | if isinstance(g_loss_output, tuple):
186 | g_loss_value, output = g_loss_output
187 | else:
188 | g_loss_value = g_loss_output
189 |
190 | if args.use_discriminator_loss:
191 | loss_value = g_loss_value + d_loss_value
192 | else:
193 | loss_value = g_loss_value
194 |
195 | loss_value.backward()
196 |
197 | # Update parameters and reset gradients.
198 | optimizer.step()
199 | optimizer.zero_grad()
200 |
201 | # Record loss and model accuracy.
202 | g_loss_meter.update(g_loss_value.item(), batch_size)
203 |         if args.use_discriminator_loss:
    |             # d_loss_value only exists when the discriminator loss is enabled.
    |             d_loss_meter.update(d_loss_value.item(), batch_size)
204 |
205 | top1, top5 = utils.topk_accuracy(output, labels, recalls=(1, 5))
206 | top1_meter.update(top1, batch_size)
207 | top5_meter.update(top5, batch_size)
208 |
209 | # Record batch time
210 | batch_time_meter.update(time.time() - timestamp)
211 | timestamp = time.time()
212 |
213 |         if i % 20 == 0:
214 | logging.info(
215 | 'Epoch: [{epoch}][{batch}/{epoch_size}]\t'
216 | 'Time {batch_time.value:.2f} ({batch_time.average:.2f}) '
217 | 'Data {data_time.value:.2f} ({data_time.average:.2f}) '
218 | 'G_Loss {g_loss.value:.3f} {{{g_loss.average:.3f}, {g_loss.average_recent:.3f}}} '
219 | 'D_Loss {d_loss.value:.3f} {{{d_loss.average:.3f}, {d_loss.average_recent:.3f}}} '
220 | 'Top-1 {top1.value:.2f} {{{top1.average:.2f}, {top1.average_recent:.2f}}} '
221 | 'Top-5 {top5.value:.2f} {{{top5.average:.2f}, {top5.average_recent:.2f}}} '
222 | 'LR {lr:.5f}'.format(
223 | epoch=epoch_number, batch=i + 1, epoch_size=len(train_loader),
224 | batch_time=batch_time_meter, data_time=data_time_meter,
225 | g_loss=g_loss_meter, d_loss=d_loss_meter, top1=top1_meter, top5=top5_meter,
226 | lr=_get_learning_rate(optimizer)))
227 | # Log the overall train stats
228 | logging.info(
229 | 'Epoch: [{epoch}] -- TRAINING SUMMARY\t'
230 | 'Time {batch_time.sum:.2f} '
231 | 'Data {data_time.sum:.2f} '
232 | 'G_Loss {g_loss.average:.3f} '
233 | 'D_Loss {d_loss.average:.3f} '
234 | 'Top-1 {top1.average:.2f} '
235 | 'Top-5 {top5.average:.2f} '.format(
236 | epoch=epoch_number, batch_time=batch_time_meter, data_time=data_time_meter,
237 | g_loss=g_loss_meter, d_loss=d_loss_meter, top1=top1_meter, top5=top5_meter))
238 |
239 |
240 | def save_checkpoint(checkpoints_dir, model, optimizer, epoch):
241 | model_state_file = os.path.join(checkpoints_dir, 'model_state_{:02}.pytar'.format(epoch))
242 | optim_state_file = os.path.join(checkpoints_dir, 'optim_state_{:02}.pytar'.format(epoch))
243 | torch.save(model.state_dict(), model_state_file)
244 | torch.save(optimizer.state_dict(), optim_state_file)
245 |
246 |
247 | def create_optimizer(model, discriminator_parameters, momentum=0.9, weight_decay=0):
248 | # Get model parameters that require a gradient.
249 | parameters = [{'params': model.parameters()}, discriminator_parameters]
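    |     # discriminator_parameters is a param-group dict that carries its own
    |     # initial learning rate ('lr': args.d_lr) for the discriminator.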
250 | optimizer = torch.optim.SGD(parameters, lr=0,
251 | momentum=momentum, weight_decay=weight_decay)
252 | return optimizer
253 |
254 | def create_discriminator_criterion(args):
255 | d = discriminator.Discriminator(outputs_size=1000, K=8).cuda()
256 | d = torch.nn.DataParallel(d)
257 | update_parameters = {'params': d.parameters(), "lr": args.d_lr}
258 | discriminators_criterion = discriminatorLoss(d).cuda()
259 | if len(args.gpus) > 1:
260 | discriminators_criterion = torch.nn.DataParallel(discriminators_criterion, device_ids=args.gpus)
261 | return discriminators_criterion, update_parameters
262 |
263 | def main(argv):
264 | """Run the training script with command line arguments @argv."""
265 | args = parse_args(argv)
266 | utils.general_setup(args.save, args.gpus)
267 |
268 | logging.info("Arguments parsed.\n{}".format(pprint.pformat(vars(args))))
269 |
270 |     # Convert to the true number of images loaded per batch, since we draw multiple crops from the same image within a mini-batch.
271 | args.batch_size = math.ceil(args.batch_size / args.num_crops)
272 |
273 | # Create the train and the validation data loaders.
274 | train_loader = imagenet.get_train_loader_FKD(args.imagenet, args.batch_size,
275 | args.num_workers, args.image_size, args.num_crops, args.softlabel_path)
276 | val_loader = imagenet.get_val_loader(args.imagenet, args.batch_size,
277 | args.num_workers, args.image_size)
278 | # Create model with optional teachers.
279 | model, loss = model_factory.create_model(
280 | args.model, args.student_state_file, args.gpus, args.teacher_model,
281 | args.teacher_state_file, True)
282 | logging.info("Model:\n{}".format(model))
283 |
284 | discriminator_loss, update_parameters = create_discriminator_criterion(args)
285 |
286 | optimizer = create_optimizer(model, update_parameters, args.momentum, args.weight_decay)
287 |
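    |     # Each FKD batch packs num_crops crops per image (args.batch_size was
    |     # divided by num_crops above), so one pass over the loader covers
    |     # num_crops epochs' worth of samples; advance the epoch counter accordingly.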
288 | for epoch in range(args.start_epoch, args.epochs, args.num_crops):
289 | adjust_learning_rate(optimizer, epoch, args)
290 | train_for_one_epoch(model, loss, discriminator_loss, train_loader, optimizer, epoch, args)
291 | test.test_for_one_epoch(model, loss, val_loader, epoch)
292 | save_checkpoint(args.save, model, optimizer, epoch)
293 |
294 |
295 | if __name__ == '__main__':
296 | main(sys.argv[1:])
297 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MEAL-V2
2 |
3 | This is the official PyTorch implementation of our paper:
4 | ["MEAL V2: Boosting Vanilla ResNet-50 to 80%+ Top-1 Accuracy on ImageNet without Tricks"](https://arxiv.org/abs/2009.08453) by
5 | [Zhiqiang Shen](http://zhiqiangshen.com/) and [Marios Savvides](https://www.ece.cmu.edu/directory/bios/savvides-marios.html) from Carnegie Mellon University.
6 |
7 |
8 | 
9 |
10 |
11 | In this paper, we introduce a simple yet effective approach that can boost the vanilla ResNet-50 to 80%+ Top-1 accuracy on ImageNet without any tricks. Generally, our method is based on the recently proposed [MEAL](https://arxiv.org/abs/1812.02425), i.e., ensemble knowledge distillation via discriminators. We further simplify it through 1) adopting the similarity loss and discriminator only on the final outputs and 2) using the average of softmax probabilities from all teacher ensembles as the stronger supervision for distillation. One crucial perspective of our method is that the one-hot/hard label should not be used in the distillation process. We show that such a simple framework can achieve state-of-the-art results without involving any commonly-used tricks, such as 1) architecture modification; 2) outside training data beyond ImageNet; 3) autoaug/randaug; 4) cosine learning rate; 5) mixup/cutmix training; 6) label smoothing; etc.
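    |
    | For reference, the soft-label objective described above can be sketched in a few lines of PyTorch (a minimal illustration; `student_logits` and `teacher_logits_list` are placeholder names, not part of this repo's API):
    |
    | ```python
    | import torch
    | import torch.nn.functional as F
    |
    | def distillation_loss(student_logits, teacher_logits_list):
    |     # Ensemble soft label: average the softmax probabilities of all teachers.
    |     soft_label = torch.stack(
    |         [F.softmax(t, dim=1) for t in teacher_logits_list]).mean(dim=0)
    |     # Cross-entropy between the soft label and the student's log-probabilities;
    |     # no one-hot/hard labels are involved.
    |     log_prob = F.log_softmax(student_logits, dim=1)
    |     return -(soft_label * log_prob).sum(dim=1).mean()
    | ```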
12 |
13 | ## Citation
14 |
15 | If you find our code is helpful for your research, please cite:
16 |
17 | @article{shen2020mealv2,
18 | title={MEAL V2: Boosting Vanilla ResNet-50 to 80%+ Top-1 Accuracy on ImageNet without Tricks},
19 | author={Shen, Zhiqiang and Savvides, Marios},
20 | journal={arXiv preprint arXiv:2009.08453},
21 | year={2020}
22 | }
23 |
24 | ## News
25 |
26 | **[Dec. 5, 2021]** **New:** Add [FKD](https://arxiv.org/abs/2112.01528) training support. We highly recommend using FKD to train MEAL V2 models: it is 2~4x faster with similar accuracy.
27 |
28 | - Download our [soft label](http://zhiqiangshen.com/projects/FKD/index.html) for MEAL V2.
29 | - Run `FKD_train.py` with the desired model architecture, the path to the ImageNet dataset, and the path to the soft labels, for example:
30 |
31 | ```shell
32 | # 224 x 224 ResNet-50
33 | python FKD_train.py --save MEAL_V2_resnet50_224 \
34 | --batch-size 512 -j 48 \
35 | --model resnet50 --epochs 200 \
36 | --teacher-model gluon_senet154,gluon_resnet152_v1s \
37 | --imagenet [imagenet-folder with train and val folders] \
38 | --num_crops 8 --soft_label_type marginal_smoothing_k5 \
39 | --softlabel_path [path of soft label] \
40 | --schedule 100 180 --use-discriminator-loss
41 | ```
42 | Add `--cos` if you would like to train with a cosine learning rate schedule.
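    |
    | With `--cos`, the learning rate follows `lr = lr0 * 0.5 * (1 + cos(pi * epoch / epochs))` (see `adjust_learning_rate` in `FKD_train.py`); otherwise it is dropped 10x at each `--schedule` milestone.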
43 |
44 | **New:** Adding back tricks (cosine *lr*, etc.) to MEAL V2 can consistently improve the accuracy:
45 |
46 |
47 | 
48 | 
49 |
50 |
51 | **New:** Add CutMix training support, use *--w-cutmix* to enable it.
52 |
53 | **[Mar. 19, 2021]** Long version of MEAL V2 is available on: [arXiv](https://arxiv.org/abs/2009.08453) or [paper](http://zhiqiangshen.com/projects/MEAL_V2/arxiv.pdf).
54 |
55 | **[Dec. 16, 2020]** MEAL V2 is now available in [PyTorch Hub](https://pytorch.org/hub/pytorch_vision_meal_v2/).
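    |
    | A pre-trained model can be loaded directly through the Hub entrypoint defined in `hubconf.py`; a minimal sketch (the entrypoint name and arguments follow `hubconf.py`, but see the Hub page for the exact call):
    |
    | ```python
    | import torch
    |
    | # Load a pre-trained MEAL V2 ResNet-50 (downloads weights on first use; requires CUDA,
    | # since the entrypoint wraps the model in DataParallel and moves it to GPU).
    | model = torch.hub.load('szq0214/MEAL-V2', 'meal_v2', 'mealv2_resnest50', pretrained=True)
    | model.eval()
    | ```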
56 |
57 | **[Nov. 3, 2020]** Short version of MEAL V2 has been accepted in NeurIPS 2020 [Beyond BackPropagation: Novel Ideas for Training Neural Architectures](https://beyondbackprop.github.io/) workshop. Long version is coming soon.
58 |
59 | ## Preparation
60 |
61 | ### 1. Requirements:
62 | This repo is tested with:
63 |
64 | * Python 3.6
65 |
66 | * CUDA 10.2
67 |
68 | * PyTorch 1.6.0
69 |
70 | * torchvision 0.7.0
71 |
72 | * timm 0.2.1
73 | (pip install timm)
74 |
75 | But it should be runnable with other PyTorch versions.
76 |
77 | ### 2. Data:
78 | * Download ImageNet dataset following https://github.com/pytorch/examples/tree/master/imagenet#requirements.
79 |
80 | ## Results & Models
81 |
82 | We provide pre-trained models from several training settings. For each model, the table reports the training/validation resolution, #parameters, and Top-1/Top-5 accuracy on the ImageNet validation set:
83 |
84 | | Models | Resolution| #Parameters | Top-1/Top-5 | Trained models |
85 | | :---: | :-: | :-: | :------:| :------: |
86 | | [MEAL-V1 w/ ResNet50](https://arxiv.org/abs/1812.02425) | 224 | 25.6M |**78.21/94.01** | [GitHub](https://github.com/AaronHeee/MEAL#imagenet-model) |
87 | | MEAL-V2 w/ ResNet18 | 224 | 11.7M | **73.19/90.82** | [Download (46.8M)](https://1drv.ms/u/s!AtMVZxJ8MfxCi03CTdVPH24ce6rD?e=l7BoZL) |
88 | | MEAL-V2 w/ ResNet50 | 224 | 25.6M | **80.67/95.09** | [Download (102.6M)](https://1drv.ms/u/s!AtMVZxJ8MfxCi0NGENlMK0pYVDQM?e=GkwZ93) |
89 | | MEAL-V2 w/ ResNet50| 380 | 25.6M | **81.72/95.81** | [Download (102.6M)](https://1drv.ms/u/s!AtMVZxJ8MfxCi0T9nodVNdnklHNt?e=7oJGIy) |
90 | | MEAL-V2 + CutMix w/ ResNet50| 224 | 25.6M | **80.98/95.35** | [Download (102.6M)](https://1drv.ms/u/s!AtMVZxJ8MfxCi0cIf5IqpBX6nl1U?e=Fig91M) |
91 | | MEAL-V2 w/ MobileNet V3-Small 0.75| 224 | 2.04M | **67.60/87.23** | [Download (8.3M)](https://1drv.ms/u/s!AtMVZxJ8MfxCi0nIq1jZo36dpN7Q?e=ODcoAN) |
92 | | MEAL-V2 w/ MobileNet V3-Small 1.0| 224 | 2.54M | **69.65/88.71** | [Download (10.3M)](https://1drv.ms/u/s!AtMVZxJ8MfxCiz9v7QqUmvQOLmTS?e=9nCWMa) |
93 | | MEAL-V2 w/ MobileNet V3-Large 1.0 | 224 | 5.48M | **76.92/93.32** | [Download (22.1M)](https://1drv.ms/u/s!AtMVZxJ8MfxCi0Ciwz-q-P2jwtXR?e=OebKAr) |
94 | | MEAL-V2 w/ EfficientNet-B0| 224 | 5.29M | **78.29/93.95** | [Download (21.5M)](https://1drv.ms/u/s!AtMVZxJ8MfxCi0XZLUEB3uYq3eBe?e=FJV9K1) |
95 |
96 |
97 | ## Training & Testing
98 | ### 1. Training:
99 | * To train a model, run `script/train.sh` with the desired model architecture and the path to the ImageNet dataset, for example:
100 |
101 | ```shell
102 | # 224 x 224 ResNet-50
103 | python train.py --save MEAL_V2_resnet50_224 --batch-size 512 -j 48 --model resnet50 --epochs 180 --teacher-model gluon_senet154,gluon_resnet152_v1s --imagenet [imagenet-folder with train and val folders]
104 | ```
105 |
106 | ```shell
107 | # 224 x 224 ResNet-50 w/ CutMix
108 | python train.py --save MEAL_V2_resnet50_224 --batch-size 512 -j 48 --model resnet50 --epochs 180 --teacher-model gluon_senet154,gluon_resnet152_v1s --imagenet [imagenet-folder with train and val folders] --w-cutmix
109 | ```
110 |
111 | ```shell
112 | # 380 x 380 ResNet-50
113 | python train.py --save MEAL_V2_resnet50_380 --batch-size 512 -j 48 --model resnet50 --image-size 380 --teacher-model tf_efficientnet_b4_ns,tf_efficientnet_b4 --imagenet [imagenet-folder with train and val folders]
114 | ```
115 |
116 | ```shell
117 | # 224 x 224 MobileNet V3-Small 0.75
118 | python train.py --save MEAL_V2_mobilenetv3_small_075 --batch-size 512 -j 48 --model tf_mobilenetv3_small_075 --teacher-model gluon_senet154,gluon_resnet152_v1s --imagenet [imagenet-folder with train and val folders]
119 | ```
120 |
121 | ```shell
122 | # 224 x 224 MobileNet V3-Small 1.0
123 | python train.py --save MEAL_V2_mobilenetv3_small_100 --batch-size 512 -j 48 --model tf_mobilenetv3_small_100 --teacher-model gluon_senet154,gluon_resnet152_v1s --imagenet [imagenet-folder with train and val folders]
124 | ```
125 |
126 | ```shell
127 | # 224 x 224 MobileNet V3-Large 1.0
128 | python train.py --save MEAL_V2_mobilenetv3_large_100 --batch-size 512 -j 48 --model tf_mobilenetv3_large_100 --teacher-model gluon_senet154,gluon_resnet152_v1s --imagenet [imagenet-folder with train and val folders]
129 | ```
130 |
131 | ```shell
132 | # 224 x 224 EfficientNet-B0
133 | python train.py --save MEAL_V2_efficientnet_b0 --batch-size 512 -j 48 --model tf_efficientnet_b0 --teacher-model gluon_senet154,gluon_resnet152_v1s --imagenet [imagenet-folder with train and val folders]
134 | ```
135 | *Please reduce the ``--batch-size`` if you get an "out of memory" error. We also notice that more training epochs can slightly improve the performance.*
136 |
137 | * To resume training a model, run `script/resume_train.sh` with the desired model architecture, the starting epoch number, and the path to the ImageNet dataset:
138 |
139 | ```shell
140 | sh script/resume_train.sh
141 | ```
142 |
143 | ### 2. Testing:
144 |
145 | * To test a model, run `inference.py` with the desired model architecture, model path, resolution, and the path to the ImageNet dataset:
146 |
147 | ```shell
148 | CUDA_VISIBLE_DEVICES=0,1,2,3 python inference.py -a resnet50 --res 224 --resume MODEL_PATH -e [imagenet-folder with train and val folders]
149 | ```
150 | Change ``--res`` to another image resolution [224/380] and ``-a`` to another model architecture [tf\_mobilenetv3\_small\_100; tf\_mobilenetv3\_large\_100; tf\_efficientnet\_b0] to test the other trained models.
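    |
    | For reference, the validation preprocessing in `inference.py` resizes the shorter side to `256/224` of the test resolution before center-cropping; a standalone sketch of that pipeline:
    |
    | ```python
    | import torchvision.transforms as transforms
    |
    | res = 224  # or 380
    | val_transform = transforms.Compose([
    |     transforms.Resize(int(256 / 224 * res)),  # keep the usual resize/crop ratio
    |     transforms.CenterCrop(res),
    |     transforms.ToTensor(),
    |     transforms.Normalize(mean=[0.485, 0.456, 0.406],
    |                          std=[0.229, 0.224, 0.225]),
    | ])
    | ```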
151 |
152 |
153 | ## Contact
154 |
155 | Zhiqiang Shen, CMU (zhiqians at andrew.cmu.edu)
156 |
157 | Any comments or suggestions are welcome!
--------------------------------------------------------------------------------
/extensions/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/szq0214/MEAL-V2/59e6bca29c77a7df9229e77d4463f2f795723810/extensions/__init__.py
--------------------------------------------------------------------------------
/extensions/data_parallel.py:
--------------------------------------------------------------------------------
1 | __author__ = "Hessam Bagherinezhad "
2 |
3 | from torch import nn
4 | from torch.nn.modules import loss
5 |
6 |
7 | class DataParallel(nn.DataParallel):
8 | """An extension of nn.DataParallel.
9 |
10 | The only extensions are:
11 | 1) If an attribute is missing in an object of this class, it will look
12 | for it in the wrapped module. This is useful for getting `LR_REGIME`
13 | of the wrapped module for example.
14 | 2) state_dict() of this class calls the wrapped module's state_dict(),
15 | hence the weights can be transferred from a data parallel wrapped
16 | module to a single gpu module.
17 | """
18 |
19 |
20 | def __getattr__(self, name):
21 | # If attribute doesn't exist in the DataParallel object this method will
22 | # be called. Here we first ask the super class to get the attribute, if
23 | # couldn't find it, we ask the underlying module that is wrapped by this
24 | # DataParallel to get the attribute.
25 | try:
26 | return super().__getattr__(name)
27 | except AttributeError:
28 | underlying_module = super().__getattr__('module')
29 | return getattr(underlying_module, name)
30 |
31 | def state_dict(self, *args, **kwargs):
32 | return self.module.state_dict(*args, **kwargs)
33 |
--------------------------------------------------------------------------------
/extensions/kd_loss.py:
--------------------------------------------------------------------------------
1 | __author__ = "Hessam Bagherinezhad "
2 |
3 | # modified by "Zhiqiang Shen "
4 |
5 | import torch
6 | from torch.nn import functional as F
7 | from torch.nn.modules import loss
8 |
9 |
10 | class KLLoss(loss._Loss):
11 | """The KL-Divergence loss for the model and soft labels output.
12 |
13 | output must be a pair of (model_output, soft_labels), both NxC tensors.
14 | The rows of soft_labels must all add up to one (probability scores);
15 | however, model_output must be the pre-softmax output of the network."""
16 |
17 | def forward(self, output, target):
18 | if not self.training:
19 | # Loss is normal cross entropy loss between the model output and the
20 | # target.
21 | return F.cross_entropy(output, target)
22 |
23 | assert type(output) == tuple and len(output) == 2 and output[0].size() == \
24 |             output[1].size(), "output must be a pair of tensors of the same size."
25 |
26 | # Target is ignored at training time. Loss is defined as KL divergence
27 | # between the model output and the soft labels.
28 | model_output, soft_labels = output
29 | if soft_labels.requires_grad:
30 | raise ValueError("soft labels should not require gradients.")
31 |
32 | model_output_log_prob = F.log_softmax(model_output, dim=1)
33 | del model_output
34 |
35 | # Loss is -dot(model_output_log_prob, soft_labels). Prepare tensors
36 |         # for batch matrix multiplication.
37 | soft_labels = soft_labels.unsqueeze(1)
38 | model_output_log_prob = model_output_log_prob.unsqueeze(2)
39 |
40 | # Compute the loss, and average for the batch.
41 | cross_entropy_loss = -torch.bmm(soft_labels, model_output_log_prob)
42 | cross_entropy_loss = cross_entropy_loss.mean()
43 | # Return a pair of (loss_output, model_output). Model output will be
44 | # used for top-1 and top-5 evaluation.
45 | model_output_log_prob = model_output_log_prob.squeeze(2)
46 | return (cross_entropy_loss, model_output_log_prob)
47 |
--------------------------------------------------------------------------------
/extensions/teacher_wrapper.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 |
5 | import random
6 | import numpy as np
7 |
8 |
9 | class ModelDistillationWrapper(nn.Module):
10 |     """Convenient wrapper class to train a model with soft labels."""
11 |
12 | def __init__(self, model, teacher):
13 | super().__init__()
14 | self.model = model
15 | self.teachers_0 = teacher
16 | self.combine = True
17 |
18 | # Since we don't want to back-prop through the teacher network,
19 | # make the parameters of the teacher network not require gradients. This
20 | # saves some GPU memory.
21 |
22 | for model in self.teachers_0:
23 | for param in model.parameters():
24 | param.requires_grad = False
25 |
26 | self.false = False
27 |
28 | @property
29 | def LR_REGIME(self):
30 |         # Training with soft labels does not change the learning rate regime;
31 |         # return the wrapped model's LR regime.
32 | return self.model.LR_REGIME
33 |
34 | def state_dict(self):
35 | return self.model.state_dict()
36 |
37 | def forward(self, input, before=False):
38 | if self.training:
39 |             if len(self.teachers_0) == 3 and not self.combine:
40 | index = [0,1,1,2,2]
41 | idx = random.randint(0, 4)
42 | soft_labels_ = self.teachers_0[index[idx]](input)
43 | soft_labels = F.softmax(soft_labels_, dim=1)
44 |
45 | elif self.combine:
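    |             # Ensemble mode: average the raw logits and the softmax
    |             # probabilities separately across all teachers.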
46 | soft_labels_ = [ torch.unsqueeze(self.teachers_0[idx](input), dim=2) for idx in range(len(self.teachers_0))]
47 | soft_labels_softmax = [F.softmax(i, dim=1) for i in soft_labels_]
48 | soft_labels_ = torch.cat(soft_labels_, dim=2).mean(dim=2)
49 | soft_labels = torch.cat(soft_labels_softmax, dim=2).mean(dim=2)
50 |
51 | else:
52 | idx = random.randint(0, len(self.teachers_0)-1)
53 | soft_labels_ = self.teachers_0[idx](input)
54 | soft_labels = F.softmax(soft_labels_, dim=1)
55 |
56 | # soft_labels = F.softmax(soft_labels_, dim=1)
57 | model_output = self.model(input)
58 |
59 | if before:
60 | return (model_output, soft_labels, soft_labels_)
61 |
62 | return (model_output, soft_labels)
63 |
64 | else:
65 | return self.model(input)
66 |
--------------------------------------------------------------------------------
/hubconf.py:
--------------------------------------------------------------------------------
1 | dependencies = [
2 | 'timm',
3 | 'torch',
4 | ]
5 |
6 | import torch, timm
7 |
8 | __all__ = ['mealv1_resnest50', 'mealv2_resnest50', 'mealv2_resnest50_cutmix', 'mealv2_resnest50_380x380', 'mealv2_mobilenetv3_small_075', 'mealv2_mobilenetv3_small_100', 'mealv2_mobilenet_v3_large_100', 'mealv2_efficientnet_b0']
9 |
10 | model_urls = {
11 | 'mealv1_resnest50': 'https://github.com/szq0214/MEAL-V2/releases/download/v1.0.0/MEALV1_ResNet50_224.pth',
12 | 'mealv2_resnest50': 'https://github.com/szq0214/MEAL-V2/releases/download/v1.0.0/MEALV2_ResNet50_224.pth',
13 | 'mealv2_resnest50_cutmix': 'https://github.com/szq0214/MEAL-V2/releases/download/v1.0.0/MEALV2_ResNet50_224_cutmix.pth',
14 | 'mealv2_resnest50_380x380': 'https://github.com/szq0214/MEAL-V2/releases/download/v1.0.0/MEALV2_ResNet50_380.pth',
15 | 'mealv2_mobilenetv3_small_075': 'https://github.com/szq0214/MEAL-V2/releases/download/v1.0.0/MEALV2_MobileNet_V3_Small_0.75_224.pth',
16 | 'mealv2_mobilenetv3_small_100': 'https://github.com/szq0214/MEAL-V2/releases/download/v1.0.0/MEALV2_MobileNet_V3_Small_1.0_224.pth',
17 | 'mealv2_mobilenet_v3_large_100': 'https://github.com/szq0214/MEAL-V2/releases/download/v1.0.0/MEALV2_MobileNet_V3_Large_1.0_224.pth',
18 | 'mealv2_efficientnet_b0': 'https://github.com/szq0214/MEAL-V2/releases/download/v1.0.0/MEALV2_EfficientNet_B0_224.pth',
19 | }
20 |
21 |
22 | mapping = {'mealv1_resnest50':'resnet50',
23 | 'mealv2_resnest50':'resnet50',
24 | 'mealv2_resnest50_cutmix':'resnet50',
25 | 'mealv2_resnest50_380x380':'resnet50',
26 | 'mealv2_mobilenetv3_small_075':'tf_mobilenetv3_small_075',
27 | 'mealv2_mobilenetv3_small_100':'tf_mobilenetv3_small_100',
28 | 'mealv2_mobilenet_v3_large_100':'tf_mobilenetv3_large_100',
29 | 'mealv2_efficientnet_b0':'tf_efficientnet_b0'
30 | }
31 |
32 | def meal_v2(model_name, pretrained=True, progress=True, exportable=False):
33 | """ MEAL V2 models from
34 |     `"MEAL V2: Boosting Vanilla ResNet-50 to 80%+ Top-1 Accuracy on ImageNet without Tricks" <https://arxiv.org/abs/2009.08453>`_
35 |
36 | Args:
37 | model_name: Name of the model to load
38 | pretrained (bool): If True, returns a model trained with MEAL V2 on ImageNet
39 | progress (bool): If True, displays a progress bar of the download to stderr
40 | """
41 |
42 | model = timm.create_model(mapping[model_name.lower()], pretrained=False, exportable=exportable)
43 | if pretrained:
44 | state_dict = torch.hub.load_state_dict_from_url(model_urls[model_name.lower()], progress=progress)
45 | model = torch.nn.DataParallel(model).cuda()
46 | model.load_state_dict(state_dict)
47 | return model
--------------------------------------------------------------------------------
/imagenet.py:
--------------------------------------------------------------------------------
1 | """Dataset class for loading imagenet data."""
2 |
3 | import os
4 |
5 | from torch.utils import data as data_utils
6 | from torchvision import datasets as torch_datasets
7 | from torchvision import transforms
8 |
9 | from utils_FKD import RandomResizedCrop_FKD, RandomHorizontalFlip_FKD, ImageFolder_FKD, Compose_FKD
10 | from torchvision.transforms import InterpolationMode
11 |
12 | def get_train_loader(imagenet_path, batch_size, num_workers, image_size):
13 | train_dataset = ImageNet(imagenet_path, image_size, is_train=True)
14 | return data_utils.DataLoader(
15 | train_dataset, shuffle=True, batch_size=batch_size, pin_memory=True,
16 | num_workers=num_workers)
17 |
18 | def get_train_loader_FKD(imagenet_path, batch_size, num_workers, image_size, num_crops, softlabel_path):
19 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
20 | std=[0.229, 0.224, 0.225])
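    |     # ImageFolder_FKD returns num_crops crops per image together with the
    |     # pre-generated soft labels stored under softlabel_path.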
21 | train_dataset = ImageFolder_FKD(
22 | num_crops=num_crops,
23 | softlabel_path=softlabel_path,
24 | root=os.path.join(imagenet_path, 'train'),
25 | transform=Compose_FKD(transforms=[
26 |             RandomResizedCrop_FKD(size=image_size,  # use the requested size instead of a hard-coded 224
27 |                                   interpolation='bilinear'),
28 | RandomHorizontalFlip_FKD(),
29 | transforms.ToTensor(),
30 | normalize,
31 | ]))
32 | return data_utils.DataLoader(
33 | train_dataset, shuffle=True, batch_size=batch_size, pin_memory=True,
34 | num_workers=num_workers)
35 |
36 | def get_val_loader(imagenet_path, batch_size, num_workers, image_size):
37 | val_dataset = ImageNet(imagenet_path, image_size, is_train=False)
38 | return data_utils.DataLoader(
39 | val_dataset, shuffle=False, batch_size=batch_size, pin_memory=True,
40 | num_workers=num_workers)
41 |
42 |
43 | class ImageNet(torch_datasets.ImageFolder):
44 | """Dataset class for ImageNet dataset.
45 |
46 | Arguments:
47 | root_dir (str): Path to the dataset root directory, which must contain
48 | train/ and val/ directories.
49 | is_train (bool): Whether to read training or validation images.
50 | """
51 | MEAN = [0.485, 0.456, 0.406]
52 | STD = [0.229, 0.224, 0.225]
53 |
54 | def __init__(self, root_dir, im_size, is_train):
55 | if is_train:
56 | root_dir = os.path.join(root_dir, 'train')
57 | transform = transforms.Compose([
58 | transforms.RandomResizedCrop(im_size),
59 | transforms.RandomHorizontalFlip(),
60 | transforms.ToTensor(),
61 | transforms.Normalize(ImageNet.MEAN, ImageNet.STD),
62 | ])
63 | else:
64 | root_dir = os.path.join(root_dir, 'val')
65 | transform = transforms.Compose([
66 | transforms.Resize(int(256/224*im_size)),
67 | transforms.CenterCrop(im_size),
68 | transforms.ToTensor(),
69 | transforms.Normalize(ImageNet.MEAN, ImageNet.STD),
70 | ])
71 | super().__init__(root_dir, transform=transform)
72 |
73 |
74 |
--------------------------------------------------------------------------------
/images/MEAL-V2_more_tricks_top1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/szq0214/MEAL-V2/59e6bca29c77a7df9229e77d4463f2f795723810/images/MEAL-V2_more_tricks_top1.png
--------------------------------------------------------------------------------
/images/MEAL-V2_more_tricks_top5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/szq0214/MEAL-V2/59e6bca29c77a7df9229e77d4463f2f795723810/images/MEAL-V2_more_tricks_top5.png
--------------------------------------------------------------------------------
/images/comparison.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/szq0214/MEAL-V2/59e6bca29c77a7df9229e77d4463f2f795723810/images/comparison.png
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import random
4 | import shutil
5 | import time
6 | import warnings
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.parallel
11 | import torch.backends.cudnn as cudnn
12 | import torch.distributed as dist
13 | import torch.optim
14 | import torch.multiprocessing as mp
15 | import torch.utils.data
16 | import torch.utils.data.distributed
17 | import torchvision.transforms as transforms
18 | import torchvision.datasets as datasets
19 | import torchvision.models as models
20 | import timm
21 |
22 | model_names = sorted(name for name in models.__dict__
23 | if name.islower() and not name.startswith("__")
24 | and callable(models.__dict__[name]))
25 |
26 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
27 | parser.add_argument('data', metavar='DIR',
28 | help='path to dataset')
29 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
30 | # choices=model_names,
31 | help='model architecture: ' +
32 | ' | '.join(model_names) +
33 | ' (default: resnet18)')
34 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
35 | help='number of data loading workers (default: 4)')
36 | parser.add_argument('--epochs', default=90, type=int, metavar='N',
37 | help='number of total epochs to run')
38 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
39 | help='manual epoch number (useful on restarts)')
40 | parser.add_argument('--res', default=224, type=int,
41 | help='image resolution for testing')
42 | parser.add_argument('-b', '--batch-size', default=256, type=int,
43 | metavar='N',
44 | help='mini-batch size (default: 256), this is the total '
45 | 'batch size of all GPUs on the current node when '
46 | 'using Data Parallel or Distributed Data Parallel')
47 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
48 | metavar='LR', help='initial learning rate', dest='lr')
49 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
50 | help='momentum')
51 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
52 | metavar='W', help='weight decay (default: 1e-4)',
53 | dest='weight_decay')
54 | parser.add_argument('-p', '--print-freq', default=10, type=int,
55 | metavar='N', help='print frequency (default: 10)')
56 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
57 | help='path to latest checkpoint (default: none)')
58 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
59 | help='evaluate model on validation set')
60 | parser.add_argument('--pretrained', dest='pretrained', action='store_true',
61 | help='use pre-trained model')
62 | parser.add_argument('--world-size', default=-1, type=int,
63 | help='number of nodes for distributed training')
64 | parser.add_argument('--rank', default=-1, type=int,
65 | help='node rank for distributed training')
66 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
67 | help='url used to set up distributed training')
68 | parser.add_argument('--dist-backend', default='nccl', type=str,
69 | help='distributed backend')
70 | parser.add_argument('--seed', default=None, type=int,
71 | help='seed for initializing training. ')
72 | parser.add_argument('--gpu', default=None, type=int,
73 | help='GPU id to use.')
74 | parser.add_argument('--multiprocessing-distributed', action='store_true',
75 | help='Use multi-processing distributed training to launch '
76 | 'N processes per node, which has N GPUs. This is the '
77 | 'fastest way to use PyTorch for either single node or '
78 | 'multi node data parallel training')
79 |
80 | best_acc1 = 0
81 |
82 |
83 | def main():
84 | args = parser.parse_args()
85 |
86 | if args.seed is not None:
87 | random.seed(args.seed)
88 | torch.manual_seed(args.seed)
89 | cudnn.deterministic = True
90 | warnings.warn('You have chosen to seed training. '
91 | 'This will turn on the CUDNN deterministic setting, '
92 | 'which can slow down your training considerably! '
93 | 'You may see unexpected behavior when restarting '
94 | 'from checkpoints.')
95 |
96 | if args.gpu is not None:
97 | warnings.warn('You have chosen a specific GPU. This will completely '
98 | 'disable data parallelism.')
99 |
100 | if args.dist_url == "env://" and args.world_size == -1:
101 | args.world_size = int(os.environ["WORLD_SIZE"])
102 |
103 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed
104 |
105 | ngpus_per_node = torch.cuda.device_count()
106 | if args.multiprocessing_distributed:
107 | # Since we have ngpus_per_node processes per node, the total world_size
108 | # needs to be adjusted accordingly
109 | args.world_size = ngpus_per_node * args.world_size
110 | # Use torch.multiprocessing.spawn to launch distributed processes: the
111 | # main_worker process function
112 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
113 | else:
114 | # Simply call main_worker function
115 | main_worker(args.gpu, ngpus_per_node, args)
116 |
117 |
118 | def main_worker(gpu, ngpus_per_node, args):
119 | global best_acc1
120 | args.gpu = gpu
121 |
122 | if args.gpu is not None:
123 | print("Use GPU: {} for training".format(args.gpu))
124 |
125 | if args.distributed:
126 | if args.dist_url == "env://" and args.rank == -1:
127 | args.rank = int(os.environ["RANK"])
128 | if args.multiprocessing_distributed:
129 | # For multiprocessing distributed training, rank needs to be the
130 | # global rank among all the processes
131 | args.rank = args.rank * ngpus_per_node + gpu
132 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
133 | world_size=args.world_size, rank=args.rank)
134 | # create model
135 | if args.pretrained:
136 | print("=> using pre-trained model '{}'".format(args.arch))
137 | # model = models.__dict__[args.arch](pretrained=True)
138 | model = timm.create_model(args.arch, pretrained=True)
139 | else:
140 | print("=> creating model '{}'".format(args.arch))
141 | # model = models.__dict__[args.arch]()
142 | model = timm.create_model(args.arch, pretrained=False)
143 |
144 | if not torch.cuda.is_available():
145 | print('using CPU, this will be slow')
146 | elif args.distributed:
147 | # For multiprocessing distributed, DistributedDataParallel constructor
148 | # should always set the single device scope, otherwise,
149 | # DistributedDataParallel will use all available devices.
150 | if args.gpu is not None:
151 | torch.cuda.set_device(args.gpu)
152 | model.cuda(args.gpu)
153 | # When using a single GPU per process and per
154 | # DistributedDataParallel, we need to divide the batch size
155 | # ourselves based on the total number of GPUs we have
156 | args.batch_size = int(args.batch_size / ngpus_per_node)
157 | args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
158 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
159 | else:
160 | model.cuda()
161 | # DistributedDataParallel will divide and allocate batch_size to all
162 | # available GPUs if device_ids are not set
163 | model = torch.nn.parallel.DistributedDataParallel(model)
164 | elif args.gpu is not None:
165 | torch.cuda.set_device(args.gpu)
166 | model = model.cuda(args.gpu)
167 | else:
168 | # DataParallel will divide and allocate batch_size to all available GPUs
169 | if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
170 | model.features = torch.nn.DataParallel(model.features)
171 | model.cuda()
172 | else:
173 | model = torch.nn.DataParallel(model).cuda()
174 |
175 | # define loss function (criterion) and optimizer
176 | criterion = nn.CrossEntropyLoss().cuda(args.gpu)
177 |
178 | optimizer = torch.optim.SGD(model.parameters(), args.lr,
179 | momentum=args.momentum,
180 | weight_decay=args.weight_decay)
181 |
182 | # optionally resume from a checkpoint
183 | if args.resume:
184 | if os.path.isfile(args.resume):
185 | print("=> loading checkpoint '{}'".format(args.resume))
186 | if args.gpu is None:
187 | checkpoint = torch.load(args.resume)
188 | else:
189 | # Map model to be loaded to specified single gpu.
190 | loc = 'cuda:{}'.format(args.gpu)
191 | checkpoint = torch.load(args.resume, map_location=loc)
192 | model.load_state_dict(checkpoint)
193 | else:
194 | print("=> no checkpoint found at '{}'".format(args.resume))
195 |
196 | cudnn.benchmark = True
197 |
198 | # Data loading code
199 | traindir = os.path.join(args.data, 'train')
200 | valdir = os.path.join(args.data, 'val')
201 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
202 | std=[0.229, 0.224, 0.225])
203 |
204 | train_dataset = datasets.ImageFolder(
205 | traindir,
206 | transforms.Compose([
207 | transforms.RandomResizedCrop(224),
208 | transforms.RandomHorizontalFlip(),
209 | transforms.ToTensor(),
210 | normalize,
211 | ]))
212 |
213 | if args.distributed:
214 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
215 | else:
216 | train_sampler = None
217 |
218 | train_loader = torch.utils.data.DataLoader(
219 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
220 | num_workers=args.workers, pin_memory=True, sampler=train_sampler)
221 |
222 | val_loader = torch.utils.data.DataLoader(
223 | datasets.ImageFolder(valdir, transforms.Compose([
224 | transforms.Resize(int(256/224*args.res)),
225 | transforms.CenterCrop(args.res),
226 | transforms.ToTensor(),
227 | normalize,
228 | ])),
229 | batch_size=args.batch_size, shuffle=False,
230 | num_workers=args.workers, pin_memory=True)
231 |
232 | if args.evaluate:
233 | validate(val_loader, model, criterion, args)
234 | return
235 |
236 | for epoch in range(args.start_epoch, args.epochs):
237 | if args.distributed:
238 | train_sampler.set_epoch(epoch)
239 | adjust_learning_rate(optimizer, epoch, args)
240 |
241 | # train for one epoch
242 | train(train_loader, model, criterion, optimizer, epoch, args)
243 |
244 | # evaluate on validation set
245 | acc1 = validate(val_loader, model, criterion, args)
246 |
247 | # remember best acc@1 and save checkpoint
248 | is_best = acc1 > best_acc1
249 | best_acc1 = max(acc1, best_acc1)
250 |
251 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed
252 | and args.rank % ngpus_per_node == 0):
253 | save_checkpoint({
254 | 'epoch': epoch + 1,
255 | 'arch': args.arch,
256 | 'state_dict': model.state_dict(),
257 | 'best_acc1': best_acc1,
258 | 'optimizer' : optimizer.state_dict(),
259 | }, is_best)
260 |
261 |
262 | def train(train_loader, model, criterion, optimizer, epoch, args):
263 | batch_time = AverageMeter('Time', ':6.3f')
264 | data_time = AverageMeter('Data', ':6.3f')
265 | losses = AverageMeter('Loss', ':.4e')
266 | top1 = AverageMeter('Acc@1', ':6.2f')
267 | top5 = AverageMeter('Acc@5', ':6.2f')
268 | progress = ProgressMeter(
269 | len(train_loader),
270 | [batch_time, data_time, losses, top1, top5],
271 | prefix="Epoch: [{}]".format(epoch))
272 |
273 | # switch to train mode
274 | model.train()
275 |
276 | end = time.time()
277 | for i, (images, target) in enumerate(train_loader):
278 | # measure data loading time
279 | data_time.update(time.time() - end)
280 |
281 | if args.gpu is not None:
282 | images = images.cuda(args.gpu, non_blocking=True)
283 | if torch.cuda.is_available():
284 | target = target.cuda(args.gpu, non_blocking=True)
285 |
286 | # compute output
287 | output = model(images)
288 | loss = criterion(output, target)
289 |
290 | # measure accuracy and record loss
291 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
292 | losses.update(loss.item(), images.size(0))
293 | top1.update(acc1[0], images.size(0))
294 | top5.update(acc5[0], images.size(0))
295 |
296 | # compute gradient and do SGD step
297 | optimizer.zero_grad()
298 | loss.backward()
299 | optimizer.step()
300 |
301 | # measure elapsed time
302 | batch_time.update(time.time() - end)
303 | end = time.time()
304 |
305 | if i % args.print_freq == 0:
306 | progress.display(i)
307 |
308 |
309 | def validate(val_loader, model, criterion, args):
310 | batch_time = AverageMeter('Time', ':6.3f')
311 | losses = AverageMeter('Loss', ':.4e')
312 | top1 = AverageMeter('Acc@1', ':6.2f')
313 | top5 = AverageMeter('Acc@5', ':6.2f')
314 | progress = ProgressMeter(
315 | len(val_loader),
316 | [batch_time, losses, top1, top5],
317 | prefix='Test: ')
318 |
319 | # switch to evaluate mode
320 | model.eval()
321 |
322 | with torch.no_grad():
323 | end = time.time()
324 | for i, (images, target) in enumerate(val_loader):
325 | if args.gpu is not None:
326 | images = images.cuda(args.gpu, non_blocking=True)
327 | if torch.cuda.is_available():
328 | target = target.cuda(args.gpu, non_blocking=True)
329 |
330 | # compute output
331 | output = model(images)
332 | loss = criterion(output, target)
333 |
334 | # measure accuracy and record loss
335 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
336 | losses.update(loss.item(), images.size(0))
337 | top1.update(acc1[0], images.size(0))
338 | top5.update(acc5[0], images.size(0))
339 |
340 | # measure elapsed time
341 | batch_time.update(time.time() - end)
342 | end = time.time()
343 |
344 | if i % args.print_freq == 0:
345 | progress.display(i)
346 |
347 | # TODO: this should also be done with the ProgressMeter
348 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
349 | .format(top1=top1, top5=top5))
350 |
351 | return top1.avg
352 |
353 |
354 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
355 | torch.save(state, filename)
356 | if is_best:
357 | shutil.copyfile(filename, 'model_best.pth.tar')
358 |
359 |
360 | class AverageMeter(object):
361 | """Computes and stores the average and current value"""
362 | def __init__(self, name, fmt=':f'):
363 | self.name = name
364 | self.fmt = fmt
365 | self.reset()
366 |
367 | def reset(self):
368 | self.val = 0
369 | self.avg = 0
370 | self.sum = 0
371 | self.count = 0
372 |
373 | def update(self, val, n=1):
374 | self.val = val
375 | self.sum += val * n
376 | self.count += n
377 | self.avg = self.sum / self.count
378 |
379 | def __str__(self):
380 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
381 | return fmtstr.format(**self.__dict__)
382 |
383 |
384 | class ProgressMeter(object):
385 | def __init__(self, num_batches, meters, prefix=""):
386 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
387 | self.meters = meters
388 | self.prefix = prefix
389 |
390 | def display(self, batch):
391 | entries = [self.prefix + self.batch_fmtstr.format(batch)]
392 | entries += [str(meter) for meter in self.meters]
393 | print('\t'.join(entries))
394 |
395 | def _get_batch_fmtstr(self, num_batches):
396 | num_digits = len(str(num_batches // 1))
397 | fmt = '{:' + str(num_digits) + 'd}'
398 | return '[' + fmt + '/' + fmt.format(num_batches) + ']'
399 |
400 |
401 | def adjust_learning_rate(optimizer, epoch, args):
402 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
403 | lr = args.lr * (0.1 ** (epoch // 30))
404 | for param_group in optimizer.param_groups:
405 | param_group['lr'] = lr
406 |
407 |
408 | def accuracy(output, target, topk=(1,)):
409 | """Computes the accuracy over the k top predictions for the specified values of k"""
410 | with torch.no_grad():
411 | maxk = max(topk)
412 | batch_size = target.size(0)
413 |
414 | _, pred = output.topk(maxk, 1, True, True)
415 | pred = pred.t()
416 | correct = pred.eq(target.view(1, -1).expand_as(pred))
417 |
418 | res = []
419 | for k in topk:
420 |             correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)  # reshape: the slice is non-contiguous for k > 1, so view() would fail
421 | res.append(correct_k.mul_(100.0 / batch_size))
422 | return res
423 |
424 |
425 | if __name__ == '__main__':
426 | main()
--------------------------------------------------------------------------------
/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | class betweenLoss(nn.Module):
6 | def __init__(self, gamma=[1,1,1,1,1,1], loss=nn.L1Loss()):
7 | super(betweenLoss, self).__init__()
8 | self.gamma = gamma
9 | self.loss = loss
10 |
11 | def forward(self, outputs, targets):
12 | assert len(outputs)
13 | assert len(outputs) == len(targets)
14 |
15 | length = len(outputs)
16 |
17 | res = sum([self.gamma[i]*self.loss(outputs[i], targets[i]) for i in range(length)])
18 |
19 | return res
20 |
21 | def CrossEntropy(outputs, targets):
22 | log_softmax_outputs = F.log_softmax(outputs, dim=1)
23 | softmax_targets = F.softmax(targets, dim=1)
24 |
25 | return -(log_softmax_outputs*softmax_targets).sum(dim=1).mean()
26 |
27 |
28 | class discriminatorLoss(nn.Module):
29 | def __init__(self, models, loss=nn.BCEWithLogitsLoss()):
30 | super(discriminatorLoss, self).__init__()
31 | self.models = models
32 | self.loss = loss
33 |
34 | def forward(self, outputs, targets):
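    |         # The first half of each concatenated batch holds the student outputs,
    |         # the second half the teacher soft labels; the discriminator is
    |         # trained to tell the two halves apart.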
35 | inputs = [torch.cat((i,j),0) for i, j in zip(outputs, targets)]
36 | inputs = torch.cat(inputs, 1)
37 | batch_size = inputs.size(0)
38 | target = torch.FloatTensor([[1, 0] for _ in range(batch_size//2)] + [[0, 1] for _ in range(batch_size//2)])
39 |         target = target.to(inputs.device)
40 | output = self.models(inputs)
41 | res = self.loss(output, target)
42 | return res
43 |
44 |
45 | class discriminatorFakeLoss(nn.Module):
46 | def forward(self, outputs, targets):
47 | res = (0*outputs[0]).sum()
48 | return res
49 |
50 |
51 |
52 |
53 |
54 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/szq0214/MEAL-V2/59e6bca29c77a7df9229e77d4463f2f795723810/models/__init__.py
--------------------------------------------------------------------------------
/models/blocks.py:
--------------------------------------------------------------------------------
1 | """A list of commonly used building blocks."""
2 |
3 | from torch import nn
4 |
5 |
6 | class Conv2dBnRelu(nn.Module):
7 | """A commonly used building block: Conv -> BN -> ReLU"""
8 |
9 | def __init__(self, in_channels, out_channels, kernel_size, stride=1,
10 | padding=0, bias=True, pooling=None,
11 | activation=nn.ReLU(inplace=True)):
12 | super().__init__()
13 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride,
14 | padding, bias=bias)
15 | self.bn = nn.BatchNorm2d(out_channels)
16 | self.pooling = pooling
17 | self.activation = activation
18 |
19 | def forward(self, x):
20 | x = self.bn(self.conv(x))
21 | if self.pooling is not None:
22 | x = self.pooling(x)
23 | return self.activation(x)
24 |
25 |
26 | class LinearBnRelu(nn.Module):
27 | """A commonly used building block: FC -> BN -> ReLU"""
28 |
29 | def __init__(self, in_features, out_features, bias=True,
30 | activation=nn.ReLU(inplace=True)):
31 | super().__init__()
32 | self.linear = nn.Linear(in_features, out_features, bias=bias)
33 | self.bn = nn.BatchNorm1d(out_features)
34 | self.activation = activation
35 |
36 | def forward(self, x):
37 | return self.activation(self.bn(self.linear(x)))
38 |
--------------------------------------------------------------------------------
/models/discriminator.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | class Discriminator(nn.Module):
7 | def __init__(self, outputs_size, K = 2):
8 | super(Discriminator, self).__init__()
9 | self.conv1 = nn.Conv2d(in_channels=outputs_size, out_channels=outputs_size//K, kernel_size=1, stride=1, bias=True)
10 | outputs_size = outputs_size // K
11 | self.conv2 = nn.Conv2d(in_channels=outputs_size, out_channels=outputs_size//K, kernel_size=1, stride=1, bias=True)
12 | outputs_size = outputs_size // K
13 | self.conv3 = nn.Conv2d(in_channels=outputs_size, out_channels=2, kernel_size=1, stride=1, bias=True)
14 |
15 | def forward(self, x):
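    |         # Treat the class-probability vector as a 1x1 feature map, so the
    |         # 1x1 convolutions act as a per-sample MLP.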
16 | x = x[:,:,None,None]
17 | out = F.relu(self.conv1(x))
18 | out = F.relu(self.conv2(out))
19 | out = F.relu(self.conv3(out))
20 | out = out.view(out.size(0), -1)
21 | return out
22 |
23 |
--------------------------------------------------------------------------------
/models/model_factory.py:
--------------------------------------------------------------------------------
1 | """Utility functions to construct a model."""
2 |
3 | import torch
4 | from torch import nn
5 |
6 | import random
7 |
8 | from extensions import data_parallel
9 | from extensions import teacher_wrapper
10 | from extensions import kd_loss
11 | import torchvision.models as models
12 | import timm
13 |
14 |
15 | def _create_single_cpu_model(model_name, state_file=None):
16 | model = _create_model(model_name, teacher=False, pretrain=True)
17 | if state_file is not None:
18 | model.load_state_dict(torch.load(state_file))
19 | return model
20 |
21 | def _create_checkpoint_model(model_name, state_file=None):
22 | model = _create_model(model_name, teacher=False, pretrain=True)
23 | # model = timm.create_model(model_name.lower(), pretrained=False)
24 | if state_file is not None:
25 | model.load_state_dict(torch.load(state_file))
26 | return model
27 |
28 | def _create_model(model_name, teacher=False, pretrain=True):
29 | if pretrain:
30 | print("=> teacher" if teacher else "=> student", end=":")
31 | print(" using pre-trained model '{}'".format(model_name))
32 |
33 | # model = models.__dict__[model_name.lower()](pretrained=True)
34 | model = timm.create_model(model_name.lower(), pretrained=True)
35 | else:
36 | print("=> creating model '{}'".format(model_name))
37 | # model = models.__dict__[model_name.lower()]()
38 | model = timm.create_model(model_name.lower(), pretrained=False)
39 |
40 | if model_name.startswith('alexnet') or model_name.startswith('vgg'):
41 | model.features = torch.nn.DataParallel(model.features)
42 | model.cuda()
43 | else:
44 | model = torch.nn.DataParallel(model).cuda()
45 |
46 | if teacher:
47 | for p in model.parameters():
48 | p.requires_grad = False
49 | model.eval()
50 |
51 | return model
52 |
53 |
54 | def teachers(teachers=['resnet50'], state_file=None):
55 | if state_file is not None:
56 | return [_create_single_cpu_model(t, state_file).cuda() for t in teachers]
57 | else:
58 | return [_create_model(t, teacher=True).cuda() for t in teachers]
59 |
60 |
61 | def create_model(model_name, student_state_file=None, gpus=[], teacher=None,
62 | teacher_state_file=None, FKD=True):
63 | if FKD:
64 | model = _create_checkpoint_model(model_name, student_state_file)
65 | loss = kd_loss.KLLoss()
66 | return model, loss
67 | else:
68 | model = _create_checkpoint_model(model_name, student_state_file)
69 |         model.LR_REGIME = [0, 100, 0.01, 101, 300, 0.001]  # epochs 0-100: lr 0.01; 101-300: lr 0.001
70 | if teacher is not None:
71 | # assert teacher_state_file is not None, "Teacher state is None."
72 |
73 | teacher = teachers(teacher.split(","), teacher_state_file)
74 | model = teacher_wrapper.ModelDistillationWrapper(model, teacher)
75 | loss = kd_loss.KLLoss()
76 | else:
77 | loss = nn.CrossEntropyLoss()
78 |
79 | return model, loss
--------------------------------------------------------------------------------
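A usage sketch for create_model in its two modes (hedged: this assumes a CUDA device, since _create_model moves every model to the GPU, and that timm can resolve the listed model names):

from models import model_factory

# FKD mode: student plus a KL-divergence loss against stored soft labels.
student, kd_loss_fn = model_factory.create_model('resnet50', FKD=True)

# MEAL mode: student wrapped with an ensemble of frozen teachers.
student, kd_loss_fn = model_factory.create_model(
    'resnet50', teacher='gluon_senet154,gluon_resnet152_v1s', FKD=False)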
/opts.py:
--------------------------------------------------------------------------------
1 | from torch.utils import data as data_utils
2 |
3 | from models import model_factory
4 |
5 |
6 | def add_general_flags(parser):
7 | parser.add_argument('--save', default='checkpoints',
8 | help="Path to the directory to save logs and "
9 | "checkpoints.")
10 | parser.add_argument('--gpus', '--gpu', nargs='+', default=[0], type=int,
11 | help="The GPU(s) on which the model should run. The "
12 | "first GPU will be the main one.")
13 | parser.add_argument('--cpu', action='store_const', const=[],
14 | dest='gpus', help="If set, no gpus will be used.")
15 |
16 |
17 | def add_dataset_flags(parser):
18 | parser.add_argument('--imagenet', required=True, help="Path to ImageNet's "
19 | "root directory holding 'train/' and 'val/' "
20 | "directories.")
21 |     parser.add_argument('--batch-size', default=256, help="Total batch size, "
22 |                         "distributed over all GPUs.", type=int)
23 |     parser.add_argument('--num-workers', '-j', default=40, help="Number of "
24 |                         "worker processes for data loading and "
25 |                         "transformation.", type=int)
26 |     parser.add_argument('--image-size', default=224, help="Input image size "
27 |                         "for training.", type=int)
28 |     parser.add_argument('--softlabel_path', default='./soft_label', type=str, metavar='PATH',
29 |                         help='path to soft label files (default: ./soft_label)')
30 |
31 |
32 | def add_model_flags(parser):
33 | parser.add_argument('--model', required=True, help="The model architecture "
34 | "name.")
35 |     parser.add_argument('--student-state-file', default=None, help="Path to "
36 |                         "a student model state file to initialize the student model.")
37 |
38 |
39 | def add_teacher_flags(parser):
40 |     parser.add_argument('--teacher-model', default="gluon_senet154,gluon_resnet152_v1s", help="The "
41 |                         "teacher model(s), comma-separated, that generate soft labels per crop.",
42 |                         )
43 | parser.add_argument('--teacher-state-file', default=None,
44 | help="Path to teacher model state file.")
45 |
46 |
47 | def add_training_flags(parser):
48 | parser.add_argument('--lr', '--learning-rate', default=0.01, type=float,
49 | metavar='LR', help='initial learning rate', dest='lr')
50 |     parser.add_argument('--lr-regime', default=None, nargs='+', type=float,
51 |                         help="If set, overrides the model's default learning "
52 |                         "rate regime. The regime must be passed as a flat "
53 |                         "list of [start, end, lr, ...] triples.")
54 | parser.add_argument('--d_lr', default=1e-4, type=float,
55 | help="The learning rate for discriminator training")
56 |     parser.add_argument('--start-epoch', default=0, help="Manual epoch number, "
57 |                         "useful on restarts.", type=int)
58 | parser.add_argument('--epochs', default=200, type=int, help='number of total epochs to run')
59 | parser.add_argument('--schedule', default=[100, 200], nargs='*', type=int,
60 | help='learning rate schedule (when to drop lr by 10x). This works for FKD training')
61 | parser.add_argument('--cos', action='store_true',
62 | help='use cosine lr schedule. This works for FKD training')
63 | parser.add_argument('--momentum', default=0.9, type=float,
64 | help="The momentum of the optimization.")
65 | parser.add_argument('--weight-decay', default=0, type=float,
66 | help="The weight decay of the optimization.")
67 |     parser.add_argument('--use-discriminator-loss', action='store_true',
68 |                         help='use the discriminator loss during training')
69 | parser.add_argument('--num_crops', default=8, type=int,
70 |                         help='number of crops per image; 1 corresponds to standard single-crop training')
71 | parser.add_argument('--soft_label_type', default='marginal_smoothing_k5', type=str, metavar='TYPE',
72 |                         help='(1) ori; (2) hard; (3) smoothing; (4) marginal_smoothing_k5; (5) marginal_smoothing_k10; (6) marginal_renorm (matches Recover_soft_label in utils_FKD.py)')
73 | parser.add_argument('--num_classes', default=1000, type=int,
74 |                         help='number of classes')
75 |
76 | def add_cutmix_training_flags(parser):
77 | parser.add_argument('--w-cutmix', action='store_true',
78 | help='use cutmix training')
79 | parser.add_argument('--beta', default=1.0, type=float,
80 | help='hyperparameter beta')
81 | parser.add_argument('--cutmix-prob', default=1.0, type=float,
82 | help='cutmix probability')
--------------------------------------------------------------------------------
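A minimal sketch of how train.py composes these flag groups into a single parser; the command-line values below are placeholders:

import argparse

import opts

parser = argparse.ArgumentParser(allow_abbrev=False)
opts.add_general_flags(parser)
opts.add_dataset_flags(parser)
opts.add_model_flags(parser)
opts.add_teacher_flags(parser)
opts.add_training_flags(parser)
opts.add_cutmix_training_flags(parser)

# Only --imagenet and --model are required; everything else has a default.
args = parser.parse_args(['--imagenet', '/data/imagenet', '--model', 'resnet50'])
print(args.batch_size, args.lr, args.num_crops)  # 256 0.01 8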
/script/resume_train.sh:
--------------------------------------------------------------------------------
1 | # An example of resuming training from a saved student checkpoint:
2 | python train.py --save MEAL_V2_resnet50_224 --batch-size 512 --model resnet50 --start-epoch 96 --teacher-model gluon_senet154,gluon_resnet152_v1s --student-state-file ./MEAL_V2_resnet50/model_state_95.pytar --imagenet [imagenet-folder with train and val folders] -j 40
3 |
--------------------------------------------------------------------------------
/script/train.sh:
--------------------------------------------------------------------------------
1 | # 224 x 224 ResNet-50 Tested on 8 TITAN Xp GPUs
2 | python train.py --save MEAL_V2_resnet50_224 --batch-size 512 -j 48 --model resnet50 --teacher-model gluon_senet154,gluon_resnet152_v1s --imagenet [imagenet-folder with train and val folders]
3 |
4 | # # 224 x 224 MobileNet V3-Small 0.75
5 | # python train.py --save MEAL_V2_mobilenetv3_small_075 --batch-size 512 -j 48 --model tf_mobilenetv3_small_075 --teacher-model gluon_senet154,gluon_resnet152_v1s --imagenet [imagenet-folder with train and val folders]
6 |
7 | # # 224 x 224 MobileNet V3-Small 1.0
8 | # python train.py --save MEAL_V2_mobilenetv3_small_100 --batch-size 512 -j 48 --model tf_mobilenetv3_small_100 --teacher-model gluon_senet154,gluon_resnet152_v1s --imagenet [imagenet-folder with train and val folders]
9 |
10 | # # 224 x 224 MobileNet V3-Large 1.0
11 | # python train.py --save MEAL_V2_mobilenetv3_large_100 --batch-size 512 -j 48 --model tf_mobilenetv3_large_100 --teacher-model gluon_senet154,gluon_resnet152_v1s --imagenet [imagenet-folder with train and val folders]
12 |
13 | # # 224 x 224 EfficientNet-B0
14 | # python train.py --save MEAL_V2_efficientnet_b0 --batch-size 512 -j 48 --model tf_efficientnet_b0 --teacher-model gluon_senet154,gluon_resnet152_v1s --imagenet [imagenet-folder with train and val folders]
15 |
16 | # 380 x 380 ResNet-50
17 | # python train.py --save MEAL_V2_resnet50_380 --batch-size 512 -j 48 --model resnet50 --image-size 380 --teacher-model tf_efficientnet_b4_ns,tf_efficientnet_b4 --imagenet [imagenet-folder with train and val folders]
18 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """Script to test a pytorch model on ImageNet's validation set."""
3 |
4 | import argparse
5 | import logging
6 | import pprint
7 | import sys
8 | import time
9 |
10 | import torch
11 | from torch import nn
12 |
13 | import imagenet
14 | from models import model_factory
15 | import opts
16 | import utils
17 |
18 |
19 | def parse_args(argv):
20 | """Parse arguments @argv and return the flags needed for training."""
21 | parser = argparse.ArgumentParser(description=__doc__, allow_abbrev=False)
22 |
23 | group = parser.add_argument_group('General Options')
24 | opts.add_general_flags(group)
25 |
26 | group = parser.add_argument_group('Dataset Options')
27 | opts.add_dataset_flags(group)
28 |
29 | group = parser.add_argument_group('Model Options')
30 | opts.add_model_flags(group)
31 |
32 | args = parser.parse_args(argv)
33 |
34 | if args.student_state_file is None:
35 | parser.error("You should set --model-state-file (student) to reload a model "
36 | "state.")
37 |
38 | return args
39 |
40 |
41 | def test_for_one_epoch(model, loss, test_loader, epoch_number):
42 | model.eval()
43 | loss.eval()
44 |
45 | data_time_meter = utils.AverageMeter()
46 | batch_time_meter = utils.AverageMeter()
47 | loss_meter = utils.AverageMeter(recent=100)
48 | top1_meter = utils.AverageMeter(recent=100)
49 | top5_meter = utils.AverageMeter(recent=100)
50 |
51 | timestamp = time.time()
52 | for i, (images, labels) in enumerate(test_loader):
53 | batch_size = images.size(0)
54 |
55 | if utils.is_model_cuda(model):
56 | images = images.cuda()
57 | labels = labels.cuda()
58 |
59 | # Record data time
60 | data_time_meter.update(time.time() - timestamp)
61 |
62 | # Forward pass without computing gradients.
63 | with torch.no_grad():
64 | outputs = model(images)
65 | loss_output = loss(outputs, labels)
66 |
67 | # Sometimes loss function returns a modified version of the output,
68 | # which must be used to compute the model accuracy.
69 | if isinstance(loss_output, tuple):
70 | loss_value, outputs = loss_output
71 | else:
72 | loss_value = loss_output
73 |
74 | # Record loss and model accuracy.
75 | loss_meter.update(loss_value.item(), batch_size)
76 | top1, top5 = utils.topk_accuracy(outputs, labels, recalls=(1, 5))
77 | top1_meter.update(top1, batch_size)
78 | top5_meter.update(top5, batch_size)
79 |
80 | # Record batch time
81 | batch_time_meter.update(time.time() - timestamp)
82 | timestamp = time.time()
83 |
84 | logging.info(
85 | 'Epoch: [{epoch}][{batch}/{epoch_size}]\t'
86 | 'Time {batch_time.value:.2f} ({batch_time.average:.2f}) '
87 | 'Data {data_time.value:.2f} ({data_time.average:.2f}) '
88 | 'Loss {loss.value:.3f} {{{loss.average:.3f}, {loss.average_recent:.3f}}} '
89 | 'Top-1 {top1.value:.2f} {{{top1.average:.2f}, {top1.average_recent:.2f}}} '
90 | 'Top-5 {top5.value:.2f} {{{top5.average:.2f}, {top5.average_recent:.2f}}} '.format(
91 | epoch=epoch_number, batch=i + 1, epoch_size=len(test_loader),
92 | batch_time=batch_time_meter, data_time=data_time_meter,
93 | loss=loss_meter, top1=top1_meter, top5=top5_meter))
94 | # Log the overall test stats
95 | logging.info(
96 | 'Epoch: [{epoch}] -- TESTING SUMMARY\t'
97 | 'Time {batch_time.sum:.2f} '
98 | 'Data {data_time.sum:.2f} '
99 | 'Loss {loss.average:.3f} '
100 | 'Top-1 {top1.average:.2f} '
101 | 'Top-5 {top5.average:.2f} '.format(
102 | epoch=epoch_number, batch_time=batch_time_meter, data_time=data_time_meter,
103 | loss=loss_meter, top1=top1_meter, top5=top5_meter))
104 |
105 |
106 | def main(argv):
107 | """Run the test script with command line arguments @argv."""
108 | args = parse_args(argv)
109 | utils.general_setup(args.save, args.gpus)
110 |
111 | logging.info("Arguments parsed.\n{}".format(pprint.pformat(vars(args))))
112 |
113 | # Create the validation data loaders.
114 | val_loader = imagenet.get_val_loader(args.imagenet, args.batch_size,
115 | args.num_workers)
116 | # Create model and the loss.
117 | model, loss = model_factory.create_model(
118 | args.model, args.student_state_file, args.gpus)
119 | logging.info("Model:\n{}".format(model))
120 |
121 | # Test for one epoch.
122 | test_for_one_epoch(model, loss, val_loader, epoch_number=1)
123 |
124 |
125 | if __name__ == '__main__':
126 | main(sys.argv[1:])
127 |
--------------------------------------------------------------------------------
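test.py can also be driven programmatically, which is handy for smoke tests; a sketch with placeholder paths (equivalent to the shell invocation):

import test

test.main(['--imagenet', '/data/imagenet',
           '--model', 'resnet50',
           '--student-state-file', 'checkpoints/model_state_95.pytar'])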
/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """Script to train a model through soft labels on ImageNet's train set."""
3 |
4 | import argparse
5 | import logging
6 | import pprint
7 | import os
8 | import sys
9 | import time
10 | import numpy as np
11 |
12 | import torch
13 | from torch import nn
14 |
15 | from loss import discriminatorLoss
16 |
17 | import imagenet
18 | from models import model_factory
19 | from models import discriminator
20 | import opts
21 | import test
22 | import utils
23 |
24 |
25 | def parse_args(argv):
26 | """Parse arguments @argv and return the flags needed for training."""
27 | parser = argparse.ArgumentParser(description=__doc__, allow_abbrev=False)
28 |
29 | group = parser.add_argument_group('General Options')
30 | opts.add_general_flags(group)
31 |
32 | group = parser.add_argument_group('Dataset Options')
33 | opts.add_dataset_flags(group)
34 |
35 | group = parser.add_argument_group('Model Options')
36 | opts.add_model_flags(group)
37 |
38 | group = parser.add_argument_group('Soft Label Options')
39 | opts.add_teacher_flags(group)
40 |
41 | group = parser.add_argument_group('Training Options')
42 | opts.add_training_flags(group)
43 |
44 | group = parser.add_argument_group('CutMix Training Options')
45 | opts.add_cutmix_training_flags(group)
46 |
47 | args = parser.parse_args(argv)
48 |
49 | # if args.teacher_model is not None and args.teacher_state_file is None:
50 | # parser.error("You should set --teacher-state-file if "
51 | # "--teacher-model is set.")
52 |
53 | return args
54 |
55 |
56 | class LearningRateRegime:
57 | """Encapsulates the learning rate regime for training a model.
58 |
59 | Args:
60 | @intervals (list): A list of triples (start, end, lr). The intervals
61 | are inclusive (for start <= epoch <= end, lr will be used). The
62 | start of each interval must be right after the end of its previous
63 | interval.
64 | """
65 |
66 | def __init__(self, regime):
67 | if len(regime) % 3 != 0:
68 | raise ValueError("Regime length should be devisible by 3.")
69 | intervals = list(zip(regime[0::3], regime[1::3], regime[2::3]))
70 | self._validate_intervals(intervals)
71 | self.intervals = intervals
72 | self.num_epochs = intervals[-1][1]
73 |
74 | @classmethod
75 | def _validate_intervals(cls, intervals):
76 | if type(intervals) is not list:
77 | raise TypeError("Intervals must be a list of triples.")
78 | elif len(intervals) == 0:
79 | raise ValueError("Intervals must be a non empty list.")
80 | # elif intervals[0][0] != 1:
81 | # raise ValueError("Intervals must start from 1: {}".format(intervals))
82 | elif any(end < start for (start, end, lr) in intervals):
83 | raise ValueError("End of intervals must be greater or equal than their"
84 | " start: {}".format(intervals))
85 | elif any(intervals[i][1] + 1 != intervals[i + 1][0]
86 | for i in range(len(intervals) - 1)):
87 | raise ValueError("Start of each each interval must be the end of its "
88 | "previous interval plus one: {}".format(intervals))
89 |
90 | def get_lr(self, epoch):
91 | for (start, end, lr) in self.intervals:
92 | if start <= epoch <= end:
93 | return lr
94 | raise ValueError("Invalid epoch {} for regime {!r}".format(
95 | epoch, self.intervals))
96 |
97 |
98 | def _set_learning_rate(optimizer, lr):
99 | for param_group in optimizer.param_groups:
100 | param_group['lr'] = lr
101 |
102 |
103 | def _get_learning_rate(optimizer):
104 | return max(param_group['lr'] for param_group in optimizer.param_groups)
105 |
106 |
107 | def train_for_one_epoch(model, g_loss, discriminator_loss, train_loader, optimizer, epoch_number, args):
108 | model.train()
109 | g_loss.train()
110 |
111 | data_time_meter = utils.AverageMeter()
112 | batch_time_meter = utils.AverageMeter()
113 | g_loss_meter = utils.AverageMeter(recent=100)
114 | d_loss_meter = utils.AverageMeter(recent=100)
115 | top1_meter = utils.AverageMeter(recent=100)
116 | top5_meter = utils.AverageMeter(recent=100)
117 |
118 | timestamp = time.time()
119 | for i, (images, labels) in enumerate(train_loader):
120 | batch_size = images.size(0)
121 |
122 | if utils.is_model_cuda(model):
123 | images = images.cuda()
124 | labels = labels.cuda()
125 |
126 | # Record data time
127 | data_time_meter.update(time.time() - timestamp)
128 |
129 |         if args.w_cutmix:
130 | r = np.random.rand(1)
131 | if args.beta > 0 and r < args.cutmix_prob:
132 | # generate mixed sample
133 | lam = np.random.beta(args.beta, args.beta)
134 | rand_index = torch.randperm(images.size()[0]).cuda()
135 |                 target_a = labels              # note: the mixed targets are unused below;
136 |                 target_b = labels[rand_index]  # the KD loss uses teacher soft labels on the mixed images
137 | bbx1, bby1, bbx2, bby2 = utils.rand_bbox(images.size(), lam)
138 | images[:, :, bbx1:bbx2, bby1:bby2] = images[rand_index, :, bbx1:bbx2, bby1:bby2]
139 |
140 | # Forward pass, backward pass, and update parameters.
141 | outputs = model(images, before=True)
142 | output, soft_label, soft_no_softmax = outputs
143 | g_loss_output = g_loss((output, soft_label), labels)
144 | d_loss_value = discriminator_loss([output], [soft_no_softmax])
145 |
146 | # Sometimes loss function returns a modified version of the output,
147 | # which must be used to compute the model accuracy.
148 | if isinstance(g_loss_output, tuple):
149 | g_loss_value, outputs = g_loss_output
150 | else:
151 | g_loss_value = g_loss_output
152 |
153 | loss_value = g_loss_value + d_loss_value
154 |
155 | loss_value.backward()
156 |
157 | # Update parameters and reset gradients.
158 | optimizer.step()
159 | optimizer.zero_grad()
160 |
161 | # Record loss and model accuracy.
162 | g_loss_meter.update(g_loss_value.item(), batch_size)
163 | d_loss_meter.update(d_loss_value.item(), batch_size)
164 |
165 | top1, top5 = utils.topk_accuracy(outputs, labels, recalls=(1, 5))
166 | top1_meter.update(top1, batch_size)
167 | top5_meter.update(top5, batch_size)
168 |
169 | # Record batch time
170 | batch_time_meter.update(time.time() - timestamp)
171 | timestamp = time.time()
172 |
173 |         if i % 20 == 0:
174 | logging.info(
175 | 'Epoch: [{epoch}][{batch}/{epoch_size}]\t'
176 | 'Time {batch_time.value:.2f} ({batch_time.average:.2f}) '
177 | 'Data {data_time.value:.2f} ({data_time.average:.2f}) '
178 | 'G_Loss {g_loss.value:.3f} {{{g_loss.average:.3f}, {g_loss.average_recent:.3f}}} '
179 | 'D_Loss {d_loss.value:.3f} {{{d_loss.average:.3f}, {d_loss.average_recent:.3f}}} '
180 | 'Top-1 {top1.value:.2f} {{{top1.average:.2f}, {top1.average_recent:.2f}}} '
181 | 'Top-5 {top5.value:.2f} {{{top5.average:.2f}, {top5.average_recent:.2f}}} '
182 | 'LR {lr:.5f}'.format(
183 | epoch=epoch_number, batch=i + 1, epoch_size=len(train_loader),
184 | batch_time=batch_time_meter, data_time=data_time_meter,
185 | g_loss=g_loss_meter, d_loss=d_loss_meter, top1=top1_meter, top5=top5_meter,
186 | lr=_get_learning_rate(optimizer)))
187 | # Log the overall train stats
188 | logging.info(
189 | 'Epoch: [{epoch}] -- TRAINING SUMMARY\t'
190 | 'Time {batch_time.sum:.2f} '
191 | 'Data {data_time.sum:.2f} '
192 | 'G_Loss {g_loss.average:.3f} '
193 | 'D_Loss {d_loss.average:.3f} '
194 | 'Top-1 {top1.average:.2f} '
195 | 'Top-5 {top5.average:.2f} '.format(
196 | epoch=epoch_number, batch_time=batch_time_meter, data_time=data_time_meter,
197 | g_loss=g_loss_meter, d_loss=d_loss_meter, top1=top1_meter, top5=top5_meter))
198 |
199 |
200 | def save_checkpoint(checkpoints_dir, model, optimizer, epoch):
201 | model_state_file = os.path.join(checkpoints_dir, 'model_state_{:02}.pytar'.format(epoch))
202 | optim_state_file = os.path.join(checkpoints_dir, 'optim_state_{:02}.pytar'.format(epoch))
203 | torch.save(model.state_dict(), model_state_file)
204 | torch.save(optimizer.state_dict(), optim_state_file)
205 |
206 |
207 | def create_optimizer(model, discriminator_parameters, momentum=0.9, weight_decay=0):
208 | # Get model parameters that require a gradient.
209 | # model_trainable_parameters = filter(lambda x: x.requires_grad, model.parameters())
210 | parameters = [{'params': model.parameters()}, discriminator_parameters]
211 | optimizer = torch.optim.SGD(parameters, lr=0,
212 | momentum=momentum, weight_decay=weight_decay)
213 | return optimizer
214 |
215 | def create_discriminator_criterion(args):
216 |     d = discriminator.Discriminator(outputs_size=args.num_classes, K=8).cuda()
217 | d = torch.nn.DataParallel(d)
218 | update_parameters = {'params': d.parameters(), "lr": args.d_lr}
219 | discriminators_criterion = discriminatorLoss(d).cuda()
220 | if len(args.gpus) > 1:
221 | discriminators_criterion = torch.nn.DataParallel(discriminators_criterion, device_ids=args.gpus)
222 | return discriminators_criterion, update_parameters
223 |
224 | def main(argv):
225 | """Run the training script with command line arguments @argv."""
226 | args = parse_args(argv)
227 | utils.general_setup(args.save, args.gpus)
228 |
229 | logging.info("Arguments parsed.\n{}".format(pprint.pformat(vars(args))))
230 |
231 | # Create the train and the validation data loaders.
232 | train_loader = imagenet.get_train_loader(args.imagenet, args.batch_size,
233 | args.num_workers, args.image_size)
234 | val_loader = imagenet.get_val_loader(args.imagenet, args.batch_size,
235 | args.num_workers, args.image_size)
236 | # Create model with optional teachers.
237 | model, loss = model_factory.create_model(
238 | args.model, args.student_state_file, args.gpus, args.teacher_model,
239 | args.teacher_state_file, False)
240 | logging.info("Model:\n{}".format(model))
241 |
242 | discriminator_loss, update_parameters = create_discriminator_criterion(args)
243 |
244 | if args.lr_regime is None:
245 | lr_regime = model.LR_REGIME
246 | else:
247 | lr_regime = args.lr_regime
248 | regime = LearningRateRegime(lr_regime)
249 | # Train and test for needed number of epochs.
250 | optimizer = create_optimizer(model, update_parameters, args.momentum, args.weight_decay)
251 |
252 | for epoch in range(args.start_epoch, args.epochs):
253 | lr = regime.get_lr(epoch)
254 | _set_learning_rate(optimizer, lr)
255 | train_for_one_epoch(model, loss, discriminator_loss, train_loader, optimizer, epoch, args)
256 | test.test_for_one_epoch(model, loss, val_loader, epoch)
257 | save_checkpoint(args.save, model, optimizer, epoch)
258 |
259 |
260 | if __name__ == '__main__':
261 | main(sys.argv[1:])
262 |
--------------------------------------------------------------------------------
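A worked example of the LearningRateRegime triple format. Note that this copy of _validate_intervals has the start-from-1 check commented out, so a regime starting at epoch 0 (like the default in model_factory) validates. A minimal sketch:

from train import LearningRateRegime

# Flat list of (start, end, lr) triples: epochs 0-100 at 0.01, 101-300 at 0.001.
regime = LearningRateRegime([0, 100, 0.01, 101, 300, 0.001])
assert regime.get_lr(0) == 0.01      # intervals are inclusive on both ends
assert regime.get_lr(100) == 0.01
assert regime.get_lr(101) == 0.001
assert regime.num_epochs == 300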
/utils.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import logging
3 | import os
4 | import sys
5 | import numpy as np
6 |
7 | import torch
8 |
9 |
10 | def general_setup(checkpoints_dir=None, gpus=[]):
11 | if checkpoints_dir is not None:
12 | os.makedirs(checkpoints_dir, exist_ok=True)
13 | if len(gpus) > 0:
14 | torch.cuda.set_device(gpus[0])
15 | # Setup python's logging module.
16 | log_formatter = logging.Formatter(
17 | '%(levelname)s %(asctime)-20s:\t %(message)s')
18 | root_logger = logging.getLogger()
19 | root_logger.setLevel(logging.INFO)
20 | # Add a console handler to write to stdout.
21 | console_handler = logging.StreamHandler(sys.stdout)
22 | console_handler.setFormatter(log_formatter)
23 | root_logger.addHandler(console_handler)
24 |     # Add a file handler to write to log.txt, guarding against checkpoints_dir=None.
25 |     if checkpoints_dir is not None:
26 |         file_handler = logging.FileHandler(os.path.join(checkpoints_dir, 'log.txt'))
27 |         file_handler.setFormatter(log_formatter)
28 |         root_logger.addHandler(file_handler)
29 |
30 |
31 | def is_model_cuda(model):
32 | # Check if the first parameter is on cuda.
33 | return next(model.parameters()).is_cuda
34 |
35 |
36 | def topk_accuracy(outputs, labels, recalls=(1, 5)):
37 | """Return @recall accuracies for the given recalls."""
38 |
39 | _, num_classes = outputs.size()
40 | maxk = min(max(recalls), num_classes)
41 |
42 | _, pred = outputs.topk(maxk, dim=1, largest=True, sorted=True)
43 | correct = (pred == labels[:,None].expand_as(pred)).float()
44 |
45 | topk_accuracy = []
46 | for recall in recalls:
47 | topk_accuracy.append(100 * correct[:, :recall].sum(1).mean())
48 | return topk_accuracy
49 |
50 |
51 | class AverageMeter:
52 | """Helper class to track the running average (and optionally the recent k
53 | items average of a sequence)."""
54 |
55 | def __init__(self, recent=None):
56 | self._recent = recent
57 | if recent is not None:
58 | self._q = collections.deque()
59 | self.reset()
60 |
61 | def reset(self):
62 | self.value = 0
63 | self.sum = 0
64 | self.count = 0
65 | if self._recent is not None:
66 | self.sum_recent = 0
67 | self.count_recent = 0
68 | self._q.clear()
69 |
70 | def update(self, value, n=1):
71 | self.value = value
72 | self.sum += value * n
73 | self.count += n
74 |
75 | if self._recent is not None:
76 | self.sum_recent += value * n
77 | self.count_recent += n
78 | self._q.append((n, value))
79 | while len(self._q) > self._recent:
80 | (n, value) = self._q.popleft()
81 | self.sum_recent -= value * n
82 | self.count_recent -= n
83 |
84 | @property
85 | def average(self):
86 | if self.count > 0:
87 | return self.sum / self.count
88 | else:
89 | return 0
90 |
91 | @property
92 | def average_recent(self):
93 | if self.count_recent > 0:
94 | return self.sum_recent / self.count_recent
95 | else:
96 | return 0
97 |
98 | def rand_bbox(size, lam):  # sample a random CutMix box covering roughly a (1 - lam) fraction of the image
99 | W = size[2]
100 | H = size[3]
101 | cut_rat = np.sqrt(1. - lam)
102 |     cut_w = int(W * cut_rat)  # np.int was removed in NumPy 1.24+; use the builtin
103 |     cut_h = int(H * cut_rat)
104 |
105 | # uniform
106 | cx = np.random.randint(W)
107 | cy = np.random.randint(H)
108 |
109 | bbx1 = np.clip(cx - cut_w // 2, 0, W)
110 | bby1 = np.clip(cy - cut_h // 2, 0, H)
111 | bbx2 = np.clip(cx + cut_w // 2, 0, W)
112 | bby2 = np.clip(cy + cut_h // 2, 0, H)
113 |
114 | return bbx1, bby1, bbx2, bby2
115 |
--------------------------------------------------------------------------------
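A small sketch of AverageMeter's two averages, since the recent window is easy to misread: the deque holds the last `recent` updates, each weighted by its count n:

from utils import AverageMeter

m = AverageMeter(recent=2)
for v in (1.0, 2.0, 3.0):
    m.update(v)
print(m.average)         # 2.0 -- mean over all three updates
print(m.average_recent)  # 2.5 -- mean over the last two updates only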
/utils_FKD.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.distributed
4 | import torch.nn as nn
5 | import torchvision
6 | from torchvision.ops import roi_align
7 | from torchvision.transforms import functional as t_F
8 | from torch.nn import functional as F
9 | from torchvision.datasets.folder import ImageFolder
10 | from torch.nn.modules import loss
11 | from torchvision.transforms import InterpolationMode
12 | import random
13 | import numpy as np
14 |
15 |
16 | class RandomResizedCrop_FKD(torchvision.transforms.RandomResizedCrop):
17 | def __init__(self, **kwargs):
18 | super(RandomResizedCrop_FKD, self).__init__(**kwargs)
19 |
20 | def __call__(self, img, coords, status):
21 | i = coords[0].item() * img.size[1]
22 | j = coords[1].item() * img.size[0]
23 | h = coords[2].item() * img.size[1]
24 | w = coords[3].item() * img.size[0]
25 |
26 |         if self.interpolation == 'bicubic':
27 |             inter = InterpolationMode.BICUBIC
28 |         else:
29 |             inter = InterpolationMode.BILINEAR  # default; avoids an unbound 'inter' for other values
30 |         return t_F.resized_crop(img, i, j, h, w, self.size, inter)
31 |
32 |
33 | class RandomHorizontalFlip_FKD(torch.nn.Module):
34 | def __init__(self, p=0.5):
35 | super().__init__()
36 | self.p = p
37 |
38 | def forward(self, img, coords, status):
39 |
40 |         if status:
41 | return t_F.hflip(img)
42 | else:
43 | return img
44 |
45 | def __repr__(self):
46 | return self.__class__.__name__ + '(p={})'.format(self.p)
47 |
48 |
49 | class Compose_FKD(torchvision.transforms.Compose):
50 | def __init__(self, **kwargs):
51 | super(Compose_FKD, self).__init__(**kwargs)
52 |
53 | def __call__(self, img, coords, status):
54 | for t in self.transforms:
55 | if type(t).__name__ == 'RandomResizedCrop_FKD':
56 | img = t(img, coords, status)
57 | elif type(t).__name__ == 'RandomCrop_FKD':
58 | img, coords = t(img)
59 | elif type(t).__name__ == 'RandomHorizontalFlip_FKD':
60 | img = t(img, coords, status)
61 | else:
62 | img = t(img)
63 | return img
64 |
65 |
66 | class ImageFolder_FKD(torchvision.datasets.ImageFolder):
67 | def __init__(self, **kwargs):
68 |         # Pop the FKD-specific arguments before delegating to ImageFolder.
69 |         self.num_crops = kwargs.pop('num_crops')
70 |         self.softlabel_path = kwargs.pop('softlabel_path')
71 |         # The remaining kwargs are standard ImageFolder arguments.
72 | super(ImageFolder_FKD, self).__init__(**kwargs)
73 |
74 | def __getitem__(self, index):
75 |
76 | path, target = self.samples[index]
77 |
78 | label_path = os.path.join(self.softlabel_path, '/'.join(path.split('/')[-4:]).split('.')[0] + '.tar')
79 |
80 | label = torch.load(label_path, map_location=torch.device('cpu'))
81 |
82 | coords, flip_status, output = label
83 |
84 | rand_index = torch.randperm(len(output))#.cuda()
85 | output_new = []
86 |
87 | sample = self.loader(path)
88 | sample_all = []
89 | target_all = []
90 |
91 | for i in range(self.num_crops):
92 | if self.transform is not None:
93 | output_new.append(output[rand_index[i]])
94 | sample_new = self.transform(sample, coords[rand_index[i]], flip_status[rand_index[i]])
95 | sample_all.append(sample_new)
96 | target_all.append(target)
97 | else:
98 | coords = None
99 | flip_status = None
100 | if self.target_transform is not None:
101 | target = self.target_transform(target)
102 |
103 | return sample_all, target_all, output_new
104 |
105 |
106 | def Recover_soft_label(label, label_type, n_classes):
107 | if label_type == 'hard':
108 |         return torch.zeros(label.size(0), n_classes).scatter_(1, label.view(-1, 1).long(), 1)
109 | elif label_type == 'smoothing':
110 | index = label[:,0].to(dtype=int)
111 | value = label[:,1]
112 | minor_value = (torch.ones_like(value) - value)/(n_classes-1)
113 | minor_value = minor_value.reshape(-1,1).repeat_interleave(n_classes, dim=1)
114 | soft_label = (minor_value * torch.ones(index.size(0), n_classes)).scatter_(1, index.view(-1, 1), value.view(-1, 1))
115 | return soft_label
116 | elif label_type == 'marginal_smoothing_k5':
117 | index = label[:,0,:].to(dtype=int)
118 | value = label[:,1,:]
119 | minor_value = (torch.ones(label.size(0),1) - torch.sum(value, dim=1, keepdim=True))/(n_classes-5)
120 | minor_value = minor_value.reshape(-1,1).repeat_interleave(n_classes, dim=1)
121 | soft_label = (minor_value * torch.ones(index.size(0), n_classes)).scatter_(1, index, value)
122 | return soft_label
123 | elif label_type == 'marginal_renorm':
124 | index = label[:,0,:].to(dtype=int)
125 | value = label[:,1,:]
126 | soft_label = torch.zeros(index.size(0), n_classes).scatter_(1, index, value)
127 | soft_label = F.normalize(soft_label, p=1.0, dim=1, eps=1e-12)
128 | return soft_label
129 | elif label_type == 'marginal_smoothing_k10':
130 | index = label[:,0,:].to(dtype=int)
131 | value = label[:,1,:]
132 | minor_value = (torch.ones(label.size(0),1) - torch.sum(value, dim=1, keepdim=True))/(n_classes-10)
133 | minor_value = minor_value.reshape(-1,1).repeat_interleave(n_classes, dim=1)
134 | soft_label = (minor_value * torch.ones(index.size(0), n_classes)).scatter_(1, index, value)
135 | return soft_label
--------------------------------------------------------------------------------
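A worked example of Recover_soft_label for the 'smoothing' type, where each row stores the top-1 class index and its probability and the leftover mass is spread uniformly:

import torch

from utils_FKD import Recover_soft_label

# One sample, 5 classes: class 2 with probability 0.9; the remaining
# 0.1 is split evenly over the other 4 classes (0.025 each).
label = torch.tensor([[2.0, 0.9]])
soft = Recover_soft_label(label, 'smoothing', n_classes=5)
print(soft)  # tensor([[0.0250, 0.0250, 0.9000, 0.0250, 0.0250]])
assert torch.allclose(soft.sum(dim=1), torch.ones(1))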