├── FKD_train.py
├── README.md
├── extensions
│   ├── __init__.py
│   ├── data_parallel.py
│   ├── kd_loss.py
│   └── teacher_wrapper.py
├── hubconf.py
├── imagenet.py
├── images
│   ├── MEAL-V2_more_tricks_top1.png
│   ├── MEAL-V2_more_tricks_top5.png
│   └── comparison.png
├── inference.py
├── loss.py
├── models
│   ├── __init__.py
│   ├── blocks.py
│   ├── discriminator.py
│   └── model_factory.py
├── opts.py
├── script
│   ├── resume_train.sh
│   └── train.sh
├── test.py
├── train.py
├── utils.py
└── utils_FKD.py

/FKD_train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """Script to train a model through soft labels on ImageNet's train set."""
3 | 
4 | import argparse
5 | import logging
6 | import pprint
7 | import os
8 | import sys
9 | import time
10 | import math
11 | import numpy as np
12 | 
13 | import torch
14 | from torch import nn
15 | 
16 | from loss import discriminatorLoss
17 | 
18 | import imagenet
19 | from models import model_factory
20 | from models import discriminator
21 | import opts
22 | import test
23 | import utils
24 | from utils_FKD import Recover_soft_label
25 | 
26 | def parse_args(argv):
27 |     """Parse arguments @argv and return the flags needed for training."""
28 |     parser = argparse.ArgumentParser(description=__doc__, allow_abbrev=False)
29 | 
30 |     group = parser.add_argument_group('General Options')
31 |     opts.add_general_flags(group)
32 | 
33 |     group = parser.add_argument_group('Dataset Options')
34 |     opts.add_dataset_flags(group)
35 | 
36 |     group = parser.add_argument_group('Model Options')
37 |     opts.add_model_flags(group)
38 | 
39 |     group = parser.add_argument_group('Soft Label Options')
40 |     opts.add_teacher_flags(group)
41 | 
42 |     group = parser.add_argument_group('Training Options')
43 |     opts.add_training_flags(group)
44 | 
45 |     group = parser.add_argument_group('CutMix Training Options')
46 |     opts.add_cutmix_training_flags(group)
47 | 
48 |     args = parser.parse_args(argv)
49 | 
50 |     return args
51 | 
52 | 
53 | class LearningRateRegime:
54 |     """Encapsulates the learning rate regime for training a model.
55 | 
56 |     Args:
57 |         @intervals (list): A list of triples (start, end, lr). The intervals
58 |             are inclusive (for start <= epoch <= end, lr will be used). The
59 |             start of each interval must be right after the end of its previous
60 |             interval.
61 |     """
62 | 
63 |     def __init__(self, regime):
64 |         if len(regime) % 3 != 0:
65 |             raise ValueError("Regime length should be divisible by 3.")
66 |         intervals = list(zip(regime[0::3], regime[1::3], regime[2::3]))
67 |         self._validate_intervals(intervals)
68 |         self.intervals = intervals
69 |         self.num_epochs = intervals[-1][1]
70 | 
71 |     @classmethod
72 |     def _validate_intervals(cls, intervals):
73 |         if type(intervals) is not list:
74 |             raise TypeError("Intervals must be a list of triples.")
75 |         elif len(intervals) == 0:
76 |             raise ValueError("Intervals must be a non-empty list.")
77 |         elif intervals[0][0] != 1:
78 |             raise ValueError("Intervals must start from 1: {}".format(intervals))
79 |         elif any(end < start for (start, end, lr) in intervals):
80 |             raise ValueError("End of intervals must be greater than or equal to"
81 |                              " their start: {}".format(intervals))
82 |         elif any(intervals[i][1] + 1 != intervals[i + 1][0]
83 |                  for i in range(len(intervals) - 1)):
84 |             raise ValueError("Start of each interval must be the end of its "
85 |                              "previous interval plus one: {}".format(intervals))
86 | 
87 |     def get_lr(self, epoch):
88 |         for (start, end, lr) in self.intervals:
89 |             if start <= epoch <= end:
90 |                 return lr
91 |         raise ValueError("Invalid epoch {} for regime {!r}".format(
92 |             epoch, self.intervals))
93 | 
94 | 
95 | def _set_learning_rate(optimizer, lr):
96 |     for param_group in optimizer.param_groups:
97 |         param_group['lr'] = lr
98 | 
99 | def adjust_learning_rate(optimizer, epoch, args):
100 |     """Decay the learning rate based on schedule."""
101 |     lr = args.lr
102 |     if args.cos:  # cosine lr schedule
103 |         lr *= 0.5 * (1. + math.cos(math.pi * epoch / args.epochs))
104 |     else:  # stepwise lr schedule
105 |         for milestone in args.schedule:
106 |             lr *= 0.1 if epoch >= milestone else 1.
107 |     for param_group in optimizer.param_groups:
108 |         param_group['lr'] = lr
109 | 
110 | def _get_learning_rate(optimizer):
111 |     return max(param_group['lr'] for param_group in optimizer.param_groups)
112 | 
113 | 
114 | def train_for_one_epoch(model, g_loss, discriminator_loss, train_loader, optimizer, epoch_number, args):
115 |     model.train()
116 |     g_loss.train()
117 | 
118 |     data_time_meter = utils.AverageMeter()
119 |     batch_time_meter = utils.AverageMeter()
120 |     g_loss_meter = utils.AverageMeter(recent=100)
121 |     d_loss_meter = utils.AverageMeter(recent=100)
122 |     top1_meter = utils.AverageMeter(recent=100)
123 |     top5_meter = utils.AverageMeter(recent=100)
124 | 
125 |     timestamp = time.time()
126 |     for i, (images, labels, soft_labels) in enumerate(train_loader):
127 |         batch_size = args.batch_size
128 | 
129 |         # Record data time
130 |         data_time_meter.update(time.time() - timestamp)
131 | 
132 |         images = torch.cat(images, dim=0)
133 |         soft_labels = torch.cat(soft_labels, dim=0)
134 |         labels = torch.cat(labels, dim=0)
135 | 
136 |         if args.soft_label_type == 'ori':
137 |             soft_labels = soft_labels.cuda()
138 |         else:
139 |             soft_labels = Recover_soft_label(soft_labels, args.soft_label_type, args.num_classes)
140 |             soft_labels = soft_labels.cuda()
141 | 
142 |         if utils.is_model_cuda(model):
143 |             images = images.cuda()
144 |             labels = labels.cuda()
145 | 
146 |         if args.w_cutmix:
147 |             r = np.random.rand(1)
148 |             if args.beta > 0 and r < args.cutmix_prob:
149 |                 # generate mixed sample
150 |                 lam = np.random.beta(args.beta, args.beta)
151 |                 rand_index = torch.randperm(images.size()[0]).cuda()
152 |                 target_a = soft_labels
153 |                 target_b = soft_labels[rand_index]
154 |                 bbx1, bby1, bbx2, bby2 = utils.rand_bbox(images.size(), lam)
155 |                 images[:, :, bbx1:bbx2, bby1:bby2] = images[rand_index, :, bbx1:bbx2, bby1:bby2]
156 |                 lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (images.size()[-1] * images.size()[-2]))
157 | 
158 |         # Forward pass, backward pass, and update parameters.
159 |         output = model(images)
160 |         # output, soft_label, soft_no_softmax = outputs
161 |         if args.w_cutmix:
162 |             g_loss_output1 = g_loss((output, target_a), labels)
163 |             g_loss_output2 = g_loss((output, target_b), labels)
164 |         else:
165 |             g_loss_output = g_loss((output, soft_labels), labels)
166 |         if args.use_discriminator_loss:
167 |             # Our stored label is "after softmax"; this is slightly different from the
168 |             # original MEAL V2, which used probabilities "before softmax" for the discriminator.
169 |             output_softmax = nn.functional.softmax(output, dim=1)
170 |             if args.w_cutmix:
171 |                 d_loss_value = discriminator_loss([output_softmax], [target_a]) * lam + discriminator_loss([output_softmax], [target_b]) * (1 - lam)
172 |             else:
173 |                 d_loss_value = discriminator_loss([output_softmax], [soft_labels])
174 | 
175 |         # Sometimes the loss function returns a modified version of the output,
176 |         # which must be used to compute the model accuracy.
177 |         if args.w_cutmix:
178 |             if isinstance(g_loss_output1, tuple):
179 |                 g_loss_value1, output1 = g_loss_output1
180 |                 g_loss_value2, output2 = g_loss_output2
181 |                 g_loss_value = g_loss_value1 * lam + g_loss_value2 * (1 - lam)
182 |             else:
183 |                 g_loss_value = g_loss_output1 * lam + g_loss_output2 * (1 - lam)
184 |         else:
185 |             if isinstance(g_loss_output, tuple):
186 |                 g_loss_value, output = g_loss_output
187 |             else:
188 |                 g_loss_value = g_loss_output
189 | 
190 |         if args.use_discriminator_loss:
191 |             loss_value = g_loss_value + d_loss_value
192 |         else:
193 |             loss_value = g_loss_value
194 | 
195 |         loss_value.backward()
196 | 
197 |         # Update parameters and reset gradients.
198 |         optimizer.step()
199 |         optimizer.zero_grad()
200 | 
201 |         # Record loss and model accuracy. Note d_loss_value only exists when the
202 |         g_loss_meter.update(g_loss_value.item(), batch_size)
203 |         if args.use_discriminator_loss: d_loss_meter.update(d_loss_value.item(), batch_size)
204 | 
205 |         top1, top5 = utils.topk_accuracy(output, labels, recalls=(1, 5))
206 |         top1_meter.update(top1, batch_size)
207 |         top5_meter.update(top5, batch_size)
208 | 
209 |         # Record batch time
210 |         batch_time_meter.update(time.time() - timestamp)
211 |         timestamp = time.time()
212 | 
213 |         if i % 20 == 0:
214 |             logging.info(
215 |                 'Epoch: [{epoch}][{batch}/{epoch_size}]\t'
216 |                 'Time {batch_time.value:.2f} ({batch_time.average:.2f}) '
217 |                 'Data {data_time.value:.2f} ({data_time.average:.2f}) '
218 |                 'G_Loss {g_loss.value:.3f} {{{g_loss.average:.3f}, {g_loss.average_recent:.3f}}} '
219 |                 'D_Loss {d_loss.value:.3f} {{{d_loss.average:.3f}, {d_loss.average_recent:.3f}}} '
220 |                 'Top-1 {top1.value:.2f} {{{top1.average:.2f}, {top1.average_recent:.2f}}} '
221 |                 'Top-5 {top5.value:.2f} {{{top5.average:.2f}, {top5.average_recent:.2f}}} '
222 |                 'LR {lr:.5f}'.format(
223 |                     epoch=epoch_number, batch=i + 1, epoch_size=len(train_loader),
224 |                     batch_time=batch_time_meter, data_time=data_time_meter,
225 |                     g_loss=g_loss_meter, d_loss=d_loss_meter, top1=top1_meter, top5=top5_meter,
226 |                     lr=_get_learning_rate(optimizer)))
227 |     # Log the overall train stats
228 |     logging.info(
229 |         'Epoch: [{epoch}] -- TRAINING SUMMARY\t'
230 |         'Time {batch_time.sum:.2f} '
231 |         'Data {data_time.sum:.2f} '
232 |         'G_Loss {g_loss.average:.3f} '
233 |         'D_Loss {d_loss.average:.3f} '
234 |         'Top-1 {top1.average:.2f} '
235 |         'Top-5 {top5.average:.2f} '.format(
236 |             epoch=epoch_number, batch_time=batch_time_meter, data_time=data_time_meter,
237 |             g_loss=g_loss_meter, d_loss=d_loss_meter, top1=top1_meter, top5=top5_meter))
238 | 
239 | 
240 | def save_checkpoint(checkpoints_dir, model, optimizer, epoch):
241 |     model_state_file = os.path.join(checkpoints_dir, 'model_state_{:02}.pytar'.format(epoch))
242 |     optim_state_file = os.path.join(checkpoints_dir, 'optim_state_{:02}.pytar'.format(epoch))
243 |     torch.save(model.state_dict(), model_state_file)
244 |     torch.save(optimizer.state_dict(), optim_state_file)
245 | 
246 | 
247 | def create_optimizer(model, discriminator_parameters, momentum=0.9, weight_decay=0):
248 |     # Optimize the model and discriminator parameters jointly with a single SGD optimizer.
249 |     parameters = [{'params': model.parameters()}, discriminator_parameters]
250 |     optimizer = torch.optim.SGD(parameters, lr=0,
251 |                                 momentum=momentum, weight_decay=weight_decay)
252 |     return optimizer
253 | 
254 | def create_discriminator_criterion(args):
255 |     d = discriminator.Discriminator(outputs_size=1000, K=8).cuda()
256 |     d = torch.nn.DataParallel(d)
257 |     update_parameters = {'params': d.parameters(), "lr": args.d_lr}
258 |     discriminators_criterion = discriminatorLoss(d).cuda()
259 |     if len(args.gpus) > 1:
260 |         discriminators_criterion = torch.nn.DataParallel(discriminators_criterion, device_ids=args.gpus)
261 |     return discriminators_criterion, update_parameters
262 | 
263 | def main(argv):
264 |     """Run the training script with command line arguments @argv."""
265 |     args = parse_args(argv)
266 |     utils.general_setup(args.save, args.gpus)
267 | 
268 |     logging.info("Arguments parsed.\n{}".format(pprint.pformat(vars(args))))
269 | 
270 |     # Convert to the true number of images to load, since we take multiple crops from the same image within a mini-batch.
271 |     args.batch_size = math.ceil(args.batch_size / args.num_crops)
272 | 
273 |     # Create the train and the validation data loaders.
274 |     train_loader = imagenet.get_train_loader_FKD(args.imagenet, args.batch_size,
275 |         args.num_workers, args.image_size, args.num_crops, args.softlabel_path)
276 |     val_loader = imagenet.get_val_loader(args.imagenet, args.batch_size,
277 |         args.num_workers, args.image_size)
278 |     # Create model with optional teachers.
279 |     model, loss = model_factory.create_model(
280 |         args.model, args.student_state_file, args.gpus, args.teacher_model,
281 |         args.teacher_state_file, True)
282 |     logging.info("Model:\n{}".format(model))
283 | 
284 |     discriminator_loss, update_parameters = create_discriminator_criterion(args)
285 | 
286 |     optimizer = create_optimizer(model, update_parameters, args.momentum, args.weight_decay)
287 | 
288 |     for epoch in range(args.start_epoch, args.epochs, args.num_crops):
289 |         adjust_learning_rate(optimizer, epoch, args)
290 |         train_for_one_epoch(model, loss, discriminator_loss, train_loader, optimizer, epoch, args)
291 |         test.test_for_one_epoch(model, loss, val_loader, epoch)
292 |         save_checkpoint(args.save, model, optimizer, epoch)
293 | 
294 | 
295 | if __name__ == '__main__':
296 |     main(sys.argv[1:])
297 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MEAL-V2
2 | 
3 | This is the official PyTorch implementation of our paper:
4 | ["MEAL V2: Boosting Vanilla ResNet-50 to 80%+ Top-1 Accuracy on ImageNet without Tricks"](https://arxiv.org/abs/2009.08453) by
5 | [Zhiqiang Shen](http://zhiqiangshen.com/) and [Marios Savvides](https://www.ece.cmu.edu/directory/bios/savvides-marios.html) from Carnegie Mellon University.
6 | 
7 | 
8 | 9 |
10 | 
11 | In this paper, we introduce a simple yet effective approach that can boost the vanilla ResNet-50 to 80%+ Top-1 accuracy on ImageNet without any tricks. Generally, our method is based on the recently proposed [MEAL](https://arxiv.org/abs/1812.02425), i.e., ensemble knowledge distillation via discriminators. We further simplify it by 1) adopting the similarity loss and discriminator only on the final outputs and 2) using the average of softmax probabilities from all teachers in the ensemble as the stronger supervision for distillation. One crucial perspective of our method is that the one-hot/hard label should not be used in the distillation process. We show that such a simple framework can achieve state-of-the-art results without any commonly-used tricks, such as: 1) architecture modification; 2) outside training data beyond ImageNet; 3) autoaug/randaug; 4) cosine learning rate; 5) mixup/cutmix training; 6) label smoothing; etc.
12 | 
13 | ## Citation
14 | 
15 | If you find our code helpful for your research, please cite:
16 | 
17 |     @article{shen2020mealv2,
18 |       title={MEAL V2: Boosting Vanilla ResNet-50 to 80%+ Top-1 Accuracy on ImageNet without Tricks},
19 |       author={Shen, Zhiqiang and Savvides, Marios},
20 |       journal={arXiv preprint arXiv:2009.08453},
21 |       year={2020}
22 |     }
23 | 
24 | ## News
25 | 
26 | **[Dec. 5, 2021]** **New:** Added [FKD](https://arxiv.org/abs/2112.01528) training support. We highly recommend using FKD to train MEAL V2 models: it is 2~4x faster with similar accuracy.
27 | 
28 | - Download our [soft labels](http://zhiqiangshen.com/projects/FKD/index.html) for MEAL V2.
29 | - Run `FKD_train.py` with the desired model architecture, the path to the ImageNet dataset, and the path to the soft labels, for example:
30 | 
31 | ```shell
32 | # 224 x 224 ResNet-50
33 | python FKD_train.py --save MEAL_V2_resnet50_224 \
34 | --batch-size 512 -j 48 \
35 | --model resnet50 --epochs 200 \
36 | --teacher-model gluon_senet154,gluon_resnet152_v1s \
37 | --imagenet [imagenet-folder with train and val folders] \
38 | --num_crops 8 --soft_label_type marginal_smoothing_k5 \
39 | --softlabel_path [path of soft label] \
40 | --schedule 100 180 --use-discriminator-loss
41 | ```
42 | Add `--cos` if you would like to train with a cosine learning rate schedule.
43 | 
44 | **New:** Basically, adding back tricks (cosine *lr*, etc.) into MEAL V2 consistently improves the accuracy:
45 | 
46 | 
47 | 48 | 49 |
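For reference, the `--cos` option above is the standard half-cosine decay implemented in `adjust_learning_rate()` of `FKD_train.py`, and the default behavior is stepwise drops at the `--schedule` milestones. Below is a minimal sketch of that schedule logic (the helper name `lr_at_epoch` is ours, and the milestones mirror the `--schedule 100 180` example above):

```python
# Sketch of the LR schedule logic from FKD_train.py's adjust_learning_rate().
import math

def lr_at_epoch(base_lr, epoch, total_epochs, cos=True, milestones=(100, 180)):
    if cos:
        # Cosine schedule (--cos): half-cosine decay from base_lr to 0.
        return base_lr * 0.5 * (1. + math.cos(math.pi * epoch / total_epochs))
    # Stepwise schedule (--schedule): multiply by 0.1 at every milestone passed.
    lr = base_lr
    for milestone in milestones:
        if epoch >= milestone:
            lr *= 0.1
    return lr

print(lr_at_epoch(0.01, 100, 200))         # 0.005 -- halfway through the cosine decay
print(lr_at_epoch(0.01, 100, 200, False))  # 0.001 -- after the first stepwise drop
```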
50 | 
51 | **New:** Added CutMix training support; use `--w-cutmix` to enable it.
52 | 
53 | **[Mar. 19, 2021]** The long version of MEAL V2 is available on [arXiv](https://arxiv.org/abs/2009.08453) or as a [paper](http://zhiqiangshen.com/projects/MEAL_V2/arxiv.pdf).
54 | 
55 | **[Dec. 16, 2020]** MEAL V2 is now available in [PyTorch Hub](https://pytorch.org/hub/pytorch_vision_meal_v2/).
56 | 
57 | **[Nov. 3, 2020]** The short version of MEAL V2 has been accepted to the NeurIPS 2020 [Beyond BackPropagation: Novel Ideas for Training Neural Architectures](https://beyondbackprop.github.io/) workshop. The long version is coming soon.
58 | 
59 | ## Preparation
60 | 
61 | ### 1. Requirements:
62 | This repo is tested with:
63 | 
64 | * Python 3.6
65 | 
66 | * CUDA 10.2
67 | 
68 | * PyTorch 1.6.0
69 | 
70 | * torchvision 0.7.0
71 | 
72 | * timm 0.2.1
73 |   (pip install timm)
74 | 
75 | It should also run with other PyTorch versions.
76 | 
77 | ### 2. Data:
78 | * Download the ImageNet dataset following https://github.com/pytorch/examples/tree/master/imagenet#requirements.
79 | 
80 | ## Results & Models
81 | 
82 | We provide models pre-trained with different settings. The table reports the training/validation resolution, #parameters, and Top-1/Top-5 accuracy on the ImageNet validation set:
83 | 
84 | | Models | Resolution | #Parameters | Top-1/Top-5 | Trained models |
85 | | :---: | :-: | :-: | :------:| :------: |
86 | | [MEAL-V1 w/ ResNet50](https://arxiv.org/abs/1812.02425) | 224 | 25.6M |**78.21/94.01** | [GitHub](https://github.com/AaronHeee/MEAL#imagenet-model) |
87 | | MEAL-V2 w/ ResNet18 | 224 | 11.7M | **73.19/90.82** | [Download (46.8M)](https://1drv.ms/u/s!AtMVZxJ8MfxCi03CTdVPH24ce6rD?e=l7BoZL) |
88 | | MEAL-V2 w/ ResNet50 | 224 | 25.6M | **80.67/95.09** | [Download (102.6M)](https://1drv.ms/u/s!AtMVZxJ8MfxCi0NGENlMK0pYVDQM?e=GkwZ93) |
89 | | MEAL-V2 w/ ResNet50| 380 | 25.6M | **81.72/95.81** | [Download (102.6M)](https://1drv.ms/u/s!AtMVZxJ8MfxCi0T9nodVNdnklHNt?e=7oJGIy) |
90 | | MEAL-V2 + CutMix w/ ResNet50| 224 | 25.6M | **80.98/95.35** | [Download (102.6M)](https://1drv.ms/u/s!AtMVZxJ8MfxCi0cIf5IqpBX6nl1U?e=Fig91M) |
91 | | MEAL-V2 w/ MobileNet V3-Small 0.75| 224 | 2.04M | **67.60/87.23** | [Download (8.3M)](https://1drv.ms/u/s!AtMVZxJ8MfxCi0nIq1jZo36dpN7Q?e=ODcoAN) |
92 | | MEAL-V2 w/ MobileNet V3-Small 1.0| 224 | 2.54M | **69.65/88.71** | [Download (10.3M)](https://1drv.ms/u/s!AtMVZxJ8MfxCiz9v7QqUmvQOLmTS?e=9nCWMa) |
93 | | MEAL-V2 w/ MobileNet V3-Large 1.0 | 224 | 5.48M | **76.92/93.32** | [Download (22.1M)](https://1drv.ms/u/s!AtMVZxJ8MfxCi0Ciwz-q-P2jwtXR?e=OebKAr) |
94 | | MEAL-V2 w/ EfficientNet-B0| 224 | 5.29M | **78.29/93.95** | [Download (21.5M)](https://1drv.ms/u/s!AtMVZxJ8MfxCi0XZLUEB3uYq3eBe?e=FJV9K1) |
95 | 
96 | 
97 | ## Training & Testing
98 | ### 1. Training:
99 | * To train a model, run `script/train.sh` with the desired model architecture and the path to the ImageNet dataset, for example:
100 | 
101 | ```shell
102 | # 224 x 224 ResNet-50
103 | python train.py --save MEAL_V2_resnet50_224 --batch-size 512 -j 48 --model resnet50 --epochs 180 --teacher-model gluon_senet154,gluon_resnet152_v1s --imagenet [imagenet-folder with train and val folders]
104 | ```
105 | 
106 | ```shell
107 | # 224 x 224 ResNet-50 w/ CutMix
108 | python train.py --save MEAL_V2_resnet50_224 --batch-size 512 -j 48 --model resnet50 --epochs 180 --teacher-model gluon_senet154,gluon_resnet152_v1s --imagenet [imagenet-folder with train and val folders] --w-cutmix
109 | ```
110 | 
111 | ```shell
112 | # 380 x 380 ResNet-50
113 | python train.py --save MEAL_V2_resnet50_380 --batch-size 512 -j 48 --model resnet50 --image-size 380 --teacher-model tf_efficientnet_b4_ns,tf_efficientnet_b4 --imagenet [imagenet-folder with train and val folders]
114 | ```
115 | 
116 | ```shell
117 | # 224 x 224 MobileNet V3-Small 0.75
118 | python train.py --save MEAL_V2_mobilenetv3_small_075 --batch-size 512 -j 48 --model tf_mobilenetv3_small_075 --teacher-model gluon_senet154,gluon_resnet152_v1s --imagenet [imagenet-folder with train and val folders]
119 | ```
120 | 
121 | ```shell
122 | # 224 x 224 MobileNet V3-Small 1.0
123 | python train.py --save MEAL_V2_mobilenetv3_small_100 --batch-size 512 -j 48 --model tf_mobilenetv3_small_100 --teacher-model gluon_senet154,gluon_resnet152_v1s --imagenet [imagenet-folder with train and val folders]
124 | ```
125 | 
126 | ```shell
127 | # 224 x 224 MobileNet V3-Large 1.0
128 | python train.py --save MEAL_V2_mobilenetv3_large_100 --batch-size 512 -j 48 --model tf_mobilenetv3_large_100 --teacher-model gluon_senet154,gluon_resnet152_v1s --imagenet [imagenet-folder with train and val folders]
129 | ```
130 | 
131 | ```shell
132 | # 224 x 224 EfficientNet-B0
133 | python train.py --save MEAL_V2_efficientnet_b0 --batch-size 512 -j 48 --model tf_efficientnet_b0 --teacher-model gluon_senet154,gluon_resnet152_v1s --imagenet [imagenet-folder with train and val folders]
134 | ```
135 | *Please reduce the ``--batch-size`` if you get an ''out of memory'' error. We also notice that more training epochs can slightly improve performance.*
136 | 
137 | * To resume training a model, run `script/resume_train.sh` with the desired model architecture, the starting epoch number, and the path to the ImageNet dataset:
138 | 
139 | ```shell
140 | sh script/resume_train.sh
141 | ```
142 | 
143 | ### 2. Testing:
144 | 
145 | * To test a model, run `inference.py` with the desired model architecture, model path, resolution, and the path to the ImageNet dataset:
146 | 
147 | ```shell
148 | CUDA_VISIBLE_DEVICES=0,1,2,3 python inference.py -a resnet50 --res 224 --resume MODEL_PATH -e [imagenet-folder with train and val folders]
149 | ```
150 | Change ``--res`` to another image resolution [224/380] and ``-a`` to another model architecture [tf\_mobilenetv3\_small\_100; tf\_mobilenetv3\_large\_100; tf\_efficientnet\_b0] to test other trained models.
151 | 
152 | 
153 | ## Contact
154 | 
155 | Zhiqiang Shen, CMU (zhiqians at andrew.cmu.edu)
156 | 
157 | Any comments or suggestions are welcome!
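*(Aside: the released weights can also be pulled through PyTorch Hub via the `hubconf.py` shown below. A minimal sketch — the `meal_v2` entry point and model names come from `hubconf.py`, and `timm` plus a CUDA device are required since the entry point wraps the weights in `torch.nn.DataParallel(...).cuda()`.)*

```python
# Minimal sketch: load a MEAL V2 model through PyTorch Hub (entry point defined in hubconf.py).
import torch

model = torch.hub.load('szq0214/MEAL-V2', 'meal_v2', 'mealv2_resnest50', pretrained=True)
model.eval()

# Dummy forward pass: an ImageNet-sized input yields 1000-way logits.
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224).cuda())
print(logits.shape)  # torch.Size([1, 1000])
```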
--------------------------------------------------------------------------------
/extensions/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/szq0214/MEAL-V2/59e6bca29c77a7df9229e77d4463f2f795723810/extensions/__init__.py
--------------------------------------------------------------------------------
/extensions/data_parallel.py:
--------------------------------------------------------------------------------
1 | __author__ = "Hessam Bagherinezhad "
2 | 
3 | from torch import nn
4 | from torch.nn.modules import loss
5 | 
6 | 
7 | class DataParallel(nn.DataParallel):
8 |     """An extension of nn.DataParallel.
9 | 
10 |     The only extensions are:
11 |         1) If an attribute is missing in an object of this class, it will look
12 |            for it in the wrapped module. This is useful for getting `LR_REGIME`
13 |            of the wrapped module, for example.
14 |         2) state_dict() of this class calls the wrapped module's state_dict(),
15 |            hence the weights can be transferred from a data parallel wrapped
16 |            module to a single gpu module.
17 |     """
18 | 
19 | 
20 |     def __getattr__(self, name):
21 |         # If the attribute doesn't exist in the DataParallel object, this method
22 |         # will be called. Here we first ask the super class to get the attribute;
23 |         # if it can't find it, we ask the underlying module that is wrapped by
24 |         # this DataParallel to get the attribute.
25 |         try:
26 |             return super().__getattr__(name)
27 |         except AttributeError:
28 |             underlying_module = super().__getattr__('module')
29 |             return getattr(underlying_module, name)
30 | 
31 |     def state_dict(self, *args, **kwargs):
32 |         return self.module.state_dict(*args, **kwargs)
--------------------------------------------------------------------------------
/extensions/kd_loss.py:
--------------------------------------------------------------------------------
1 | __author__ = "Hessam Bagherinezhad "
2 | 
3 | # modified by "Zhiqiang Shen "
4 | 
5 | import torch
6 | from torch.nn import functional as F
7 | from torch.nn.modules import loss
8 | 
9 | 
10 | class KLLoss(loss._Loss):
11 |     """The KL-Divergence loss for the model and soft labels output.
12 | 
13 |     output must be a pair of (model_output, soft_labels), both NxC tensors.
14 |     The rows of soft_labels must all add up to one (probability scores);
15 |     however, model_output must be the pre-softmax output of the network."""
16 | 
17 |     def forward(self, output, target):
18 |         if not self.training:
19 |             # Loss is the normal cross entropy loss between the model output
20 |             # and the target.
21 |             return F.cross_entropy(output, target)
22 | 
23 |         assert type(output) == tuple and len(output) == 2 and output[0].size() == \
24 |             output[1].size(), "output must be a pair of tensors of the same size."
25 | 
26 |         # Target is ignored at training time. Loss is defined as KL divergence
27 |         # between the model output and the soft labels.
28 |         model_output, soft_labels = output
29 |         if soft_labels.requires_grad:
30 |             raise ValueError("soft labels should not require gradients.")
31 | 
32 |         model_output_log_prob = F.log_softmax(model_output, dim=1)
33 |         del model_output
34 | 
35 |         # Loss is -dot(model_output_log_prob, soft_labels). Prepare tensors
36 |         # for batch matrix multiplication.
37 |         soft_labels = soft_labels.unsqueeze(1)
38 |         model_output_log_prob = model_output_log_prob.unsqueeze(2)
39 | 
40 |         # Compute the loss, and average for the batch.
41 |         cross_entropy_loss = -torch.bmm(soft_labels, model_output_log_prob)
42 |         cross_entropy_loss = cross_entropy_loss.mean()
43 |         # Return a pair of (loss_output, model_output). Model output will be
44 |         # used for top-1 and top-5 evaluation.
45 |         model_output_log_prob = model_output_log_prob.squeeze(2)
46 |         return (cross_entropy_loss, model_output_log_prob)
--------------------------------------------------------------------------------
/extensions/teacher_wrapper.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 | 
5 | import random
6 | import numpy as np
7 | 
8 | 
9 | class ModelDistillationWrapper(nn.Module):
10 |     """Convenient wrapper class to train a model with soft labels."""
11 | 
12 |     def __init__(self, model, teacher):
13 |         super().__init__()
14 |         self.model = model
15 |         self.teachers_0 = teacher
16 |         self.combine = True
17 | 
18 |         # Since we don't want to back-prop through the teacher networks,
19 |         # make their parameters not require gradients. This saves some
20 |         # GPU memory.
21 | 
22 |         for model in self.teachers_0:
23 |             for param in model.parameters():
24 |                 param.requires_grad = False
25 | 
26 |         self.false = False
27 | 
28 |     @property
29 |     def LR_REGIME(self):
30 |         # Training with soft labels does not change the learning rate regime.
31 |         # Return the wrapped model's LR regime.
32 |         return self.model.LR_REGIME
33 | 
34 |     def state_dict(self):
35 |         return self.model.state_dict()
36 | 
37 |     def forward(self, input, before=False):
38 |         if self.training:
39 |             if len(self.teachers_0) == 3 and self.combine == False:
40 |                 index = [0,1,1,2,2]
41 |                 idx = random.randint(0, 4)
42 |                 soft_labels_ = self.teachers_0[index[idx]](input)
43 |                 soft_labels = F.softmax(soft_labels_, dim=1)
44 | 
45 |             elif self.combine:
46 |                 soft_labels_ = [ torch.unsqueeze(self.teachers_0[idx](input), dim=2) for idx in range(len(self.teachers_0))]
47 |                 soft_labels_softmax = [F.softmax(i, dim=1) for i in soft_labels_]
48 |                 soft_labels_ = torch.cat(soft_labels_, dim=2).mean(dim=2)
49 |                 soft_labels = torch.cat(soft_labels_softmax, dim=2).mean(dim=2)
50 | 
51 |             else:
52 |                 idx = random.randint(0, len(self.teachers_0)-1)
53 |                 soft_labels_ = self.teachers_0[idx](input)
54 |                 soft_labels = F.softmax(soft_labels_, dim=1)
55 | 
56 |             # soft_labels = F.softmax(soft_labels_, dim=1)
57 |             model_output = self.model(input)
58 | 
59 |             if before:
60 |                 return (model_output, soft_labels, soft_labels_)
61 | 
62 |             return (model_output, soft_labels)
63 | 
64 |         else:
65 |             return self.model(input)
--------------------------------------------------------------------------------
/hubconf.py:
--------------------------------------------------------------------------------
1 | dependencies = [
2 |     'timm',
3 |     'torch',
4 | ]
5 | 
6 | import torch, timm
7 | 
8 | __all__ = ['mealv1_resnest50', 'mealv2_resnest50', 'mealv2_resnest50_cutmix', 'mealv2_resnest50_380x380', 'mealv2_mobilenetv3_small_075', 'mealv2_mobilenetv3_small_100', 'mealv2_mobilenet_v3_large_100', 'mealv2_efficientnet_b0']
9 | 
10 | model_urls = {
11 |     'mealv1_resnest50': 'https://github.com/szq0214/MEAL-V2/releases/download/v1.0.0/MEALV1_ResNet50_224.pth',
12 |     'mealv2_resnest50': 'https://github.com/szq0214/MEAL-V2/releases/download/v1.0.0/MEALV2_ResNet50_224.pth',
13 |     'mealv2_resnest50_cutmix': 'https://github.com/szq0214/MEAL-V2/releases/download/v1.0.0/MEALV2_ResNet50_224_cutmix.pth',
14 |     'mealv2_resnest50_380x380':
'https://github.com/szq0214/MEAL-V2/releases/download/v1.0.0/MEALV2_ResNet50_380.pth', 15 | 'mealv2_mobilenetv3_small_075': 'https://github.com/szq0214/MEAL-V2/releases/download/v1.0.0/MEALV2_MobileNet_V3_Small_0.75_224.pth', 16 | 'mealv2_mobilenetv3_small_100': 'https://github.com/szq0214/MEAL-V2/releases/download/v1.0.0/MEALV2_MobileNet_V3_Small_1.0_224.pth', 17 | 'mealv2_mobilenet_v3_large_100': 'https://github.com/szq0214/MEAL-V2/releases/download/v1.0.0/MEALV2_MobileNet_V3_Large_1.0_224.pth', 18 | 'mealv2_efficientnet_b0': 'https://github.com/szq0214/MEAL-V2/releases/download/v1.0.0/MEALV2_EfficientNet_B0_224.pth', 19 | } 20 | 21 | 22 | mapping = {'mealv1_resnest50':'resnet50', 23 | 'mealv2_resnest50':'resnet50', 24 | 'mealv2_resnest50_cutmix':'resnet50', 25 | 'mealv2_resnest50_380x380':'resnet50', 26 | 'mealv2_mobilenetv3_small_075':'tf_mobilenetv3_small_075', 27 | 'mealv2_mobilenetv3_small_100':'tf_mobilenetv3_small_100', 28 | 'mealv2_mobilenet_v3_large_100':'tf_mobilenetv3_large_100', 29 | 'mealv2_efficientnet_b0':'tf_efficientnet_b0' 30 | } 31 | 32 | def meal_v2(model_name, pretrained=True, progress=True, exportable=False): 33 | """ MEAL V2 models from 34 | `"MEAL V2: Boosting Vanilla ResNet-50 to 80%+ Top-1 Accuracy on ImageNet without Tricks" `_ 35 | 36 | Args: 37 | model_name: Name of the model to load 38 | pretrained (bool): If True, returns a model trained with MEAL V2 on ImageNet 39 | progress (bool): If True, displays a progress bar of the download to stderr 40 | """ 41 | 42 | model = timm.create_model(mapping[model_name.lower()], pretrained=False, exportable=exportable) 43 | if pretrained: 44 | state_dict = torch.hub.load_state_dict_from_url(model_urls[model_name.lower()], progress=progress) 45 | model = torch.nn.DataParallel(model).cuda() 46 | model.load_state_dict(state_dict) 47 | return model -------------------------------------------------------------------------------- /imagenet.py: -------------------------------------------------------------------------------- 1 | """Dataset class for loading imagenet data.""" 2 | 3 | import os 4 | 5 | from torch.utils import data as data_utils 6 | from torchvision import datasets as torch_datasets 7 | from torchvision import transforms 8 | 9 | from utils_FKD import RandomResizedCrop_FKD,RandomHorizontalFlip_FKD,ImageFolder_FKD,Compose_FKD 10 | from torchvision.transforms import InterpolationMode 11 | 12 | def get_train_loader(imagenet_path, batch_size, num_workers, image_size): 13 | train_dataset = ImageNet(imagenet_path, image_size, is_train=True) 14 | return data_utils.DataLoader( 15 | train_dataset, shuffle=True, batch_size=batch_size, pin_memory=True, 16 | num_workers=num_workers) 17 | 18 | def get_train_loader_FKD(imagenet_path, batch_size, num_workers, image_size, num_crops, softlabel_path): 19 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 20 | std=[0.229, 0.224, 0.225]) 21 | train_dataset = ImageFolder_FKD( 22 | num_crops=num_crops, 23 | softlabel_path=softlabel_path, 24 | root=os.path.join(imagenet_path, 'train'), 25 | transform=Compose_FKD(transforms=[ 26 | RandomResizedCrop_FKD(size=224, 27 | interpolation='bilinear'), 28 | RandomHorizontalFlip_FKD(), 29 | transforms.ToTensor(), 30 | normalize, 31 | ])) 32 | return data_utils.DataLoader( 33 | train_dataset, shuffle=True, batch_size=batch_size, pin_memory=True, 34 | num_workers=num_workers) 35 | 36 | def get_val_loader(imagenet_path, batch_size, num_workers, image_size): 37 | val_dataset = ImageNet(imagenet_path, image_size, is_train=False) 38 | return 
data_utils.DataLoader( 39 | val_dataset, shuffle=False, batch_size=batch_size, pin_memory=True, 40 | num_workers=num_workers) 41 | 42 | 43 | class ImageNet(torch_datasets.ImageFolder): 44 | """Dataset class for ImageNet dataset. 45 | 46 | Arguments: 47 | root_dir (str): Path to the dataset root directory, which must contain 48 | train/ and val/ directories. 49 | is_train (bool): Whether to read training or validation images. 50 | """ 51 | MEAN = [0.485, 0.456, 0.406] 52 | STD = [0.229, 0.224, 0.225] 53 | 54 | def __init__(self, root_dir, im_size, is_train): 55 | if is_train: 56 | root_dir = os.path.join(root_dir, 'train') 57 | transform = transforms.Compose([ 58 | transforms.RandomResizedCrop(im_size), 59 | transforms.RandomHorizontalFlip(), 60 | transforms.ToTensor(), 61 | transforms.Normalize(ImageNet.MEAN, ImageNet.STD), 62 | ]) 63 | else: 64 | root_dir = os.path.join(root_dir, 'val') 65 | transform = transforms.Compose([ 66 | transforms.Resize(int(256/224*im_size)), 67 | transforms.CenterCrop(im_size), 68 | transforms.ToTensor(), 69 | transforms.Normalize(ImageNet.MEAN, ImageNet.STD), 70 | ]) 71 | super().__init__(root_dir, transform=transform) 72 | 73 | 74 | -------------------------------------------------------------------------------- /images/MEAL-V2_more_tricks_top1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/szq0214/MEAL-V2/59e6bca29c77a7df9229e77d4463f2f795723810/images/MEAL-V2_more_tricks_top1.png -------------------------------------------------------------------------------- /images/MEAL-V2_more_tricks_top5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/szq0214/MEAL-V2/59e6bca29c77a7df9229e77d4463f2f795723810/images/MEAL-V2_more_tricks_top5.png -------------------------------------------------------------------------------- /images/comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/szq0214/MEAL-V2/59e6bca29c77a7df9229e77d4463f2f795723810/images/comparison.png -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import shutil 5 | import time 6 | import warnings 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.parallel 11 | import torch.backends.cudnn as cudnn 12 | import torch.distributed as dist 13 | import torch.optim 14 | import torch.multiprocessing as mp 15 | import torch.utils.data 16 | import torch.utils.data.distributed 17 | import torchvision.transforms as transforms 18 | import torchvision.datasets as datasets 19 | import torchvision.models as models 20 | import timm 21 | 22 | model_names = sorted(name for name in models.__dict__ 23 | if name.islower() and not name.startswith("__") 24 | and callable(models.__dict__[name])) 25 | 26 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 27 | parser.add_argument('data', metavar='DIR', 28 | help='path to dataset') 29 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', 30 | # choices=model_names, 31 | help='model architecture: ' + 32 | ' | '.join(model_names) + 33 | ' (default: resnet18)') 34 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 35 | help='number of data loading workers (default: 4)') 36 | parser.add_argument('--epochs', 
default=90, type=int, metavar='N', 37 | help='number of total epochs to run') 38 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 39 | help='manual epoch number (useful on restarts)') 40 | parser.add_argument('--res', default=224, type=int, 41 | help='image resolution for testing') 42 | parser.add_argument('-b', '--batch-size', default=256, type=int, 43 | metavar='N', 44 | help='mini-batch size (default: 256), this is the total ' 45 | 'batch size of all GPUs on the current node when ' 46 | 'using Data Parallel or Distributed Data Parallel') 47 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 48 | metavar='LR', help='initial learning rate', dest='lr') 49 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 50 | help='momentum') 51 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, 52 | metavar='W', help='weight decay (default: 1e-4)', 53 | dest='weight_decay') 54 | parser.add_argument('-p', '--print-freq', default=10, type=int, 55 | metavar='N', help='print frequency (default: 10)') 56 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 57 | help='path to latest checkpoint (default: none)') 58 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 59 | help='evaluate model on validation set') 60 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 61 | help='use pre-trained model') 62 | parser.add_argument('--world-size', default=-1, type=int, 63 | help='number of nodes for distributed training') 64 | parser.add_argument('--rank', default=-1, type=int, 65 | help='node rank for distributed training') 66 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, 67 | help='url used to set up distributed training') 68 | parser.add_argument('--dist-backend', default='nccl', type=str, 69 | help='distributed backend') 70 | parser.add_argument('--seed', default=None, type=int, 71 | help='seed for initializing training. ') 72 | parser.add_argument('--gpu', default=None, type=int, 73 | help='GPU id to use.') 74 | parser.add_argument('--multiprocessing-distributed', action='store_true', 75 | help='Use multi-processing distributed training to launch ' 76 | 'N processes per node, which has N GPUs. This is the ' 77 | 'fastest way to use PyTorch for either single node or ' 78 | 'multi node data parallel training') 79 | 80 | best_acc1 = 0 81 | 82 | 83 | def main(): 84 | args = parser.parse_args() 85 | 86 | if args.seed is not None: 87 | random.seed(args.seed) 88 | torch.manual_seed(args.seed) 89 | cudnn.deterministic = True 90 | warnings.warn('You have chosen to seed training. ' 91 | 'This will turn on the CUDNN deterministic setting, ' 92 | 'which can slow down your training considerably! ' 93 | 'You may see unexpected behavior when restarting ' 94 | 'from checkpoints.') 95 | 96 | if args.gpu is not None: 97 | warnings.warn('You have chosen a specific GPU. 
This will completely ' 98 | 'disable data parallelism.') 99 | 100 | if args.dist_url == "env://" and args.world_size == -1: 101 | args.world_size = int(os.environ["WORLD_SIZE"]) 102 | 103 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 104 | 105 | ngpus_per_node = torch.cuda.device_count() 106 | if args.multiprocessing_distributed: 107 | # Since we have ngpus_per_node processes per node, the total world_size 108 | # needs to be adjusted accordingly 109 | args.world_size = ngpus_per_node * args.world_size 110 | # Use torch.multiprocessing.spawn to launch distributed processes: the 111 | # main_worker process function 112 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 113 | else: 114 | # Simply call main_worker function 115 | main_worker(args.gpu, ngpus_per_node, args) 116 | 117 | 118 | def main_worker(gpu, ngpus_per_node, args): 119 | global best_acc1 120 | args.gpu = gpu 121 | 122 | if args.gpu is not None: 123 | print("Use GPU: {} for training".format(args.gpu)) 124 | 125 | if args.distributed: 126 | if args.dist_url == "env://" and args.rank == -1: 127 | args.rank = int(os.environ["RANK"]) 128 | if args.multiprocessing_distributed: 129 | # For multiprocessing distributed training, rank needs to be the 130 | # global rank among all the processes 131 | args.rank = args.rank * ngpus_per_node + gpu 132 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 133 | world_size=args.world_size, rank=args.rank) 134 | # create model 135 | if args.pretrained: 136 | print("=> using pre-trained model '{}'".format(args.arch)) 137 | # model = models.__dict__[args.arch](pretrained=True) 138 | model = timm.create_model(args.arch, pretrained=True) 139 | else: 140 | print("=> creating model '{}'".format(args.arch)) 141 | # model = models.__dict__[args.arch]() 142 | model = timm.create_model(args.arch, pretrained=False) 143 | 144 | if not torch.cuda.is_available(): 145 | print('using CPU, this will be slow') 146 | elif args.distributed: 147 | # For multiprocessing distributed, DistributedDataParallel constructor 148 | # should always set the single device scope, otherwise, 149 | # DistributedDataParallel will use all available devices. 
150 | if args.gpu is not None: 151 | torch.cuda.set_device(args.gpu) 152 | model.cuda(args.gpu) 153 | # When using a single GPU per process and per 154 | # DistributedDataParallel, we need to divide the batch size 155 | # ourselves based on the total number of GPUs we have 156 | args.batch_size = int(args.batch_size / ngpus_per_node) 157 | args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node) 158 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 159 | else: 160 | model.cuda() 161 | # DistributedDataParallel will divide and allocate batch_size to all 162 | # available GPUs if device_ids are not set 163 | model = torch.nn.parallel.DistributedDataParallel(model) 164 | elif args.gpu is not None: 165 | torch.cuda.set_device(args.gpu) 166 | model = model.cuda(args.gpu) 167 | else: 168 | # DataParallel will divide and allocate batch_size to all available GPUs 169 | if args.arch.startswith('alexnet') or args.arch.startswith('vgg'): 170 | model.features = torch.nn.DataParallel(model.features) 171 | model.cuda() 172 | else: 173 | model = torch.nn.DataParallel(model).cuda() 174 | 175 | # define loss function (criterion) and optimizer 176 | criterion = nn.CrossEntropyLoss().cuda(args.gpu) 177 | 178 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 179 | momentum=args.momentum, 180 | weight_decay=args.weight_decay) 181 | 182 | # optionally resume from a checkpoint 183 | if args.resume: 184 | if os.path.isfile(args.resume): 185 | print("=> loading checkpoint '{}'".format(args.resume)) 186 | if args.gpu is None: 187 | checkpoint = torch.load(args.resume) 188 | else: 189 | # Map model to be loaded to specified single gpu. 190 | loc = 'cuda:{}'.format(args.gpu) 191 | checkpoint = torch.load(args.resume, map_location=loc) 192 | model.load_state_dict(checkpoint) 193 | else: 194 | print("=> no checkpoint found at '{}'".format(args.resume)) 195 | 196 | cudnn.benchmark = True 197 | 198 | # Data loading code 199 | traindir = os.path.join(args.data, 'train') 200 | valdir = os.path.join(args.data, 'val') 201 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 202 | std=[0.229, 0.224, 0.225]) 203 | 204 | train_dataset = datasets.ImageFolder( 205 | traindir, 206 | transforms.Compose([ 207 | transforms.RandomResizedCrop(224), 208 | transforms.RandomHorizontalFlip(), 209 | transforms.ToTensor(), 210 | normalize, 211 | ])) 212 | 213 | if args.distributed: 214 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 215 | else: 216 | train_sampler = None 217 | 218 | train_loader = torch.utils.data.DataLoader( 219 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 220 | num_workers=args.workers, pin_memory=True, sampler=train_sampler) 221 | 222 | val_loader = torch.utils.data.DataLoader( 223 | datasets.ImageFolder(valdir, transforms.Compose([ 224 | transforms.Resize(int(256/224*args.res)), 225 | transforms.CenterCrop(args.res), 226 | transforms.ToTensor(), 227 | normalize, 228 | ])), 229 | batch_size=args.batch_size, shuffle=False, 230 | num_workers=args.workers, pin_memory=True) 231 | 232 | if args.evaluate: 233 | validate(val_loader, model, criterion, args) 234 | return 235 | 236 | for epoch in range(args.start_epoch, args.epochs): 237 | if args.distributed: 238 | train_sampler.set_epoch(epoch) 239 | adjust_learning_rate(optimizer, epoch, args) 240 | 241 | # train for one epoch 242 | train(train_loader, model, criterion, optimizer, epoch, args) 243 | 244 | # evaluate on validation set 245 | acc1 = 
validate(val_loader, model, criterion, args) 246 | 247 | # remember best acc@1 and save checkpoint 248 | is_best = acc1 > best_acc1 249 | best_acc1 = max(acc1, best_acc1) 250 | 251 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed 252 | and args.rank % ngpus_per_node == 0): 253 | save_checkpoint({ 254 | 'epoch': epoch + 1, 255 | 'arch': args.arch, 256 | 'state_dict': model.state_dict(), 257 | 'best_acc1': best_acc1, 258 | 'optimizer' : optimizer.state_dict(), 259 | }, is_best) 260 | 261 | 262 | def train(train_loader, model, criterion, optimizer, epoch, args): 263 | batch_time = AverageMeter('Time', ':6.3f') 264 | data_time = AverageMeter('Data', ':6.3f') 265 | losses = AverageMeter('Loss', ':.4e') 266 | top1 = AverageMeter('Acc@1', ':6.2f') 267 | top5 = AverageMeter('Acc@5', ':6.2f') 268 | progress = ProgressMeter( 269 | len(train_loader), 270 | [batch_time, data_time, losses, top1, top5], 271 | prefix="Epoch: [{}]".format(epoch)) 272 | 273 | # switch to train mode 274 | model.train() 275 | 276 | end = time.time() 277 | for i, (images, target) in enumerate(train_loader): 278 | # measure data loading time 279 | data_time.update(time.time() - end) 280 | 281 | if args.gpu is not None: 282 | images = images.cuda(args.gpu, non_blocking=True) 283 | if torch.cuda.is_available(): 284 | target = target.cuda(args.gpu, non_blocking=True) 285 | 286 | # compute output 287 | output = model(images) 288 | loss = criterion(output, target) 289 | 290 | # measure accuracy and record loss 291 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 292 | losses.update(loss.item(), images.size(0)) 293 | top1.update(acc1[0], images.size(0)) 294 | top5.update(acc5[0], images.size(0)) 295 | 296 | # compute gradient and do SGD step 297 | optimizer.zero_grad() 298 | loss.backward() 299 | optimizer.step() 300 | 301 | # measure elapsed time 302 | batch_time.update(time.time() - end) 303 | end = time.time() 304 | 305 | if i % args.print_freq == 0: 306 | progress.display(i) 307 | 308 | 309 | def validate(val_loader, model, criterion, args): 310 | batch_time = AverageMeter('Time', ':6.3f') 311 | losses = AverageMeter('Loss', ':.4e') 312 | top1 = AverageMeter('Acc@1', ':6.2f') 313 | top5 = AverageMeter('Acc@5', ':6.2f') 314 | progress = ProgressMeter( 315 | len(val_loader), 316 | [batch_time, losses, top1, top5], 317 | prefix='Test: ') 318 | 319 | # switch to evaluate mode 320 | model.eval() 321 | 322 | with torch.no_grad(): 323 | end = time.time() 324 | for i, (images, target) in enumerate(val_loader): 325 | if args.gpu is not None: 326 | images = images.cuda(args.gpu, non_blocking=True) 327 | if torch.cuda.is_available(): 328 | target = target.cuda(args.gpu, non_blocking=True) 329 | 330 | # compute output 331 | output = model(images) 332 | loss = criterion(output, target) 333 | 334 | # measure accuracy and record loss 335 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 336 | losses.update(loss.item(), images.size(0)) 337 | top1.update(acc1[0], images.size(0)) 338 | top5.update(acc5[0], images.size(0)) 339 | 340 | # measure elapsed time 341 | batch_time.update(time.time() - end) 342 | end = time.time() 343 | 344 | if i % args.print_freq == 0: 345 | progress.display(i) 346 | 347 | # TODO: this should also be done with the ProgressMeter 348 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}' 349 | .format(top1=top1, top5=top5)) 350 | 351 | return top1.avg 352 | 353 | 354 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 355 | torch.save(state, filename) 356 | if 
is_best: 357 | shutil.copyfile(filename, 'model_best.pth.tar') 358 | 359 | 360 | class AverageMeter(object): 361 | """Computes and stores the average and current value""" 362 | def __init__(self, name, fmt=':f'): 363 | self.name = name 364 | self.fmt = fmt 365 | self.reset() 366 | 367 | def reset(self): 368 | self.val = 0 369 | self.avg = 0 370 | self.sum = 0 371 | self.count = 0 372 | 373 | def update(self, val, n=1): 374 | self.val = val 375 | self.sum += val * n 376 | self.count += n 377 | self.avg = self.sum / self.count 378 | 379 | def __str__(self): 380 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 381 | return fmtstr.format(**self.__dict__) 382 | 383 | 384 | class ProgressMeter(object): 385 | def __init__(self, num_batches, meters, prefix=""): 386 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 387 | self.meters = meters 388 | self.prefix = prefix 389 | 390 | def display(self, batch): 391 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 392 | entries += [str(meter) for meter in self.meters] 393 | print('\t'.join(entries)) 394 | 395 | def _get_batch_fmtstr(self, num_batches): 396 | num_digits = len(str(num_batches // 1)) 397 | fmt = '{:' + str(num_digits) + 'd}' 398 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 399 | 400 | 401 | def adjust_learning_rate(optimizer, epoch, args): 402 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 403 | lr = args.lr * (0.1 ** (epoch // 30)) 404 | for param_group in optimizer.param_groups: 405 | param_group['lr'] = lr 406 | 407 | 408 | def accuracy(output, target, topk=(1,)): 409 | """Computes the accuracy over the k top predictions for the specified values of k""" 410 | with torch.no_grad(): 411 | maxk = max(topk) 412 | batch_size = target.size(0) 413 | 414 | _, pred = output.topk(maxk, 1, True, True) 415 | pred = pred.t() 416 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 417 | 418 | res = [] 419 | for k in topk: 420 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 421 | res.append(correct_k.mul_(100.0 / batch_size)) 422 | return res 423 | 424 | 425 | if __name__ == '__main__': 426 | main() -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class betweenLoss(nn.Module): 6 | def __init__(self, gamma=[1,1,1,1,1,1], loss=nn.L1Loss()): 7 | super(betweenLoss, self).__init__() 8 | self.gamma = gamma 9 | self.loss = loss 10 | 11 | def forward(self, outputs, targets): 12 | assert len(outputs) 13 | assert len(outputs) == len(targets) 14 | 15 | length = len(outputs) 16 | 17 | res = sum([self.gamma[i]*self.loss(outputs[i], targets[i]) for i in range(length)]) 18 | 19 | return res 20 | 21 | def CrossEntropy(outputs, targets): 22 | log_softmax_outputs = F.log_softmax(outputs, dim=1) 23 | softmax_targets = F.softmax(targets, dim=1) 24 | 25 | return -(log_softmax_outputs*softmax_targets).sum(dim=1).mean() 26 | 27 | 28 | class discriminatorLoss(nn.Module): 29 | def __init__(self, models, loss=nn.BCEWithLogitsLoss()): 30 | super(discriminatorLoss, self).__init__() 31 | self.models = models 32 | self.loss = loss 33 | 34 | def forward(self, outputs, targets): 35 | inputs = [torch.cat((i,j),0) for i, j in zip(outputs, targets)] 36 | inputs = torch.cat(inputs, 1) 37 | batch_size = inputs.size(0) 38 | target = torch.FloatTensor([[1, 0] for _ in 
range(batch_size//2)] + [[0, 1] for _ in range(batch_size//2)]) 39 | target = target.to(inputs[0].device) 40 | output = self.models(inputs) 41 | res = self.loss(output, target) 42 | return res 43 | 44 | 45 | class discriminatorFakeLoss(nn.Module): 46 | def forward(self, outputs, targets): 47 | res = (0*outputs[0]).sum() 48 | return res 49 | 50 | 51 | 52 | 53 | 54 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/szq0214/MEAL-V2/59e6bca29c77a7df9229e77d4463f2f795723810/models/__init__.py -------------------------------------------------------------------------------- /models/blocks.py: -------------------------------------------------------------------------------- 1 | """A list of commonly used building blocks.""" 2 | 3 | from torch import nn 4 | 5 | 6 | class Conv2dBnRelu(nn.Module): 7 | """A commonly used building block: Conv -> BN -> ReLU""" 8 | 9 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 10 | padding=0, bias=True, pooling=None, 11 | activation=nn.ReLU(inplace=True)): 12 | super().__init__() 13 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, 14 | padding, bias=bias) 15 | self.bn = nn.BatchNorm2d(out_channels) 16 | self.pooling = pooling 17 | self.activation = activation 18 | 19 | def forward(self, x): 20 | x = self.bn(self.conv(x)) 21 | if self.pooling is not None: 22 | x = self.pooling(x) 23 | return self.activation(x) 24 | 25 | 26 | class LinearBnRelu(nn.Module): 27 | """A commonly used building block: FC -> BN -> ReLU""" 28 | 29 | def __init__(self, in_features, out_features, bias=True, 30 | activation=nn.ReLU(inplace=True)): 31 | super().__init__() 32 | self.linear = nn.Linear(in_features, out_features, bias=bias) 33 | self.bn = nn.BatchNorm1d(out_features) 34 | self.activation = activation 35 | 36 | def forward(self, x): 37 | return self.activation(self.bn(self.linear(x))) 38 | -------------------------------------------------------------------------------- /models/discriminator.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | class Discriminator(nn.Module): 7 | def __init__(self, outputs_size, K = 2): 8 | super(Discriminator, self).__init__() 9 | self.conv1 = nn.Conv2d(in_channels=outputs_size, out_channels=outputs_size//K, kernel_size=1, stride=1, bias=True) 10 | outputs_size = outputs_size // K 11 | self.conv2 = nn.Conv2d(in_channels=outputs_size, out_channels=outputs_size//K, kernel_size=1, stride=1, bias=True) 12 | outputs_size = outputs_size // K 13 | self.conv3 = nn.Conv2d(in_channels=outputs_size, out_channels=2, kernel_size=1, stride=1, bias=True) 14 | 15 | def forward(self, x): 16 | x = x[:,:,None,None] 17 | out = F.relu(self.conv1(x)) 18 | out = F.relu(self.conv2(out)) 19 | out = F.relu(self.conv3(out)) 20 | out = out.view(out.size(0), -1) 21 | return out 22 | 23 | -------------------------------------------------------------------------------- /models/model_factory.py: -------------------------------------------------------------------------------- 1 | """Utility functions to construct a model.""" 2 | 3 | import torch 4 | from torch import nn 5 | 6 | import random 7 | 8 | from extensions import data_parallel 9 | from extensions import teacher_wrapper 10 | from extensions import kd_loss 11 | import torchvision.models as models 
12 | import timm 13 | 14 | 15 | def _create_single_cpu_model(model_name, state_file=None): 16 | model = _create_model(model_name, teacher=False, pretrain=True) 17 | if state_file is not None: 18 | model.load_state_dict(torch.load(state_file)) 19 | return model 20 | 21 | def _create_checkpoint_model(model_name, state_file=None): 22 | model = _create_model(model_name, teacher=False, pretrain=True) 23 | # model = timm.create_model(model_name.lower(), pretrained=False) 24 | if state_file is not None: 25 | model.load_state_dict(torch.load(state_file)) 26 | return model 27 | 28 | def _create_model(model_name, teacher=False, pretrain=True): 29 | if pretrain: 30 | print("=> teacher" if teacher else "=> student", end=":") 31 | print(" using pre-trained model '{}'".format(model_name)) 32 | 33 | # model = models.__dict__[model_name.lower()](pretrained=True) 34 | model = timm.create_model(model_name.lower(), pretrained=True) 35 | else: 36 | print("=> creating model '{}'".format(model_name)) 37 | # model = models.__dict__[model_name.lower()]() 38 | model = timm.create_model(model_name.lower(), pretrained=False) 39 | 40 | if model_name.startswith('alexnet') or model_name.startswith('vgg'): 41 | model.features = torch.nn.DataParallel(model.features) 42 | model.cuda() 43 | else: 44 | model = torch.nn.DataParallel(model).cuda() 45 | 46 | if teacher: 47 | for p in model.parameters(): 48 | p.requires_grad = False 49 | model.eval() 50 | 51 | return model 52 | 53 | 54 | def teachers(teachers=['resnet50'], state_file=None): 55 | if state_file is not None: 56 | return [_create_single_cpu_model(t, state_file).cuda() for t in teachers] 57 | else: 58 | return [_create_model(t, teacher=True).cuda() for t in teachers] 59 | 60 | 61 | def create_model(model_name, student_state_file=None, gpus=[], teacher=None, 62 | teacher_state_file=None, FKD=True): 63 | if FKD: 64 | model = _create_checkpoint_model(model_name, student_state_file) 65 | loss = kd_loss.KLLoss() 66 | return model, loss 67 | else: 68 | model = _create_checkpoint_model(model_name, student_state_file) 69 | model.LR_REGIME = [0, 100, 0.01, 101, 300, 0.001] # LR_REGIME 70 | if teacher is not None: 71 | # assert teacher_state_file is not None, "Teacher state is None." 72 | 73 | teacher = teachers(teacher.split(","), teacher_state_file) 74 | model = teacher_wrapper.ModelDistillationWrapper(model, teacher) 75 | loss = kd_loss.KLLoss() 76 | else: 77 | loss = nn.CrossEntropyLoss() 78 | 79 | return model, loss -------------------------------------------------------------------------------- /opts.py: -------------------------------------------------------------------------------- 1 | from torch.utils import data as data_utils 2 | 3 | from models import model_factory 4 | 5 | 6 | def add_general_flags(parser): 7 | parser.add_argument('--save', default='checkpoints', 8 | help="Path to the directory to save logs and " 9 | "checkpoints.") 10 | parser.add_argument('--gpus', '--gpu', nargs='+', default=[0], type=int, 11 | help="The GPU(s) on which the model should run. 
The " 12 | "first GPU will be the main one.") 13 | parser.add_argument('--cpu', action='store_const', const=[], 14 | dest='gpus', help="If set, no gpus will be used.") 15 | 16 | 17 | def add_dataset_flags(parser): 18 | parser.add_argument('--imagenet', required=True, help="Path to ImageNet's " 19 | "root directory holding 'train/' and 'val/' " 20 | "directories.") 21 | parser.add_argument('--batch-size', default=256, help="Batch size to use " 22 | "distributed over all GPUs.", type=int) 23 | parser.add_argument('--num-workers', '-j', default=40, help="Number of " 24 | "data loading processes to use for loading data and " 25 | "transforming.", type=int) 26 | parser.add_argument('--image-size', default=224, help="image size to train " 27 | "input image size.", type=int) 28 | parser.add_argument('--softlabel_path', default='./soft_label', type=str, metavar='PATH', 29 | help='path to soft label files (default: none)') 30 | 31 | 32 | def add_model_flags(parser): 33 | parser.add_argument('--model', required=True, help="The model architecture " 34 | "name.") 35 | parser.add_argument('--student-state-file', default=None, help="Path to student model" 36 | "state file to initialize the student model.") 37 | 38 | 39 | def add_teacher_flags(parser): 40 | parser.add_argument('--teacher-model', default="gluon_senet154,gluon_resnet152_v1s", help="The " 41 | "model that will generate soft labels per crop.", 42 | ) 43 | parser.add_argument('--teacher-state-file', default=None, 44 | help="Path to teacher model state file.") 45 | 46 | 47 | def add_training_flags(parser): 48 | parser.add_argument('--lr', '--learning-rate', default=0.01, type=float, 49 | metavar='LR', help='initial learning rate', dest='lr') 50 | parser.add_argument('--lr-regime', default=None, nargs='+', type=float, 51 | help="If set, it will override the default learning " 52 | "rate regime of the model. Learning rate passed must " 53 | "be as list of [start, end, lr, ...].") 54 | parser.add_argument('--d_lr', default=1e-4, type=float, 55 | help="The learning rate for discriminator training") 56 | parser.add_argument('--start-epoch', default=0, help="manual epoch number " 57 | "useful on restarts.", type=int) 58 | parser.add_argument('--epochs', default=200, type=int, help='number of total epochs to run') 59 | parser.add_argument('--schedule', default=[100, 200], nargs='*', type=int, 60 | help='learning rate schedule (when to drop lr by 10x). This works for FKD training') 61 | parser.add_argument('--cos', action='store_true', 62 | help='use cosine lr schedule. This works for FKD training') 63 | parser.add_argument('--momentum', default=0.9, type=float, 64 | help="The momentum of the optimization.") 65 | parser.add_argument('--weight-decay', default=0, type=float, 66 | help="The weight decay of the optimization.") 67 | parser.add_argument('--use-discriminator-loss', action='store_true', 68 | help='use discriminating training') 69 | parser.add_argument('--num_crops', default=8, type=int, 70 | help='number of crops in each image, 1 is the standard training') 71 | parser.add_argument('--soft_label_type', default='marginal_smoothing_k5', type=str, metavar='TYPE', 72 | help='(1) ori; (2) hard; (3) smoothing; (4) marginal_smoothing_k5; (5) marginal_smoothing_k10; (6) marginal_renorm_k5') 73 | parser.add_argument('--num_classes', default=1000, type=int, 74 | help='number of classes. 
') 75 | 76 | def add_cutmix_training_flags(parser): 77 | parser.add_argument('--w-cutmix', action='store_true', 78 | help='use cutmix training') 79 | parser.add_argument('--beta', default=1.0, type=float, 80 | help='hyperparameter beta') 81 | parser.add_argument('--cutmix-prob', default=1.0, type=float, 82 | help='cutmix probability') -------------------------------------------------------------------------------- /script/resume_train.sh: -------------------------------------------------------------------------------- 1 | # an example 2 | python train.py --save MEAL_V2_resnet50_224 --batch-size 512 --model resnet50 --start-epoch 96 --teacher-model gluon_senet154,gluon_resnet152_v1s --student-state-file ./MEAL_V2_resnet50/model_state_95.pytar --imagenet [imagenet-folder with train and val folders] -j 40 3 | -------------------------------------------------------------------------------- /script/train.sh: -------------------------------------------------------------------------------- 1 | # 224 x 224 ResNet-50 Tested on 8 TITAN Xp GPUs 2 | python train.py --save MEAL_V2_resnet50_224 --batch-size 512 -j 48 --model resnet50 --teacher-model gluon_senet154,gluon_resnet152_v1s --imagenet [imagenet-folder with train and val folders] 3 | 4 | # # 224 x 224 MobileNet V3-Small 0.75 5 | # python train.py --save MEAL_V2_mobilenetv3_small_075 --batch-size 512 -j 48 --model tf_mobilenetv3_small_075 --teacher-model gluon_senet154,gluon_resnet152_v1s --imagenet [imagenet-folder with train and val folders] 6 | 7 | # # 224 x 224 MobileNet V3-Small 1.0 8 | # python train.py --save MEAL_V2_mobilenetv3_small_100 --batch-size 512 -j 48 --model tf_mobilenetv3_small_100 --teacher-model gluon_senet154,gluon_resnet152_v1s --imagenet [imagenet-folder with train and val folders] 9 | 10 | # # 224 x 224 MobileNet V3-Large 1.0 11 | # python train.py --save MEAL_V2_mobilenetv3_large_100 --batch-size 512 -j 48 --model tf_mobilenetv3_large_100 --teacher-model gluon_senet154,gluon_resnet152_v1s --imagenet [imagenet-folder with train and val folders] 12 | 13 | # # 224 x 224 EfficientNet-B0 14 | # python train.py --save MEAL_V2_efficientnet_b0 --batch-size 512 -j 48 --model tf_efficientnet_b0 --teacher-model gluon_senet154,gluon_resnet152_v1s --imagenet [imagenet-folder with train and val folders] 15 | 16 | # 380 x 380 ResNet-50 17 | # python train.py --save MEAL_V2_resnet50_380 --batch-size 512 -j 48 --model resnet50 --image-size 380 --teacher-model tf_efficientnet_b4_ns,tf_efficientnet_b4 --imagenet [imagenet-folder with train and val folders] 18 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Script to test a pytorch model on ImageNet's validation set.""" 3 | 4 | import argparse 5 | import logging 6 | import pprint 7 | import sys 8 | import time 9 | 10 | import torch 11 | from torch import nn 12 | 13 | import imagenet 14 | from models import model_factory 15 | import opts 16 | import utils 17 | 18 | 19 | def parse_args(argv): 20 | """Parse arguments @argv and return the flags needed for training.""" 21 | parser = argparse.ArgumentParser(description=__doc__, allow_abbrev=False) 22 | 23 | group = parser.add_argument_group('General Options') 24 | opts.add_general_flags(group) 25 | 26 | group = parser.add_argument_group('Dataset Options') 27 | opts.add_dataset_flags(group) 28 | 29 | group = parser.add_argument_group('Model Options') 30 | opts.add_model_flags(group) 31 | 
32 |     args = parser.parse_args(argv)
33 | 
34 |     if args.student_state_file is None:
35 |         parser.error("You should set --student-state-file to reload a model "
36 |                      "state.")
37 | 
38 |     return args
39 | 
40 | 
41 | def test_for_one_epoch(model, loss, test_loader, epoch_number):
42 |     model.eval()
43 |     loss.eval()
44 | 
45 |     data_time_meter = utils.AverageMeter()
46 |     batch_time_meter = utils.AverageMeter()
47 |     loss_meter = utils.AverageMeter(recent=100)
48 |     top1_meter = utils.AverageMeter(recent=100)
49 |     top5_meter = utils.AverageMeter(recent=100)
50 | 
51 |     timestamp = time.time()
52 |     for i, (images, labels) in enumerate(test_loader):
53 |         batch_size = images.size(0)
54 | 
55 |         if utils.is_model_cuda(model):
56 |             images = images.cuda()
57 |             labels = labels.cuda()
58 | 
59 |         # Record data time
60 |         data_time_meter.update(time.time() - timestamp)
61 | 
62 |         # Forward pass without computing gradients.
63 |         with torch.no_grad():
64 |             outputs = model(images)
65 |             loss_output = loss(outputs, labels)
66 | 
67 |         # Sometimes the loss function returns a modified version of the output,
68 |         # which must be used to compute the model accuracy.
69 |         if isinstance(loss_output, tuple):
70 |             loss_value, outputs = loss_output
71 |         else:
72 |             loss_value = loss_output
73 | 
74 |         # Record loss and model accuracy.
75 |         loss_meter.update(loss_value.item(), batch_size)
76 |         top1, top5 = utils.topk_accuracy(outputs, labels, recalls=(1, 5))
77 |         top1_meter.update(top1, batch_size)
78 |         top5_meter.update(top5, batch_size)
79 | 
80 |         # Record batch time
81 |         batch_time_meter.update(time.time() - timestamp)
82 |         timestamp = time.time()
83 | 
84 |         logging.info(
85 |             'Epoch: [{epoch}][{batch}/{epoch_size}]\t'
86 |             'Time {batch_time.value:.2f} ({batch_time.average:.2f}) '
87 |             'Data {data_time.value:.2f} ({data_time.average:.2f}) '
88 |             'Loss {loss.value:.3f} {{{loss.average:.3f}, {loss.average_recent:.3f}}} '
89 |             'Top-1 {top1.value:.2f} {{{top1.average:.2f}, {top1.average_recent:.2f}}} '
90 |             'Top-5 {top5.value:.2f} {{{top5.average:.2f}, {top5.average_recent:.2f}}} '.format(
91 |                 epoch=epoch_number, batch=i + 1, epoch_size=len(test_loader),
92 |                 batch_time=batch_time_meter, data_time=data_time_meter,
93 |                 loss=loss_meter, top1=top1_meter, top5=top5_meter))
94 |     # Log the overall test stats
95 |     logging.info(
96 |         'Epoch: [{epoch}] -- TESTING SUMMARY\t'
97 |         'Time {batch_time.sum:.2f} '
98 |         'Data {data_time.sum:.2f} '
99 |         'Loss {loss.average:.3f} '
100 |         'Top-1 {top1.average:.2f} '
101 |         'Top-5 {top5.average:.2f} '.format(
102 |             epoch=epoch_number, batch_time=batch_time_meter, data_time=data_time_meter,
103 |             loss=loss_meter, top1=top1_meter, top5=top5_meter))
104 | 
105 | 
106 | def main(argv):
107 |     """Run the test script with command line arguments @argv."""
108 |     args = parse_args(argv)
109 |     utils.general_setup(args.save, args.gpus)
110 | 
111 |     logging.info("Arguments parsed.\n{}".format(pprint.pformat(vars(args))))
112 | 
113 |     # Create the validation data loader.
114 |     val_loader = imagenet.get_val_loader(args.imagenet, args.batch_size,
115 |                                          args.num_workers)
116 |     # Create model and the loss.
117 |     model, loss = model_factory.create_model(
118 |         args.model, args.student_state_file, args.gpus)
119 |     logging.info("Model:\n{}".format(model))
120 | 
121 |     # Test for one epoch.
122 | test_for_one_epoch(model, loss, val_loader, epoch_number=1) 123 | 124 | 125 | if __name__ == '__main__': 126 | main(sys.argv[1:]) 127 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Script to train a model through soft labels on ImageNet's train set.""" 3 | 4 | import argparse 5 | import logging 6 | import pprint 7 | import os 8 | import sys 9 | import time 10 | import numpy as np 11 | 12 | import torch 13 | from torch import nn 14 | 15 | from loss import discriminatorLoss 16 | 17 | import imagenet 18 | from models import model_factory 19 | from models import discriminator 20 | import opts 21 | import test 22 | import utils 23 | 24 | 25 | def parse_args(argv): 26 | """Parse arguments @argv and return the flags needed for training.""" 27 | parser = argparse.ArgumentParser(description=__doc__, allow_abbrev=False) 28 | 29 | group = parser.add_argument_group('General Options') 30 | opts.add_general_flags(group) 31 | 32 | group = parser.add_argument_group('Dataset Options') 33 | opts.add_dataset_flags(group) 34 | 35 | group = parser.add_argument_group('Model Options') 36 | opts.add_model_flags(group) 37 | 38 | group = parser.add_argument_group('Soft Label Options') 39 | opts.add_teacher_flags(group) 40 | 41 | group = parser.add_argument_group('Training Options') 42 | opts.add_training_flags(group) 43 | 44 | group = parser.add_argument_group('CutMix Training Options') 45 | opts.add_cutmix_training_flags(group) 46 | 47 | args = parser.parse_args(argv) 48 | 49 | # if args.teacher_model is not None and args.teacher_state_file is None: 50 | # parser.error("You should set --teacher-state-file if " 51 | # "--teacher-model is set.") 52 | 53 | return args 54 | 55 | 56 | class LearningRateRegime: 57 | """Encapsulates the learning rate regime for training a model. 58 | 59 | Args: 60 | @intervals (list): A list of triples (start, end, lr). The intervals 61 | are inclusive (for start <= epoch <= end, lr will be used). The 62 | start of each interval must be right after the end of its previous 63 | interval. 
64 | """ 65 | 66 | def __init__(self, regime): 67 | if len(regime) % 3 != 0: 68 | raise ValueError("Regime length should be devisible by 3.") 69 | intervals = list(zip(regime[0::3], regime[1::3], regime[2::3])) 70 | self._validate_intervals(intervals) 71 | self.intervals = intervals 72 | self.num_epochs = intervals[-1][1] 73 | 74 | @classmethod 75 | def _validate_intervals(cls, intervals): 76 | if type(intervals) is not list: 77 | raise TypeError("Intervals must be a list of triples.") 78 | elif len(intervals) == 0: 79 | raise ValueError("Intervals must be a non empty list.") 80 | # elif intervals[0][0] != 1: 81 | # raise ValueError("Intervals must start from 1: {}".format(intervals)) 82 | elif any(end < start for (start, end, lr) in intervals): 83 | raise ValueError("End of intervals must be greater or equal than their" 84 | " start: {}".format(intervals)) 85 | elif any(intervals[i][1] + 1 != intervals[i + 1][0] 86 | for i in range(len(intervals) - 1)): 87 | raise ValueError("Start of each each interval must be the end of its " 88 | "previous interval plus one: {}".format(intervals)) 89 | 90 | def get_lr(self, epoch): 91 | for (start, end, lr) in self.intervals: 92 | if start <= epoch <= end: 93 | return lr 94 | raise ValueError("Invalid epoch {} for regime {!r}".format( 95 | epoch, self.intervals)) 96 | 97 | 98 | def _set_learning_rate(optimizer, lr): 99 | for param_group in optimizer.param_groups: 100 | param_group['lr'] = lr 101 | 102 | 103 | def _get_learning_rate(optimizer): 104 | return max(param_group['lr'] for param_group in optimizer.param_groups) 105 | 106 | 107 | def train_for_one_epoch(model, g_loss, discriminator_loss, train_loader, optimizer, epoch_number, args): 108 | model.train() 109 | g_loss.train() 110 | 111 | data_time_meter = utils.AverageMeter() 112 | batch_time_meter = utils.AverageMeter() 113 | g_loss_meter = utils.AverageMeter(recent=100) 114 | d_loss_meter = utils.AverageMeter(recent=100) 115 | top1_meter = utils.AverageMeter(recent=100) 116 | top5_meter = utils.AverageMeter(recent=100) 117 | 118 | timestamp = time.time() 119 | for i, (images, labels) in enumerate(train_loader): 120 | batch_size = images.size(0) 121 | 122 | if utils.is_model_cuda(model): 123 | images = images.cuda() 124 | labels = labels.cuda() 125 | 126 | # Record data time 127 | data_time_meter.update(time.time() - timestamp) 128 | 129 | if args.w_cutmix == True: 130 | r = np.random.rand(1) 131 | if args.beta > 0 and r < args.cutmix_prob: 132 | # generate mixed sample 133 | lam = np.random.beta(args.beta, args.beta) 134 | rand_index = torch.randperm(images.size()[0]).cuda() 135 | target_a = labels 136 | target_b = labels[rand_index] 137 | bbx1, bby1, bbx2, bby2 = utils.rand_bbox(images.size(), lam) 138 | images[:, :, bbx1:bbx2, bby1:bby2] = images[rand_index, :, bbx1:bbx2, bby1:bby2] 139 | 140 | # Forward pass, backward pass, and update parameters. 141 | outputs = model(images, before=True) 142 | output, soft_label, soft_no_softmax = outputs 143 | g_loss_output = g_loss((output, soft_label), labels) 144 | d_loss_value = discriminator_loss([output], [soft_no_softmax]) 145 | 146 | # Sometimes loss function returns a modified version of the output, 147 | # which must be used to compute the model accuracy. 148 | if isinstance(g_loss_output, tuple): 149 | g_loss_value, outputs = g_loss_output 150 | else: 151 | g_loss_value = g_loss_output 152 | 153 | loss_value = g_loss_value + d_loss_value 154 | 155 | loss_value.backward() 156 | 157 | # Update parameters and reset gradients. 
158 | optimizer.step() 159 | optimizer.zero_grad() 160 | 161 | # Record loss and model accuracy. 162 | g_loss_meter.update(g_loss_value.item(), batch_size) 163 | d_loss_meter.update(d_loss_value.item(), batch_size) 164 | 165 | top1, top5 = utils.topk_accuracy(outputs, labels, recalls=(1, 5)) 166 | top1_meter.update(top1, batch_size) 167 | top5_meter.update(top5, batch_size) 168 | 169 | # Record batch time 170 | batch_time_meter.update(time.time() - timestamp) 171 | timestamp = time.time() 172 | 173 | if i%20 == 0: 174 | logging.info( 175 | 'Epoch: [{epoch}][{batch}/{epoch_size}]\t' 176 | 'Time {batch_time.value:.2f} ({batch_time.average:.2f}) ' 177 | 'Data {data_time.value:.2f} ({data_time.average:.2f}) ' 178 | 'G_Loss {g_loss.value:.3f} {{{g_loss.average:.3f}, {g_loss.average_recent:.3f}}} ' 179 | 'D_Loss {d_loss.value:.3f} {{{d_loss.average:.3f}, {d_loss.average_recent:.3f}}} ' 180 | 'Top-1 {top1.value:.2f} {{{top1.average:.2f}, {top1.average_recent:.2f}}} ' 181 | 'Top-5 {top5.value:.2f} {{{top5.average:.2f}, {top5.average_recent:.2f}}} ' 182 | 'LR {lr:.5f}'.format( 183 | epoch=epoch_number, batch=i + 1, epoch_size=len(train_loader), 184 | batch_time=batch_time_meter, data_time=data_time_meter, 185 | g_loss=g_loss_meter, d_loss=d_loss_meter, top1=top1_meter, top5=top5_meter, 186 | lr=_get_learning_rate(optimizer))) 187 | # Log the overall train stats 188 | logging.info( 189 | 'Epoch: [{epoch}] -- TRAINING SUMMARY\t' 190 | 'Time {batch_time.sum:.2f} ' 191 | 'Data {data_time.sum:.2f} ' 192 | 'G_Loss {g_loss.average:.3f} ' 193 | 'D_Loss {d_loss.average:.3f} ' 194 | 'Top-1 {top1.average:.2f} ' 195 | 'Top-5 {top5.average:.2f} '.format( 196 | epoch=epoch_number, batch_time=batch_time_meter, data_time=data_time_meter, 197 | g_loss=g_loss_meter, d_loss=d_loss_meter, top1=top1_meter, top5=top5_meter)) 198 | 199 | 200 | def save_checkpoint(checkpoints_dir, model, optimizer, epoch): 201 | model_state_file = os.path.join(checkpoints_dir, 'model_state_{:02}.pytar'.format(epoch)) 202 | optim_state_file = os.path.join(checkpoints_dir, 'optim_state_{:02}.pytar'.format(epoch)) 203 | torch.save(model.state_dict(), model_state_file) 204 | torch.save(optimizer.state_dict(), optim_state_file) 205 | 206 | 207 | def create_optimizer(model, discriminator_parameters, momentum=0.9, weight_decay=0): 208 | # Get model parameters that require a gradient. 209 | # model_trainable_parameters = filter(lambda x: x.requires_grad, model.parameters()) 210 | parameters = [{'params': model.parameters()}, discriminator_parameters] 211 | optimizer = torch.optim.SGD(parameters, lr=0, 212 | momentum=momentum, weight_decay=weight_decay) 213 | return optimizer 214 | 215 | def create_discriminator_criterion(args): 216 | d = discriminator.Discriminator(outputs_size=1000, K=8).cuda() 217 | d = torch.nn.DataParallel(d) 218 | update_parameters = {'params': d.parameters(), "lr": args.d_lr} 219 | discriminators_criterion = discriminatorLoss(d).cuda() 220 | if len(args.gpus) > 1: 221 | discriminators_criterion = torch.nn.DataParallel(discriminators_criterion, device_ids=args.gpus) 222 | return discriminators_criterion, update_parameters 223 | 224 | def main(argv): 225 | """Run the training script with command line arguments @argv.""" 226 | args = parse_args(argv) 227 | utils.general_setup(args.save, args.gpus) 228 | 229 | logging.info("Arguments parsed.\n{}".format(pprint.pformat(vars(args)))) 230 | 231 | # Create the train and the validation data loaders. 
232 | train_loader = imagenet.get_train_loader(args.imagenet, args.batch_size, 233 | args.num_workers, args.image_size) 234 | val_loader = imagenet.get_val_loader(args.imagenet, args.batch_size, 235 | args.num_workers, args.image_size) 236 | # Create model with optional teachers. 237 | model, loss = model_factory.create_model( 238 | args.model, args.student_state_file, args.gpus, args.teacher_model, 239 | args.teacher_state_file, False) 240 | logging.info("Model:\n{}".format(model)) 241 | 242 | discriminator_loss, update_parameters = create_discriminator_criterion(args) 243 | 244 | if args.lr_regime is None: 245 | lr_regime = model.LR_REGIME 246 | else: 247 | lr_regime = args.lr_regime 248 | regime = LearningRateRegime(lr_regime) 249 | # Train and test for needed number of epochs. 250 | optimizer = create_optimizer(model, update_parameters, args.momentum, args.weight_decay) 251 | 252 | for epoch in range(args.start_epoch, args.epochs): 253 | lr = regime.get_lr(epoch) 254 | _set_learning_rate(optimizer, lr) 255 | train_for_one_epoch(model, loss, discriminator_loss, train_loader, optimizer, epoch, args) 256 | test.test_for_one_epoch(model, loss, val_loader, epoch) 257 | save_checkpoint(args.save, model, optimizer, epoch) 258 | 259 | 260 | if __name__ == '__main__': 261 | main(sys.argv[1:]) 262 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import logging 3 | import os 4 | import sys 5 | import numpy as np 6 | 7 | import torch 8 | 9 | 10 | def general_setup(checkpoints_dir=None, gpus=[]): 11 | if checkpoints_dir is not None: 12 | os.makedirs(checkpoints_dir, exist_ok=True) 13 | if len(gpus) > 0: 14 | torch.cuda.set_device(gpus[0]) 15 | # Setup python's logging module. 16 | log_formatter = logging.Formatter( 17 | '%(levelname)s %(asctime)-20s:\t %(message)s') 18 | root_logger = logging.getLogger() 19 | root_logger.setLevel(logging.INFO) 20 | # Add a console handler to write to stdout. 21 | console_handler = logging.StreamHandler(sys.stdout) 22 | console_handler.setFormatter(log_formatter) 23 | root_logger.addHandler(console_handler) 24 | # Add a file handler to write to log.txt. 25 | log_filepath = os.path.join(checkpoints_dir, 'log.txt') 26 | file_handler = logging.FileHandler(log_filepath) 27 | file_handler.setFormatter(log_formatter) 28 | root_logger.addHandler(file_handler) 29 | 30 | 31 | def is_model_cuda(model): 32 | # Check if the first parameter is on cuda. 
33 | return next(model.parameters()).is_cuda 34 | 35 | 36 | def topk_accuracy(outputs, labels, recalls=(1, 5)): 37 | """Return @recall accuracies for the given recalls.""" 38 | 39 | _, num_classes = outputs.size() 40 | maxk = min(max(recalls), num_classes) 41 | 42 | _, pred = outputs.topk(maxk, dim=1, largest=True, sorted=True) 43 | correct = (pred == labels[:,None].expand_as(pred)).float() 44 | 45 | topk_accuracy = [] 46 | for recall in recalls: 47 | topk_accuracy.append(100 * correct[:, :recall].sum(1).mean()) 48 | return topk_accuracy 49 | 50 | 51 | class AverageMeter: 52 | """Helper class to track the running average (and optionally the recent k 53 | items average of a sequence).""" 54 | 55 | def __init__(self, recent=None): 56 | self._recent = recent 57 | if recent is not None: 58 | self._q = collections.deque() 59 | self.reset() 60 | 61 | def reset(self): 62 | self.value = 0 63 | self.sum = 0 64 | self.count = 0 65 | if self._recent is not None: 66 | self.sum_recent = 0 67 | self.count_recent = 0 68 | self._q.clear() 69 | 70 | def update(self, value, n=1): 71 | self.value = value 72 | self.sum += value * n 73 | self.count += n 74 | 75 | if self._recent is not None: 76 | self.sum_recent += value * n 77 | self.count_recent += n 78 | self._q.append((n, value)) 79 | while len(self._q) > self._recent: 80 | (n, value) = self._q.popleft() 81 | self.sum_recent -= value * n 82 | self.count_recent -= n 83 | 84 | @property 85 | def average(self): 86 | if self.count > 0: 87 | return self.sum / self.count 88 | else: 89 | return 0 90 | 91 | @property 92 | def average_recent(self): 93 | if self.count_recent > 0: 94 | return self.sum_recent / self.count_recent 95 | else: 96 | return 0 97 | 98 | def rand_bbox(size, lam): 99 | W = size[2] 100 | H = size[3] 101 | cut_rat = np.sqrt(1. 
- lam)
102 |     cut_w = int(W * cut_rat)  # np.int was removed in recent NumPy; use the builtin int
103 |     cut_h = int(H * cut_rat)
104 | 
105 |     # uniform
106 |     cx = np.random.randint(W)
107 |     cy = np.random.randint(H)
108 | 
109 |     bbx1 = np.clip(cx - cut_w // 2, 0, W)
110 |     bby1 = np.clip(cy - cut_h // 2, 0, H)
111 |     bbx2 = np.clip(cx + cut_w // 2, 0, W)
112 |     bby2 = np.clip(cy + cut_h // 2, 0, H)
113 | 
114 |     return bbx1, bby1, bbx2, bby2
115 | 
--------------------------------------------------------------------------------
/utils_FKD.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.distributed
4 | import torch.nn as nn
5 | import torchvision
6 | from torchvision.ops import roi_align
7 | from torchvision.transforms import functional as t_F
8 | from torch.nn import functional as F
9 | from torchvision.datasets.folder import ImageFolder
10 | from torch.nn.modules import loss
11 | from torchvision.transforms import InterpolationMode
12 | import random
13 | import numpy as np
14 | 
15 | 
16 | class RandomResizedCrop_FKD(torchvision.transforms.RandomResizedCrop):
17 |     def __init__(self, **kwargs):
18 |         super(RandomResizedCrop_FKD, self).__init__(**kwargs)
19 | 
20 |     def __call__(self, img, coords, status):
21 |         i = coords[0].item() * img.size[1]  # stored coords are relative; img.size is (W, H)
22 |         j = coords[1].item() * img.size[0]
23 |         h = coords[2].item() * img.size[1]
24 |         w = coords[3].item() * img.size[0]
25 | 
26 |         if self.interpolation == 'bilinear':  # assumes interpolation was given as a string
27 |             inter = InterpolationMode.BILINEAR
28 |         elif self.interpolation == 'bicubic':
29 |             inter = InterpolationMode.BICUBIC
30 |         return t_F.resized_crop(img, i, j, h, w, self.size, inter)
31 | 
32 | 
33 | class RandomHorizontalFlip_FKD(torch.nn.Module):
34 |     def __init__(self, p=0.5):
35 |         super().__init__()
36 |         self.p = p
37 | 
38 |     def forward(self, img, coords, status):
39 | 
40 |         if status:  # the recorded flip decision for this crop
41 |             return t_F.hflip(img)
42 |         else:
43 |             return img
44 | 
45 |     def __repr__(self):
46 |         return self.__class__.__name__ + '(p={})'.format(self.p)
47 | 
48 | 
49 | class Compose_FKD(torchvision.transforms.Compose):
50 |     def __init__(self, **kwargs):
51 |         super(Compose_FKD, self).__init__(**kwargs)
52 | 
53 |     def __call__(self, img, coords, status):
54 |         for t in self.transforms:
55 |             if type(t).__name__ == 'RandomResizedCrop_FKD':
56 |                 img = t(img, coords, status)
57 |             elif type(t).__name__ == 'RandomCrop_FKD':
58 |                 img, coords = t(img)
59 |             elif type(t).__name__ == 'RandomHorizontalFlip_FKD':
60 |                 img = t(img, coords, status)
61 |             else:
62 |                 img = t(img)
63 |         return img
64 | 
65 | 
66 | class ImageFolder_FKD(torchvision.datasets.ImageFolder):
67 |     def __init__(self, **kwargs):
68 |         self.num_crops = kwargs['num_crops']
69 |         self.softlabel_path = kwargs['softlabel_path']
70 |         kwargs.pop('num_crops')
71 |         kwargs.pop('softlabel_path')
72 |         super(ImageFolder_FKD, self).__init__(**kwargs)
73 | 
74 |     def __getitem__(self, index):
75 | 
76 |         path, target = self.samples[index]
77 | 
78 |         label_path = os.path.join(self.softlabel_path, '/'.join(path.split('/')[-4:]).split('.')[0] + '.tar')
79 | 
80 |         label = torch.load(label_path, map_location=torch.device('cpu'))
81 | 
82 |         coords, flip_status, output = label
83 | 
84 |         rand_index = torch.randperm(len(output))#.cuda()
85 |         output_new = []
86 | 
87 |         sample = self.loader(path)
88 |         sample_all = []
89 |         target_all = []
90 | 
91 |         for i in range(self.num_crops):
92 |             if self.transform is not None:
93 |                 output_new.append(output[rand_index[i]])
94 |                 sample_new = self.transform(sample, coords[rand_index[i]], flip_status[rand_index[i]])
95 | 
sample_all.append(sample_new) 96 | target_all.append(target) 97 | else: 98 | coords = None 99 | flip_status = None 100 | if self.target_transform is not None: 101 | target = self.target_transform(target) 102 | 103 | return sample_all, target_all, output_new 104 | 105 | 106 | def Recover_soft_label(label, label_type, n_classes): 107 | if label_type == 'hard': 108 | return torch.zeros(label.size(0), n_classes).scatter_(1, label.view(-1, 1), 1) 109 | elif label_type == 'smoothing': 110 | index = label[:,0].to(dtype=int) 111 | value = label[:,1] 112 | minor_value = (torch.ones_like(value) - value)/(n_classes-1) 113 | minor_value = minor_value.reshape(-1,1).repeat_interleave(n_classes, dim=1) 114 | soft_label = (minor_value * torch.ones(index.size(0), n_classes)).scatter_(1, index.view(-1, 1), value.view(-1, 1)) 115 | return soft_label 116 | elif label_type == 'marginal_smoothing_k5': 117 | index = label[:,0,:].to(dtype=int) 118 | value = label[:,1,:] 119 | minor_value = (torch.ones(label.size(0),1) - torch.sum(value, dim=1, keepdim=True))/(n_classes-5) 120 | minor_value = minor_value.reshape(-1,1).repeat_interleave(n_classes, dim=1) 121 | soft_label = (minor_value * torch.ones(index.size(0), n_classes)).scatter_(1, index, value) 122 | return soft_label 123 | elif label_type == 'marginal_renorm': 124 | index = label[:,0,:].to(dtype=int) 125 | value = label[:,1,:] 126 | soft_label = torch.zeros(index.size(0), n_classes).scatter_(1, index, value) 127 | soft_label = F.normalize(soft_label, p=1.0, dim=1, eps=1e-12) 128 | return soft_label 129 | elif label_type == 'marginal_smoothing_k10': 130 | index = label[:,0,:].to(dtype=int) 131 | value = label[:,1,:] 132 | minor_value = (torch.ones(label.size(0),1) - torch.sum(value, dim=1, keepdim=True))/(n_classes-10) 133 | minor_value = minor_value.reshape(-1,1).repeat_interleave(n_classes, dim=1) 134 | soft_label = (minor_value * torch.ones(index.size(0), n_classes)).scatter_(1, index, value) 135 | return soft_label --------------------------------------------------------------------------------
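
Editor's note (not part of the repository): a minimal usage sketch for Recover_soft_label, assuming the [batch, 2, K] label layout implied by the indexing above — row 0 holds the top-K class indices, row 1 the matching probabilities. The tensor values below are illustrative only.

# Illustrative sketch: reconstruct full soft labels from a top-5 "marginal_smoothing_k5" record.
import torch
from utils_FKD import Recover_soft_label

batch, n_classes, k = 4, 1000, 5
# Distinct top-5 indices per sample, stored as floats as in the saved labels.
index = torch.stack([torch.randperm(n_classes)[:k] for _ in range(batch)]).float()
value = torch.full((batch, k), 0.9 / k)     # the top-5 entries carry 0.9 of the mass
label = torch.stack([index, value], dim=1)  # shape: [batch, 2, k]

soft = Recover_soft_label(label, 'marginal_smoothing_k5', n_classes)
print(soft.shape)       # torch.Size([4, 1000])
print(soft.sum(dim=1))  # ~1.0 per row: the remaining 0.1 is spread over the other 995 classes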