├── models
│   ├── __init__.py
│   └── volo.py
├── misc
│   └── eccv.png
├── loss
│   ├── __init__.py
│   └── cross_entropy.py
├── distributed_train.sh
├── LICENSE
├── loss_adv.py
├── README.md
├── validate.py
├── train_kd.py
└── train.py

/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .volo import *
--------------------------------------------------------------------------------
/misc/eccv.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiawangbai/HAT/HEAD/misc/eccv.png
--------------------------------------------------------------------------------
/loss/__init__.py:
--------------------------------------------------------------------------------
1 | from .cross_entropy import TokenLabelGTCrossEntropy, TokenLabelSoftTargetCrossEntropy, TokenLabelCrossEntropy
--------------------------------------------------------------------------------
/distributed_train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | NUM_PROC=$1
3 | shift
4 | python3 -m torch.distributed.launch --nproc_per_node=$NUM_PROC train.py "$@"
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 jiawangbai
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/loss_adv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class LabelSmoothingCrossEntropy(nn.Module):
7 |     """
8 |     NLL loss with label smoothing.
9 |     """
10 |     def __init__(self, smoothing=0.1):
11 |         """
12 |         Constructor for the LabelSmoothing module.
13 |         :param smoothing: label smoothing factor
14 |         """
15 |         super(LabelSmoothingCrossEntropy, self).__init__()
16 |         assert smoothing < 1.0
17 |         self.smoothing = smoothing
18 |         self.confidence = 1. - smoothing
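        # (Added note, not in the original file) `confidence` weights the
        # true-class NLL term and `smoothing` weights the uniform term, so
        # forward() below computes
        #   loss = (1 - smoothing) * NLL + smoothing * mean(-log p)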
19 |
20 |     def forward(self, x, target):
21 |         logprobs = F.log_softmax(x, dim=-1)
22 |         nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
23 |         nll_loss = nll_loss.squeeze(1)
24 |         smooth_loss = -logprobs.mean(dim=-1)
25 |         loss = self.confidence * nll_loss + self.smoothing * smooth_loss
26 |         return loss.mean()
27 |
28 |
29 | class SoftTargetCrossEntropy(nn.Module):
30 |
31 |     def __init__(self, reduce=False):
32 |         super(SoftTargetCrossEntropy, self).__init__()
33 |         self.reduce = reduce
34 |
35 |     def forward(self, x, target):
36 |         loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1)
37 |         if self.reduce:
38 |             return loss.mean()
39 |         else:
40 |             return loss
41 |
42 |
43 |
44 | class DistillationLoss(torch.nn.Module):
45 |     """
46 |     This module wraps a standard criterion and adds an extra knowledge distillation loss by
47 |     taking a teacher model prediction and using it as additional supervision.
48 |     """
49 |     def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module,
50 |                  distillation_type: str, alpha: float, tau: float):
51 |         super().__init__()
52 |         self.base_criterion = base_criterion
53 |         self.teacher_model = teacher_model
54 |         assert distillation_type in ['none', 'soft', 'hard']
55 |         self.distillation_type = distillation_type
56 |         self.alpha = alpha
57 |         self.tau = tau
58 |
59 |     def forward(self, inputs, outputs, labels):
60 |         """
61 |         Args:
62 |             inputs: The original inputs that are fed to the teacher model
63 |             outputs: the outputs of the model to be trained. It is expected to be
64 |                 either a Tensor, or a Tuple[Tensor, Tensor], with the original output
65 |                 in the first position and the distillation predictions as the second output
66 |             labels: the labels for the base criterion
67 |         """
68 |         outputs_kd = None
69 |         if not isinstance(outputs, torch.Tensor):
70 |             # assume that the model outputs a tuple of [outputs, outputs_kd]
71 |             outputs, outputs_kd = outputs
72 |         base_loss = self.base_criterion(outputs, labels)
73 |         if self.distillation_type == 'none':
74 |             return base_loss
75 |
76 |         if outputs_kd is None:
77 |             raise ValueError("When knowledge distillation is enabled, the model is "
78 |                              "expected to return a Tuple[Tensor, Tensor] with the output of the "
79 |                              "class_token and the dist_token")
80 |         # don't backprop through the teacher
81 |         with torch.no_grad():
82 |             teacher_outputs = self.teacher_model(inputs)
83 |
84 |         if self.distillation_type == 'soft':
85 |             T = self.tau
86 |             # taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
87 |             # with slight modifications
88 |             distillation_loss = F.kl_div(
89 |                 F.log_softmax(outputs_kd / T, dim=1),
90 |                 # We provide the teacher's targets in log probability because we use log_target=True
91 |                 # (as recommended in pytorch https://github.com/pytorch/pytorch/blob/9324181d0ac7b4f7949a574dbc3e8be30abe7041/torch/nn/functional.py#L2719)
92 |                 # but it is possible to give just the probabilities and set log_target=False. In our experiments we tried both.
93 |                 F.log_softmax(teacher_outputs / T, dim=1),
94 |                 reduction='sum',
95 |                 log_target=True
96 |             ) * (T * T) / outputs_kd.numel()
97 |             # We divide by outputs_kd.numel() to have the legacy PyTorch behavior.
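            # (Added sketch, not in the original file) an equivalent form uses
            # reduction='batchmean', which divides the summed KL divergence by
            # the batch size instead of numel():
            #
            #   distillation_loss = F.kl_div(
            #       F.log_softmax(outputs_kd / T, dim=1),
            #       F.log_softmax(teacher_outputs / T, dim=1),
            #       reduction='batchmean', log_target=True) * (T * T)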
98 |             # But we also experimented with dividing by output_kd.size(0);
99 |             # see issue 61 (https://github.com/facebookresearch/deit/issues/61) for more details
100 |         elif self.distillation_type == 'hard':
101 |             distillation_loss = F.cross_entropy(outputs_kd, teacher_outputs.argmax(dim=1))
102 |
103 |         loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha
104 |         return loss
--------------------------------------------------------------------------------
/loss/cross_entropy.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class SoftTargetCrossEntropy(nn.Module):
7 |     """
8 |     The native CE loss with soft target
9 |     input: x is output of model, target is ground truth
10 |     return: loss
11 |     """
12 |     def __init__(self):
13 |         super(SoftTargetCrossEntropy, self).__init__()
14 |
15 |     def forward(self, x, target):
16 |         N_rep = x.shape[0]
17 |         N = target.shape[0]
18 |         if not N == N_rep:
19 |             target = target.repeat(N_rep // N, 1)
20 |         loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1)
21 |         return loss.mean()
22 |
23 |
24 | class TokenLabelGTCrossEntropy(nn.Module):
25 |     """
26 |     Token labeling dense loss with ground truth, see more in token labeling
27 |     input: x is output of model, target is ground truth
28 |     return: loss
29 |     """
30 |     def __init__(self,
31 |                  dense_weight=1.0,
32 |                  cls_weight=1.0,
33 |                  mixup_active=True,
34 |                  smoothing=0.1,
35 |                  classes=1000):
36 |         super(TokenLabelGTCrossEntropy, self).__init__()
37 |
38 |         self.CE = SoftTargetCrossEntropy()
39 |
40 |         self.dense_weight = dense_weight
41 |         self.smoothing = smoothing
42 |         self.mixup_active = mixup_active
43 |         self.classes = classes
44 |         self.cls_weight = cls_weight
45 |         assert dense_weight + cls_weight > 0
46 |
47 |     def forward(self, x, target):
48 |
49 |         output, aux_output, bb = x
50 |         bbx1, bby1, bbx2, bby2 = bb
51 |
52 |         B, N, C = aux_output.shape
53 |         if len(target.shape) == 2:
54 |             target_cls = target
55 |             target_aux = target.repeat(1, N).reshape(B * N, C)
56 |         else:
57 |             ground_truth = target[:, :, 0]
58 |             target_cls = target[:, :, 1]
59 |             ratio = (0.9 - 0.4 *
60 |                      (ground_truth.max(-1)[1] == target_cls.max(-1)[1])
61 |                      ).unsqueeze(-1)
62 |             target_cls = target_cls * ratio + ground_truth * (1 - ratio)
63 |             target_aux = target[:, :, 2:]
64 |             target_aux = target_aux.transpose(1, 2).reshape(-1, C)
65 |         lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / N)
66 |         if lam < 1:
67 |             target_cls = lam * target_cls + (1 - lam) * target_cls.flip(0)
68 |
69 |         aux_output = aux_output.reshape(-1, C)
70 |
71 |         loss_cls = self.CE(output, target_cls)
72 |         loss_aux = self.CE(aux_output, target_aux)
73 |
74 |         return self.cls_weight * loss_cls + self.dense_weight * loss_aux
75 |
76 |
77 | class TokenLabelSoftTargetCrossEntropy(nn.Module):
78 |     """
79 |     Token labeling dense loss with soft target, see more in token labeling
80 |     input: x is output of model, target is ground truth
81 |     return: loss
82 |     """
83 |     def __init__(self):
84 |         super(TokenLabelSoftTargetCrossEntropy, self).__init__()
85 |
86 |     def forward(self, x, target):
87 |         N_rep = x.shape[0]
88 |         N = target.shape[0]
89 |         if not N == N_rep:
90 |             target = target.repeat(N_rep // N, 1)
91 |         if len(target.shape) == 3 and target.shape[-1] == 2:
92 |             target = target[:, :, 1]
93 |         loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1)
94 |         return loss.mean()
95 |
96 |
97 | class TokenLabelCrossEntropy(nn.Module):
98 |     """
99 |     Token labeling loss without ground truth
100 |
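    (here x is the (class logits, aux token logits, bounding box) triple
    returned by VOLO.forward in training mode)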
input: x is output of model, target is ground truth 101 | return: loss 102 | """ 103 | def __init__(self, 104 | dense_weight=1.0, 105 | cls_weight=1.0, 106 | mixup_active=True, 107 | classes=1000): 108 | """ 109 | Constructor Token labeling loss. 110 | """ 111 | super(TokenLabelCrossEntropy, self).__init__() 112 | 113 | self.CE = SoftTargetCrossEntropy() 114 | 115 | self.dense_weight = dense_weight 116 | self.mixup_active = mixup_active 117 | self.classes = classes 118 | self.cls_weight = cls_weight 119 | assert dense_weight + cls_weight > 0 120 | 121 | def forward(self, x, target): 122 | 123 | output, aux_output, bb = x 124 | bbx1, bby1, bbx2, bby2 = bb 125 | 126 | B, N, C = aux_output.shape 127 | if len(target.shape) == 2: 128 | target_cls = target 129 | target_aux = target.repeat(1, N).reshape(B * N, C) 130 | else: 131 | target_cls = target[:, :, 1] 132 | target_aux = target[:, :, 2:] 133 | target_aux = target_aux.transpose(1, 2).reshape(-1, C) 134 | lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / N) 135 | if lam < 1: 136 | target_cls = lam * target_cls + (1 - lam) * target_cls.flip(0) 137 | 138 | aux_output = aux_output.reshape(-1, C) 139 | loss_cls = self.CE(output, target_cls) 140 | loss_aux = self.CE(aux_output, target_aux) 141 | return self.cls_weight * loss_cls + self.dense_weight * loss_aux 142 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 |
3 | # HAT
4 |
5 | Implementation of HAT ([arXiv:2204.00993](https://arxiv.org/pdf/2204.00993))
6 | ```
7 | @inproceedings{bai2022improving,
8 |   title={Improving Vision Transformers by Revisiting High-frequency Components},
9 |   author={Bai, Jiawang and Yuan, Li and Xia, Shu-Tao and Yan, Shuicheng and Li, Zhifeng and Liu, Wei},
10 |   booktitle={European Conference on Computer Vision},
11 |   year={2022}
12 | }
13 | ```
14 |
15 |
16 |
17 |
18 | ## Requirements
19 | - torch>=1.7.0
20 | - torchvision>=0.8.0
21 | - timm==0.4.5
22 | - tlt==0.1.0
23 | - pyyaml
24 | - apex-amp
25 |
26 | ## ImageNet Classification
27 |
28 | ### Data Preparation
29 | We use the ImageNet-1K training and validation datasets by default.
30 | Please save them in [your_imagenet_path].
31 |
32 |
33 | ### Training
34 | Train ViT models with HAT on 8 GPUs, using the default settings in our paper:
35 |
36 | ```shell
37 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./distributed_train.sh 8 \
38 | --data_dir [your_imagenet_path] \
39 | --model [your_vit_model_name] \
40 | --adv-epochs 200 \
41 | --adv-iters 3 \
42 | --adv-eps 0.00784314 \
43 | --adv-kl-weight 0.01 \
44 | --adv-ce-weight 3.0 \
45 | --output [your_output_path] \
46 | and_other_parameters_specified_for_your_vit_models...
47 | ```
48 |
49 | For instance, we train Swin-T with the following command:
50 | ```shell
51 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./distributed_train.sh 8 \
52 | --data_dir [your_imagenet_path] \
53 | --model swin_tiny_patch4_window7_224 \
54 | --adv-epochs 200 \
55 | --adv-iters 3 \
56 | --adv-eps 0.00784314 \
57 | --adv-kl-weight 0.01 \
58 | --adv-ce-weight 3.0 \
59 | --output [your_output_path] \
60 | --batch-size 256 \
61 | --drop-path 0.2 \
62 | --lr 1e-3 \
63 | --weight-decay 0.05 \
64 | --clip-grad 1.0
65 | ```
66 | For training variants of ViT, Swin Transformer, and VOLO, we use the hyper-parameters in [3], [4], and [2], respectively.
67 |
68 |
69 | We also combine HAT with knowledge distillation in [5], using [train_kd.py](train_kd.py).
70 |
71 | ### Validation
72 |
73 | After training, we can use [validate.py](validate.py) to evaluate the ViT model trained with HAT.
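A trained or released checkpoint can also be loaded programmatically for quick sanity checks before running the full evaluation. A minimal sketch (the checkpoint filename is a placeholder for a file downloaded from the Results table below; `import models` assumes this repo is on `PYTHONPATH`):
```python
from timm.models import create_model, load_checkpoint

import models  # registers the VOLO variants from this repo with timm

# placeholder path: a checkpoint downloaded from the Results table below
model = create_model('volo_d1', img_size=224)
load_checkpoint(model, 'hat_volo_d1_224.pth.tar', use_ema=False, strict=False)
model.eval()
```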
74 |
75 | For instance, we evaluate Swin-T with the following command:
76 | ```shell
77 | python3 -u validate.py \
78 | --data_dir [your_imagenet_path] \
79 | --model swin_tiny_patch4_window7_224 \
80 | --checkpoint [your_checkpoint_path] \
81 | --batch-size 128 \
82 | --num-gpu 8 \
83 | --apex-amp \
84 | --results-file [your_results_file_path]
85 | ```
86 |
87 |
88 | ### Results
89 | | Model | Params | FLOPs | Test Size | Top-1 | +HAT Top-1 | Download |
90 | |:-:|:-:|:-:|:-:|:-:|:-:|:-:|
91 | | ViT-T | 5.7M | 1.6G | 224 | 72.2 | **73.3** | [link](https://github.com/jiawangbai/HAT/releases/download/v0.0.1/hat_vit_tiny_patch16_224.pth.tar) |
92 | | ViT-S | 22.1M | 4.7G | 224 | 80.1 | **80.9** | [link](https://github.com/jiawangbai/HAT/releases/download/v0.0.1/hat_vit_small_patch16_224.pth.tar) |
93 | | ViT-B | 86.6M | 17.6G | 224 | 82.0 | **83.2** | [link](https://github.com/jiawangbai/HAT/releases/download/v0.0.1/hat_vit_base_patch16_224.pth.tar) |
94 | | Swin-T | 28.3M | 4.5G | 224 | 81.2 | **82.0** | [link](https://github.com/jiawangbai/HAT/releases/download/v0.0.1/hat_swin_tiny_patch4_window7_224.pth.tar) |
95 | | Swin-S | 49.6M | 8.7G | 224 | 83.0 | **83.3** | [link](https://github.com/jiawangbai/HAT/releases/download/v0.0.1/hat_swin_small_patch4_window7_224.pth.tar) |
96 | | Swin-B | 87.8M | 15.4G | 224 | 83.5 | **84.0** | [link](https://github.com/jiawangbai/HAT/releases/download/v0.0.1/hat_swin_base_patch4_window7_224.pth.tar) |
97 | | VOLO-D1 | 26.6M | 6.8G | 224 | 84.2 | **84.5** | [link](https://github.com/jiawangbai/HAT/releases/download/v0.0.1/hat_volo_d1_224.pth.tar) |
98 | | VOLO-D1 | 26.6M | 22.8G | 384 | 85.2 | **85.5** | [link](https://github.com/jiawangbai/HAT/releases/download/v0.0.1/hat_volo_d1_384.pth.tar) |
99 | | VOLO-D5 | 295.5M | 69.0G | 224 | 86.1 | **86.3** | [link](https://github.com/jiawangbai/HAT/releases/download/v0.0.1/hat_volo_d5_224.pth.tar) |
100 | | VOLO-D5 | 295.5M | 304G | 448 | 87.0 | **87.2** | [link](https://github.com/jiawangbai/HAT/releases/download/v0.0.1/hat_volo_d5_448.pth.tar) |
101 | | VOLO-D5 | 295.5M | 412G | 512 | 87.1 | **87.3** | [link](https://github.com/jiawangbai/HAT/releases/download/v0.0.1/hat_volo_d5_512.pth.tar) |
102 |
103 | The result of combining HAT with knowledge distillation in [5] is 84.3% for ViT-B, and it can be downloaded [here](https://github.com/jiawangbai/HAT/releases/download/v0.0.1/hat_deit_base_distilled_patch16_224.pth.tar).
104 |
105 | ## Downstream Tasks
106 |
107 | We first pretrain Swin-T/S/B on the ImageNet-1k dataset with our proposed HAT, and then transfer the models to the downstream tasks, including object detection, instance segmentation, and semantic segmentation.
108 |
109 | We use the code from [Swin Transformer for Object Detection](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection) and [Swin Transformer for Semantic Segmentation](https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation), and follow their configurations.
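As an illustration, fine-tuning detection with a HAT-pretrained backbone goes through that repo's usual entry point (a sketch, assuming its mmdetection-style training scripts; the checkpoint path is a placeholder for an ImageNet-pretrained HAT backbone):
```shell
# inside Swin-Transformer-Object-Detection
tools/dist_train.sh configs/swin/cascade_mask_rcnn_swin_tiny_patch4_window7_mstrain_480-800_giou_4conv1f_adamw_3x_coco.py 8 \
  --cfg-options model.pretrained=hat_swin_tiny_patch4_window7_224.pth.tar
```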
110 | 111 | ### Cascade Mask R-CNN on COCO val 2017 112 | | Backbone | Params | FLOPs | Config| AP_box | +HAT AP_box | AP_mask | +HAT AP_mask | 113 | |:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:| 114 | | Swin-T | 86M | 745G | [config](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/blob/master/configs/swin/cascade_mask_rcnn_swin_tiny_patch4_window7_mstrain_480-800_giou_4conv1f_adamw_3x_coco.py) | 50.5 | **50.9** |43.7| **43.9** | 115 | | Swin-S | 107M | 838G | [config](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/blob/master/configs/swin/cascade_mask_rcnn_swin_small_patch4_window7_mstrain_480-800_giou_4conv1f_adamw_3x_coco.py) | 51.8 | **52.5** |44.7| **45.4** | 116 | | Swin-B | 145M | 982G | [config](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/blob/master/configs/swin/cascade_mask_rcnn_swin_base_patch4_window7_mstrain_480-800_giou_4conv1f_adamw_3x_coco.py) | 51.9 | **52.8** |45.0| **45.6** | 117 | 118 | ### UperNet on ADE20K 119 | | Backbone | Params | FLOPs | Config| mIoU(MS) | +HAT mIoU(MS) | 120 | |:-:|:-:|:-:|:-:|:-:|:-:| 121 | | Swin-T | 60M | 945G | [config](https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation/blob/main/configs/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k.py) | 46.1 | **46.7** | 122 | | Swin-S | 81M | 1038G | [config](https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation/blob/main/configs/swin/upernet_swin_small_patch4_window7_512x512_160k_ade20k.py) | 49.5 | **49.7** | 123 | | Swin-B | 121M | 1088G | [config](https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation/blob/main/configs/swin/upernet_swin_base_patch4_window7_512x512_160k_ade20k.py) | 49.7 | **50.3** | 124 | 125 | 126 | [1] Wightman, R. Pytorch image models. https://github.com/rwightman/pytorch-image-models , 2019. 127 | [2] Yuan, L. et al. Volo: Vision outlooker for visual recognition. arXiv, 2021. 128 | [3] Dosovitskiy, A. et al. An image is worth 16x16 words: Transformers for image recognition at scale. ICLR, 2020. 129 | [4] Liu, Z. et al. Swin transformer: Hierarchical vision transformer using shifted windows. ICCV, 2021. 130 | [5] Touvron H. et al. Training data-efficient image transformers & distillation through attention. ICML, 2021. 
131 | -------------------------------------------------------------------------------- /validate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import csv 4 | import glob 5 | import time 6 | import logging 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.parallel 10 | from collections import OrderedDict 11 | from contextlib import suppress 12 | 13 | from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models 14 | from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet 15 | from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_legacy 16 | import models 17 | 18 | has_apex = False 19 | try: 20 | from apex import amp 21 | has_apex = True 22 | except ImportError: 23 | pass 24 | 25 | has_native_amp = False 26 | try: 27 | if getattr(torch.cuda.amp, 'autocast') is not None: 28 | has_native_amp = True 29 | except AttributeError: 30 | pass 31 | 32 | torch.backends.cudnn.benchmark = True 33 | _logger = logging.getLogger('validate') 34 | 35 | 36 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation') 37 | parser.add_argument('--data_dir', metavar='DIR', 38 | help='path to dataset') 39 | parser.add_argument('--dataset', '-d', metavar='NAME', default='', 40 | help='dataset type (default: ImageFolder/ImageTar if empty)') 41 | parser.add_argument('--split', metavar='NAME', default='validation', 42 | help='dataset split (default: validation)') 43 | parser.add_argument('--model', '-m', metavar='NAME', default='dpn92', 44 | help='model architecture (default: dpn92)') 45 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 46 | help='number of data loading workers (default: 2)') 47 | parser.add_argument('-b', '--batch-size', default=256, type=int, 48 | metavar='N', help='mini-batch size (default: 256)') 49 | parser.add_argument('--img-size', default=None, type=int, 50 | metavar='N', help='Input image dimension, uses model default if empty') 51 | parser.add_argument('--input-size', default=None, nargs=3, type=int, 52 | metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224),' 53 | ' uses model default if empty') 54 | parser.add_argument('--crop-pct', default=None, type=float, 55 | metavar='N', help='Input image center crop pct') 56 | parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', 57 | help='Override mean pixel value of dataset') 58 | parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', 59 | help='Override std deviation of of dataset') 60 | parser.add_argument('--interpolation', default='', type=str, metavar='NAME', 61 | help='Image resize interpolation type (overrides model)') 62 | parser.add_argument('--num-classes', type=int, default=None, 63 | help='Number classes in dataset') 64 | parser.add_argument('--class-map', default='', type=str, metavar='FILENAME', 65 | help='path to class to idx mapping file (default: "")') 66 | parser.add_argument('--gp', default=None, type=str, metavar='POOL', 67 | help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). 
Model default if None.') 68 | parser.add_argument('--log-freq', default=50, type=int, 69 | metavar='N', help='batch logging frequency (default: 10)') 70 | parser.add_argument('--checkpoint', default='', type=str, metavar='PATH', 71 | help='path to latest checkpoint (default: none)') 72 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 73 | help='use pre-trained model') 74 | parser.add_argument('--num-gpu', type=int, default=1, 75 | help='Number of GPUS to use') 76 | parser.add_argument('--no-test-pool', dest='no_test_pool', action='store_true', 77 | help='disable test time pool') 78 | parser.add_argument('--no-prefetcher', action='store_true', default=False, 79 | help='disable fast prefetcher') 80 | parser.add_argument('--pin-mem', action='store_true', default=False, 81 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 82 | parser.add_argument('--channels-last', action='store_true', default=False, 83 | help='Use channels_last memory layout') 84 | parser.add_argument('--amp', action='store_true', default=False, 85 | help='Use AMP mixed precision. Defaults to Apex, fallback to native Torch AMP.') 86 | parser.add_argument('--apex-amp', action='store_true', default=False, 87 | help='Use NVIDIA Apex AMP mixed precision') 88 | parser.add_argument('--native-amp', action='store_true', default=False, 89 | help='Use Native Torch AMP mixed precision') 90 | parser.add_argument('--tf-preprocessing', action='store_true', default=False, 91 | help='Use Tensorflow preprocessing pipeline (require CPU TF installed') 92 | parser.add_argument('--use-ema', dest='use_ema', action='store_true', 93 | help='use ema version of weights if present') 94 | parser.add_argument('--torchscript', dest='torchscript', action='store_true', 95 | help='convert model torchscript for inference') 96 | parser.add_argument('--legacy-jit', dest='legacy_jit', action='store_true', 97 | help='use legacy jit mode for pytorch 1.5/1.5.1/1.6 to get back fusion performance') 98 | parser.add_argument('--results-file', default='', type=str, metavar='FILENAME', 99 | help='Output csv file for validation results (summary)') 100 | parser.add_argument('--real-labels', default='', type=str, metavar='FILENAME', 101 | help='Real labels JSON file for imagenet evaluation') 102 | parser.add_argument('--valid-labels', default='', type=str, metavar='FILENAME', 103 | help='Valid label indices txt file for validation of partial label space') 104 | 105 | 106 | def validate(args): 107 | # might as well try to validate something 108 | args.pretrained = args.pretrained or not args.checkpoint 109 | args.prefetcher = not args.no_prefetcher 110 | amp_autocast = suppress # do nothing 111 | if args.amp: 112 | if has_native_amp: 113 | args.native_amp = True 114 | elif has_apex: 115 | args.apex_amp = True 116 | else: 117 | _logger.warning("Neither APEX or Native Torch AMP is available.") 118 | assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set." 119 | if args.native_amp: 120 | amp_autocast = torch.cuda.amp.autocast 121 | _logger.info('Validating in mixed precision with native PyTorch AMP.') 122 | elif args.apex_amp: 123 | _logger.info('Validating in mixed precision with NVIDIA APEX AMP.') 124 | else: 125 | _logger.info('Validating in float32. 
AMP not enabled.') 126 | 127 | if args.legacy_jit: 128 | set_jit_legacy() 129 | 130 | # create model 131 | model = create_model( 132 | args.model, 133 | pretrained=args.pretrained, 134 | num_classes=args.num_classes, 135 | in_chans=3, 136 | global_pool=args.gp, 137 | scriptable=args.torchscript, 138 | img_size=args.img_size) 139 | if args.num_classes is None: 140 | assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.' 141 | args.num_classes = model.num_classes 142 | 143 | if args.checkpoint: 144 | load_checkpoint(model, args.checkpoint, args.use_ema, strict=False) 145 | 146 | param_count = sum([m.numel() for m in model.parameters()]) 147 | _logger.info('Model %s created, param count: %d' % (args.model, param_count)) 148 | 149 | data_config = resolve_data_config(vars(args), model=model, use_test_size=True) 150 | test_time_pool = False 151 | if not args.no_test_pool: 152 | model, test_time_pool = apply_test_time_pool(model, data_config, use_test_size=True) 153 | 154 | if args.torchscript: 155 | torch.jit.optimized_execution(True) 156 | model = torch.jit.script(model) 157 | 158 | model = model.cuda() 159 | if args.apex_amp: 160 | model = amp.initialize(model, opt_level='O1') 161 | 162 | if args.channels_last: 163 | model = model.to(memory_format=torch.channels_last) 164 | 165 | if args.num_gpu > 1: 166 | model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))) 167 | 168 | criterion = nn.CrossEntropyLoss().cuda() 169 | 170 | dataset = create_dataset( 171 | root=args.data_dir, name=args.dataset, split=args.split, 172 | load_bytes=args.tf_preprocessing, class_map=args.class_map) 173 | 174 | if args.valid_labels: 175 | with open(args.valid_labels, 'r') as f: 176 | valid_labels = {int(line.rstrip()) for line in f} 177 | valid_labels = [i in valid_labels for i in range(args.num_classes)] 178 | else: 179 | valid_labels = None 180 | 181 | if args.real_labels: 182 | real_labels = RealLabelsImagenet(dataset.filenames(basename=True), real_json=args.real_labels) 183 | else: 184 | real_labels = None 185 | 186 | crop_pct = 1.0 if test_time_pool else data_config['crop_pct'] 187 | loader = create_loader( 188 | dataset, 189 | input_size=data_config['input_size'], 190 | batch_size=args.batch_size, 191 | use_prefetcher=args.prefetcher, 192 | interpolation=data_config['interpolation'], 193 | mean=data_config['mean'], 194 | std=data_config['std'], 195 | num_workers=args.workers, 196 | crop_pct=crop_pct, 197 | pin_memory=args.pin_mem, 198 | tf_preprocessing=args.tf_preprocessing) 199 | 200 | batch_time = AverageMeter() 201 | losses = AverageMeter() 202 | top1 = AverageMeter() 203 | top5 = AverageMeter() 204 | 205 | model.eval() 206 | with torch.no_grad(): 207 | # warmup, reduce variability of first batch time, especially for comparing torchscript vs non 208 | input = torch.randn((args.batch_size,) + data_config['input_size']).cuda() 209 | if args.channels_last: 210 | input = input.contiguous(memory_format=torch.channels_last) 211 | model(input) 212 | end = time.time() 213 | for batch_idx, (input, target) in enumerate(loader): 214 | if args.no_prefetcher: 215 | target = target.cuda() 216 | input = input.cuda() 217 | if args.channels_last: 218 | input = input.contiguous(memory_format=torch.channels_last) 219 | 220 | # compute output 221 | with amp_autocast(): 222 | output = model(input) 223 | if isinstance(output, (tuple, list)): 224 | output = output[0] 225 | if valid_labels is not None: 226 | output = output[:, valid_labels] 227 | loss = 
criterion(output, target)
228 |
229 |             if real_labels is not None:
230 |                 real_labels.add_result(output)
231 |
232 |             # measure accuracy and record loss
233 |             acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5))
234 |             losses.update(loss.item(), input.size(0))
235 |             top1.update(acc1.item(), input.size(0))
236 |             top5.update(acc5.item(), input.size(0))
237 |
238 |             # measure elapsed time
239 |             batch_time.update(time.time() - end)
240 |             end = time.time()
241 |
242 |             if batch_idx % args.log_freq == 0:
243 |                 _logger.info(
244 |                     'Test: [{0:>4d}/{1}] '
245 |                     'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
246 |                     'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
247 |                     'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
248 |                     'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
249 |                         batch_idx, len(loader), batch_time=batch_time,
250 |                         rate_avg=input.size(0) / batch_time.avg,
251 |                         loss=losses, top1=top1, top5=top5))
252 |
253 |     if real_labels is not None:
254 |         # real labels mode replaces topk values at the end
255 |         top1a, top5a = real_labels.get_accuracy(k=1), real_labels.get_accuracy(k=5)
256 |     else:
257 |         top1a, top5a = top1.avg, top5.avg
258 |     results = OrderedDict(
259 |         top1=round(top1a, 4), top1_err=round(100 - top1a, 4),
260 |         top5=round(top5a, 4), top5_err=round(100 - top5a, 4),
261 |         param_count=round(param_count / 1e6, 2),
262 |         img_size=data_config['input_size'][-1],
263 |         crop_pct=crop_pct,
264 |         interpolation=data_config['interpolation'])
265 |
266 |     _logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
267 |         results['top1'], results['top1_err'], results['top5'], results['top5_err']))
268 |
269 |     return results
270 |
271 |
272 | def main():
273 |     setup_default_logging()
274 |     args = parser.parse_args()
275 |     model_cfgs = []
276 |     model_names = []
277 |     if os.path.isdir(args.checkpoint):
278 |         # validate all checkpoints in a path with same model
279 |         checkpoints = glob.glob(args.checkpoint + '/*.pth.tar')
280 |         checkpoints += glob.glob(args.checkpoint + '/*.pth')
281 |         model_names = list_models(args.model)
282 |         model_cfgs = [(args.model, c) for c in sorted(checkpoints, key=natural_key)]
283 |     else:
284 |         if args.model == 'all':
285 |             # validate all models in a list of names with pretrained checkpoints
286 |             args.pretrained = True
287 |             model_names = list_models(pretrained=True, exclude_filters=['*in21k'])
288 |             model_cfgs = [(n, '') for n in model_names]
289 |         elif not is_model(args.model):
290 |             # model name doesn't exist, try as wildcard filter
291 |             model_names = list_models(args.model)
292 |             model_cfgs = [(n, '') for n in model_names]
293 |
294 |     if len(model_cfgs):
295 |         results_file = args.results_file or './results-all.csv'
296 |         _logger.info('Running bulk validation on these pretrained models: {}'.format(', '.join(model_names)))
297 |         results = []
298 |         try:
299 |             start_batch_size = args.batch_size
300 |             for m, c in model_cfgs:
301 |                 batch_size = start_batch_size
302 |                 args.model = m
303 |                 args.checkpoint = c
304 |                 result = OrderedDict(model=args.model)
305 |                 r = {}
306 |                 while not r and batch_size >= args.num_gpu:
307 |                     torch.cuda.empty_cache()
308 |                     try:
309 |                         args.batch_size = batch_size
310 |                         print('Validating with batch size: %d' % args.batch_size)
311 |                         r = validate(args)
312 |                     except RuntimeError as e:
313 |                         if batch_size <= args.num_gpu:
314 |                             print("Validation failed with no ability to reduce batch size.
Exiting.") 315 | raise e 316 | batch_size = max(batch_size // 2, args.num_gpu) 317 | print("Validation failed, reducing batch size by 50%") 318 | result.update(r) 319 | if args.checkpoint: 320 | result['checkpoint'] = args.checkpoint 321 | results.append(result) 322 | except KeyboardInterrupt as e: 323 | pass 324 | results = sorted(results, key=lambda x: x['top1'], reverse=True) 325 | if len(results): 326 | write_results(results_file, results) 327 | else: 328 | validate(args) 329 | 330 | def write_results(results_file, results): 331 | with open(results_file, mode='w') as cf: 332 | dw = csv.DictWriter(cf, fieldnames=results[0].keys()) 333 | dw.writeheader() 334 | for r in results: 335 | dw.writerow(r) 336 | cf.flush() 337 | 338 | if __name__ == '__main__': 339 | main() -------------------------------------------------------------------------------- /models/volo.py: -------------------------------------------------------------------------------- 1 | """ 2 | Vision OutLOoker (VOLO) implementation 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 9 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 10 | from timm.models.registry import register_model 11 | import math 12 | import numpy as np 13 | 14 | 15 | def _cfg(url='', **kwargs): 16 | return { 17 | 'url': url, 18 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 19 | 'crop_pct': .96, 'interpolation': 'bicubic', 20 | 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 21 | 'first_conv': 'patch_embed.proj', 'classifier': 'head', 22 | **kwargs 23 | } 24 | 25 | 26 | default_cfgs = { 27 | 'volo': _cfg(crop_pct=0.96), 28 | 'volo_large': _cfg(crop_pct=1.15), 29 | } 30 | 31 | 32 | class OutlookAttention(nn.Module): 33 | """ 34 | Implementation of outlook attention 35 | --dim: hidden dim 36 | --num_heads: number of heads 37 | --kernel_size: kernel size in each window for outlook attention 38 | return: token features after outlook attention 39 | """ 40 | 41 | def __init__(self, dim, num_heads, kernel_size=3, padding=1, stride=1, 42 | qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 43 | super().__init__() 44 | head_dim = dim // num_heads 45 | self.num_heads = num_heads 46 | self.kernel_size = kernel_size 47 | self.padding = padding 48 | self.stride = stride 49 | self.scale = qk_scale or head_dim**-0.5 50 | 51 | self.v = nn.Linear(dim, dim, bias=qkv_bias) 52 | self.attn = nn.Linear(dim, kernel_size**4 * num_heads) 53 | 54 | self.attn_drop = nn.Dropout(attn_drop) 55 | self.proj = nn.Linear(dim, dim) 56 | self.proj_drop = nn.Dropout(proj_drop) 57 | 58 | self.unfold = nn.Unfold(kernel_size=kernel_size, padding=padding, stride=stride) 59 | self.pool = nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True) 60 | 61 | def forward(self, x): 62 | B, H, W, C = x.shape 63 | 64 | v = self.v(x).permute(0, 3, 1, 2) # B, C, H, W 65 | 66 | h, w = math.ceil(H / self.stride), math.ceil(W / self.stride) 67 | v = self.unfold(v).reshape(B, self.num_heads, C // self.num_heads, 68 | self.kernel_size * self.kernel_size, 69 | h * w).permute(0, 1, 4, 3, 2) # B,H,N,kxk,C/H 70 | 71 | attn = self.pool(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) 72 | attn = self.attn(attn).reshape( 73 | B, h * w, self.num_heads, self.kernel_size * self.kernel_size, 74 | self.kernel_size * self.kernel_size).permute(0, 2, 1, 3, 4) # B,H,N,kxk,kxk 75 | attn = attn * self.scale 76 | attn = attn.softmax(dim=-1) 77 | attn = 
self.attn_drop(attn) 78 | 79 | x = (attn @ v).permute(0, 1, 4, 3, 2).reshape( 80 | B, C * self.kernel_size * self.kernel_size, h * w) 81 | x = F.fold(x, output_size=(H, W), kernel_size=self.kernel_size, 82 | padding=self.padding, stride=self.stride) 83 | 84 | x = self.proj(x.permute(0, 2, 3, 1)) 85 | x = self.proj_drop(x) 86 | 87 | return x 88 | 89 | 90 | class Outlooker(nn.Module): 91 | """ 92 | Implementation of outlooker layer: which includes outlook attention + MLP 93 | Outlooker is the first stage in our VOLO 94 | --dim: hidden dim 95 | --num_heads: number of heads 96 | --mlp_ratio: mlp ratio 97 | --kernel_size: kernel size in each window for outlook attention 98 | return: outlooker layer 99 | """ 100 | def __init__(self, dim, kernel_size, padding, stride=1, 101 | num_heads=1,mlp_ratio=3., attn_drop=0., 102 | drop_path=0., act_layer=nn.GELU, 103 | norm_layer=nn.LayerNorm, qkv_bias=False, 104 | qk_scale=None): 105 | super().__init__() 106 | self.norm1 = norm_layer(dim) 107 | self.attn = OutlookAttention(dim, num_heads, kernel_size=kernel_size, 108 | padding=padding, stride=stride, 109 | qkv_bias=qkv_bias, qk_scale=qk_scale, 110 | attn_drop=attn_drop) 111 | 112 | self.drop_path = DropPath( 113 | drop_path) if drop_path > 0. else nn.Identity() 114 | 115 | self.norm2 = norm_layer(dim) 116 | mlp_hidden_dim = int(dim * mlp_ratio) 117 | self.mlp = Mlp(in_features=dim, 118 | hidden_features=mlp_hidden_dim, 119 | act_layer=act_layer) 120 | 121 | def forward(self, x): 122 | x = x + self.drop_path(self.attn(self.norm1(x))) 123 | x = x + self.drop_path(self.mlp(self.norm2(x))) 124 | return x 125 | 126 | 127 | class Mlp(nn.Module): 128 | "Implementation of MLP" 129 | 130 | def __init__(self, in_features, hidden_features=None, 131 | out_features=None, act_layer=nn.GELU, 132 | drop=0.): 133 | super().__init__() 134 | out_features = out_features or in_features 135 | hidden_features = hidden_features or in_features 136 | self.fc1 = nn.Linear(in_features, hidden_features) 137 | self.act = act_layer() 138 | self.fc2 = nn.Linear(hidden_features, out_features) 139 | self.drop = nn.Dropout(drop) 140 | 141 | def forward(self, x): 142 | x = self.fc1(x) 143 | x = self.act(x) 144 | x = self.drop(x) 145 | x = self.fc2(x) 146 | x = self.drop(x) 147 | return x 148 | 149 | 150 | class Attention(nn.Module): 151 | "Implementation of self-attention" 152 | 153 | def __init__(self, dim, num_heads=8, qkv_bias=False, 154 | qk_scale=None, attn_drop=0., proj_drop=0.): 155 | super().__init__() 156 | self.num_heads = num_heads 157 | head_dim = dim // num_heads 158 | self.scale = qk_scale or head_dim**-0.5 159 | 160 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 161 | self.attn_drop = nn.Dropout(attn_drop) 162 | self.proj = nn.Linear(dim, dim) 163 | self.proj_drop = nn.Dropout(proj_drop) 164 | 165 | def forward(self, x): 166 | B, H, W, C = x.shape 167 | 168 | qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, 169 | C // self.num_heads).permute(2, 0, 3, 1, 4) 170 | q, k, v = qkv[0], qkv[1], qkv[ 171 | 2] # make torchscript happy (cannot use tensor as tuple) 172 | 173 | attn = (q @ k.transpose(-2, -1)) * self.scale 174 | attn = attn.softmax(dim=-1) 175 | attn = self.attn_drop(attn) 176 | 177 | x = (attn @ v).transpose(1, 2).reshape(B, H, W, C) 178 | x = self.proj(x) 179 | x = self.proj_drop(x) 180 | 181 | return x 182 | 183 | 184 | class Transformer(nn.Module): 185 | """ 186 | Implementation of Transformer, 187 | Transformer is the second stage in our VOLO 188 | """ 189 | def __init__(self, dim, num_heads, 
mlp_ratio=4., qkv_bias=False, 190 | qk_scale=None, attn_drop=0., drop_path=0., 191 | act_layer=nn.GELU, norm_layer=nn.LayerNorm): 192 | super().__init__() 193 | self.norm1 = norm_layer(dim) 194 | self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, 195 | qk_scale=qk_scale, attn_drop=attn_drop) 196 | 197 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 198 | self.drop_path = DropPath( 199 | drop_path) if drop_path > 0. else nn.Identity() 200 | 201 | self.norm2 = norm_layer(dim) 202 | mlp_hidden_dim = int(dim * mlp_ratio) 203 | self.mlp = Mlp(in_features=dim, 204 | hidden_features=mlp_hidden_dim, 205 | act_layer=act_layer) 206 | 207 | def forward(self, x): 208 | x = x + self.drop_path(self.attn(self.norm1(x))) 209 | x = x + self.drop_path(self.mlp(self.norm2(x))) 210 | return x 211 | 212 | 213 | class ClassAttention(nn.Module): 214 | """ 215 | Class attention layer from CaiT, see details in CaiT 216 | Class attention is the post stage in our VOLO, which is optional. 217 | """ 218 | def __init__(self, dim, num_heads=8, head_dim=None, qkv_bias=False, 219 | qk_scale=None, attn_drop=0., proj_drop=0.): 220 | super().__init__() 221 | self.num_heads = num_heads 222 | if head_dim is not None: 223 | self.head_dim = head_dim 224 | else: 225 | head_dim = dim // num_heads 226 | self.head_dim = head_dim 227 | self.scale = qk_scale or head_dim**-0.5 228 | 229 | self.kv = nn.Linear(dim, 230 | self.head_dim * self.num_heads * 2, 231 | bias=qkv_bias) 232 | self.q = nn.Linear(dim, self.head_dim * self.num_heads, bias=qkv_bias) 233 | self.attn_drop = nn.Dropout(attn_drop) 234 | self.proj = nn.Linear(self.head_dim * self.num_heads, dim) 235 | self.proj_drop = nn.Dropout(proj_drop) 236 | 237 | def forward(self, x): 238 | B, N, C = x.shape 239 | 240 | kv = self.kv(x).reshape(B, N, 2, self.num_heads, 241 | self.head_dim).permute(2, 0, 3, 1, 4) 242 | k, v = kv[0], kv[ 243 | 1] # make torchscript happy (cannot use tensor as tuple) 244 | q = self.q(x[:, :1, :]).reshape(B, self.num_heads, 1, self.head_dim) 245 | attn = ((q * self.scale) @ k.transpose(-2, -1)) 246 | attn = attn.softmax(dim=-1) 247 | attn = self.attn_drop(attn) 248 | 249 | cls_embed = (attn @ v).transpose(1, 2).reshape( 250 | B, 1, self.head_dim * self.num_heads) 251 | cls_embed = self.proj(cls_embed) 252 | cls_embed = self.proj_drop(cls_embed) 253 | return cls_embed 254 | 255 | 256 | class ClassBlock(nn.Module): 257 | """ 258 | Class attention block from CaiT, see details in CaiT 259 | We use two-layers class attention in our VOLO, which is optional. 260 | """ 261 | 262 | def __init__(self, dim, num_heads, head_dim=None, mlp_ratio=4., 263 | qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 264 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 265 | super().__init__() 266 | self.norm1 = norm_layer(dim) 267 | self.attn = ClassAttention( 268 | dim, num_heads=num_heads, head_dim=head_dim, qkv_bias=qkv_bias, 269 | qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 270 | # NOTE: drop path for stochastic depth 271 | self.drop_path = DropPath( 272 | drop_path) if drop_path > 0. 
else nn.Identity()
273 |         self.norm2 = norm_layer(dim)
274 |         mlp_hidden_dim = int(dim * mlp_ratio)
275 |         self.mlp = Mlp(in_features=dim,
276 |                        hidden_features=mlp_hidden_dim,
277 |                        act_layer=act_layer,
278 |                        drop=drop)
279 |
280 |     def forward(self, x):
281 |         cls_embed = x[:, :1]
282 |         cls_embed = cls_embed + self.drop_path(self.attn(self.norm1(x)))
283 |         cls_embed = cls_embed + self.drop_path(self.mlp(self.norm2(cls_embed)))
284 |         return torch.cat([cls_embed, x[:, 1:]], dim=1)
285 |
286 |
287 | def get_block(block_type, **kargs):
288 |     """
289 |     get block by name, specifically the class attention block here
290 |     """
291 |     if block_type == 'ca':
292 |         return ClassBlock(**kargs)
293 |
294 |
295 | def rand_bbox(size, lam, scale=1):
296 |     """
297 |     get bounding box for token labeling (https://github.com/zihangJiang/TokenLabeling)
298 |     return: bounding box
299 |     """
300 |     W = size[1] // scale
301 |     H = size[2] // scale
302 |     cut_rat = np.sqrt(1. - lam)
303 |     cut_w = int(W * cut_rat)  # builtin int; np.int is removed in recent NumPy
304 |     cut_h = int(H * cut_rat)
305 |
306 |     # uniform
307 |     cx = np.random.randint(W)
308 |     cy = np.random.randint(H)
309 |
310 |     bbx1 = np.clip(cx - cut_w // 2, 0, W)
311 |     bby1 = np.clip(cy - cut_h // 2, 0, H)
312 |     bbx2 = np.clip(cx + cut_w // 2, 0, W)
313 |     bby2 = np.clip(cy + cut_h // 2, 0, H)
314 |
315 |     return bbx1, bby1, bbx2, bby2
316 |
317 |
318 | class PatchEmbed(nn.Module):
319 |     """
320 |     Image to Patch Embedding.
321 |     Unlike ViT, which uses a single conv layer, we use 4 conv layers for patch embedding
322 |     """
323 |
324 |     def __init__(self, img_size=224, stem_conv=False, stem_stride=1,
325 |                  patch_size=8, in_chans=3, hidden_dim=64, embed_dim=384):
326 |         super().__init__()
327 |         assert patch_size in [4, 8, 16]
328 |
329 |         self.stem_conv = stem_conv
330 |         if stem_conv:
331 |             self.conv = nn.Sequential(
332 |                 nn.Conv2d(in_chans, hidden_dim, kernel_size=7, stride=stem_stride,
333 |                           padding=3, bias=False),  # 112x112
334 |                 nn.BatchNorm2d(hidden_dim),
335 |                 nn.ReLU(inplace=True),
336 |                 nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1,
337 |                           padding=1, bias=False),  # 112x112
338 |                 nn.BatchNorm2d(hidden_dim),
339 |                 nn.ReLU(inplace=True),
340 |                 nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1,
341 |                           padding=1, bias=False),  # 112x112
342 |                 nn.BatchNorm2d(hidden_dim),
343 |                 nn.ReLU(inplace=True),
344 |             )
345 |
346 |         self.proj = nn.Conv2d(hidden_dim,
347 |                               embed_dim,
348 |                               kernel_size=patch_size // stem_stride,
349 |                               stride=patch_size // stem_stride)
350 |         self.num_patches = (img_size // patch_size) * (img_size // patch_size)
351 |
352 |     def forward(self, x):
353 |         if self.stem_conv:
354 |             x = self.conv(x)
355 |         x = self.proj(x)  # B, C, H, W
356 |         return x
357 |
358 |
359 | class Downsample(nn.Module):
360 |     """
361 |     Image to Patch Embedding, downsampling between stage1 and stage2
362 |     """
363 |     def __init__(self, in_embed_dim, out_embed_dim, patch_size):
364 |         super().__init__()
365 |         self.proj = nn.Conv2d(in_embed_dim, out_embed_dim,
366 |                               kernel_size=patch_size, stride=patch_size)
367 |
368 |     def forward(self, x):
369 |         x = x.permute(0, 3, 1, 2)
370 |         x = self.proj(x)  # B, C, H, W
371 |         x = x.permute(0, 2, 3, 1)
372 |         return x
373 |
374 |
375 | def outlooker_blocks(block_fn, index, dim, layers, num_heads=1, kernel_size=3,
376 |                      padding=1, stride=1, mlp_ratio=3., qkv_bias=False, qk_scale=None,
377 |                      attn_drop=0, drop_path_rate=0., **kwargs):
378 |     """
379 |     generate outlooker layers in stage1
380 |     return: outlooker layers
381 |     """
382 |     blocks = []
383 |     for block_idx in range(layers[index]):
384 |         block_dpr = drop_path_rate * (block_idx +
385 |                                       sum(layers[:index])) / (sum(layers) - 1)
386 |         blocks.append(block_fn(dim, kernel_size=kernel_size, padding=padding,
387 |                                stride=stride, num_heads=num_heads, mlp_ratio=mlp_ratio,
388 |                                qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
389 |                                drop_path=block_dpr))
390 |
391 |     blocks = nn.Sequential(*blocks)
392 |
393 |     return blocks
394 |
395 |
396 | def transformer_blocks(block_fn, index, dim, layers, num_heads, mlp_ratio=3.,
397 |                        qkv_bias=False, qk_scale=None, attn_drop=0,
398 |                        drop_path_rate=0., **kwargs):
399 |     """
400 |     generate transformer layers in stage2
401 |     return: transformer layers
402 |     """
403 |     blocks = []
404 |     for block_idx in range(layers[index]):
405 |         block_dpr = drop_path_rate * (block_idx +
406 |                                       sum(layers[:index])) / (sum(layers) - 1)
407 |         blocks.append(
408 |             block_fn(dim, num_heads,
409 |                      mlp_ratio=mlp_ratio,
410 |                      qkv_bias=qkv_bias,
411 |                      qk_scale=qk_scale,
412 |                      attn_drop=attn_drop,
413 |                      drop_path=block_dpr))
414 |
415 |     blocks = nn.Sequential(*blocks)
416 |
417 |     return blocks
418 |
419 |
420 | class VOLO(nn.Module):
421 |     """
422 |     Vision Outlooker, the main class of our model
423 |     --layers: [x,x,x,x], four blocks in two stages, the first block is outlooker, the
424 |               other three are transformer, we set four blocks, which are easily
425 |               applied to downstream tasks
426 |     --img_size, --in_chans, --num_classes: these three are very easy to understand
427 |     --patch_size: patch_size in outlook attention
428 |     --stem_hidden_dim: hidden dim of patch embedding, d1-d4 is 64, d5 is 128
429 |     --embed_dims, --num_heads: embedding dim, number of heads in each block
430 |     --downsamples: flags to apply downsampling or not
431 |     --outlook_attention: flags to apply outlook attention or not
432 |     --mlp_ratios, --qkv_bias, --qk_scale, --drop_rate: easy to understand
433 |     --attn_drop_rate, --drop_path_rate, --norm_layer: easy to understand
434 |     --post_layers: post layers like two class attention layers using [ca, ca],
435 |                    if yes, return_mean=False
436 |     --return_mean: use mean of all feature tokens for classification, if yes, no class token
437 |     --return_dense: use token labeling, details are here:
438 |                     https://github.com/zihangJiang/TokenLabeling
439 |     --mix_token: mixing tokens as token labeling, details are here:
440 |                  https://github.com/zihangJiang/TokenLabeling
441 |     --pooling_scale: pooling_scale=2 means we downsample 2x
442 |     --out_kernel, --out_stride, --out_padding: kernel size,
443 |                    stride, and padding for outlook attention
444 |     """
445 |     def __init__(self, layers, img_size=224, in_chans=3, num_classes=1000, patch_size=8,
446 |                  stem_hidden_dim=64, embed_dims=None, num_heads=None, downsamples=None,
447 |                  outlook_attention=None, mlp_ratios=None, qkv_bias=False, qk_scale=None,
448 |                  drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
449 |                  post_layers=None, return_mean=False, return_dense=True, mix_token=True,
450 |                  pooling_scale=2, out_kernel=3, out_stride=2, out_padding=1):
451 |
452 |         super().__init__()
453 |         self.num_classes = num_classes
454 |         self.patch_embed = PatchEmbed(stem_conv=True, stem_stride=2, patch_size=patch_size,
455 |                                       in_chans=in_chans, hidden_dim=stem_hidden_dim,
456 |                                       embed_dim=embed_dims[0])
457 |
458 |         # initial positional encoding, we add positional encoding after outlooker blocks
459 |         self.pos_embed = nn.Parameter(
460 |             torch.zeros(1, img_size // patch_size // pooling_scale,
461 |                         img_size // patch_size // pooling_scale,
462 |
embed_dims[-1])) 463 | 464 | self.pos_drop = nn.Dropout(p=drop_rate) 465 | 466 | # set the main block in network 467 | network = [] 468 | for i in range(len(layers)): 469 | if outlook_attention[i]: 470 | # stage 1 471 | stage = outlooker_blocks(Outlooker, i, embed_dims[i], layers, 472 | downsample=downsamples[i], num_heads=num_heads[i], 473 | kernel_size=out_kernel, stride=out_stride, 474 | padding=out_padding, mlp_ratio=mlp_ratios[i], 475 | qkv_bias=qkv_bias, qk_scale=qk_scale, 476 | attn_drop=attn_drop_rate, norm_layer=norm_layer) 477 | network.append(stage) 478 | else: 479 | # stage 2 480 | stage = transformer_blocks(Transformer, i, embed_dims[i], layers, 481 | num_heads[i], mlp_ratio=mlp_ratios[i], 482 | qkv_bias=qkv_bias, qk_scale=qk_scale, 483 | drop_path_rate=drop_path_rate, 484 | attn_drop=attn_drop_rate, 485 | norm_layer=norm_layer) 486 | network.append(stage) 487 | 488 | if downsamples[i]: 489 | # downsampling between two stages 490 | network.append(Downsample(embed_dims[i], embed_dims[i + 1], 2)) 491 | 492 | self.network = nn.ModuleList(network) 493 | 494 | # set post block, for example, class attention layers 495 | self.post_network = None 496 | if post_layers is not None: 497 | self.post_network = nn.ModuleList([ 498 | get_block(post_layers[i], 499 | dim=embed_dims[-1], 500 | num_heads=num_heads[-1], 501 | mlp_ratio=mlp_ratios[-1], 502 | qkv_bias=qkv_bias, 503 | qk_scale=qk_scale, 504 | attn_drop=attn_drop_rate, 505 | drop_path=0., 506 | norm_layer=norm_layer) 507 | for i in range(len(post_layers)) 508 | ]) 509 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims[-1])) 510 | trunc_normal_(self.cls_token, std=.02) 511 | 512 | # set output type 513 | self.return_mean = return_mean # if yes, return mean, not use class token 514 | self.return_dense = return_dense # if yes, return class token and all feature tokens 515 | if return_dense: 516 | assert not return_mean, "cannot return both mean and dense" 517 | self.mix_token = mix_token 518 | self.pooling_scale = pooling_scale 519 | if mix_token: # enable token mixing, see token labeling for details. 
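            # (Added note, not in the original file) token mixing is CutMix at
            # the token level: forward() samples lam ~ Beta(beta, beta), swaps a
            # random token box with the flipped batch,
            #   x[:, x1:x2, y1:y2, :] = x.flip(0)[:, x1:x2, y1:y2, :]
            # and returns the box so the loss can mix targets with
            #   lam = 1 - box_area / num_tokens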
520 | self.beta = 1.0 521 | assert return_dense, "return all tokens if mix_token is enabled" 522 | if return_dense: 523 | self.aux_head = nn.Linear( 524 | embed_dims[-1], 525 | num_classes) if num_classes > 0 else nn.Identity() 526 | self.norm = norm_layer(embed_dims[-1]) 527 | 528 | # Classifier head 529 | self.head = nn.Linear( 530 | embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity() 531 | 532 | trunc_normal_(self.pos_embed, std=.02) 533 | self.apply(self._init_weights) 534 | 535 | def _init_weights(self, m): 536 | if isinstance(m, nn.Linear): 537 | trunc_normal_(m.weight, std=.02) 538 | if isinstance(m, nn.Linear) and m.bias is not None: 539 | nn.init.constant_(m.bias, 0) 540 | elif isinstance(m, nn.LayerNorm): 541 | nn.init.constant_(m.bias, 0) 542 | nn.init.constant_(m.weight, 1.0) 543 | 544 | @torch.jit.ignore 545 | def no_weight_decay(self): 546 | return {'pos_embed', 'cls_token'} 547 | 548 | def get_classifier(self): 549 | return self.head 550 | 551 | def reset_classifier(self, num_classes): 552 | self.num_classes = num_classes 553 | self.head = nn.Linear( 554 | self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 555 | 556 | def forward_embeddings(self, x): 557 | # patch embedding 558 | x = self.patch_embed(x) 559 | # B,C,H,W-> B,H,W,C 560 | x = x.permute(0, 2, 3, 1) 561 | return x 562 | 563 | def forward_tokens(self, x): 564 | for idx, block in enumerate(self.network): 565 | if idx == 2: # add positional encoding after outlooker blocks 566 | x = x + self.pos_embed 567 | x = self.pos_drop(x) 568 | x = block(x) 569 | 570 | B, H, W, C = x.shape 571 | x = x.reshape(B, -1, C) 572 | return x 573 | 574 | def forward_cls(self, x): 575 | B, N, C = x.shape 576 | cls_tokens = self.cls_token.expand(B, -1, -1) 577 | x = torch.cat((cls_tokens, x), dim=1) 578 | for block in self.post_network: 579 | x = block(x) 580 | return x 581 | 582 | def forward(self, x): 583 | # step1: patch embedding 584 | x = self.forward_embeddings(x) 585 | 586 | # mix token, see token labeling for details. 587 | if self.mix_token and self.training: 588 | lam = np.random.beta(self.beta, self.beta) 589 | patch_h, patch_w = x.shape[1] // self.pooling_scale, x.shape[ 590 | 2] // self.pooling_scale 591 | bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), lam, scale=self.pooling_scale) 592 | temp_x = x.clone() 593 | sbbx1,sbby1,sbbx2,sbby2=self.pooling_scale*bbx1,self.pooling_scale*bby1,\ 594 | self.pooling_scale*bbx2,self.pooling_scale*bby2 595 | temp_x[:, sbbx1:sbbx2, sbby1:sbby2, :] = x.flip(0)[:, sbbx1:sbbx2, sbby1:sbby2, :] 596 | x = temp_x 597 | else: 598 | bbx1, bby1, bbx2, bby2 = 0, 0, 0, 0 599 | 600 | # step2: tokens learning in the two stages 601 | x = self.forward_tokens(x) 602 | 603 | # step3: post network, apply class attention or not 604 | if self.post_network is not None: 605 | x = self.forward_cls(x) 606 | x = self.norm(x) 607 | 608 | if self.return_mean: # if no class token, return mean 609 | return self.head(x.mean(1)) 610 | 611 | x_cls = self.head(x[:, 0]) 612 | if not self.return_dense: 613 | return x_cls 614 | 615 | x_aux = self.aux_head( 616 | x[:, 1:] 617 | ) # generate classes in all feature tokens, see token labeling 618 | 619 | if not self.training: 620 | return x_cls + 0.5 * x_aux.max(1)[0] 621 | 622 | if self.mix_token and self.training: # reverse "mix token", see token labeling for details. 
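            # (Added note, not in the original file) the same box swap is
            # applied to the token predictions below, so each token's
            # prediction is re-aligned with the dense target of its original
            # image; the loss handles the class-token mixing via the returned
            # bounding box instead.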
623 | x_aux = x_aux.reshape(x_aux.shape[0], patch_h, patch_w, x_aux.shape[-1]) 624 | 625 | temp_x = x_aux.clone() 626 | temp_x[:, bbx1:bbx2, bby1:bby2, :] = x_aux.flip(0)[:, bbx1:bbx2, bby1:bby2, :] 627 | x_aux = temp_x 628 | 629 | x_aux = x_aux.reshape(x_aux.shape[0], patch_h * patch_w, x_aux.shape[-1]) 630 | 631 | # return these: 1. class token, 2. classes from all feature tokens, 3. bounding box 632 | return x_cls, x_aux, (bbx1, bby1, bbx2, bby2) 633 | 634 | 635 | @register_model 636 | def volo_d1(pretrained=False, **kwargs): 637 | """ 638 | VOLO-D1 model, Params: 27M 639 | --layers: [x,x,x,x], four blocks in two stages, the first stage(block) is outlooker, 640 | the other three blocks are transformer, we set four blocks, which are easily 641 | applied to downstream tasks 642 | --embed_dims, --num_heads,: embedding dim, number of heads in each block 643 | --downsamples: flags to apply downsampling or not in four blocks 644 | --outlook_attention: flags to apply outlook attention or not 645 | --mlp_ratios: mlp ratio in four blocks 646 | --post_layers: post layers like two class attention layers using [ca, ca] 647 | See detail for all args in the class VOLO() 648 | """ 649 | layers = [4, 4, 8, 2] # num of layers in the four blocks 650 | embed_dims = [192, 384, 384, 384] 651 | num_heads = [6, 12, 12, 12] 652 | mlp_ratios = [3, 3, 3, 3] 653 | downsamples = [True, False, False, False] # do downsampling after first block 654 | outlook_attention = [True, False, False, False ] 655 | # first block is outlooker (stage1), the other three are transformer (stage2) 656 | model = VOLO(layers, 657 | embed_dims=embed_dims, 658 | num_heads=num_heads, 659 | mlp_ratios=mlp_ratios, 660 | downsamples=downsamples, 661 | outlook_attention=outlook_attention, 662 | post_layers=['ca', 'ca'], 663 | **kwargs) 664 | model.default_cfg = default_cfgs['volo'] 665 | return model 666 | 667 | 668 | @register_model 669 | def volo_d2(pretrained=False, **kwargs): 670 | """ 671 | VOLO-D2 model, Params: 59M 672 | """ 673 | layers = [6, 4, 10, 4] 674 | embed_dims = [256, 512, 512, 512] 675 | num_heads = [8, 16, 16, 16] 676 | mlp_ratios = [3, 3, 3, 3] 677 | downsamples = [True, False, False, False] 678 | outlook_attention = [True, False, False, False] 679 | model = VOLO(layers, 680 | embed_dims=embed_dims, 681 | num_heads=num_heads, 682 | mlp_ratios=mlp_ratios, 683 | downsamples=downsamples, 684 | outlook_attention=outlook_attention, 685 | post_layers=['ca', 'ca'], 686 | **kwargs) 687 | model.default_cfg = default_cfgs['volo'] 688 | return model 689 | 690 | 691 | @register_model 692 | def volo_d3(pretrained=False, **kwargs): 693 | """ 694 | VOLO-D3 model, Params: 86M 695 | """ 696 | layers = [8, 8, 16, 4] 697 | embed_dims = [256, 512, 512, 512] 698 | num_heads = [8, 16, 16, 16] 699 | mlp_ratios = [3, 3, 3, 3] 700 | downsamples = [True, False, False, False] 701 | outlook_attention = [True, False, False, False] 702 | model = VOLO(layers, 703 | embed_dims=embed_dims, 704 | num_heads=num_heads, 705 | mlp_ratios=mlp_ratios, 706 | downsamples=downsamples, 707 | outlook_attention=outlook_attention, 708 | post_layers=['ca', 'ca'], 709 | **kwargs) 710 | model.default_cfg = default_cfgs['volo'] 711 | return model 712 | 713 | 714 | @register_model 715 | def volo_d4(pretrained=False, **kwargs): 716 | """ 717 | VOLO-D4 model, Params: 193M 718 | """ 719 | layers = [8, 8, 16, 4] 720 | embed_dims = [384, 768, 768, 768] 721 | num_heads = [12, 16, 16, 16] 722 | mlp_ratios = [3, 3, 3, 3] 723 | downsamples = [True, False, False, False] 724 | 
outlook_attention = [True, False, False, False] 725 | model = VOLO(layers, 726 | embed_dims=embed_dims, 727 | num_heads=num_heads, 728 | mlp_ratios=mlp_ratios, 729 | downsamples=downsamples, 730 | outlook_attention=outlook_attention, 731 | post_layers=['ca', 'ca'], 732 | **kwargs) 733 | model.default_cfg = default_cfgs['volo_large'] 734 | return model 735 | 736 | 737 | @register_model 738 | def volo_d5(pretrained=False, **kwargs): 739 | """ 740 | VOLO-D5 model, Params: 296M 741 | stem_hidden_dim=128, the dim in patch embedding is 128 for VOLO-D5 742 | """ 743 | layers = [12, 12, 20, 4] 744 | embed_dims = [384, 768, 768, 768] 745 | num_heads = [12, 16, 16, 16] 746 | mlp_ratios = [4, 4, 4, 4] 747 | downsamples = [True, False, False, False] 748 | outlook_attention = [True, False, False, False] 749 | model = VOLO(layers, 750 | embed_dims=embed_dims, 751 | num_heads=num_heads, 752 | mlp_ratios=mlp_ratios, 753 | downsamples=downsamples, 754 | outlook_attention=outlook_attention, 755 | post_layers=['ca', 'ca'], 756 | stem_hidden_dim=128, 757 | **kwargs) 758 | model.default_cfg = default_cfgs['volo_large'] 759 | return model 760 | -------------------------------------------------------------------------------- /train_kd.py: -------------------------------------------------------------------------------- 1 | """ 2 | ImageNet Training Script 3 | This script is adapted from pytorch-image-models by Ross Wightman (https://github.com/rwightman/pytorch-image-models/) 4 | It was started from an early version of the PyTorch ImageNet example 5 | (https://github.com/pytorch/examples/tree/master/imagenet) 6 | """ 7 | import warnings 8 | warnings.filterwarnings("ignore") 9 | 10 | import argparse 11 | import time 12 | import yaml 13 | import os 14 | import logging 15 | from collections import OrderedDict 16 | from contextlib import suppress 17 | from datetime import datetime 18 | import numpy as np 19 | 20 | import torch 21 | import torch.nn as nn 22 | import torchvision.utils 23 | from torch.nn.parallel import DistributedDataParallel as NativeDDP 24 | 25 | from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset 26 | from timm.models import create_model, resume_checkpoint, load_checkpoint, convert_splitbn_model, model_parameters 27 | from timm.utils import * 28 | from timm.loss import LabelSmoothingCrossEntropy 29 | from loss_adv import SoftTargetCrossEntropy, DistillationLoss 30 | from timm.optim import create_optimizer 31 | from timm.scheduler import create_scheduler 32 | from timm.utils import ApexScaler, NativeScaler 33 | 34 | from tlt.utils import load_pretrained_weights 35 | import models 36 | import torch.nn.functional as F 37 | 38 | try: 39 | from apex import amp 40 | from apex.parallel import DistributedDataParallel as ApexDDP 41 | from apex.parallel import convert_syncbn_model 42 | 43 | has_apex = True 44 | except ImportError: 45 | has_apex = False 46 | 47 | has_native_amp = False 48 | try: 49 | if getattr(torch.cuda.amp, 'autocast') is not None: 50 | has_native_amp = True 51 | except AttributeError: 52 | pass 53 | 54 | torch.backends.cudnn.benchmark = True 55 | _logger = logging.getLogger('train') 56 | 57 | # The first arg parser parses out only the --config argument, this argument is used to 58 | # load a yaml file containing key-values that override the defaults for the main parser below 59 | config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False) 60 | parser.add_argument('-c', '--config', default='', 
type=str, metavar='FILE', 61 | help='YAML config file specifying default arguments') 62 | 63 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 64 | 65 | # Dataset / Model parameters 66 | parser.add_argument('--data_dir', metavar='DIR', 67 | help='path to dataset') 68 | parser.add_argument('--dataset', '-d', metavar='NAME', default='', 69 | help='dataset type (default: ImageFolder/ImageTar if empty)') 70 | parser.add_argument('--train-split', metavar='NAME', default='train', 71 | help='dataset train split (default: train)') 72 | parser.add_argument('--val-split', metavar='NAME', default='validation', 73 | help='dataset validation split (default: validation)') 74 | parser.add_argument('--model', default='volo_d1', type=str, metavar='MODEL', 75 | help='Name of model to train (default: "volo_d1")') 76 | parser.add_argument('--pretrained', action='store_true', default=False, 77 | help='Start with pretrained version of specified network (if avail)') 78 | parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH', 79 | help='Initialize model from this checkpoint (default: none)') 80 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 81 | help='Resume full model and optimizer state from checkpoint (default: none)') 82 | parser.add_argument('--no-resume-opt', action='store_true', default=False, 83 | help='prevent resume of optimizer state when resuming model') 84 | parser.add_argument('--num-classes', type=int, default=None, metavar='N', 85 | help='number of label classes (Model default if None)') 86 | parser.add_argument('--gp', default=None, type=str, metavar='POOL', 87 | help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.') 88 | parser.add_argument('--img-size', type=int, default=None, metavar='N', 89 | help='Image size (default: None => model default)') 90 | parser.add_argument('--input-size', default=None, nargs=3, type=int, 91 | metavar='N N N', 92 | help='Input all image dimensions (d h w, e.g.
--input-size 3 224 224),' 93 | ' uses model default if empty') 94 | parser.add_argument('--crop-pct', default=None, type=float, 95 | metavar='N', help='Input image center crop percent (for validation only)') 96 | parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', 97 | help='Override mean pixel value of dataset') 98 | parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', 99 | help='Override std deviation of dataset') 100 | parser.add_argument('--interpolation', default='', type=str, metavar='NAME', 101 | help='Image resize interpolation type (overrides model)') 102 | parser.add_argument('-b', '--batch-size', type=int, default=128, metavar='N', 103 | help='input batch size for training (default: 128)') 104 | parser.add_argument('-vb', '--validation-batch-size-multiplier', type=int, default=1, metavar='N', 105 | help='ratio of validation batch size to training batch size (default: 1)') 106 | 107 | # Optimizer parameters 108 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', 109 | help='Optimizer (default: "adamw")') 110 | parser.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON', 111 | help='Optimizer Epsilon (default: None, use opt default)') 112 | parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA', 113 | help='Optimizer Betas (default: None, use opt default)') 114 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 115 | help='Optimizer momentum (default: 0.9)') 116 | parser.add_argument('--weight-decay', type=float, default=0.05, 117 | help='weight decay (default: 0.05)') 118 | parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM', 119 | help='Clip gradient norm (default: None, no clipping)') 120 | parser.add_argument('--clip-mode', type=str, default='norm', 121 | help='Gradient clipping mode. 
One of ("norm", "value", "agc")') 122 | 123 | # Learning rate schedule parameters 124 | parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER', 125 | help='LR scheduler (default: "cosine")') 126 | parser.add_argument('--lr', type=float, default=1.6e-3, metavar='LR', 127 | help='learning rate (default: 1.6e-3)') 128 | parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct', 129 | help='learning rate noise on/off epoch percentages') 130 | parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT', 131 | help='learning rate noise limit percent (default: 0.67)') 132 | parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV', 133 | help='learning rate noise std-dev (default: 1.0)') 134 | parser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT', 135 | help='learning rate cycle len multiplier (default: 1.0)') 136 | parser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N', 137 | help='learning rate cycle limit') 138 | parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR', 139 | help='warmup learning rate (default: 1e-6)') 140 | parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR', 141 | help='lower lr bound for cyclic schedulers that hit 0 (1e-5)') 142 | parser.add_argument('--epochs', type=int, default=300, metavar='N', 143 | help='number of epochs to train (default: 300)') 144 | parser.add_argument('--start-epoch', default=None, type=int, metavar='N', 145 | help='manual epoch number (useful on restarts)') 146 | parser.add_argument('--decay-epochs', type=float, default=30, metavar='N', 147 | help='epoch interval to decay LR') 148 | parser.add_argument('--warmup-epochs', type=int, default=20, metavar='N', 149 | help='epochs to warmup LR, if scheduler supports') 150 | parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N', 151 | help='epochs to cooldown LR at min_lr, after cyclic schedule ends') 152 | parser.add_argument('--patience-epochs', type=int, default=10, metavar='N', 153 | help='patience epochs for Plateau LR scheduler (default: 10)') 154 | parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', 155 | help='LR decay rate (default: 0.1)') 156 | 157 | # Augmentation & regularization parameters 158 | parser.add_argument('--no-aug', action='store_true', default=False, 159 | help='Disable all training augmentation, override other train aug args') 160 | parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT', 161 | help='Random resize scale (default: 0.08 1.0)') 162 | parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO', 163 | help='Random resize aspect ratio (default: 0.75 1.33)') 164 | parser.add_argument('--hflip', type=float, default=0.5, 165 | help='Horizontal flip training aug probability') 166 | parser.add_argument('--vflip', type=float, default=0., 167 | help='Vertical flip training aug probability') 168 | parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT', 169 | help='Color jitter factor (default: 0.4)') 170 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', 171 | help='Use AutoAugment policy. "v0" or "original". 
(default: rand-m9-mstd0.5-inc1)') 172 | parser.add_argument('--aug-splits', type=int, default=0, 173 | help='Number of augmentation splits (default: 0, valid: 0 or >=2)') 174 | parser.add_argument('--jsd', action='store_true', default=False, 175 | help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.') 176 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', 177 | help='Random erase prob (default: 0.25)') 178 | parser.add_argument('--remode', type=str, default='pixel', 179 | help='Random erase mode (default: "pixel")') 180 | parser.add_argument('--recount', type=int, default=1, 181 | help='Random erase count (default: 1)') 182 | parser.add_argument('--resplit', action='store_true', default=False, 183 | help='Do not random erase first (clean) augmentation split') 184 | parser.add_argument('--mixup', type=float, default=0.8, 185 | help='mixup alpha, mixup enabled if > 0. (default: 0.8)') 186 | parser.add_argument('--cutmix', type=float, default=1.0, 187 | help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)') 188 | parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None, 189 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') 190 | parser.add_argument('--mixup-prob', type=float, default=1.0, 191 | help='Probability of performing mixup or cutmix when either/both is enabled') 192 | parser.add_argument('--mixup-switch-prob', type=float, default=0.5, 193 | help='Probability of switching to cutmix when both mixup and cutmix enabled') 194 | parser.add_argument('--mixup-mode', type=str, default='batch', 195 | help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"') 196 | parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N', 197 | help='Turn off mixup after this epoch, disabled if 0 (default: 0)') 198 | parser.add_argument('--smoothing', type=float, default=0.1, 199 | help='Label smoothing (default: 0.1)') 200 | parser.add_argument('--train-interpolation', type=str, default='random', 201 | help='Training interpolation (random, bilinear, bicubic default: "random")') 202 | parser.add_argument('--drop', type=float, default=0.0, metavar='PCT', 203 | help='Dropout rate (default: 0.)') 204 | parser.add_argument('--drop-connect', type=float, default=None, metavar='PCT', 205 | help='Drop connect rate, DEPRECATED, use drop-path (default: None)') 206 | parser.add_argument('--drop-path', type=float, default=None, metavar='PCT', 207 | help='Drop path rate (default: None)') 208 | parser.add_argument('--drop-block', type=float, default=None, metavar='PCT', 209 | help='Drop block rate (default: None)') 210 | 211 | # Batch norm parameters (only works with gen_efficientnet based models currently) 212 | parser.add_argument('--bn-tf', action='store_true', default=False, 213 | help='Use Tensorflow BatchNorm defaults for models that support it (default: False)') 214 | parser.add_argument('--bn-momentum', type=float, default=None, 215 | help='BatchNorm momentum override (if not None)') 216 | parser.add_argument('--bn-eps', type=float, default=None, 217 | help='BatchNorm epsilon override (if not None)') 218 | parser.add_argument('--sync-bn', action='store_true', 219 | help='Enable NVIDIA Apex or Torch synchronized BatchNorm.') 220 | parser.add_argument('--dist-bn', type=str, default='', 221 | help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")') 222 | parser.add_argument('--split-bn', action='store_true', 223 | help='Enable separate BN 
layers per augmentation split.') 224 | 225 | # Model Exponential Moving Average 226 | parser.add_argument('--model-ema', action='store_true', default=False, 227 | help='Enable tracking moving average of model weights') 228 | parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, 229 | help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.') 230 | parser.add_argument('--model-ema-decay', type=float, default=0.99992, 231 | help='decay factor for model weights moving average (default: 0.99992)') 232 | 233 | # Misc 234 | parser.add_argument('--seed', type=int, default=42, metavar='S', 235 | help='random seed (default: 42)') 236 | parser.add_argument('--log-interval', type=int, default=50, metavar='N', 237 | help='how many batches to wait before logging training status') 238 | parser.add_argument('--recovery-interval', type=int, default=0, metavar='N', 239 | help='how many batches to wait before writing recovery checkpoint') 240 | parser.add_argument('--checkpoint-hist', type=int, default=3, metavar='N', 241 | help='number of checkpoints to keep (default: 3)') 242 | parser.add_argument('-j', '--workers', type=int, default=8, metavar='N', 243 | help='how many training processes to use (default: 8)') 244 | parser.add_argument('--save-images', action='store_true', default=False, 245 | help='save images of input batches every log interval for debugging') 246 | parser.add_argument('--amp', action='store_true', default=False, 247 | help='use NVIDIA Apex AMP or Native AMP for mixed precision training') 248 | parser.add_argument('--apex-amp', action='store_true', default=True, 249 | help='Use NVIDIA Apex AMP mixed precision') 250 | parser.add_argument('--native-amp', action='store_true', default=False, 251 | help='Use Native Torch AMP mixed precision') 252 | parser.add_argument('--channels-last', action='store_true', default=False, 253 | help='Use channels_last memory layout') 254 | parser.add_argument('--pin-mem', action='store_true', default=False, 255 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 256 | parser.add_argument('--no-prefetcher', action='store_true', default=False, 257 | help='disable fast prefetcher') 258 | parser.add_argument('--output', default='', type=str, metavar='PATH', 259 | help='path to output folder (default: none, current dir)') 260 | parser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC', 261 | help='Best metric (default: "top1")') 262 | parser.add_argument('--tta', type=int, default=0, metavar='N', 263 | help='Test/inference time augmentation (oversampling) factor. 
0=None (default: 0)') 264 | parser.add_argument("--local_rank", default=0, type=int) 265 | parser.add_argument('--use-multi-epochs-loader', action='store_true', default=False, 266 | help='use the multi-epochs-loader to save time at the beginning of every epoch') 267 | parser.add_argument('--torchscript', dest='torchscript', action='store_true', 268 | help='convert model torchscript for inference') 269 | 270 | 271 | # Finetune 272 | parser.add_argument('--finetune', default='', type=str, metavar='PATH', 273 | help='path to checkpoint file (default: none)') 274 | 275 | 276 | # Distillation parameters 277 | parser.add_argument('--teacher-model', default='regnety_160', type=str, metavar='MODEL', 278 | help='Name of teacher model for distillation (default: "regnety_160")') 279 | parser.add_argument('--teacher-path', type=str, default='') 280 | parser.add_argument('--distillation-type', default='none', choices=['none', 'soft', 'hard'], type=str, help='distillation loss type (default: "none")') 281 | parser.add_argument('--distillation-alpha', default=0.5, type=float, help='distillation loss weight (default: 0.5)') 282 | parser.add_argument('--distillation-tau', default=1.0, type=float, help='distillation temperature (default: 1.0)') 283 | 284 | # Adversarial 285 | parser.add_argument('--adv-epochs', default=200, type=int, metavar='N', 286 | help='number of epochs for performing adversarial training') 287 | parser.add_argument('--adv-iters', default=3, type=int, metavar='N', 288 | help='number of iterations for adversarial augmentation') 289 | parser.add_argument('--adv-eps', default=2.0/255, type=float, 290 | help='adversarial strength for adversarial augmentation') 291 | parser.add_argument('--adv-lr', default=1.0/255, type=float, 292 | help='learning rate for adversarial augmentation') 293 | parser.add_argument('--adv-kl-weight', default=0.01, type=float, 294 | help='weight of KL-divergence for adversarial augmentation') 295 | parser.add_argument('--adv-ce-weight', default=3.0, type=float, 296 | help='weight of ce-loss for adversarial augmentation') 297 | 298 | 299 | def _parse_args(): 300 | # Do we have a config file to parse? 301 | args_config, remaining = config_parser.parse_known_args() 302 | if args_config.config: 303 | with open(args_config.config, 'r') as f: 304 | cfg = yaml.safe_load(f) 305 | parser.set_defaults(**cfg) 306 | 307 | # The main arg parser parses the rest of the args, the usual 308 | # defaults will have been overridden if config file specified. 309 | args = parser.parse_args(remaining) 310 | 311 | # Cache the args as a text string to save them in the output dir later 312 | args_text = yaml.safe_dump(args.__dict__, default_flow_style=False) 313 | return args, args_text 314 | 315 | 316 | def main(): 317 | setup_default_logging() 318 | args, args_text = _parse_args() 319 | 320 | args.prefetcher = not args.no_prefetcher 321 | args.distributed = False 322 | if 'WORLD_SIZE' in os.environ: 323 | args.distributed = int(os.environ['WORLD_SIZE']) > 1 324 | args.device = 'cuda:0' 325 | args.world_size = 1 326 | args.rank = 0 # global rank 327 | if args.distributed: 328 | args.device = 'cuda:%d' % args.local_rank 329 | torch.cuda.set_device(args.local_rank) 330 | torch.distributed.init_process_group(backend='nccl', 331 | init_method='env://') 332 | args.world_size = torch.distributed.get_world_size() 333 | args.rank = torch.distributed.get_rank() 334 | _logger.info( 335 | 'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.' 
336 | % (args.rank, args.world_size)) 337 | else: 338 | _logger.info('Training with a single process on 1 GPU.') 339 | assert args.rank >= 0 340 | 341 | # resolve AMP arguments based on PyTorch / Apex availability 342 | use_amp = None 343 | if args.amp: 344 | # `--amp` chooses native amp before apex (APEX ver not actively maintained) 345 | if has_native_amp: 346 | args.native_amp = True 347 | elif has_apex: 348 | args.apex_amp = True 349 | if args.apex_amp and has_apex: 350 | use_amp = 'apex' 351 | elif args.native_amp and has_native_amp: 352 | use_amp = 'native' 353 | elif args.apex_amp or args.native_amp: 354 | _logger.warning( 355 | "Neither APEX nor native Torch AMP is available, using float32. " 356 | "Install NVIDIA apex or upgrade to PyTorch 1.6") 357 | 358 | torch.manual_seed(args.seed + args.rank) 359 | np.random.seed(args.seed + args.rank) 360 | model = create_model( 361 | args.model, 362 | pretrained=args.pretrained, 363 | num_classes=args.num_classes, 364 | drop_rate=args.drop, 365 | drop_connect_rate=args.drop_connect, # DEPRECATED, use drop_path 366 | drop_path_rate=args.drop_path, 367 | drop_block_rate=args.drop_block, 368 | global_pool=args.gp, 369 | bn_tf=args.bn_tf, 370 | bn_momentum=args.bn_momentum, 371 | bn_eps=args.bn_eps, 372 | scriptable=args.torchscript, 373 | checkpoint_path=args.initial_checkpoint, 374 | img_size=args.img_size) 375 | if args.num_classes is None: 376 | assert hasattr( 377 | model, 'num_classes' 378 | ), 'Model must have `num_classes` attr if not set on cmd line/config.' 379 | args.num_classes = model.num_classes # FIXME handle model default vs config num_classes more elegantly 380 | 381 | if args.finetune: 382 | load_pretrained_weights(model=model, 383 | checkpoint_path=args.finetune, 384 | use_ema=args.model_ema, 385 | strict=False, 386 | num_classes=args.num_classes) 387 | 388 | if args.rank == 0: 389 | _logger.info('Model %s created, param count: %d' % 390 | (args.model, sum([m.numel() 391 | for m in model.parameters()]))) 392 | 393 | data_config = resolve_data_config(vars(args), 394 | model=model, 395 | verbose=args.local_rank == 0) 396 | 397 | # setup augmentation batch splits for contrastive loss or split bn 398 | num_aug_splits = 0 399 | if args.aug_splits > 0: 400 | assert args.aug_splits > 1, 'A split of 1 makes no sense' 401 | num_aug_splits = args.aug_splits 402 | 403 | # enable split bn (separate bn stats per batch-portion) 404 | if args.split_bn: 405 | assert num_aug_splits > 1 or args.resplit 406 | model = convert_splitbn_model(model, max(num_aug_splits, 2)) 407 | 408 | # move model to GPU, enable channels last layout if set 409 | model.cuda() 410 | if args.channels_last: 411 | model = model.to(memory_format=torch.channels_last) 412 | 413 | # setup synchronized BatchNorm for distributed training 414 | if args.distributed and args.sync_bn: 415 | assert not args.split_bn 416 | if has_apex and use_amp != 'native': 417 | # Apex SyncBN preferred unless native amp is activated 418 | model = convert_syncbn_model(model) 419 | else: 420 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 421 | if args.rank == 0: 422 | _logger.info( 423 | 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using ' 424 | 'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.' 
425 | ) 426 | 427 | if args.torchscript: 428 | assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model' 429 | assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model' 430 | model = torch.jit.script(model) 431 | 432 | linear_scaled_lr = args.lr * args.batch_size * args.world_size / 1024.0 433 | args.lr = linear_scaled_lr 434 | if args.rank == 0: 435 | print("learning rate:", args.lr) 436 | optimizer = create_optimizer(args, model) 437 | 438 | # setup automatic mixed-precision (AMP) loss scaling and op casting 439 | amp_autocast = suppress # do nothing 440 | loss_scaler = None 441 | optimizers = None 442 | if use_amp == 'apex': 443 | model, optimizer = amp.initialize(model, optimizer, opt_level='O1') 444 | loss_scaler = ApexScaler() 445 | if args.rank == 0: 446 | _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.') 447 | elif use_amp == 'native': 448 | amp_autocast = torch.cuda.amp.autocast 449 | loss_scaler = NativeScaler() 450 | if args.rank == 0: 451 | _logger.info( 452 | 'Using native Torch AMP. Training in mixed precision.') 453 | else: 454 | if args.rank == 0: 455 | _logger.info('AMP not enabled. Training in float32.') 456 | 457 | # optionally resume from a checkpoint 458 | resume_epoch = None 459 | if args.resume and os.path.isfile(args.resume): 460 | resume_epoch = resume_checkpoint( 461 | model, 462 | args.resume, 463 | optimizer=None if args.no_resume_opt else optimizer, 464 | loss_scaler=None if args.no_resume_opt else loss_scaler, 465 | log_info=args.rank == 0) 466 | 467 | # setup exponential moving average of model weights, SWA could be used here too 468 | model_ema = None 469 | if args.model_ema: 470 | # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper 471 | model_ema = ModelEmaV2( 472 | model, 473 | decay=args.model_ema_decay, 474 | device='cpu' if args.model_ema_force_cpu else None) 475 | if args.resume and os.path.isfile(args.resume): 476 | load_checkpoint(model_ema.module, args.resume, use_ema=True) 477 | 478 | # setup distributed training 479 | if args.distributed: 480 | if has_apex and use_amp != 'native': 481 | # Apex DDP preferred unless native amp is activated 482 | if args.rank == 0: 483 | _logger.info("Using NVIDIA APEX DistributedDataParallel.") 484 | model = ApexDDP(model, delay_allreduce=True) 485 | else: 486 | if args.rank == 0: 487 | _logger.info("Using native Torch DistributedDataParallel.") 488 | model = NativeDDP(model, device_ids=[ 489 | args.local_rank 490 | ]) # can use device str in Torch >= 1.1 491 | # NOTE: EMA model does not need to be wrapped by DDP 492 | 493 | # setup learning rate schedule and starting epoch 494 | lr_scheduler, num_epochs = create_scheduler(args, optimizer) 495 | start_epoch = 0 496 | if args.start_epoch is not None: 497 | # a specified start_epoch will always override the resume epoch 498 | start_epoch = args.start_epoch 499 | elif resume_epoch is not None: 500 | start_epoch = resume_epoch 501 | if lr_scheduler is not None and start_epoch > 0: 502 | lr_scheduler.step(start_epoch) 503 | 504 | if args.rank == 0: 505 | _logger.info('Scheduled epochs: {}'.format(num_epochs)) 506 | 507 | 508 | 509 | # create the train and eval datasets 510 | dataset_train = create_dataset(args.dataset, 511 | root=args.data_dir, 512 | split=args.train_split, 513 | is_training=True, 514 | batch_size=args.batch_size) 515 | dataset_eval = create_dataset(args.dataset, 516 | root=args.data_dir, 517 | 
split=args.val_split, 518 | is_training=False, 519 | batch_size=args.batch_size) 520 | 521 | # setup mixup / cutmix 522 | collate_fn = None 523 | mixup_fn = None 524 | mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None 525 | if mixup_active: 526 | mixup_args = dict(mixup_alpha=args.mixup, 527 | cutmix_alpha=args.cutmix, 528 | cutmix_minmax=args.cutmix_minmax, 529 | prob=args.mixup_prob, 530 | switch_prob=args.mixup_switch_prob, 531 | mode=args.mixup_mode, 532 | label_smoothing=args.smoothing, 533 | num_classes=args.num_classes) 534 | if args.prefetcher: 535 | assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup) 536 | collate_fn = FastCollateMixup(**mixup_args) 537 | else: 538 | mixup_fn = Mixup(**mixup_args) 539 | 540 | # wrap dataset in AugMix helper 541 | if num_aug_splits > 1: 542 | dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits) 543 | 544 | # create data loaders w/ augmentation pipeline 545 | train_interpolation = args.train_interpolation 546 | if args.no_aug or not train_interpolation: 547 | train_interpolation = data_config['interpolation'] 548 | 549 | loader_train = create_loader( 550 | dataset_train, 551 | input_size=data_config['input_size'], 552 | batch_size=args.batch_size, 553 | is_training=True, 554 | use_prefetcher=args.prefetcher, 555 | no_aug=args.no_aug, 556 | re_prob=args.reprob, 557 | re_mode=args.remode, 558 | re_count=args.recount, 559 | re_split=args.resplit, 560 | scale=args.scale, 561 | ratio=args.ratio, 562 | hflip=args.hflip, 563 | vflip=args.vflip, 564 | color_jitter=args.color_jitter, 565 | auto_augment=args.aa, 566 | # num_aug_repeats=args.aug_repeats, 567 | num_aug_splits=num_aug_splits, 568 | interpolation=train_interpolation, 569 | mean=data_config['mean'], 570 | std=data_config['std'], 571 | num_workers=args.workers, 572 | distributed=args.distributed, 573 | collate_fn=collate_fn, 574 | pin_memory=args.pin_mem, 575 | use_multi_epochs_loader=args.use_multi_epochs_loader, 576 | # worker_seeding=args.worker_seeding, 577 | ) 578 | 579 | loader_eval = create_loader( 580 | dataset_eval, 581 | input_size=data_config['input_size'], 582 | batch_size=args.validation_batch_size_multiplier * args.batch_size, 583 | is_training=False, 584 | use_prefetcher=args.prefetcher, 585 | interpolation=data_config['interpolation'], 586 | mean=data_config['mean'], 587 | std=data_config['std'], 588 | num_workers=args.workers, 589 | distributed=args.distributed, 590 | crop_pct=data_config['crop_pct'], 591 | pin_memory=args.pin_mem, 592 | ) 593 | 594 | # setup loss function 595 | if mixup_active: 596 | train_loss_fn = SoftTargetCrossEntropy(reduce=True) 597 | elif args.smoothing: 598 | train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing) 599 | else: 600 | train_loss_fn = nn.CrossEntropyLoss() 601 | train_loss_fn = train_loss_fn.cuda() 602 | validate_loss_fn = nn.CrossEntropyLoss().cuda() 603 | 604 | teacher_model = None 605 | if args.distillation_type != 'none': 606 | assert args.teacher_path, 'need to specify teacher-path when using distillation' 607 | print(f"Creating teacher model: {args.teacher_model}") 608 | teacher_model = create_model( 609 | args.teacher_model, 610 | pretrained=False 611 | ) 612 | if args.teacher_path.startswith('https'): 613 | checkpoint = torch.hub.load_state_dict_from_url( 614 | args.teacher_path, map_location='cpu', check_hash=True) 615 | else: 616 | checkpoint = torch.load(args.teacher_path, map_location='cpu') 617 | if 'model' in 
checkpoint.keys(): 618 | teacher_model.load_state_dict(checkpoint['model']) 619 | elif 'state_dict' in checkpoint.keys(): 620 | teacher_model.load_state_dict(checkpoint['state_dict']) 621 | teacher_model.to(args.device) 622 | teacher_model.eval() 623 | 624 | # wrap the criterion in our custom DistillationLoss, which 625 | # just dispatches to the original criterion if args.distillation_type is 'none' 626 | train_loss_fn = DistillationLoss( 627 | train_loss_fn, teacher_model, args.distillation_type, args.distillation_alpha, args.distillation_tau 628 | ) 629 | 630 | # setup checkpoint saver and eval metric tracking 631 | eval_metric = args.eval_metric 632 | best_metric = None 633 | best_epoch = None 634 | saver = None 635 | output_dir = '' 636 | if args.rank == 0: 637 | output_base = args.output if args.output else './output' 638 | output_dir = get_outdir(output_base) 639 | decreasing = True if eval_metric == 'loss' else False 640 | saver = CheckpointSaver(model=model, 641 | optimizer=optimizer, 642 | args=args, 643 | model_ema=model_ema, 644 | amp_scaler=loss_scaler, 645 | checkpoint_dir=output_dir, 646 | recovery_dir=output_dir, 647 | decreasing=decreasing, 648 | max_history=args.checkpoint_hist) 649 | with open(os.path.join(output_dir, 'args.yaml'), 'w') as f: 650 | f.write(args_text) 651 | 652 | try: 653 | if args.finetune: 654 | validate(model, 655 | loader_eval, 656 | validate_loss_fn, 657 | args, 658 | amp_autocast=amp_autocast) 659 | for epoch in range(start_epoch, num_epochs): 660 | 661 | if epoch >= args.adv_epochs: 662 | args.adv_iters = 1 663 | 664 | if args.distributed and hasattr(loader_train.sampler, 'set_epoch'): 665 | loader_train.sampler.set_epoch(epoch) 666 | 667 | train_metrics = train_one_epoch(epoch, 668 | model, 669 | loader_train, 670 | optimizer, 671 | train_loss_fn, 672 | args, 673 | lr_scheduler=lr_scheduler, 674 | saver=saver, 675 | output_dir=output_dir, 676 | amp_autocast=amp_autocast, 677 | loss_scaler=loss_scaler, 678 | model_ema=model_ema, 679 | mixup_fn=mixup_fn, 680 | optimizers=optimizers) 681 | 682 | if args.distributed and args.dist_bn in ('broadcast', 'reduce'): 683 | if args.rank == 0: 684 | _logger.info( 685 | "Distributing BatchNorm running means and vars") 686 | distribute_bn(model, args.world_size, args.dist_bn == 'reduce') 687 | 688 | eval_metrics = validate(model, 689 | loader_eval, 690 | validate_loss_fn, 691 | args, 692 | amp_autocast=amp_autocast) 693 | 694 | if model_ema is not None and not args.model_ema_force_cpu: 695 | if args.distributed and args.dist_bn in ('broadcast', 696 | 'reduce'): 697 | distribute_bn(model_ema, args.world_size, 698 | args.dist_bn == 'reduce') 699 | ema_eval_metrics = validate(model_ema.module, 700 | loader_eval, 701 | validate_loss_fn, 702 | args, 703 | amp_autocast=amp_autocast, 704 | log_suffix=' (EMA)') 705 | eval_metrics = ema_eval_metrics 706 | 707 | if lr_scheduler is not None: 708 | # step LR for next epoch 709 | lr_scheduler.step(epoch + 1, eval_metrics[eval_metric]) 710 | 711 | update_summary(epoch, 712 | train_metrics, 713 | eval_metrics, 714 | os.path.join(output_dir, 'summary.csv'), 715 | write_header=best_metric is None) 716 | 717 | if saver is not None: 718 | # save proper checkpoint with eval metric 719 | save_metric = eval_metrics[eval_metric] 720 | best_metric, best_epoch = saver.save_checkpoint( 721 | epoch, metric=save_metric) 722 | 723 | 724 | except KeyboardInterrupt: 725 | pass 726 | if best_metric is not None: 727 | _logger.info('*** Best metric: {0} (epoch {1})'.format( 728 | best_metric, 
best_epoch)) 729 | 730 | 731 | def train_one_epoch(epoch, 732 | model, 733 | loader, 734 | optimizer, 735 | loss_fn, 736 | args, 737 | lr_scheduler=None, 738 | saver=None, 739 | output_dir='', 740 | amp_autocast=suppress, 741 | loss_scaler=None, 742 | model_ema=None, 743 | mixup_fn=None, 744 | optimizers=None): 745 | assert isinstance(loss_scaler, ApexScaler) 746 | 747 | if args.mixup_off_epoch and epoch >= args.mixup_off_epoch: 748 | if args.prefetcher and loader.mixup_enabled: 749 | loader.mixup_enabled = False 750 | elif mixup_fn is not None: 751 | mixup_fn.mixup_enabled = False 752 | 753 | batch_time_m = AverageMeter() 754 | data_time_m = AverageMeter() 755 | losses_m = AverageMeter() 756 | 757 | model.train() 758 | 759 | end = time.time() 760 | last_idx = len(loader) - 1 761 | num_updates = epoch * len(loader) 762 | for batch_idx, (input, target) in enumerate(loader): 763 | last_batch = batch_idx == last_idx 764 | data_time_m.update(time.time() - end) 765 | if not args.prefetcher: 766 | input, target = input.cuda(), target.cuda() 767 | if mixup_fn is not None: 768 | input, target = mixup_fn(input, target) 769 | if args.channels_last: 770 | input = input.contiguous(memory_format=torch.channels_last) 771 | 772 | delta = torch.zeros_like(input) 773 | 774 | output_cle, output_cle_kd = model(input) 775 | output_cle = output_cle.detach() 776 | output_cle_kd = output_cle_kd.detach() 777 | 778 | output_cle_prob = F.softmax(output_cle, dim=1) 779 | output_cle_logprob = F.log_softmax(output_cle, dim=1) 780 | 781 | for a_iter in range(args.adv_iters): 782 | delta.requires_grad_() 783 | 784 | with amp_autocast(): 785 | 786 | output_adv, output_adv_kd = model(input+delta) 787 | 788 | loss_ce = loss_fn(input+delta, [output_adv, output_adv_kd], target) 789 | 790 | output_adv_prob = F.softmax(output_adv, dim=1) 791 | output_adv_logprob = F.log_softmax(output_adv, dim=1) 792 | loss_kl = (F.kl_div(output_adv_logprob, output_cle_prob, reduce=False).sum(dim=1) + 793 | F.kl_div(output_cle_logprob, output_adv_prob, reduce=False).sum(dim=1)).mean() / 2 794 | 795 | if a_iter == 0: 796 | loss = (loss_ce + args.adv_kl_weight * loss_kl) 797 | else: 798 | loss = (args.adv_ce_weight * loss_ce + args.adv_kl_weight * loss_kl) / (args.adv_iters-1) 799 | # loss = loss.mean() 800 | # print(loss.item(), loss_ce.item(), loss_kl.item()) 801 | 802 | with amp.scale_loss(loss, optimizer) as scaled_loss: 803 | scaled_loss.backward(retain_graph=False) 804 | 805 | delta_grad = delta.grad.clone().detach().float() 806 | delta = (delta + args.adv_lr * torch.sign(delta_grad)).detach() 807 | delta = torch.clamp(delta, -args.adv_eps, args.adv_eps).detach() 808 | 809 | 810 | if not args.distributed: 811 | losses_m.update(loss.item(), input.size(0)) 812 | 813 | optimizer.step() 814 | optimizer.zero_grad() 815 | 816 | if model_ema is not None: 817 | model_ema.update(model) 818 | 819 | torch.cuda.synchronize() 820 | num_updates += 1 821 | batch_time_m.update(time.time() - end) 822 | if last_batch or batch_idx % args.log_interval == 0: 823 | lrl = [param_group['lr'] for param_group in optimizer.param_groups] 824 | lr = sum(lrl) / len(lrl) 825 | 826 | if args.distributed: 827 | reduced_loss = reduce_tensor(loss.data, args.world_size) 828 | losses_m.update(reduced_loss.item(), input.size(0)) 829 | 830 | if args.rank == 0: 831 | _logger.info( 832 | 'Train: {} [{:>4d}/{} ({:>3.0f}%)] ' 833 | 'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f}) ' 834 | 'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s ' 835 | '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) ' 
836 | 'LR: {lr:.3e} ' 837 | 'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format( 838 | epoch, 839 | batch_idx, 840 | len(loader), 841 | 100. * batch_idx / last_idx, 842 | loss=losses_m, 843 | batch_time=batch_time_m, 844 | rate=input.size(0) * args.world_size / 845 | batch_time_m.val, 846 | rate_avg=input.size(0) * args.world_size / 847 | batch_time_m.avg, 848 | lr=lr, 849 | data_time=data_time_m)) 850 | 851 | if args.save_images and output_dir: 852 | torchvision.utils.save_image( 853 | input, 854 | os.path.join(output_dir, 855 | 'train-batch-%d.jpg' % batch_idx), 856 | padding=0, 857 | normalize=True) 858 | 859 | if saver is not None and args.recovery_interval and ( 860 | last_batch or (batch_idx + 1) % args.recovery_interval == 0): 861 | saver.save_recovery(epoch, batch_idx=batch_idx) 862 | 863 | if lr_scheduler is not None: 864 | lr_scheduler.step_update(num_updates=num_updates, 865 | metric=losses_m.avg) 866 | 867 | end = time.time() 868 | # end for 869 | 870 | if hasattr(optimizer, 'sync_lookahead'): 871 | optimizer.sync_lookahead() 872 | 873 | return OrderedDict([('loss', losses_m.avg)]) 874 | 875 | 876 | def validate(model, 877 | loader, 878 | loss_fn, 879 | args, 880 | amp_autocast=suppress, 881 | log_suffix=''): 882 | batch_time_m = AverageMeter() 883 | losses_m = AverageMeter() 884 | top1_m = AverageMeter() 885 | top5_m = AverageMeter() 886 | 887 | model.eval() 888 | 889 | end = time.time() 890 | last_idx = len(loader) - 1 891 | with torch.no_grad(): 892 | for batch_idx, (input, target) in enumerate(loader): 893 | last_batch = batch_idx == last_idx 894 | if not args.prefetcher: 895 | input = input.cuda() 896 | target = target.cuda() 897 | if args.channels_last: 898 | input = input.contiguous(memory_format=torch.channels_last) 899 | 900 | with amp_autocast(): 901 | output = model(input) 902 | if isinstance(output, (tuple, list)): 903 | output = output[0] 904 | 905 | # augmentation reduction 906 | reduce_factor = args.tta 907 | if reduce_factor > 1: 908 | output = output.unfold(0, reduce_factor, 909 | reduce_factor).mean(dim=2) 910 | target = target[0:target.size(0):reduce_factor] 911 | 912 | loss = loss_fn(output, target) 913 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 914 | 915 | if args.distributed: 916 | reduced_loss = reduce_tensor(loss.data, args.world_size) 917 | acc1 = reduce_tensor(acc1, args.world_size) 918 | acc5 = reduce_tensor(acc5, args.world_size) 919 | else: 920 | reduced_loss = loss.data 921 | 922 | torch.cuda.synchronize() 923 | 924 | losses_m.update(reduced_loss.item(), input.size(0)) 925 | top1_m.update(acc1.item(), output.size(0)) 926 | top5_m.update(acc5.item(), output.size(0)) 927 | 928 | batch_time_m.update(time.time() - end) 929 | end = time.time() 930 | if args.rank == 0 and (last_batch or 931 | batch_idx % args.log_interval == 0): 932 | log_name = 'Test' + log_suffix 933 | _logger.info( 934 | '{0}: [{1:>4d}/{2}] ' 935 | 'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) ' 936 | 'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) ' 937 | 'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) ' 938 | 'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format( 939 | log_name, 940 | batch_idx, 941 | last_idx, 942 | batch_time=batch_time_m, 943 | loss=losses_m, 944 | top1=top1_m, 945 | top5=top5_m)) 946 | 947 | metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), 948 | ('top5', top5_m.avg)]) 949 | 950 | return metrics 951 | 952 | 953 | if __name__ == '__main__': 954 | main() 955 | 
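# NOTE: the inner adversarial loop of train_one_epoch above, reduced to its core
# update rule. A simplified sketch only: no apex/AMP loss scaling, the
# per-iteration weighting of the CE term (adv_ce_weight / (adv_iters - 1) after
# the first step) is collapsed into one form, and `ce` / `sym_kl` are
# hypothetical stand-ins for the wrapped criterion and the symmetric
# KL-divergence term used above.
#
#     delta = torch.zeros_like(x)                    # L_inf perturbation, |delta| <= adv_eps
#     out_clean = model(x)[0].detach()               # clean prediction, consistency target
#     for _ in range(adv_iters):
#         delta.requires_grad_()
#         out_adv = model(x + delta)[0]
#         loss = ce(out_adv, y) + adv_kl_weight * sym_kl(out_adv, out_clean)
#         loss.backward()                            # grads accumulate in the model AND delta
#         delta = (delta + adv_lr * delta.grad.sign()).clamp(-adv_eps, adv_eps).detach()
#     optimizer.step()                               # one step on the accumulated model grads
#     optimizer.zero_grad()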
-------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | warnings.filterwarnings("ignore") 3 | 4 | import argparse 5 | import time 6 | import yaml 7 | import os 8 | import logging 9 | from collections import OrderedDict 10 | from contextlib import suppress 11 | from datetime import datetime 12 | import numpy as np 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torchvision.utils 17 | from torch.nn.parallel import DistributedDataParallel as NativeDDP 18 | 19 | from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset 20 | from timm.models import create_model, resume_checkpoint, load_checkpoint, convert_splitbn_model, model_parameters 21 | from timm.utils import * 22 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 23 | from timm.optim import create_optimizer 24 | from timm.scheduler import create_scheduler 25 | from timm.utils import ApexScaler, NativeScaler 26 | 27 | from tlt.data import create_token_label_target, TokenLabelMixup, FastCollateTokenLabelMixup, \ 28 | create_token_label_loader, create_token_label_dataset 29 | from tlt.utils import load_pretrained_weights 30 | from loss import TokenLabelGTCrossEntropy, TokenLabelCrossEntropy, TokenLabelSoftTargetCrossEntropy 31 | import models 32 | import torch.nn.functional as F 33 | 34 | try: 35 | from apex import amp 36 | from apex.parallel import DistributedDataParallel as ApexDDP 37 | from apex.parallel import convert_syncbn_model 38 | 39 | has_apex = True 40 | except ImportError: 41 | has_apex = False 42 | 43 | has_native_amp = False 44 | try: 45 | if getattr(torch.cuda.amp, 'autocast') is not None: 46 | has_native_amp = True 47 | except AttributeError: 48 | pass 49 | 50 | torch.backends.cudnn.benchmark = True 51 | _logger = logging.getLogger('train') 52 | 53 | # The first arg parser parses out only the --config argument, this argument is used to 54 | # load a yaml file containing key-values that override the defaults for the main parser below 55 | config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False) 56 | parser.add_argument('-c', '--config', default='', type=str, metavar='FILE', 57 | help='YAML config file specifying default arguments') 58 | 59 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 60 | 61 | # Dataset / Model parameters 62 | parser.add_argument('--data_dir', metavar='DIR', 63 | help='path to dataset') 64 | parser.add_argument('--dataset', '-d', metavar='NAME', default='', 65 | help='dataset type (default: ImageFolder/ImageTar if empty)') 66 | parser.add_argument('--train-split', metavar='NAME', default='train', 67 | help='dataset train split (default: train)') 68 | parser.add_argument('--val-split', metavar='NAME', default='validation', 69 | help='dataset validation split (default: validation)') 70 | parser.add_argument('--model', default='volo_d1', type=str, metavar='MODEL', 71 | help='Name of model to train (default: "volo_d1")') 72 | parser.add_argument('--pretrained', action='store_true', default=False, 73 | help='Start with pretrained version of specified network (if avail)') 74 | parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH', 75 | help='Initialize model from this checkpoint (default: none)') 76 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 77 | help='Resume full 
model and optimizer state from checkpoint (default: none)') 78 | parser.add_argument('--no-resume-opt', action='store_true', default=False, 79 | help='prevent resume of optimizer state when resuming model') 80 | parser.add_argument('--num-classes', type=int, default=None, metavar='N', 81 | help='number of label classes (Model default if None)') 82 | parser.add_argument('--gp', default=None, type=str, metavar='POOL', 83 | help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.') 84 | parser.add_argument('--img-size', type=int, default=None, metavar='N', 85 | help='Image size (default: None => model default)') 86 | parser.add_argument('--input-size', default=None, nargs=3, type=int, 87 | metavar='N N N', 88 | help='Input all image dimensions (d h w, e.g. --input-size 3 224 224),' 89 | ' uses model default if empty') 90 | parser.add_argument('--crop-pct', default=None, type=float, 91 | metavar='N', help='Input image center crop percent (for validation only)') 92 | parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', 93 | help='Override mean pixel value of dataset') 94 | parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', 95 | help='Override std deviation of dataset') 96 | parser.add_argument('--interpolation', default='', type=str, metavar='NAME', 97 | help='Image resize interpolation type (overrides model)') 98 | parser.add_argument('-b', '--batch-size', type=int, default=128, metavar='N', 99 | help='input batch size for training (default: 128)') 100 | parser.add_argument('-vb', '--validation-batch-size-multiplier', type=int, default=1, metavar='N', 101 | help='ratio of validation batch size to training batch size (default: 1)') 102 | 103 | # Optimizer parameters 104 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', 105 | help='Optimizer (default: "adamw")') 106 | parser.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON', 107 | help='Optimizer Epsilon (default: None, use opt default)') 108 | parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA', 109 | help='Optimizer Betas (default: None, use opt default)') 110 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 111 | help='Optimizer momentum (default: 0.9)') 112 | parser.add_argument('--weight-decay', type=float, default=0.05, 113 | help='weight decay (default: 0.05)') 114 | parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM', 115 | help='Clip gradient norm (default: None, no clipping)') 116 | parser.add_argument('--clip-mode', type=str, default='norm', 117 | help='Gradient clipping mode. 
One of ("norm", "value", "agc")') 118 | 119 | # Learning rate schedule parameters 120 | parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER', 121 | help='LR scheduler (default: "cosine")') 122 | parser.add_argument('--lr', type=float, default=1.6e-3, metavar='LR', 123 | help='learning rate (default: 1.6e-3)') 124 | parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct', 125 | help='learning rate noise on/off epoch percentages') 126 | parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT', 127 | help='learning rate noise limit percent (default: 0.67)') 128 | parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV', 129 | help='learning rate noise std-dev (default: 1.0)') 130 | parser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT', 131 | help='learning rate cycle len multiplier (default: 1.0)') 132 | parser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N', 133 | help='learning rate cycle limit') 134 | parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR', 135 | help='warmup learning rate (default: 1e-6)') 136 | parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR', 137 | help='lower lr bound for cyclic schedulers that hit 0 (1e-5)') 138 | parser.add_argument('--epochs', type=int, default=300, metavar='N', 139 | help='number of epochs to train (default: 300)') 140 | parser.add_argument('--start-epoch', default=None, type=int, metavar='N', 141 | help='manual epoch number (useful on restarts)') 142 | parser.add_argument('--decay-epochs', type=float, default=30, metavar='N', 143 | help='epoch interval to decay LR') 144 | parser.add_argument('--warmup-epochs', type=int, default=20, metavar='N', 145 | help='epochs to warmup LR, if scheduler supports') 146 | parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N', 147 | help='epochs to cooldown LR at min_lr, after cyclic schedule ends') 148 | parser.add_argument('--patience-epochs', type=int, default=10, metavar='N', 149 | help='patience epochs for Plateau LR scheduler (default: 10)') 150 | parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', 151 | help='LR decay rate (default: 0.1)') 152 | 153 | # Augmentation & regularization parameters 154 | parser.add_argument('--no-aug', action='store_true', default=False, 155 | help='Disable all training augmentation, override other train aug args') 156 | parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT', 157 | help='Random resize scale (default: 0.08 1.0)') 158 | parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO', 159 | help='Random resize aspect ratio (default: 0.75 1.33)') 160 | parser.add_argument('--hflip', type=float, default=0.5, 161 | help='Horizontal flip training aug probability') 162 | parser.add_argument('--vflip', type=float, default=0., 163 | help='Vertical flip training aug probability') 164 | parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT', 165 | help='Color jitter factor (default: 0.4)') 166 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', 167 | help='Use AutoAugment policy. "v0" or "original". 
(default: rand-m9-mstd0.5-inc1)') 168 | parser.add_argument('--aug-splits', type=int, default=0, 169 | help='Number of augmentation splits (default: 0, valid: 0 or >=2)') 170 | parser.add_argument('--jsd', action='store_true', default=False, 171 | help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.') 172 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', 173 | help='Random erase prob (default: 0.25)') 174 | parser.add_argument('--remode', type=str, default='pixel', 175 | help='Random erase mode (default: "pixel")') 176 | parser.add_argument('--recount', type=int, default=1, 177 | help='Random erase count (default: 1)') 178 | parser.add_argument('--resplit', action='store_true', default=False, 179 | help='Do not random erase first (clean) augmentation split') 180 | parser.add_argument('--mixup', type=float, default=0.8, 181 | help='mixup alpha, mixup enabled if > 0. (default: 0.8)') 182 | parser.add_argument('--cutmix', type=float, default=1.0, 183 | help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)') 184 | parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None, 185 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') 186 | parser.add_argument('--mixup-prob', type=float, default=1.0, 187 | help='Probability of performing mixup or cutmix when either/both is enabled') 188 | parser.add_argument('--mixup-switch-prob', type=float, default=0.5, 189 | help='Probability of switching to cutmix when both mixup and cutmix enabled') 190 | parser.add_argument('--mixup-mode', type=str, default='batch', 191 | help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"') 192 | parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N', 193 | help='Turn off mixup after this epoch, disabled if 0 (default: 0)') 194 | parser.add_argument('--smoothing', type=float, default=0.1, 195 | help='Label smoothing (default: 0.1)') 196 | parser.add_argument('--train-interpolation', type=str, default='random', 197 | help='Training interpolation (random, bilinear, bicubic default: "random")') 198 | parser.add_argument('--drop', type=float, default=0.0, metavar='PCT', 199 | help='Dropout rate (default: 0.)') 200 | parser.add_argument('--drop-connect', type=float, default=None, metavar='PCT', 201 | help='Drop connect rate, DEPRECATED, use drop-path (default: None)') 202 | parser.add_argument('--drop-path', type=float, default=None, metavar='PCT', 203 | help='Drop path rate (default: None)') 204 | parser.add_argument('--drop-block', type=float, default=None, metavar='PCT', 205 | help='Drop block rate (default: None)') 206 | 207 | # Batch norm parameters (only works with gen_efficientnet based models currently) 208 | parser.add_argument('--bn-tf', action='store_true', default=False, 209 | help='Use Tensorflow BatchNorm defaults for models that support it (default: False)') 210 | parser.add_argument('--bn-momentum', type=float, default=None, 211 | help='BatchNorm momentum override (if not None)') 212 | parser.add_argument('--bn-eps', type=float, default=None, 213 | help='BatchNorm epsilon override (if not None)') 214 | parser.add_argument('--sync-bn', action='store_true', 215 | help='Enable NVIDIA Apex or Torch synchronized BatchNorm.') 216 | parser.add_argument('--dist-bn', type=str, default='', 217 | help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")') 218 | parser.add_argument('--split-bn', action='store_true', 219 | help='Enable separate BN 
layers per augmentation split.') 220 | 221 | # Model Exponential Moving Average 222 | parser.add_argument('--model-ema', action='store_true', default=False, 223 | help='Enable tracking moving average of model weights') 224 | parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, 225 | help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.') 226 | parser.add_argument('--model-ema-decay', type=float, default=0.99992, 227 | help='decay factor for model weights moving average (default: 0.99992)') 228 | 229 | # Misc 230 | parser.add_argument('--seed', type=int, default=42, metavar='S', 231 | help='random seed (default: 42)') 232 | parser.add_argument('--log-interval', type=int, default=50, metavar='N', 233 | help='how many batches to wait before logging training status') 234 | parser.add_argument('--recovery-interval', type=int, default=0, metavar='N', 235 | help='how many batches to wait before writing recovery checkpoint') 236 | parser.add_argument('--checkpoint-hist', type=int, default=3, metavar='N', 237 | help='number of checkpoints to keep (default: 3)') 238 | parser.add_argument('-j', '--workers', type=int, default=8, metavar='N', 239 | help='how many training processes to use (default: 8)') 240 | parser.add_argument('--save-images', action='store_true', default=False, 241 | help='save images of input batches every log interval for debugging') 242 | parser.add_argument('--amp', action='store_true', default=False, 243 | help='use NVIDIA Apex AMP or Native AMP for mixed precision training') 244 | parser.add_argument('--apex-amp', action='store_true', default=True, 245 | help='Use NVIDIA Apex AMP mixed precision') 246 | parser.add_argument('--native-amp', action='store_true', default=False, 247 | help='Use Native Torch AMP mixed precision') 248 | parser.add_argument('--channels-last', action='store_true', default=False, 249 | help='Use channels_last memory layout') 250 | parser.add_argument('--pin-mem', action='store_true', default=False, 251 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 252 | parser.add_argument('--no-prefetcher', action='store_true', default=False, 253 | help='disable fast prefetcher') 254 | parser.add_argument('--output', default='', type=str, metavar='PATH', 255 | help='path to output folder (default: none, current dir)') 256 | parser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC', 257 | help='Best metric (default: "top1")') 258 | parser.add_argument('--tta', type=int, default=0, metavar='N', 259 | help='Test/inference time augmentation (oversampling) factor. 
248 | parser.add_argument('--channels-last', action='store_true', default=False,
249 |                     help='Use channels_last memory layout')
250 | parser.add_argument('--pin-mem', action='store_true', default=False,
251 |                     help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
252 | parser.add_argument('--no-prefetcher', action='store_true', default=False,
253 |                     help='disable fast prefetcher')
254 | parser.add_argument('--output', default='', type=str, metavar='PATH',
255 |                     help='path to output folder (default: none, current dir)')
256 | parser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',
257 |                     help='Best metric (default: "top1")')
258 | parser.add_argument('--tta', type=int, default=0, metavar='N',
259 |                     help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
260 | parser.add_argument("--local_rank", default=0, type=int)
261 | parser.add_argument('--use-multi-epochs-loader', action='store_true', default=False,
262 |                     help='use the multi-epochs-loader to save time at the beginning of every epoch')
263 | parser.add_argument('--torchscript', dest='torchscript', action='store_true',
264 |                     help='convert model to torchscript for inference')
265 |
266 | # Token labeling
267 |
268 | parser.add_argument('--token-label', action='store_true', default=False,
269 |                     help='Use dense token-level label map for training')
270 | parser.add_argument('--token-label-data', type=str, default='', metavar='DIR',
271 |                     help='path to token_label data')
272 | parser.add_argument('--token-label-size', type=int, default=1, metavar='N',
273 |                     help='size of the resulting token label map')
274 | parser.add_argument('--dense-weight', type=float, default=0.5,
275 |                     help='Token labeling loss multiplier (default: 0.5)')
276 | parser.add_argument('--cls-weight', type=float, default=1.0,
277 |                     help='Cls token prediction loss multiplier (default: 1.0)')
278 | parser.add_argument('--ground-truth', action='store_true', default=False,
279 |                     help='mix ground truth when using token labeling')
280 |
281 | # Finetune
282 | parser.add_argument('--finetune', default='', type=str, metavar='PATH',
283 |                     help='path to checkpoint file (default: none)')
284 |
285 | # Adversarial
286 | parser.add_argument('--adv-epochs', default=200, type=int, metavar='N',
287 |                     help='number of epochs for performing adversarial training')
288 | parser.add_argument('--adv-iters', default=3, type=int, metavar='N',
289 |                     help='number of PGD steps')
290 | parser.add_argument('--adv-eps', default=2.0/255, type=float,
291 |                     help='adversarial strength (L-inf budget) for adversarial perturbations')
292 | parser.add_argument('--adv-lr', default=1.0/255, type=float,
293 |                     help='step size (learning rate) of PGD')
294 | parser.add_argument('--adv-kl-weight', default=0.01, type=float,
295 |                     help='weight of KL-divergence for adversarial training')
296 | parser.add_argument('--adv-ce-weight', default=3.0, type=float,
297 |                     help='weight of CE loss for adversarial training')
298 |
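# NOTE: with the defaults above, each training batch runs 3 PGD steps of size
# 1/255 inside an L-inf ball of radius eps=2/255, so the perturbation can
# reach the boundary of the ball (3 * 1/255 > 2/255). Once epoch >= --adv-epochs,
# main() drops --adv-iters to 1, leaving only the first pass per batch (the
# one with delta == 0, i.e. effectively clean training plus the KL term).
# Hypothetical invocation (data/model flags elided; the adversarial values
# shown are just the defaults made explicit):
#   bash distributed_train.sh 8 [data/model args] --adv-iters 3 \
#       --adv-eps 0.00784 --adv-lr 0.00392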
299 | def _parse_args():
300 |     # Do we have a config file to parse?
301 |     args_config, remaining = config_parser.parse_known_args()
302 |     if args_config.config:
303 |         with open(args_config.config, 'r') as f:
304 |             cfg = yaml.safe_load(f)
305 |             parser.set_defaults(**cfg)
306 |
307 |     # The main arg parser parses the rest of the args, the usual
308 |     # defaults will have been overridden if config file specified.
309 |     args = parser.parse_args(remaining)
310 |
311 |     # Cache the args as a text string to save them in the output dir later
312 |     args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
313 |     return args, args_text
314 |
315 |
316 | def main():
317 |     setup_default_logging()
318 |     args, args_text = _parse_args()
319 |
320 |     args.prefetcher = not args.no_prefetcher
321 |     args.distributed = False
322 |     if 'WORLD_SIZE' in os.environ:
323 |         args.distributed = int(os.environ['WORLD_SIZE']) > 1
324 |     args.device = 'cuda:0'
325 |     args.world_size = 1
326 |     args.rank = 0  # global rank
327 |     if args.distributed:
328 |         args.device = 'cuda:%d' % args.local_rank
329 |         torch.cuda.set_device(args.local_rank)
330 |         torch.distributed.init_process_group(backend='nccl',
331 |                                              init_method='env://')
332 |         args.world_size = torch.distributed.get_world_size()
333 |         args.rank = torch.distributed.get_rank()
334 |         _logger.info(
335 |             'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
336 |             % (args.rank, args.world_size))
337 |     else:
338 |         _logger.info('Training with a single process on 1 GPU.')
339 |     assert args.rank >= 0
340 |
341 |     # resolve AMP arguments based on PyTorch / Apex availability
342 |     use_amp = None
343 |     if args.amp:
344 |         # `--amp` chooses native amp before apex (APEX ver not actively maintained)
345 |         if has_native_amp:
346 |             args.native_amp = True
347 |         elif has_apex:
348 |             args.apex_amp = True
349 |     if args.apex_amp and has_apex:
350 |         use_amp = 'apex'
351 |     elif args.native_amp and has_native_amp:
352 |         use_amp = 'native'
353 |     elif args.apex_amp or args.native_amp:
354 |         _logger.warning(
355 |             "Neither APEX nor native Torch AMP is available, using float32. "
356 |             "Install NVIDIA apex or upgrade to PyTorch 1.6")
357 |
358 |     torch.manual_seed(args.seed + args.rank)
359 |     np.random.seed(args.seed + args.rank)
360 |     model = create_model(
361 |         args.model,
362 |         pretrained=args.pretrained,
363 |         num_classes=args.num_classes,
364 |         drop_rate=args.drop,
365 |         drop_connect_rate=args.drop_connect,  # DEPRECATED, use drop_path
366 |         drop_path_rate=args.drop_path,
367 |         drop_block_rate=args.drop_block,
368 |         global_pool=args.gp,
369 |         bn_tf=args.bn_tf,
370 |         bn_momentum=args.bn_momentum,
371 |         bn_eps=args.bn_eps,
372 |         scriptable=args.torchscript,
373 |         checkpoint_path=args.initial_checkpoint,
374 |         img_size=args.img_size)
375 |     if args.num_classes is None:
376 |         assert hasattr(
377 |             model, 'num_classes'
378 |         ), 'Model must have `num_classes` attr if not set on cmd line/config.'
379 |         args.num_classes = model.num_classes  # FIXME handle model default vs config num_classes more elegantly
380 |
381 |     if args.finetune:
382 |         load_pretrained_weights(model=model,
383 |                                 checkpoint_path=args.finetune,
384 |                                 use_ema=args.model_ema,
385 |                                 strict=False,
386 |                                 num_classes=args.num_classes)
387 |
388 |     if args.rank == 0:
389 |         _logger.info('Model %s created, param count: %d' %
390 |                      (args.model, sum([m.numel()
391 |                                        for m in model.parameters()])))
392 |
393 |     data_config = resolve_data_config(vars(args),
394 |                                       model=model,
395 |                                       verbose=args.local_rank == 0)
396 |
397 |     # setup augmentation batch splits for contrastive loss or split bn
398 |     num_aug_splits = 0
399 |     if args.aug_splits > 0:
400 |         assert args.aug_splits > 1, 'A split of 1 makes no sense'
401 |         num_aug_splits = args.aug_splits
402 |
403 |     # enable split bn (separate bn stats per batch-portion)
404 |     if args.split_bn:
405 |         assert num_aug_splits > 1 or args.resplit
406 |         model = convert_splitbn_model(model, max(num_aug_splits, 2))
407 |
408 |     # move model to GPU, enable channels last layout if set
409 |     model.cuda()
410 |     if args.channels_last:
411 |         model = model.to(memory_format=torch.channels_last)
412 |
413 |     # setup synchronized BatchNorm for distributed training
414 |     if args.distributed and args.sync_bn:
415 |         assert not args.split_bn
416 |         if has_apex and use_amp != 'native':
417 |             # Apex SyncBN preferred unless native amp is activated
418 |             model = convert_syncbn_model(model)
419 |         else:
420 |             model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
421 |         if args.rank == 0:
422 |             _logger.info(
423 |                 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
424 |                 'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.'
425 |             )
426 |
427 |     if args.torchscript:
428 |         assert use_amp != 'apex', 'Cannot use APEX AMP with torchscripted model'
429 |         assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
430 |         model = torch.jit.script(model)
431 |
432 |     linear_scaled_lr = args.lr * args.batch_size * args.world_size / 1024.0
433 |     args.lr = linear_scaled_lr
434 |     if args.rank == 0:
435 |         _logger.info('learning rate: %s' % args.lr)
436 |     optimizer = create_optimizer(args, model)
437 |
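    # NOTE: linear LR scaling rule above: the configured --lr is interpreted
    # as the rate for a global batch of 1024. For example, with an
    # illustrative --lr of 1.6e-3 (not a repo default), batch size 128 per
    # GPU on 8 GPUs keeps it at 1.6e-3 (128 * 8 / 1024 = 1), while a single
    # GPU at batch 128 scales it down to 1.6e-3 * 128 / 1024 = 2e-4.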
438 |     # setup automatic mixed-precision (AMP) loss scaling and op casting
439 |     amp_autocast = suppress  # do nothing
440 |     loss_scaler = None
441 |     optimizers = None
442 |     if use_amp == 'apex':
443 |         model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
444 |         loss_scaler = ApexScaler()
445 |         if args.rank == 0:
446 |             _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
447 |     elif use_amp == 'native':
448 |         amp_autocast = torch.cuda.amp.autocast
449 |         loss_scaler = NativeScaler()
450 |         if args.rank == 0:
451 |             _logger.info(
452 |                 'Using native Torch AMP. Training in mixed precision.')
453 |     else:
454 |         if args.rank == 0:
455 |             _logger.info('AMP not enabled. Training in float32.')
456 |
457 |     # optionally resume from a checkpoint
458 |     resume_epoch = None
459 |     if args.resume and os.path.isfile(args.resume):
460 |         resume_epoch = resume_checkpoint(
461 |             model,
462 |             args.resume,
463 |             optimizer=None if args.no_resume_opt else optimizer,
464 |             loss_scaler=None if args.no_resume_opt else loss_scaler,
465 |             log_info=args.rank == 0)
466 |
467 |     # setup exponential moving average of model weights, SWA could be used here too
468 |     model_ema = None
469 |     if args.model_ema:
470 |         # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
471 |         model_ema = ModelEmaV2(
472 |             model,
473 |             decay=args.model_ema_decay,
474 |             device='cpu' if args.model_ema_force_cpu else None)
475 |         if args.resume and os.path.isfile(args.resume):
476 |             load_checkpoint(model_ema.module, args.resume, use_ema=True)
477 |
478 |     # setup distributed training
479 |     if args.distributed:
480 |         if has_apex and use_amp != 'native':
481 |             # Apex DDP preferred unless native amp is activated
482 |             if args.rank == 0:
483 |                 _logger.info("Using NVIDIA APEX DistributedDataParallel.")
484 |             model = ApexDDP(model, delay_allreduce=True)
485 |         else:
486 |             if args.rank == 0:
487 |                 _logger.info("Using native Torch DistributedDataParallel.")
488 |             model = NativeDDP(model, device_ids=[
489 |                 args.local_rank
490 |             ])  # can use device str in Torch >= 1.1
491 |     # NOTE: EMA model does not need to be wrapped by DDP
492 |
493 |     # setup learning rate schedule and starting epoch
494 |     lr_scheduler, num_epochs = create_scheduler(args, optimizer)
495 |     start_epoch = 0
496 |     if args.start_epoch is not None:
497 |         # a specified start_epoch will always override the resume epoch
498 |         start_epoch = args.start_epoch
499 |     elif resume_epoch is not None:
500 |         start_epoch = resume_epoch
501 |     if lr_scheduler is not None and start_epoch > 0:
502 |         lr_scheduler.step(start_epoch)
503 |
504 |     if args.rank == 0:
505 |         _logger.info('Scheduled epochs: {}'.format(num_epochs))
506 |
507 |     # create the train and eval datasets
508 |
509 |     # create token_label dataset
510 |     if args.token_label_data:
511 |         dataset_train = create_token_label_dataset(
512 |             args.dataset, root=args.data_dir, label_root=args.token_label_data)
513 |     else:
514 |         dataset_train = create_dataset(args.dataset,
515 |                                        root=args.data_dir,
516 |                                        split=args.train_split,
517 |                                        is_training=True,
518 |                                        batch_size=args.batch_size)
519 |     dataset_eval = create_dataset(args.dataset,
520 |                                   root=args.data_dir,
521 |                                   split=args.val_split,
522 |                                   is_training=False,
523 |                                   batch_size=args.batch_size)
524 |
525 |     # setup mixup / cutmix
526 |     collate_fn = None
527 |     mixup_fn = None
528 |     mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
529 |     if mixup_active:
530 |         mixup_args = dict(mixup_alpha=args.mixup,
531 |                           cutmix_alpha=args.cutmix,
532 |                           cutmix_minmax=args.cutmix_minmax,
533 |                           prob=args.mixup_prob,
534 |                           switch_prob=args.mixup_switch_prob,
535 |                           mode=args.mixup_mode,
536 |                           label_smoothing=args.smoothing,
537 |                           num_classes=args.num_classes)
538 |         # create token_label mixup
539 |         if args.token_label_data:
540 |             mixup_args['label_size'] = args.token_label_size
541 |             if args.prefetcher:
542 |                 assert not num_aug_splits
543 |                 collate_fn = FastCollateTokenLabelMixup(**mixup_args)
544 |             else:
545 |                 mixup_fn = TokenLabelMixup(**mixup_args)
546 |         else:
547 |             if args.prefetcher:
548 |                 assert not num_aug_splits  # collate conflict (need to support deinterleaving in collate mixup)
549 |                 collate_fn = FastCollateMixup(**mixup_args)
550 |             else:
551 |                 mixup_fn = Mixup(**mixup_args)
552 |
553 |     # wrap dataset in AugMix helper
554 |     if num_aug_splits > 1:
555 |         assert not args.token_label
556 |         dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)
557 |
558 |     # create data loaders w/ augmentation pipeline
559 |     train_interpolation = args.train_interpolation
560 |     if args.no_aug or not train_interpolation:
561 |         train_interpolation = data_config['interpolation']
562 |     if args.token_label and args.token_label_data:
563 |         use_token_label = True
564 |     else:
565 |         use_token_label = False
566 |     loader_train = create_token_label_loader(
567 |         dataset_train,
568 |         input_size=data_config['input_size'],
569 |         batch_size=args.batch_size,
570 |         is_training=True,
571 |         use_prefetcher=args.prefetcher,
572 |         no_aug=args.no_aug,
573 |         re_prob=args.reprob,
574 |         re_mode=args.remode,
575 |         re_count=args.recount,
576 |         re_split=args.resplit,
577 |         scale=args.scale,
578 |         ratio=args.ratio,
579 |         hflip=args.hflip,
580 |         vflip=args.vflip,
581 |         color_jitter=args.color_jitter,
582 |         auto_augment=args.aa,
583 |         num_aug_splits=num_aug_splits,
584 |         interpolation=train_interpolation,
585 |         mean=data_config['mean'],
586 |         std=data_config['std'],
587 |         num_workers=args.workers,
588 |         distributed=args.distributed,
589 |         collate_fn=collate_fn,
590 |         pin_memory=args.pin_mem,
591 |         use_multi_epochs_loader=args.use_multi_epochs_loader,
592 |         use_token_label=use_token_label)
593 |
594 |     loader_eval = create_loader(
595 |         dataset_eval,
596 |         input_size=data_config['input_size'],
597 |         batch_size=args.validation_batch_size_multiplier * args.batch_size,
598 |         is_training=False,
599 |         use_prefetcher=args.prefetcher,
600 |         interpolation=data_config['interpolation'],
601 |         mean=data_config['mean'],
602 |         std=data_config['std'],
603 |         num_workers=args.workers,
604 |         distributed=args.distributed,
605 |         crop_pct=data_config['crop_pct'],
606 |         pin_memory=args.pin_mem,
607 |     )
608 |
609 |     # setup loss function
610 |     # use token_label loss
611 |     if args.token_label:
612 |         if args.token_label_size == 1:
613 |             # back to relabel/original ImageNet label
614 |             train_loss_fn = TokenLabelSoftTargetCrossEntropy().cuda()
615 |         else:
616 |             if args.ground_truth:
617 |                 train_loss_fn = TokenLabelGTCrossEntropy(dense_weight=args.dense_weight,
618 |                                                          cls_weight=args.cls_weight, mixup_active=mixup_active).cuda()
619 |
620 |             else:
621 |                 train_loss_fn = TokenLabelCrossEntropy(dense_weight=args.dense_weight,
622 |                                                        cls_weight=args.cls_weight, mixup_active=mixup_active).cuda()
623 |
624 |     else:
625 |         # smoothing is handled with mixup target transform or create_token_label_target function
626 |         train_loss_fn = SoftTargetCrossEntropy().cuda()
627 |
628 |     validate_loss_fn = nn.CrossEntropyLoss().cuda()
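    # NOTE: loss selection summary. With --token-label, the TokenLabel* losses
    # (from loss/cross_entropy.py) combine the class-token term (--cls-weight)
    # with the dense per-patch term (--dense-weight); otherwise the mixup /
    # token-label target transforms already yield soft targets, so plain
    # SoftTargetCrossEntropy suffices. Validation always scores hard labels,
    # hence nn.CrossEntropyLoss.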
629 |
630 |     # setup checkpoint saver and eval metric tracking
631 |     eval_metric = args.eval_metric
632 |     best_metric = None
633 |     best_epoch = None
634 |     saver = None
635 |     output_dir = ''
636 |     if args.rank == 0:
637 |         output_base = args.output if args.output else './output'
638 |         output_dir = get_outdir(output_base)
639 |         decreasing = eval_metric == 'loss'
640 |         saver = CheckpointSaver(model=model,
641 |                                 optimizer=optimizer,
642 |                                 args=args,
643 |                                 model_ema=model_ema,
644 |                                 amp_scaler=loss_scaler,
645 |                                 checkpoint_dir=output_dir,
646 |                                 recovery_dir=output_dir,
647 |                                 decreasing=decreasing,
648 |                                 max_history=args.checkpoint_hist)
649 |         with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
650 |             f.write(args_text)
651 |
652 |     try:
653 |         if args.finetune:
654 |             validate(model,
655 |                      loader_eval,
656 |                      validate_loss_fn,
657 |                      args,
658 |                      amp_autocast=amp_autocast)
659 |         for epoch in range(start_epoch, num_epochs):
660 |
661 |             if epoch >= args.adv_epochs:
662 |                 args.adv_iters = 1
663 |
664 |             if args.distributed and hasattr(loader_train.sampler, 'set_epoch'):
665 |                 loader_train.sampler.set_epoch(epoch)
666 |
667 |             train_metrics = train_one_epoch(epoch,
668 |                                             model,
669 |                                             loader_train,
670 |                                             optimizer,
671 |                                             train_loss_fn,
672 |                                             args,
673 |                                             lr_scheduler=lr_scheduler,
674 |                                             saver=saver,
675 |                                             output_dir=output_dir,
676 |                                             amp_autocast=amp_autocast,
677 |                                             loss_scaler=loss_scaler,
678 |                                             model_ema=model_ema,
679 |                                             mixup_fn=mixup_fn,
680 |                                             optimizers=optimizers)
681 |
682 |             if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
683 |                 if args.rank == 0:
684 |                     _logger.info(
685 |                         "Distributing BatchNorm running means and vars")
686 |                 distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
687 |
688 |             eval_metrics = validate(model,
689 |                                     loader_eval,
690 |                                     validate_loss_fn,
691 |                                     args,
692 |                                     amp_autocast=amp_autocast)
693 |
694 |             if model_ema is not None and not args.model_ema_force_cpu:
695 |                 if args.distributed and args.dist_bn in ('broadcast',
696 |                                                          'reduce'):
697 |                     distribute_bn(model_ema, args.world_size,
698 |                                   args.dist_bn == 'reduce')
699 |                 ema_eval_metrics = validate(model_ema.module,
700 |                                             loader_eval,
701 |                                             validate_loss_fn,
702 |                                             args,
703 |                                             amp_autocast=amp_autocast,
704 |                                             log_suffix=' (EMA)')
705 |                 eval_metrics = ema_eval_metrics
706 |
707 |             if lr_scheduler is not None:
708 |                 # step LR for next epoch
709 |                 lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])
710 |
711 |             update_summary(epoch,
712 |                            train_metrics,
713 |                            eval_metrics,
714 |                            os.path.join(output_dir, 'summary.csv'),
715 |                            write_header=best_metric is None)
716 |
717 |             if saver is not None:
718 |                 # save proper checkpoint with eval metric
719 |                 save_metric = eval_metrics[eval_metric]
720 |                 best_metric, best_epoch = saver.save_checkpoint(
721 |                     epoch, metric=save_metric)
722 |
723 |
724 |     except KeyboardInterrupt:
725 |         pass
726 |     if best_metric is not None:
727 |         _logger.info('*** Best metric: {0} (epoch {1})'.format(
728 |             best_metric, best_epoch))
729 |
730 |
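# NOTE: train_one_epoch below implements this repo's adversarial objective.
# For each batch it first records the (detached) clean prediction, then runs
# --adv-iters passes: pass 0 uses the unperturbed input (delta == 0), later
# passes use PGD-perturbed inputs, and each pass adds a CE term plus a
# symmetric-KL consistency term to the accumulated gradients before a single
# optimizer step per batch.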
731 | def train_one_epoch(epoch,
732 |                     model,
733 |                     loader,
734 |                     optimizer,
735 |                     loss_fn,
736 |                     args,
737 |                     lr_scheduler=None,
738 |                     saver=None,
739 |                     output_dir='',
740 |                     amp_autocast=suppress,
741 |                     loss_scaler=None,
742 |                     model_ema=None,
743 |                     mixup_fn=None,
744 |                     optimizers=None):
745 |
746 |     if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
747 |         if args.prefetcher and loader.mixup_enabled:
748 |             loader.mixup_enabled = False
749 |         elif mixup_fn is not None:
750 |             mixup_fn.mixup_enabled = False
751 |
752 |     second_order = hasattr(optimizer,
753 |                            'is_second_order') and optimizer.is_second_order
754 |     batch_time_m = AverageMeter()
755 |     data_time_m = AverageMeter()
756 |     losses_m = AverageMeter()
757 |
758 |     model.train()
759 |
760 |     end = time.time()
761 |     last_idx = len(loader) - 1
762 |     num_updates = epoch * len(loader)
763 |     for batch_idx, (input, target) in enumerate(loader):
764 |         last_batch = batch_idx == last_idx
765 |         data_time_m.update(time.time() - end)
766 |         if not args.prefetcher:
767 |             input, target = input.cuda(), target.cuda()
768 |             if mixup_fn is not None:
769 |                 input, target = mixup_fn(input, target)
770 |             else:
771 |                 # handle token_label without mixup
772 |                 if args.token_label and args.token_label_data:
773 |                     target = create_token_label_target(
774 |                         target,
775 |                         num_classes=args.num_classes,
776 |                         smoothing=args.smoothing,
777 |                         label_size=args.token_label_size)
778 |                 if len(target.shape) == 1:
779 |                     target = create_token_label_target(
780 |                         target,
781 |                         num_classes=args.num_classes,
782 |                         smoothing=args.smoothing)
783 |         else:
784 |             if args.token_label and args.token_label_data and not loader.mixup_enabled:
785 |                 target = create_token_label_target(
786 |                     target,
787 |                     num_classes=args.num_classes,
788 |                     smoothing=args.smoothing,
789 |                     label_size=args.token_label_size)
790 |                 if len(target.shape) == 1:
791 |                     target = create_token_label_target(
792 |                         target,
793 |                         num_classes=args.num_classes,
794 |                         smoothing=args.smoothing)
795 |         if args.channels_last:
796 |             input = input.contiguous(memory_format=torch.channels_last)
797 |
798 |         # initialize the adversarial perturbation for this batch
799 |         delta = torch.zeros_like(input)
800 |         # record the clean prediction once; it anchors the KL consistency term
801 |         if args.token_label:
802 |             output_cle, output_aux, _ = model(input)
803 |             output_cle = output_cle.detach()
804 |         else:
805 |             output_cle = model(input).detach()
806 |
807 |         output_cle_prob = F.softmax(output_cle, dim=1)
808 |         output_cle_logprob = F.log_softmax(output_cle, dim=1)
809 |
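        # NOTE: sketch of the PGD inner loop that follows. Each of the
        # args.adv_iters steps updates the L-inf-bounded perturbation:
        #     delta <- clamp(delta + adv_lr * sign(grad_delta(loss)), -eps, +eps)
        # where loss = CE(model(x + delta), y) + adv_kl_weight * SymKL(adv, clean),
        # and iterations after the first are additionally scaled by
        # adv_ce_weight / (adv_iters - 1). Gradients from every pass accumulate
        # on the model parameters; one optimizer.step() runs per batch.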
810 |         for a_iter in range(args.adv_iters):
811 |             delta.requires_grad_()
812 |
813 |             with amp_autocast():
814 |
815 |                 if args.token_label:
816 |                     output = model(input + delta)
817 |                     output_adv = output[0]
818 |                     loss_ce = loss_fn(output, target)
819 |                 else:
820 |                     output_adv = model(input + delta)
821 |                     loss_ce = loss_fn(output_adv, target)
822 |
823 |                 output_adv_prob = F.softmax(output_adv, dim=1)
824 |                 output_adv_logprob = F.log_softmax(output_adv, dim=1)
825 |                 loss_kl = (F.kl_div(output_adv_logprob, output_cle_prob, reduction='none').sum(dim=1) +
826 |                            F.kl_div(output_cle_logprob, output_adv_prob, reduction='none').sum(dim=1)).mean() / 2
827 |
828 |                 if a_iter == 0:
829 |                     loss = (loss_ce + args.adv_kl_weight * loss_kl)
830 |                 else:
831 |                     loss = (args.adv_ce_weight * loss_ce + args.adv_kl_weight * loss_kl) / (args.adv_iters - 1)
832 |
833 |             # apex amp is assumed here (use_amp == 'apex'); see the AMP setup in main()
834 |             with amp.scale_loss(loss, optimizer) as scaled_loss:
835 |                 scaled_loss.backward(retain_graph=False)
836 |
837 |             delta_grad = delta.grad.clone().detach().float()
838 |             delta = (delta + args.adv_lr * torch.sign(delta_grad)).detach()
839 |             delta = torch.clamp(delta, -args.adv_eps, args.adv_eps).detach()
840 |
841 |             if not args.distributed:
842 |                 losses_m.update(loss.item(), input.size(0))
843 |
844 |         # one optimizer update per batch, after all adversarial passes
845 |         optimizer.step()
846 |         optimizer.zero_grad()
847 |
848 |
849 |         if model_ema is not None:
850 |             model_ema.update(model)
851 |
852 |         torch.cuda.synchronize()
853 |         num_updates += 1
854 |         batch_time_m.update(time.time() - end)
855 |         if last_batch or batch_idx % args.log_interval == 0:
856 |             lrl = [param_group['lr'] for param_group in optimizer.param_groups]
857 |             lr = sum(lrl) / len(lrl)
858 |
859 |             if args.distributed:
860 |                 reduced_loss = reduce_tensor(loss.data, args.world_size)
861 |                 losses_m.update(reduced_loss.item(), input.size(0))
862 |
863 |             if args.rank == 0:
864 |                 _logger.info(
865 |                     'Train: {} [{:>4d}/{} ({:>3.0f}%)] '
866 |                     'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f}) '
867 |                     'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s '
868 |                     '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
869 |                     'LR: {lr:.3e} '
870 |                     'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
871 |                         epoch,
872 |                         batch_idx,
873 |                         len(loader),
874 |                         100. * batch_idx / last_idx,
875 |                         loss=losses_m,
876 |                         batch_time=batch_time_m,
877 |                         rate=input.size(0) * args.world_size /
878 |                         batch_time_m.val,
879 |                         rate_avg=input.size(0) * args.world_size /
880 |                         batch_time_m.avg,
881 |                         lr=lr,
882 |                         data_time=data_time_m))
883 |
884 |                 if args.save_images and output_dir:
885 |                     torchvision.utils.save_image(
886 |                         input,
887 |                         os.path.join(output_dir,
888 |                                      'train-batch-%d.jpg' % batch_idx),
889 |                         padding=0,
890 |                         normalize=True)
891 |
892 |             if saver is not None and args.recovery_interval and (
893 |                     last_batch or (batch_idx + 1) % args.recovery_interval == 0):
894 |                 saver.save_recovery(epoch, batch_idx=batch_idx)
895 |
896 |             if lr_scheduler is not None:
897 |                 lr_scheduler.step_update(num_updates=num_updates,
898 |                                          metric=losses_m.avg)
899 |
900 |         end = time.time()
901 |         # end for
902 |
903 |     if hasattr(optimizer, 'sync_lookahead'):
904 |         optimizer.sync_lookahead()
905 |
906 |     return OrderedDict([('loss', losses_m.avg)])
907 |
908 |
909 | def validate(model,
910 |              loader,
911 |              loss_fn,
912 |              args,
913 |              amp_autocast=suppress,
914 |              log_suffix=''):
915 |     batch_time_m = AverageMeter()
916 |     losses_m = AverageMeter()
917 |     top1_m = AverageMeter()
918 |     top5_m = AverageMeter()
919 |
920 |     model.eval()
921 |
922 |     end = time.time()
923 |     last_idx = len(loader) - 1
924 |     with torch.no_grad():
925 |         for batch_idx, (input, target) in enumerate(loader):
926 |             last_batch = batch_idx == last_idx
927 |             if not args.prefetcher:
928 |                 input = input.cuda()
929 |                 target = target.cuda()
930 |             if args.channels_last:
931 |                 input = input.contiguous(memory_format=torch.channels_last)
932 |
933 |             with amp_autocast():
934 |                 output = model(input)
935 |                 if isinstance(output, (tuple, list)):
936 |                     # use the mean of the dense (aux) predictions when the cls
937 |                     # token carries no weight; otherwise score the cls output
938 |                     output = output[1].mean(1) if args.cls_weight == 0 else output[0]
939 |
940 |             # augmentation reduction
941 |             reduce_factor = args.tta
942 |             if reduce_factor > 1:
943 |                 output = output.unfold(0, reduce_factor,
944 |                                        reduce_factor).mean(dim=2)
945 |                 target = target[0:target.size(0):reduce_factor]
946 |
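            # NOTE: with --tta N > 1 the loader is expected to stack N
            # augmented copies of each sample consecutively in the batch;
            # unfold(0, N, N) groups them so .mean(dim=2) averages the N
            # predictions, and the target is subsampled to one label per group.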
947 |             loss = loss_fn(output, target)
948 |             acc1, acc5 = accuracy(output, target, topk=(1, 5))
949 |
950 |             if args.distributed:
951 |                 reduced_loss = reduce_tensor(loss.data, args.world_size)
952 |                 acc1 = reduce_tensor(acc1, args.world_size)
953 |                 acc5 = reduce_tensor(acc5, args.world_size)
954 |             else:
955 |                 reduced_loss = loss.data
956 |
957 |             torch.cuda.synchronize()
958 |
959 |             losses_m.update(reduced_loss.item(), input.size(0))
960 |             top1_m.update(acc1.item(), output.size(0))
961 |             top5_m.update(acc5.item(), output.size(0))
962 |
963 |             batch_time_m.update(time.time() - end)
964 |             end = time.time()
965 |             if args.rank == 0 and (last_batch or
966 |                                    batch_idx % args.log_interval == 0):
967 |                 log_name = 'Test' + log_suffix
968 |                 _logger.info(
969 |                     '{0}: [{1:>4d}/{2}] '
970 |                     'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
971 |                     'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
972 |                     'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
973 |                     'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
974 |                         log_name,
975 |                         batch_idx,
976 |                         last_idx,
977 |                         batch_time=batch_time_m,
978 |                         loss=losses_m,
979 |                         top1=top1_m,
980 |                         top5=top5_m))
981 |
982 |     metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg),
983 |                            ('top5', top5_m.avg)])
984 |
985 |     return metrics
986 |
987 |
988 | if __name__ == '__main__':
989 |     main()
990 |
--------------------------------------------------------------------------------