├── models
│   ├── __init__.py
│   └── volo.py
├── misc
│   └── eccv.png
├── loss
│   ├── __init__.py
│   └── cross_entropy.py
├── distributed_train.sh
├── LICENSE
├── loss_adv.py
├── README.md
├── validate.py
├── train_kd.py
└── train.py
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .volo import *
--------------------------------------------------------------------------------
/misc/eccv.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiawangbai/HAT/HEAD/misc/eccv.png
--------------------------------------------------------------------------------
/loss/__init__.py:
--------------------------------------------------------------------------------
1 | from .cross_entropy import TokenLabelGTCrossEntropy, TokenLabelSoftTargetCrossEntropy, TokenLabelCrossEntropy
--------------------------------------------------------------------------------
/distributed_train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | NUM_PROC=$1
3 | shift
4 | python3 -m torch.distributed.launch --nproc_per_node=$NUM_PROC train.py "$@"
--------------------------------------------------------------------------------
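Note: `torch.distributed.launch` is deprecated on newer PyTorch (>= 1.10) in favor of `torchrun`. A hedged equivalent of the script above, not part of this repo; unlike the legacy launcher, `torchrun` passes the local rank via the `LOCAL_RANK` environment variable instead of a `--local_rank` argument, so `train.py` would need to read it from the environment:

```bash
#!/bin/bash
# Sketch only: torchrun equivalent of distributed_train.sh for PyTorch >= 1.10.
# torchrun sets LOCAL_RANK in the environment rather than appending --local_rank.
NUM_PROC=$1
shift
torchrun --nproc_per_node=$NUM_PROC train.py "$@"
```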
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 jiawangbai
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/loss_adv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class LabelSmoothingCrossEntropy(nn.Module):
7 | """
8 | NLL loss with label smoothing.
9 | """
10 | def __init__(self, smoothing=0.1):
11 | """
12 | Constructor for the LabelSmoothing module.
13 | :param smoothing: label smoothing factor
14 | """
15 | super(LabelSmoothingCrossEntropy, self).__init__()
16 | assert smoothing < 1.0
17 | self.smoothing = smoothing
18 | self.confidence = 1. - smoothing
19 |
20 | def forward(self, x, target):
21 | logprobs = F.log_softmax(x, dim=-1)
22 | nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
23 | nll_loss = nll_loss.squeeze(1)
24 | smooth_loss = -logprobs.mean(dim=-1)
25 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss
26 | return loss.mean()
27 |
28 |
29 | class SoftTargetCrossEntropy(nn.Module):
30 |
31 | def __init__(self, reduce=False):
32 | super(SoftTargetCrossEntropy, self).__init__()
33 | self.reduce = reduce
34 |
35 | def forward(self, x, target):
36 | loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1)
37 | if self.reduce:
38 | return loss.mean()
39 | else:
40 | return loss
41 |
42 |
43 |
44 | class DistillationLoss(torch.nn.Module):
45 | """
46 | This module wraps a standard criterion and adds an extra knowledge distillation loss by
47 | taking a teacher model prediction and using it as additional supervision.
48 | """
49 | def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module,
50 | distillation_type: str, alpha: float, tau: float):
51 | super().__init__()
52 | self.base_criterion = base_criterion
53 | self.teacher_model = teacher_model
54 | assert distillation_type in ['none', 'soft', 'hard']
55 | self.distillation_type = distillation_type
56 | self.alpha = alpha
57 | self.tau = tau
58 |
59 | def forward(self, inputs, outputs, labels):
60 | """
61 | Args:
62 |             inputs: The original inputs that are fed to the teacher model
63 | outputs: the outputs of the model to be trained. It is expected to be
64 | either a Tensor, or a Tuple[Tensor, Tensor], with the original output
65 | in the first position and the distillation predictions as the second output
66 | labels: the labels for the base criterion
67 | """
68 | outputs_kd = None
69 | if not isinstance(outputs, torch.Tensor):
70 | # assume that the model outputs a tuple of [outputs, outputs_kd]
71 | outputs, outputs_kd = outputs
72 | base_loss = self.base_criterion(outputs, labels)
73 | if self.distillation_type == 'none':
74 | return base_loss
75 |
76 | if outputs_kd is None:
77 | raise ValueError("When knowledge distillation is enabled, the model is "
78 | "expected to return a Tuple[Tensor, Tensor] with the output of the "
79 | "class_token and the dist_token")
80 |         # don't backprop through the teacher
81 | with torch.no_grad():
82 | teacher_outputs = self.teacher_model(inputs)
83 |
84 | if self.distillation_type == 'soft':
85 | T = self.tau
86 | # taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
87 | # with slight modifications
88 | distillation_loss = F.kl_div(
89 | F.log_softmax(outputs_kd / T, dim=1),
90 | #We provide the teacher's targets in log probability because we use log_target=True
91 | #(as recommended in pytorch https://github.com/pytorch/pytorch/blob/9324181d0ac7b4f7949a574dbc3e8be30abe7041/torch/nn/functional.py#L2719)
92 | #but it is possible to give just the probabilities and set log_target=False. In our experiments we tried both.
93 | F.log_softmax(teacher_outputs / T, dim=1),
94 | reduction='sum',
95 | log_target=True
96 | ) * (T * T) / outputs_kd.numel()
97 | #We divide by outputs_kd.numel() to have the legacy PyTorch behavior.
98 |             #But we also experimented with outputs_kd.size(0).
99 | #see issue 61(https://github.com/facebookresearch/deit/issues/61) for more details
100 | elif self.distillation_type == 'hard':
101 | distillation_loss = F.cross_entropy(outputs_kd, teacher_outputs.argmax(dim=1))
102 |
103 | loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha
104 | return loss
--------------------------------------------------------------------------------
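For reference, a minimal usage sketch of `DistillationLoss`; the one-layer teacher and the random tensors below are hypothetical stand-ins, not code from this repo:

```python
import torch
import torch.nn as nn
from loss_adv import DistillationLoss, LabelSmoothingCrossEntropy

teacher = nn.Linear(16, 10).eval()               # stand-in teacher model
criterion = DistillationLoss(
    base_criterion=LabelSmoothingCrossEntropy(smoothing=0.1),
    teacher_model=teacher, distillation_type='soft', alpha=0.5, tau=1.0)

inputs = torch.randn(4, 16)                      # raw inputs, fed to the teacher
labels = torch.randint(0, 10, (4,))
# the student is expected to return (class_output, distillation_output)
outputs = (torch.randn(4, 10), torch.randn(4, 10))
loss = criterion(inputs, outputs, labels)        # scalar training loss
```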
/loss/cross_entropy.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class SoftTargetCrossEntropy(nn.Module):
7 | """
8 | The native CE loss with soft target
9 | input: x is output of model, target is ground truth
10 | return: loss
11 | """
12 | def __init__(self):
13 | super(SoftTargetCrossEntropy, self).__init__()
14 |
15 | def forward(self, x, target):
16 | N_rep = x.shape[0]
17 | N = target.shape[0]
18 | if not N == N_rep:
19 | target = target.repeat(N_rep // N, 1)
20 | loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1)
21 | return loss.mean()
22 |
23 |
24 | class TokenLabelGTCrossEntropy(nn.Module):
25 | """
26 |     Token labeling dense loss with ground truth, see more from token labeling
27 | input: x is output of model, target is ground truth
28 | return: loss
29 | """
30 | def __init__(self,
31 | dense_weight=1.0,
32 | cls_weight=1.0,
33 | mixup_active=True,
34 | smoothing=0.1,
35 | classes=1000):
36 | super(TokenLabelGTCrossEntropy, self).__init__()
37 |
38 | self.CE = SoftTargetCrossEntropy()
39 |
40 | self.dense_weight = dense_weight
41 | self.smoothing = smoothing
42 | self.mixup_active = mixup_active
43 | self.classes = classes
44 | self.cls_weight = cls_weight
45 | assert dense_weight + cls_weight > 0
46 |
47 | def forward(self, x, target):
48 |
49 | output, aux_output, bb = x
50 | bbx1, bby1, bbx2, bby2 = bb
51 |
52 | B, N, C = aux_output.shape
53 | if len(target.shape) == 2:
54 | target_cls = target
55 | target_aux = target.repeat(1, N).reshape(B * N, C)
56 | else:
57 | ground_truth = target[:, :, 0]
58 | target_cls = target[:, :, 1]
59 | ratio = (0.9 - 0.4 *
60 | (ground_truth.max(-1)[1] == target_cls.max(-1)[1])
61 | ).unsqueeze(-1)
62 | target_cls = target_cls * ratio + ground_truth * (1 - ratio)
63 | target_aux = target[:, :, 2:]
64 | target_aux = target_aux.transpose(1, 2).reshape(-1, C)
65 | lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / N)
66 | if lam < 1:
67 | target_cls = lam * target_cls + (1 - lam) * target_cls.flip(0)
68 |
69 | aux_output = aux_output.reshape(-1, C)
70 |
71 | loss_cls = self.CE(output, target_cls)
72 | loss_aux = self.CE(aux_output, target_aux)
73 |
74 | return self.cls_weight * loss_cls + self.dense_weight * loss_aux
75 |
76 |
77 | class TokenLabelSoftTargetCrossEntropy(nn.Module):
78 | """
79 | Token labeling dense loss with soft target, see more from token labeling
80 | input: x is output of model, target is ground truth
81 | return: loss
82 | """
83 | def __init__(self):
84 | super(TokenLabelSoftTargetCrossEntropy, self).__init__()
85 |
86 | def forward(self, x, target):
87 | N_rep = x.shape[0]
88 | N = target.shape[0]
89 | if not N == N_rep:
90 | target = target.repeat(N_rep // N, 1)
91 | if len(target.shape) == 3 and target.shape[-1] == 2:
92 | target = target[:, :, 1]
93 | loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1)
94 | return loss.mean()
95 |
96 |
97 | class TokenLabelCrossEntropy(nn.Module):
98 | """
99 | Token labeling loss without ground truth
100 |     input: x is output of model, target is the token-labeling target
101 | return: loss
102 | """
103 | def __init__(self,
104 | dense_weight=1.0,
105 | cls_weight=1.0,
106 | mixup_active=True,
107 | classes=1000):
108 | """
109 | Constructor Token labeling loss.
110 | """
111 | super(TokenLabelCrossEntropy, self).__init__()
112 |
113 | self.CE = SoftTargetCrossEntropy()
114 |
115 | self.dense_weight = dense_weight
116 | self.mixup_active = mixup_active
117 | self.classes = classes
118 | self.cls_weight = cls_weight
119 | assert dense_weight + cls_weight > 0
120 |
121 | def forward(self, x, target):
122 |
123 | output, aux_output, bb = x
124 | bbx1, bby1, bbx2, bby2 = bb
125 |
126 | B, N, C = aux_output.shape
127 | if len(target.shape) == 2:
128 | target_cls = target
129 | target_aux = target.repeat(1, N).reshape(B * N, C)
130 | else:
131 | target_cls = target[:, :, 1]
132 | target_aux = target[:, :, 2:]
133 | target_aux = target_aux.transpose(1, 2).reshape(-1, C)
134 | lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / N)
135 | if lam < 1:
136 | target_cls = lam * target_cls + (1 - lam) * target_cls.flip(0)
137 |
138 | aux_output = aux_output.reshape(-1, C)
139 | loss_cls = self.CE(output, target_cls)
140 | loss_aux = self.CE(aux_output, target_aux)
141 | return self.cls_weight * loss_cls + self.dense_weight * loss_aux
142 |
--------------------------------------------------------------------------------
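To make the expected shapes concrete, a hypothetical walk-through of `TokenLabelCrossEntropy` with random tensors (B=2 images, N=4 tokens, C=5 classes); the `(output, aux_output, bbox)` triple mirrors what `VOLO.forward` returns during training:

```python
import torch
from loss import TokenLabelCrossEntropy

B, N, C = 2, 4, 5                        # images, feature tokens, classes
output = torch.randn(B, C)               # class-token logits
aux_output = torch.randn(B, N, C)        # per-token logits
bbox = (0, 0, 0, 0)                      # empty mix-token box, so lam == 1
# token-labeling target, laid out B x C x (2 + N): channel 1 holds the class
# target and channels 2: the dense per-token targets; channel 0 (hard ground
# truth) is only consumed by TokenLabelGTCrossEntropy.
target = torch.rand(B, C, 2 + N)
criterion = TokenLabelCrossEntropy(dense_weight=1.0, cls_weight=1.0)
loss = criterion((output, aux_output, bbox), target)
```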
/README.md:
--------------------------------------------------------------------------------
1 |
2 | 
3 | # HAT
4 |
5 | Implementation of HAT (https://arxiv.org/pdf/2204.00993).
6 | ```
7 | @inproceedings{bai2022improving,
8 | title={Improving Vision Transformers by Revisiting High-frequency Components},
9 | author={Bai, Jiawang and Yuan, Li and Xia, Shu-Tao and Yan, Shuicheng and Li, Zhifeng and Liu, Wei},
10 | booktitle={European Conference on Computer Vision},
11 | year={2022}
12 | }
13 | ```
14 |
15 |
16 |
17 |
18 | ## Requirements
19 | - torch>=1.7.0
20 | - torchvision>=0.8.0
21 | - timm==0.4.5
22 | - tlt==0.1.0
23 | - pyyaml
24 | - apex-amp
25 |
26 | ## ImageNet Classification
27 |
28 | ### Data Preparation
29 | We use the ImageNet-1K training and validation datasets by default.
30 | Please save them in [your_imagenet_path].
31 |
32 |
33 | ### Training
34 | To train ViT models with HAT on 8 GPUs, using the default settings in our paper:
35 |
36 | ```shell
37 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./distributed_train.sh 8 \
38 | --data_dir [your_imagenet_path] \
39 | --model [your_vit_model_name] \
40 | --adv-epochs 200 \
41 | --adv-iters 3 \
42 | --adv-eps 0.00784314 \
43 | --adv-kl-weight 0.01 \
44 | --adv-ce-weight 3.0 \
45 | --output [your_output_path] \
46 | and_other_parameters_specified_for_your_vit_models...
47 | ```
48 |
49 | For instance, we train Swin-T with the following command:
50 | ```shell
51 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./distributed_train.sh 8 \
52 | --data_dir [your_imagenet_path] \
53 | --model swin_tiny_patch4_window7_224 \
54 | --adv-epochs 200 \
55 | --adv-iters 3 \
56 | --adv-eps 0.00784314 \
57 | --adv-kl-weight 0.01 \
58 | --adv-ce-weight 3.0 \
59 | --output [your_output_path] \
60 | --batch-size 256 \
61 | --drop-path 0.2 \
62 | --lr 1e-3 \
63 | --weight-decay 0.05 \
64 | --clip-grad 1.0
65 | ```
66 | For training variants of ViT, Swin Transformer, and VOLO, we use the hyper-parameters in [3], [4], and [2], respectively.
67 |
68 |
69 | We also combine HAT with knowledge distillation in [5], using `train_kd.py`.
70 |
71 | ### Validation
72 |
73 | After training, we can use `validate.py` to evaluate the ViT model trained with HAT.
74 |
75 | For instance, we evaluate Swin-T with the following command:
76 | ```shell
77 | python3 -u validate.py \
78 | --data_dir [your_imagenet_path] \
79 | --model swin_tiny_patch4_window7_224 \
80 | --checkpoint [your_checkpoint_path] \
81 | --batch-size 128 \
82 | --num-gpu 8 \
83 | --apex-amp \
84 | --results-file [your_results_file_path]
85 | ```
86 |
87 |
88 | ### Results
89 | | Model | Params | FLOPs | Test Size | Top-1 | +HAT Top-1 | Download |
90 | |:-:|:-:|:-:|:-:|:-:|:-:|:-:|
91 | | ViT-T | 5.7M | 1.6G | 224 | 72.2 | **73.3** | [link](https://github.com/jiawangbai/HAT/releases/download/v0.0.1/hat_vit_tiny_patch16_224.pth.tar)|
92 | | ViT-S | 22.1M | 4.7G | 224 | 80.1 | **80.9** |[link](https://github.com/jiawangbai/HAT/releases/download/v0.0.1/hat_vit_small_patch16_224.pth.tar)|
93 | | ViT-B | 86.6M | 17.6G | 224 | 82.0 | **83.2** |[link](https://github.com/jiawangbai/HAT/releases/download/v0.0.1/hat_vit_base_patch16_224.pth.tar)|
94 | | Swin-T | 28.3M | 4.5G | 224 | 81.2 | **82.0** |[link](https://github.com/jiawangbai/HAT/releases/download/v0.0.1/hat_swin_tiny_patch4_window7_224.pth.tar)|
95 | | Swin-S | 49.6M | 8.7G | 224 | 83.0 | **83.3** |[link](https://github.com/jiawangbai/HAT/releases/download/v0.0.1/hat_swin_small_patch4_window7_224.pth.tar)|
96 | | Swin-B | 87.8M | 15.4G | 224 | 83.5 | **84.0** |[link](https://github.com/jiawangbai/HAT/releases/download/v0.0.1/hat_swin_base_patch4_window7_224.pth.tar)|
97 | | VOLO-D1 | 26.6M | 6.8G | 224 | 84.2 | **84.5** |[link](https://github.com/jiawangbai/HAT/releases/download/v0.0.1/hat_volo_d1_224.pth.tar)|
98 | | VOLO-D1 | 26.6M | 22.8G | 384 | 85.2 | **85.5** |[link](https://github.com/jiawangbai/HAT/releases/download/v0.0.1/hat_volo_d1_384.pth.tar)|
99 | | VOLO-D5 | 295.5M | 69.0G | 224 | 86.1 | **86.3** |[link](https://github.com/jiawangbai/HAT/releases/download/v0.0.1/hat_volo_d5_224.pth.tar)|
100 | | VOLO-D5 | 295.5M | 304G | 448 | 87.0 | **87.2** |[link](https://github.com/jiawangbai/HAT/releases/download/v0.0.1/hat_volo_d5_448.pth.tar)|
101 | | VOLO-D5 | 295.5M | 412G | 512 | 87.1 | **87.3** |[link](https://github.com/jiawangbai/HAT/releases/download/v0.0.1/hat_volo_d5_512.pth.tar)|
102 |
103 | The result of combining HAT with knowledge distillation in [5] is 84.3% for ViT-B, and it can be downloaded [here](https://github.com/jiawangbai/HAT/releases/download/v0.0.1/hat_deit_base_distilled_patch16_224.pth.tar).
104 |
105 | ## Downstream Tasks
106 |
107 | We first pretrain Swin-T/S/B on the ImageNet-1K dataset with our proposed HAT, and then transfer the models to downstream tasks, including object detection, instance segmentation, and semantic segmentation.
108 |
109 | We use the code in [Swin Transformer for Object Detection](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection) and [Swin Transformer for Semantic Segmentation](https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation), and follow their configurations.
110 |
111 | ### Cascade Mask R-CNN on COCO val 2017
112 | | Backbone | Params | FLOPs | Config| AP_box | +HAT AP_box | AP_mask | +HAT AP_mask |
113 | |:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|
114 | | Swin-T | 86M | 745G | [config](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/blob/master/configs/swin/cascade_mask_rcnn_swin_tiny_patch4_window7_mstrain_480-800_giou_4conv1f_adamw_3x_coco.py) | 50.5 | **50.9** |43.7| **43.9** |
115 | | Swin-S | 107M | 838G | [config](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/blob/master/configs/swin/cascade_mask_rcnn_swin_small_patch4_window7_mstrain_480-800_giou_4conv1f_adamw_3x_coco.py) | 51.8 | **52.5** |44.7| **45.4** |
116 | | Swin-B | 145M | 982G | [config](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/blob/master/configs/swin/cascade_mask_rcnn_swin_base_patch4_window7_mstrain_480-800_giou_4conv1f_adamw_3x_coco.py) | 51.9 | **52.8** |45.0| **45.6** |
117 |
118 | ### UperNet on ADE20K
119 | | Backbone | Params | FLOPs | Config| mIoU(MS) | +HAT mIoU(MS) |
120 | |:-:|:-:|:-:|:-:|:-:|:-:|
121 | | Swin-T | 60M | 945G | [config](https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation/blob/main/configs/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k.py) | 46.1 | **46.7** |
122 | | Swin-S | 81M | 1038G | [config](https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation/blob/main/configs/swin/upernet_swin_small_patch4_window7_512x512_160k_ade20k.py) | 49.5 | **49.7** |
123 | | Swin-B | 121M | 1088G | [config](https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation/blob/main/configs/swin/upernet_swin_base_patch4_window7_512x512_160k_ade20k.py) | 49.7 | **50.3** |
124 |
125 |
126 | [1] Wightman, R. PyTorch Image Models. https://github.com/rwightman/pytorch-image-models, 2019.
127 | [2] Yuan, L. et al. VOLO: Vision outlooker for visual recognition. arXiv, 2021.
128 | [3] Dosovitskiy, A. et al. An image is worth 16x16 words: Transformers for image recognition at scale. ICLR, 2020.
129 | [4] Liu, Z. et al. Swin transformer: Hierarchical vision transformer using shifted windows. ICCV, 2021.
130 | [5] Touvron, H. et al. Training data-efficient image transformers & distillation through attention. ICML, 2021.
131 |
--------------------------------------------------------------------------------
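The Requirements section of the README lists packages without install commands. A sketch of one way to set them up, unverified against this repo: `tlt` is the Token Labeling toolbox (https://github.com/zihangJiang/TokenLabeling), and NVIDIA apex, used for `--apex-amp`, is normally built from source:

```shell
# Sketch only: environment setup matching the README's Requirements section.
pip install "torch>=1.7.0" "torchvision>=0.8.0" timm==0.4.5 tlt==0.1.0 pyyaml
# apex is typically installed from source for AMP support:
git clone https://github.com/NVIDIA/apex
pip install -v --no-cache-dir ./apex
```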
/validate.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import csv
4 | import glob
5 | import time
6 | import logging
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.parallel
10 | from collections import OrderedDict
11 | from contextlib import suppress
12 |
13 | from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
14 | from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet
15 | from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_legacy
16 | import models
17 |
18 | has_apex = False
19 | try:
20 | from apex import amp
21 | has_apex = True
22 | except ImportError:
23 | pass
24 |
25 | has_native_amp = False
26 | try:
27 | if getattr(torch.cuda.amp, 'autocast') is not None:
28 | has_native_amp = True
29 | except AttributeError:
30 | pass
31 |
32 | torch.backends.cudnn.benchmark = True
33 | _logger = logging.getLogger('validate')
34 |
35 |
36 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation')
37 | parser.add_argument('--data_dir', metavar='DIR',
38 | help='path to dataset')
39 | parser.add_argument('--dataset', '-d', metavar='NAME', default='',
40 | help='dataset type (default: ImageFolder/ImageTar if empty)')
41 | parser.add_argument('--split', metavar='NAME', default='validation',
42 | help='dataset split (default: validation)')
43 | parser.add_argument('--model', '-m', metavar='NAME', default='dpn92',
44 | help='model architecture (default: dpn92)')
45 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
46 |                     help='number of data loading workers (default: 4)')
47 | parser.add_argument('-b', '--batch-size', default=256, type=int,
48 | metavar='N', help='mini-batch size (default: 256)')
49 | parser.add_argument('--img-size', default=None, type=int,
50 | metavar='N', help='Input image dimension, uses model default if empty')
51 | parser.add_argument('--input-size', default=None, nargs=3, type=int,
52 | metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224),'
53 | ' uses model default if empty')
54 | parser.add_argument('--crop-pct', default=None, type=float,
55 | metavar='N', help='Input image center crop pct')
56 | parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
57 | help='Override mean pixel value of dataset')
58 | parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
59 |                     help='Override std deviation of dataset')
60 | parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
61 | help='Image resize interpolation type (overrides model)')
62 | parser.add_argument('--num-classes', type=int, default=None,
63 | help='Number classes in dataset')
64 | parser.add_argument('--class-map', default='', type=str, metavar='FILENAME',
65 | help='path to class to idx mapping file (default: "")')
66 | parser.add_argument('--gp', default=None, type=str, metavar='POOL',
67 | help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
68 | parser.add_argument('--log-freq', default=50, type=int,
69 |                     metavar='N', help='batch logging frequency (default: 50)')
70 | parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
71 | help='path to latest checkpoint (default: none)')
72 | parser.add_argument('--pretrained', dest='pretrained', action='store_true',
73 | help='use pre-trained model')
74 | parser.add_argument('--num-gpu', type=int, default=1,
75 | help='Number of GPUS to use')
76 | parser.add_argument('--no-test-pool', dest='no_test_pool', action='store_true',
77 | help='disable test time pool')
78 | parser.add_argument('--no-prefetcher', action='store_true', default=False,
79 | help='disable fast prefetcher')
80 | parser.add_argument('--pin-mem', action='store_true', default=False,
81 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
82 | parser.add_argument('--channels-last', action='store_true', default=False,
83 | help='Use channels_last memory layout')
84 | parser.add_argument('--amp', action='store_true', default=False,
85 | help='Use AMP mixed precision. Defaults to Apex, fallback to native Torch AMP.')
86 | parser.add_argument('--apex-amp', action='store_true', default=False,
87 | help='Use NVIDIA Apex AMP mixed precision')
88 | parser.add_argument('--native-amp', action='store_true', default=False,
89 | help='Use Native Torch AMP mixed precision')
90 | parser.add_argument('--tf-preprocessing', action='store_true', default=False,
91 |                     help='Use Tensorflow preprocessing pipeline (requires CPU TF installed)')
92 | parser.add_argument('--use-ema', dest='use_ema', action='store_true',
93 | help='use ema version of weights if present')
94 | parser.add_argument('--torchscript', dest='torchscript', action='store_true',
95 |                     help='convert model to torchscript for inference')
96 | parser.add_argument('--legacy-jit', dest='legacy_jit', action='store_true',
97 | help='use legacy jit mode for pytorch 1.5/1.5.1/1.6 to get back fusion performance')
98 | parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
99 | help='Output csv file for validation results (summary)')
100 | parser.add_argument('--real-labels', default='', type=str, metavar='FILENAME',
101 | help='Real labels JSON file for imagenet evaluation')
102 | parser.add_argument('--valid-labels', default='', type=str, metavar='FILENAME',
103 | help='Valid label indices txt file for validation of partial label space')
104 |
105 |
106 | def validate(args):
107 | # might as well try to validate something
108 | args.pretrained = args.pretrained or not args.checkpoint
109 | args.prefetcher = not args.no_prefetcher
110 | amp_autocast = suppress # do nothing
111 | if args.amp:
112 | if has_native_amp:
113 | args.native_amp = True
114 | elif has_apex:
115 | args.apex_amp = True
116 | else:
117 |             _logger.warning("Neither APEX nor Native Torch AMP is available.")
118 | assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set."
119 | if args.native_amp:
120 | amp_autocast = torch.cuda.amp.autocast
121 | _logger.info('Validating in mixed precision with native PyTorch AMP.')
122 | elif args.apex_amp:
123 | _logger.info('Validating in mixed precision with NVIDIA APEX AMP.')
124 | else:
125 | _logger.info('Validating in float32. AMP not enabled.')
126 |
127 | if args.legacy_jit:
128 | set_jit_legacy()
129 |
130 | # create model
131 | model = create_model(
132 | args.model,
133 | pretrained=args.pretrained,
134 | num_classes=args.num_classes,
135 | in_chans=3,
136 | global_pool=args.gp,
137 | scriptable=args.torchscript,
138 | img_size=args.img_size)
139 | if args.num_classes is None:
140 | assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
141 | args.num_classes = model.num_classes
142 |
143 | if args.checkpoint:
144 | load_checkpoint(model, args.checkpoint, args.use_ema, strict=False)
145 |
146 | param_count = sum([m.numel() for m in model.parameters()])
147 | _logger.info('Model %s created, param count: %d' % (args.model, param_count))
148 |
149 | data_config = resolve_data_config(vars(args), model=model, use_test_size=True)
150 | test_time_pool = False
151 | if not args.no_test_pool:
152 | model, test_time_pool = apply_test_time_pool(model, data_config, use_test_size=True)
153 |
154 | if args.torchscript:
155 | torch.jit.optimized_execution(True)
156 | model = torch.jit.script(model)
157 |
158 | model = model.cuda()
159 | if args.apex_amp:
160 | model = amp.initialize(model, opt_level='O1')
161 |
162 | if args.channels_last:
163 | model = model.to(memory_format=torch.channels_last)
164 |
165 | if args.num_gpu > 1:
166 | model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))
167 |
168 | criterion = nn.CrossEntropyLoss().cuda()
169 |
170 | dataset = create_dataset(
171 | root=args.data_dir, name=args.dataset, split=args.split,
172 | load_bytes=args.tf_preprocessing, class_map=args.class_map)
173 |
174 | if args.valid_labels:
175 | with open(args.valid_labels, 'r') as f:
176 | valid_labels = {int(line.rstrip()) for line in f}
177 | valid_labels = [i in valid_labels for i in range(args.num_classes)]
178 | else:
179 | valid_labels = None
180 |
181 | if args.real_labels:
182 | real_labels = RealLabelsImagenet(dataset.filenames(basename=True), real_json=args.real_labels)
183 | else:
184 | real_labels = None
185 |
186 | crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
187 | loader = create_loader(
188 | dataset,
189 | input_size=data_config['input_size'],
190 | batch_size=args.batch_size,
191 | use_prefetcher=args.prefetcher,
192 | interpolation=data_config['interpolation'],
193 | mean=data_config['mean'],
194 | std=data_config['std'],
195 | num_workers=args.workers,
196 | crop_pct=crop_pct,
197 | pin_memory=args.pin_mem,
198 | tf_preprocessing=args.tf_preprocessing)
199 |
200 | batch_time = AverageMeter()
201 | losses = AverageMeter()
202 | top1 = AverageMeter()
203 | top5 = AverageMeter()
204 |
205 | model.eval()
206 | with torch.no_grad():
207 | # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
208 | input = torch.randn((args.batch_size,) + data_config['input_size']).cuda()
209 | if args.channels_last:
210 | input = input.contiguous(memory_format=torch.channels_last)
211 | model(input)
212 | end = time.time()
213 | for batch_idx, (input, target) in enumerate(loader):
214 | if args.no_prefetcher:
215 | target = target.cuda()
216 | input = input.cuda()
217 | if args.channels_last:
218 | input = input.contiguous(memory_format=torch.channels_last)
219 |
220 | # compute output
221 | with amp_autocast():
222 | output = model(input)
223 | if isinstance(output, (tuple, list)):
224 | output = output[0]
225 | if valid_labels is not None:
226 | output = output[:, valid_labels]
227 | loss = criterion(output, target)
228 |
229 | if real_labels is not None:
230 | real_labels.add_result(output)
231 |
232 | # measure accuracy and record loss
233 | acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5))
234 | losses.update(loss.item(), input.size(0))
235 | top1.update(acc1.item(), input.size(0))
236 | top5.update(acc5.item(), input.size(0))
237 |
238 | # measure elapsed time
239 | batch_time.update(time.time() - end)
240 | end = time.time()
241 |
242 | if batch_idx % args.log_freq == 0:
243 | _logger.info(
244 | 'Test: [{0:>4d}/{1}] '
245 | 'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
246 | 'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
247 | 'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
248 | 'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
249 | batch_idx, len(loader), batch_time=batch_time,
250 | rate_avg=input.size(0) / batch_time.avg,
251 | loss=losses, top1=top1, top5=top5))
252 |
253 | if real_labels is not None:
254 | # real labels mode replaces topk values at the end
255 | top1a, top5a = real_labels.get_accuracy(k=1), real_labels.get_accuracy(k=5)
256 | else:
257 | top1a, top5a = top1.avg, top5.avg
258 | results = OrderedDict(
259 | top1=round(top1a, 4), top1_err=round(100 - top1a, 4),
260 | top5=round(top5a, 4), top5_err=round(100 - top5a, 4),
261 | param_count=round(param_count / 1e6, 2),
262 | img_size=data_config['input_size'][-1],
263 |         crop_pct=crop_pct,
264 | interpolation=data_config['interpolation'])
265 |
266 | _logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
267 | results['top1'], results['top1_err'], results['top5'], results['top5_err']))
268 |
269 | return results
270 |
271 |
272 | def main():
273 | setup_default_logging()
274 | args = parser.parse_args()
275 | model_cfgs = []
276 | model_names = []
277 | if os.path.isdir(args.checkpoint):
278 | # validate all checkpoints in a path with same model
279 | checkpoints = glob.glob(args.checkpoint + '/*.pth.tar')
280 | checkpoints += glob.glob(args.checkpoint + '/*.pth')
281 | model_names = list_models(args.model)
282 | model_cfgs = [(args.model, c) for c in sorted(checkpoints, key=natural_key)]
283 | else:
284 | if args.model == 'all':
285 | # validate all models in a list of names with pretrained checkpoints
286 | args.pretrained = True
287 | model_names = list_models(pretrained=True, exclude_filters=['*in21k'])
288 | model_cfgs = [(n, '') for n in model_names]
289 | elif not is_model(args.model):
290 | # model name doesn't exist, try as wildcard filter
291 | model_names = list_models(args.model)
292 | model_cfgs = [(n, '') for n in model_names]
293 |
294 | if len(model_cfgs):
295 | results_file = args.results_file or './results-all.csv'
296 | _logger.info('Running bulk validation on these pretrained models: {}'.format(', '.join(model_names)))
297 | results = []
298 | try:
299 | start_batch_size = args.batch_size
300 | for m, c in model_cfgs:
301 | batch_size = start_batch_size
302 | args.model = m
303 | args.checkpoint = c
304 | result = OrderedDict(model=args.model)
305 | r = {}
306 | while not r and batch_size >= args.num_gpu:
307 | torch.cuda.empty_cache()
308 | try:
309 | args.batch_size = batch_size
310 | print('Validating with batch size: %d' % args.batch_size)
311 | r = validate(args)
312 | except RuntimeError as e:
313 | if batch_size <= args.num_gpu:
314 | print("Validation failed with no ability to reduce batch size. Exiting.")
315 | raise e
316 | batch_size = max(batch_size // 2, args.num_gpu)
317 | print("Validation failed, reducing batch size by 50%")
318 | result.update(r)
319 | if args.checkpoint:
320 | result['checkpoint'] = args.checkpoint
321 | results.append(result)
322 | except KeyboardInterrupt as e:
323 | pass
324 | results = sorted(results, key=lambda x: x['top1'], reverse=True)
325 | if len(results):
326 | write_results(results_file, results)
327 | else:
328 | validate(args)
329 |
330 | def write_results(results_file, results):
331 | with open(results_file, mode='w') as cf:
332 | dw = csv.DictWriter(cf, fieldnames=results[0].keys())
333 | dw.writeheader()
334 | for r in results:
335 | dw.writerow(r)
336 | cf.flush()
337 |
338 | if __name__ == '__main__':
339 | main()
--------------------------------------------------------------------------------
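Beyond the single-checkpoint usage shown in the README, `main()` above also supports sweeping a directory: when `--checkpoint` points at a folder, every `*.pth`/`*.pth.tar` inside it is validated with the same `--model`, out-of-memory failures retry with the batch size halved, and a summary CSV is written. A hypothetical invocation, with placeholders as in the README:

```shell
python3 -u validate.py \
  --data_dir [your_imagenet_path] \
  --model swin_tiny_patch4_window7_224 \
  --checkpoint [your_checkpoint_dir] \
  --batch-size 128 \
  --results-file results-all.csv
```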
/models/volo.py:
--------------------------------------------------------------------------------
1 | """
2 | Vision OutLOoker (VOLO) implementation
3 | """
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
9 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_
10 | from timm.models.registry import register_model
11 | import math
12 | import numpy as np
13 |
14 |
15 | def _cfg(url='', **kwargs):
16 | return {
17 | 'url': url,
18 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
19 | 'crop_pct': .96, 'interpolation': 'bicubic',
20 | 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
21 | 'first_conv': 'patch_embed.proj', 'classifier': 'head',
22 | **kwargs
23 | }
24 |
25 |
26 | default_cfgs = {
27 | 'volo': _cfg(crop_pct=0.96),
28 | 'volo_large': _cfg(crop_pct=1.15),
29 | }
30 |
31 |
32 | class OutlookAttention(nn.Module):
33 | """
34 | Implementation of outlook attention
35 | --dim: hidden dim
36 | --num_heads: number of heads
37 | --kernel_size: kernel size in each window for outlook attention
38 | return: token features after outlook attention
39 | """
40 |
41 | def __init__(self, dim, num_heads, kernel_size=3, padding=1, stride=1,
42 | qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
43 | super().__init__()
44 | head_dim = dim // num_heads
45 | self.num_heads = num_heads
46 | self.kernel_size = kernel_size
47 | self.padding = padding
48 | self.stride = stride
49 | self.scale = qk_scale or head_dim**-0.5
50 |
51 | self.v = nn.Linear(dim, dim, bias=qkv_bias)
52 | self.attn = nn.Linear(dim, kernel_size**4 * num_heads)
53 |
54 | self.attn_drop = nn.Dropout(attn_drop)
55 | self.proj = nn.Linear(dim, dim)
56 | self.proj_drop = nn.Dropout(proj_drop)
57 |
58 | self.unfold = nn.Unfold(kernel_size=kernel_size, padding=padding, stride=stride)
59 | self.pool = nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True)
60 |
61 | def forward(self, x):
62 | B, H, W, C = x.shape
63 |
64 | v = self.v(x).permute(0, 3, 1, 2) # B, C, H, W
65 |
66 | h, w = math.ceil(H / self.stride), math.ceil(W / self.stride)
67 | v = self.unfold(v).reshape(B, self.num_heads, C // self.num_heads,
68 | self.kernel_size * self.kernel_size,
69 | h * w).permute(0, 1, 4, 3, 2) # B,H,N,kxk,C/H
70 |
71 | attn = self.pool(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
72 | attn = self.attn(attn).reshape(
73 | B, h * w, self.num_heads, self.kernel_size * self.kernel_size,
74 | self.kernel_size * self.kernel_size).permute(0, 2, 1, 3, 4) # B,H,N,kxk,kxk
75 | attn = attn * self.scale
76 | attn = attn.softmax(dim=-1)
77 | attn = self.attn_drop(attn)
78 |
79 | x = (attn @ v).permute(0, 1, 4, 3, 2).reshape(
80 | B, C * self.kernel_size * self.kernel_size, h * w)
81 | x = F.fold(x, output_size=(H, W), kernel_size=self.kernel_size,
82 | padding=self.padding, stride=self.stride)
83 |
84 | x = self.proj(x.permute(0, 2, 3, 1))
85 | x = self.proj_drop(x)
86 |
87 | return x
88 |
89 |
90 | class Outlooker(nn.Module):
91 | """
92 |     Implementation of the outlooker layer, which includes outlook attention + MLP
93 | Outlooker is the first stage in our VOLO
94 | --dim: hidden dim
95 | --num_heads: number of heads
96 | --mlp_ratio: mlp ratio
97 | --kernel_size: kernel size in each window for outlook attention
98 | return: outlooker layer
99 | """
100 | def __init__(self, dim, kernel_size, padding, stride=1,
101 |                  num_heads=1, mlp_ratio=3., attn_drop=0.,
102 | drop_path=0., act_layer=nn.GELU,
103 | norm_layer=nn.LayerNorm, qkv_bias=False,
104 | qk_scale=None):
105 | super().__init__()
106 | self.norm1 = norm_layer(dim)
107 | self.attn = OutlookAttention(dim, num_heads, kernel_size=kernel_size,
108 | padding=padding, stride=stride,
109 | qkv_bias=qkv_bias, qk_scale=qk_scale,
110 | attn_drop=attn_drop)
111 |
112 | self.drop_path = DropPath(
113 | drop_path) if drop_path > 0. else nn.Identity()
114 |
115 | self.norm2 = norm_layer(dim)
116 | mlp_hidden_dim = int(dim * mlp_ratio)
117 | self.mlp = Mlp(in_features=dim,
118 | hidden_features=mlp_hidden_dim,
119 | act_layer=act_layer)
120 |
121 | def forward(self, x):
122 | x = x + self.drop_path(self.attn(self.norm1(x)))
123 | x = x + self.drop_path(self.mlp(self.norm2(x)))
124 | return x
125 |
126 |
127 | class Mlp(nn.Module):
128 | "Implementation of MLP"
129 |
130 | def __init__(self, in_features, hidden_features=None,
131 | out_features=None, act_layer=nn.GELU,
132 | drop=0.):
133 | super().__init__()
134 | out_features = out_features or in_features
135 | hidden_features = hidden_features or in_features
136 | self.fc1 = nn.Linear(in_features, hidden_features)
137 | self.act = act_layer()
138 | self.fc2 = nn.Linear(hidden_features, out_features)
139 | self.drop = nn.Dropout(drop)
140 |
141 | def forward(self, x):
142 | x = self.fc1(x)
143 | x = self.act(x)
144 | x = self.drop(x)
145 | x = self.fc2(x)
146 | x = self.drop(x)
147 | return x
148 |
149 |
150 | class Attention(nn.Module):
151 | "Implementation of self-attention"
152 |
153 | def __init__(self, dim, num_heads=8, qkv_bias=False,
154 | qk_scale=None, attn_drop=0., proj_drop=0.):
155 | super().__init__()
156 | self.num_heads = num_heads
157 | head_dim = dim // num_heads
158 | self.scale = qk_scale or head_dim**-0.5
159 |
160 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
161 | self.attn_drop = nn.Dropout(attn_drop)
162 | self.proj = nn.Linear(dim, dim)
163 | self.proj_drop = nn.Dropout(proj_drop)
164 |
165 | def forward(self, x):
166 | B, H, W, C = x.shape
167 |
168 | qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads,
169 | C // self.num_heads).permute(2, 0, 3, 1, 4)
170 | q, k, v = qkv[0], qkv[1], qkv[
171 | 2] # make torchscript happy (cannot use tensor as tuple)
172 |
173 | attn = (q @ k.transpose(-2, -1)) * self.scale
174 | attn = attn.softmax(dim=-1)
175 | attn = self.attn_drop(attn)
176 |
177 | x = (attn @ v).transpose(1, 2).reshape(B, H, W, C)
178 | x = self.proj(x)
179 | x = self.proj_drop(x)
180 |
181 | return x
182 |
183 |
184 | class Transformer(nn.Module):
185 | """
186 | Implementation of Transformer,
187 | Transformer is the second stage in our VOLO
188 | """
189 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False,
190 | qk_scale=None, attn_drop=0., drop_path=0.,
191 | act_layer=nn.GELU, norm_layer=nn.LayerNorm):
192 | super().__init__()
193 | self.norm1 = norm_layer(dim)
194 | self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias,
195 | qk_scale=qk_scale, attn_drop=attn_drop)
196 |
197 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
198 | self.drop_path = DropPath(
199 | drop_path) if drop_path > 0. else nn.Identity()
200 |
201 | self.norm2 = norm_layer(dim)
202 | mlp_hidden_dim = int(dim * mlp_ratio)
203 | self.mlp = Mlp(in_features=dim,
204 | hidden_features=mlp_hidden_dim,
205 | act_layer=act_layer)
206 |
207 | def forward(self, x):
208 | x = x + self.drop_path(self.attn(self.norm1(x)))
209 | x = x + self.drop_path(self.mlp(self.norm2(x)))
210 | return x
211 |
212 |
213 | class ClassAttention(nn.Module):
214 | """
215 | Class attention layer from CaiT, see details in CaiT
216 | Class attention is the post stage in our VOLO, which is optional.
217 | """
218 | def __init__(self, dim, num_heads=8, head_dim=None, qkv_bias=False,
219 | qk_scale=None, attn_drop=0., proj_drop=0.):
220 | super().__init__()
221 | self.num_heads = num_heads
222 | if head_dim is not None:
223 | self.head_dim = head_dim
224 | else:
225 | head_dim = dim // num_heads
226 | self.head_dim = head_dim
227 | self.scale = qk_scale or head_dim**-0.5
228 |
229 | self.kv = nn.Linear(dim,
230 | self.head_dim * self.num_heads * 2,
231 | bias=qkv_bias)
232 | self.q = nn.Linear(dim, self.head_dim * self.num_heads, bias=qkv_bias)
233 | self.attn_drop = nn.Dropout(attn_drop)
234 | self.proj = nn.Linear(self.head_dim * self.num_heads, dim)
235 | self.proj_drop = nn.Dropout(proj_drop)
236 |
237 | def forward(self, x):
238 | B, N, C = x.shape
239 |
240 | kv = self.kv(x).reshape(B, N, 2, self.num_heads,
241 | self.head_dim).permute(2, 0, 3, 1, 4)
242 | k, v = kv[0], kv[
243 | 1] # make torchscript happy (cannot use tensor as tuple)
244 | q = self.q(x[:, :1, :]).reshape(B, self.num_heads, 1, self.head_dim)
245 | attn = ((q * self.scale) @ k.transpose(-2, -1))
246 | attn = attn.softmax(dim=-1)
247 | attn = self.attn_drop(attn)
248 |
249 | cls_embed = (attn @ v).transpose(1, 2).reshape(
250 | B, 1, self.head_dim * self.num_heads)
251 | cls_embed = self.proj(cls_embed)
252 | cls_embed = self.proj_drop(cls_embed)
253 | return cls_embed
254 |
255 |
256 | class ClassBlock(nn.Module):
257 | """
258 | Class attention block from CaiT, see details in CaiT
259 | We use two-layers class attention in our VOLO, which is optional.
260 | """
261 |
262 | def __init__(self, dim, num_heads, head_dim=None, mlp_ratio=4.,
263 | qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
264 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
265 | super().__init__()
266 | self.norm1 = norm_layer(dim)
267 | self.attn = ClassAttention(
268 | dim, num_heads=num_heads, head_dim=head_dim, qkv_bias=qkv_bias,
269 | qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
270 | # NOTE: drop path for stochastic depth
271 | self.drop_path = DropPath(
272 | drop_path) if drop_path > 0. else nn.Identity()
273 | self.norm2 = norm_layer(dim)
274 | mlp_hidden_dim = int(dim * mlp_ratio)
275 | self.mlp = Mlp(in_features=dim,
276 | hidden_features=mlp_hidden_dim,
277 | act_layer=act_layer,
278 | drop=drop)
279 |
280 | def forward(self, x):
281 | cls_embed = x[:, :1]
282 | cls_embed = cls_embed + self.drop_path(self.attn(self.norm1(x)))
283 | cls_embed = cls_embed + self.drop_path(self.mlp(self.norm2(cls_embed)))
284 | return torch.cat([cls_embed, x[:, 1:]], dim=1)
285 |
286 |
287 | def get_block(block_type, **kargs):
288 | """
289 | get block by name, specifically for class attention block in here
290 | """
291 | if block_type == 'ca':
292 | return ClassBlock(**kargs)
293 |
294 |
295 | def rand_bbox(size, lam, scale=1):
296 | """
297 | get bounding box as token labeling (https://github.com/zihangJiang/TokenLabeling)
298 | return: bounding box
299 | """
300 | W = size[1] // scale
301 | H = size[2] // scale
302 | cut_rat = np.sqrt(1. - lam)
303 |     cut_w = int(W * cut_rat)  # np.int is removed in NumPy >= 1.24
304 |     cut_h = int(H * cut_rat)
305 |
306 | # uniform
307 | cx = np.random.randint(W)
308 | cy = np.random.randint(H)
309 |
310 | bbx1 = np.clip(cx - cut_w // 2, 0, W)
311 | bby1 = np.clip(cy - cut_h // 2, 0, H)
312 | bbx2 = np.clip(cx + cut_w // 2, 0, W)
313 | bby2 = np.clip(cy + cut_h // 2, 0, H)
314 |
315 | return bbx1, bby1, bbx2, bby2
316 |
317 |
318 | class PatchEmbed(nn.Module):
319 | """
320 | Image to Patch Embedding.
321 |     Different from ViT, which uses 1 conv layer, we use 4 conv layers for patch embedding
322 | """
323 |
324 | def __init__(self, img_size=224, stem_conv=False, stem_stride=1,
325 | patch_size=8, in_chans=3, hidden_dim=64, embed_dim=384):
326 | super().__init__()
327 | assert patch_size in [4, 8, 16]
328 |
329 | self.stem_conv = stem_conv
330 | if stem_conv:
331 | self.conv = nn.Sequential(
332 | nn.Conv2d(in_chans, hidden_dim, kernel_size=7, stride=stem_stride,
333 | padding=3, bias=False), # 112x112
334 | nn.BatchNorm2d(hidden_dim),
335 | nn.ReLU(inplace=True),
336 | nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1,
337 | padding=1, bias=False), # 112x112
338 | nn.BatchNorm2d(hidden_dim),
339 | nn.ReLU(inplace=True),
340 | nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1,
341 | padding=1, bias=False), # 112x112
342 | nn.BatchNorm2d(hidden_dim),
343 | nn.ReLU(inplace=True),
344 | )
345 |
346 | self.proj = nn.Conv2d(hidden_dim,
347 | embed_dim,
348 | kernel_size=patch_size // stem_stride,
349 | stride=patch_size // stem_stride)
350 | self.num_patches = (img_size // patch_size) * (img_size // patch_size)
351 |
352 | def forward(self, x):
353 | if self.stem_conv:
354 | x = self.conv(x)
355 | x = self.proj(x) # B, C, H, W
356 | return x
357 |
358 |
359 | class Downsample(nn.Module):
360 | """
361 | Image to Patch Embedding, downsampling between stage1 and stage2
362 | """
363 | def __init__(self, in_embed_dim, out_embed_dim, patch_size):
364 | super().__init__()
365 | self.proj = nn.Conv2d(in_embed_dim, out_embed_dim,
366 | kernel_size=patch_size, stride=patch_size)
367 |
368 | def forward(self, x):
369 | x = x.permute(0, 3, 1, 2)
370 | x = self.proj(x) # B, C, H, W
371 | x = x.permute(0, 2, 3, 1)
372 | return x
373 |
374 |
375 | def outlooker_blocks(block_fn, index, dim, layers, num_heads=1, kernel_size=3,
376 |                      padding=1, stride=1, mlp_ratio=3., qkv_bias=False, qk_scale=None,
377 | attn_drop=0, drop_path_rate=0., **kwargs):
378 | """
379 | generate outlooker layer in stage1
380 | return: outlooker layers
381 | """
382 | blocks = []
383 | for block_idx in range(layers[index]):
384 | block_dpr = drop_path_rate * (block_idx +
385 | sum(layers[:index])) / (sum(layers) - 1)
386 | blocks.append(block_fn(dim, kernel_size=kernel_size, padding=padding,
387 | stride=stride, num_heads=num_heads, mlp_ratio=mlp_ratio,
388 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
389 | drop_path=block_dpr))
390 |
391 | blocks = nn.Sequential(*blocks)
392 |
393 | return blocks
394 |
395 |
396 | def transformer_blocks(block_fn, index, dim, layers, num_heads, mlp_ratio=3.,
397 | qkv_bias=False, qk_scale=None, attn_drop=0,
398 | drop_path_rate=0., **kwargs):
399 | """
400 | generate transformer layers in stage2
401 | return: transformer layers
402 | """
403 | blocks = []
404 | for block_idx in range(layers[index]):
405 | block_dpr = drop_path_rate * (block_idx +
406 | sum(layers[:index])) / (sum(layers) - 1)
407 | blocks.append(
408 | block_fn(dim, num_heads,
409 | mlp_ratio=mlp_ratio,
410 | qkv_bias=qkv_bias,
411 | qk_scale=qk_scale,
412 | attn_drop=attn_drop,
413 | drop_path=block_dpr))
414 |
415 | blocks = nn.Sequential(*blocks)
416 |
417 | return blocks
418 |
419 |
420 | class VOLO(nn.Module):
421 | """
422 | Vision Outlooker, the main class of our model
423 | --layers: [x,x,x,x], four blocks in two stages, the first block is outlooker, the
424 | other three are transformer, we set four blocks, which are easily
425 | applied to downstream tasks
426 | --img_size, --in_chans, --num_classes: these three are very easy to understand
427 | --patch_size: patch_size in outlook attention
428 | --stem_hidden_dim: hidden dim of patch embedding, d1-d4 is 64, d5 is 128
429 | --embed_dims, --num_heads: embedding dim, number of heads in each block
430 | --downsamples: flags to apply downsampling or not
431 | --outlook_attention: flags to apply outlook attention or not
432 |     --mlp_ratios, --qkv_bias, --qk_scale, --drop_rate: easy to understand
433 |     --attn_drop_rate, --drop_path_rate, --norm_layer: easy to understand
434 | --post_layers: post layers like two class attention layers using [ca, ca],
435 | if yes, return_mean=False
436 | --return_mean: use mean of all feature tokens for classification, if yes, no class token
437 | --return_dense: use token labeling, details are here:
438 | https://github.com/zihangJiang/TokenLabeling
439 | --mix_token: mixing tokens as token labeling, details are here:
440 | https://github.com/zihangJiang/TokenLabeling
441 | --pooling_scale: pooling_scale=2 means we downsample 2x
442 |     --out_kernel, --out_stride, --out_padding: kernel size,
443 | stride, and padding for outlook attention
444 | """
445 | def __init__(self, layers, img_size=224, in_chans=3, num_classes=1000, patch_size=8,
446 | stem_hidden_dim=64, embed_dims=None, num_heads=None, downsamples=None,
447 | outlook_attention=None, mlp_ratios=None, qkv_bias=False, qk_scale=None,
448 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
449 | post_layers=None, return_mean=False, return_dense=True, mix_token=True,
450 | pooling_scale=2, out_kernel=3, out_stride=2, out_padding=1):
451 |
452 | super().__init__()
453 | self.num_classes = num_classes
454 | self.patch_embed = PatchEmbed(stem_conv=True, stem_stride=2, patch_size=patch_size,
455 | in_chans=in_chans, hidden_dim=stem_hidden_dim,
456 | embed_dim=embed_dims[0])
457 |
458 |         # initial positional encoding, we add positional encoding after outlooker blocks
459 | self.pos_embed = nn.Parameter(
460 | torch.zeros(1, img_size // patch_size // pooling_scale,
461 | img_size // patch_size // pooling_scale,
462 | embed_dims[-1]))
463 |
464 | self.pos_drop = nn.Dropout(p=drop_rate)
465 |
466 | # set the main block in network
467 | network = []
468 | for i in range(len(layers)):
469 | if outlook_attention[i]:
470 | # stage 1
471 | stage = outlooker_blocks(Outlooker, i, embed_dims[i], layers,
472 | downsample=downsamples[i], num_heads=num_heads[i],
473 | kernel_size=out_kernel, stride=out_stride,
474 | padding=out_padding, mlp_ratio=mlp_ratios[i],
475 | qkv_bias=qkv_bias, qk_scale=qk_scale,
476 | attn_drop=attn_drop_rate, norm_layer=norm_layer)
477 | network.append(stage)
478 | else:
479 | # stage 2
480 | stage = transformer_blocks(Transformer, i, embed_dims[i], layers,
481 | num_heads[i], mlp_ratio=mlp_ratios[i],
482 | qkv_bias=qkv_bias, qk_scale=qk_scale,
483 | drop_path_rate=drop_path_rate,
484 | attn_drop=attn_drop_rate,
485 | norm_layer=norm_layer)
486 | network.append(stage)
487 |
488 | if downsamples[i]:
489 | # downsampling between two stages
490 | network.append(Downsample(embed_dims[i], embed_dims[i + 1], 2))
491 |
492 | self.network = nn.ModuleList(network)
493 |
494 | # set post block, for example, class attention layers
495 | self.post_network = None
496 | if post_layers is not None:
497 | self.post_network = nn.ModuleList([
498 | get_block(post_layers[i],
499 | dim=embed_dims[-1],
500 | num_heads=num_heads[-1],
501 | mlp_ratio=mlp_ratios[-1],
502 | qkv_bias=qkv_bias,
503 | qk_scale=qk_scale,
504 | attn_drop=attn_drop_rate,
505 | drop_path=0.,
506 | norm_layer=norm_layer)
507 | for i in range(len(post_layers))
508 | ])
509 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims[-1]))
510 | trunc_normal_(self.cls_token, std=.02)
511 |
512 | # set output type
513 | self.return_mean = return_mean # if yes, return mean, not use class token
514 | self.return_dense = return_dense # if yes, return class token and all feature tokens
515 | if return_dense:
516 | assert not return_mean, "cannot return both mean and dense"
517 | self.mix_token = mix_token
518 | self.pooling_scale = pooling_scale
519 | if mix_token: # enable token mixing, see token labeling for details.
520 | self.beta = 1.0
521 | assert return_dense, "return all tokens if mix_token is enabled"
522 | if return_dense:
523 | self.aux_head = nn.Linear(
524 | embed_dims[-1],
525 | num_classes) if num_classes > 0 else nn.Identity()
526 | self.norm = norm_layer(embed_dims[-1])
527 |
528 | # Classifier head
529 | self.head = nn.Linear(
530 | embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()
531 |
532 | trunc_normal_(self.pos_embed, std=.02)
533 | self.apply(self._init_weights)
534 |
535 | def _init_weights(self, m):
536 | if isinstance(m, nn.Linear):
537 | trunc_normal_(m.weight, std=.02)
538 | if isinstance(m, nn.Linear) and m.bias is not None:
539 | nn.init.constant_(m.bias, 0)
540 | elif isinstance(m, nn.LayerNorm):
541 | nn.init.constant_(m.bias, 0)
542 | nn.init.constant_(m.weight, 1.0)
543 |
544 | @torch.jit.ignore
545 | def no_weight_decay(self):
546 | return {'pos_embed', 'cls_token'}
547 |
548 | def get_classifier(self):
549 | return self.head
550 |
551 | def reset_classifier(self, num_classes):
552 | self.num_classes = num_classes
553 |         self.head = nn.Linear(  # infer dim from the final norm; self.embed_dim is never set
554 |             self.norm.normalized_shape[0], num_classes) if num_classes > 0 else nn.Identity()
555 |
556 | def forward_embeddings(self, x):
557 | # patch embedding
558 | x = self.patch_embed(x)
559 | # B,C,H,W-> B,H,W,C
560 | x = x.permute(0, 2, 3, 1)
561 | return x
562 |
563 | def forward_tokens(self, x):
564 | for idx, block in enumerate(self.network):
565 | if idx == 2: # add positional encoding after outlooker blocks
566 | x = x + self.pos_embed
567 | x = self.pos_drop(x)
568 | x = block(x)
569 |
570 | B, H, W, C = x.shape
571 | x = x.reshape(B, -1, C)
572 | return x
573 |
574 | def forward_cls(self, x):
575 | B, N, C = x.shape
576 | cls_tokens = self.cls_token.expand(B, -1, -1)
577 | x = torch.cat((cls_tokens, x), dim=1)
578 | for block in self.post_network:
579 | x = block(x)
580 | return x
581 |
582 | def forward(self, x):
583 | # step1: patch embedding
584 | x = self.forward_embeddings(x)
585 |
586 | # mix token, see token labeling for details.
587 | if self.mix_token and self.training:
588 | lam = np.random.beta(self.beta, self.beta)
589 | patch_h, patch_w = x.shape[1] // self.pooling_scale, x.shape[
590 | 2] // self.pooling_scale
591 | bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), lam, scale=self.pooling_scale)
592 | temp_x = x.clone()
593 |             sbbx1, sbby1, sbbx2, sbby2 = self.pooling_scale * bbx1, self.pooling_scale * bby1, \
594 |                 self.pooling_scale * bbx2, self.pooling_scale * bby2
595 | temp_x[:, sbbx1:sbbx2, sbby1:sbby2, :] = x.flip(0)[:, sbbx1:sbbx2, sbby1:sbby2, :]
596 | x = temp_x
597 | else:
598 | bbx1, bby1, bbx2, bby2 = 0, 0, 0, 0
599 |
600 | # step2: tokens learning in the two stages
601 | x = self.forward_tokens(x)
602 |
603 | # step3: post network, apply class attention or not
604 | if self.post_network is not None:
605 | x = self.forward_cls(x)
606 | x = self.norm(x)
607 |
608 | if self.return_mean: # if no class token, return mean
609 | return self.head(x.mean(1))
610 |
611 | x_cls = self.head(x[:, 0])
612 | if not self.return_dense:
613 | return x_cls
614 |
615 | x_aux = self.aux_head(
616 | x[:, 1:]
617 | ) # generate classes in all feature tokens, see token labeling
618 |
619 | if not self.training:
620 | return x_cls + 0.5 * x_aux.max(1)[0]
621 |
622 | if self.mix_token and self.training: # reverse "mix token", see token labeling for details.
623 | x_aux = x_aux.reshape(x_aux.shape[0], patch_h, patch_w, x_aux.shape[-1])
624 |
625 | temp_x = x_aux.clone()
626 | temp_x[:, bbx1:bbx2, bby1:bby2, :] = x_aux.flip(0)[:, bbx1:bbx2, bby1:bby2, :]
627 | x_aux = temp_x
628 |
629 | x_aux = x_aux.reshape(x_aux.shape[0], patch_h * patch_w, x_aux.shape[-1])
630 |
631 | # return these: 1. class token, 2. classes from all feature tokens, 3. bounding box
632 | return x_cls, x_aux, (bbx1, bby1, bbx2, bby2)
633 |
634 |
635 | @register_model
636 | def volo_d1(pretrained=False, **kwargs):
637 | """
638 | VOLO-D1 model, Params: 27M
639 | --layers: [x,x,x,x], four blocks in two stages, the first stage(block) is outlooker,
640 | the other three blocks are transformer, we set four blocks, which are easily
641 | applied to downstream tasks
642 | --embed_dims, --num_heads,: embedding dim, number of heads in each block
643 | --downsamples: flags to apply downsampling or not in four blocks
644 | --outlook_attention: flags to apply outlook attention or not
645 | --mlp_ratios: mlp ratio in four blocks
646 | --post_layers: post layers like two class attention layers using [ca, ca]
647 | See detail for all args in the class VOLO()
648 | """
649 | layers = [4, 4, 8, 2] # num of layers in the four blocks
650 | embed_dims = [192, 384, 384, 384]
651 | num_heads = [6, 12, 12, 12]
652 | mlp_ratios = [3, 3, 3, 3]
653 | downsamples = [True, False, False, False] # do downsampling after first block
654 |     outlook_attention = [True, False, False, False]
655 | # first block is outlooker (stage1), the other three are transformer (stage2)
656 | model = VOLO(layers,
657 | embed_dims=embed_dims,
658 | num_heads=num_heads,
659 | mlp_ratios=mlp_ratios,
660 | downsamples=downsamples,
661 | outlook_attention=outlook_attention,
662 | post_layers=['ca', 'ca'],
663 | **kwargs)
664 | model.default_cfg = default_cfgs['volo']
665 | return model
666 |
667 |
668 | @register_model
669 | def volo_d2(pretrained=False, **kwargs):
670 | """
671 | VOLO-D2 model, Params: 59M
672 | """
673 | layers = [6, 4, 10, 4]
674 | embed_dims = [256, 512, 512, 512]
675 | num_heads = [8, 16, 16, 16]
676 | mlp_ratios = [3, 3, 3, 3]
677 | downsamples = [True, False, False, False]
678 | outlook_attention = [True, False, False, False]
679 | model = VOLO(layers,
680 | embed_dims=embed_dims,
681 | num_heads=num_heads,
682 | mlp_ratios=mlp_ratios,
683 | downsamples=downsamples,
684 | outlook_attention=outlook_attention,
685 | post_layers=['ca', 'ca'],
686 | **kwargs)
687 | model.default_cfg = default_cfgs['volo']
688 | return model
689 |
690 |
691 | @register_model
692 | def volo_d3(pretrained=False, **kwargs):
693 | """
694 | VOLO-D3 model, Params: 86M
695 | """
696 | layers = [8, 8, 16, 4]
697 | embed_dims = [256, 512, 512, 512]
698 | num_heads = [8, 16, 16, 16]
699 | mlp_ratios = [3, 3, 3, 3]
700 | downsamples = [True, False, False, False]
701 | outlook_attention = [True, False, False, False]
702 | model = VOLO(layers,
703 | embed_dims=embed_dims,
704 | num_heads=num_heads,
705 | mlp_ratios=mlp_ratios,
706 | downsamples=downsamples,
707 | outlook_attention=outlook_attention,
708 | post_layers=['ca', 'ca'],
709 | **kwargs)
710 | model.default_cfg = default_cfgs['volo']
711 | return model
712 |
713 |
714 | @register_model
715 | def volo_d4(pretrained=False, **kwargs):
716 | """
717 | VOLO-D4 model, Params: 193M
718 | """
719 | layers = [8, 8, 16, 4]
720 | embed_dims = [384, 768, 768, 768]
721 | num_heads = [12, 16, 16, 16]
722 | mlp_ratios = [3, 3, 3, 3]
723 | downsamples = [True, False, False, False]
724 | outlook_attention = [True, False, False, False]
725 | model = VOLO(layers,
726 | embed_dims=embed_dims,
727 | num_heads=num_heads,
728 | mlp_ratios=mlp_ratios,
729 | downsamples=downsamples,
730 | outlook_attention=outlook_attention,
731 | post_layers=['ca', 'ca'],
732 | **kwargs)
733 | model.default_cfg = default_cfgs['volo_large']
734 | return model
735 |
736 |
737 | @register_model
738 | def volo_d5(pretrained=False, **kwargs):
739 | """
740 | VOLO-D5 model, Params: 296M
741 |     stem_hidden_dim=128: the hidden dim of the patch embedding stem is 128 for VOLO-D5
742 | """
743 | layers = [12, 12, 20, 4]
744 | embed_dims = [384, 768, 768, 768]
745 | num_heads = [12, 16, 16, 16]
746 | mlp_ratios = [4, 4, 4, 4]
747 | downsamples = [True, False, False, False]
748 | outlook_attention = [True, False, False, False]
749 | model = VOLO(layers,
750 | embed_dims=embed_dims,
751 | num_heads=num_heads,
752 | mlp_ratios=mlp_ratios,
753 | downsamples=downsamples,
754 | outlook_attention=outlook_attention,
755 | post_layers=['ca', 'ca'],
756 | stem_hidden_dim=128,
757 | **kwargs)
758 | model.default_cfg = default_cfgs['volo_large']
759 | return model
760 |
--------------------------------------------------------------------------------
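Note on the VOLO forward pass above: in training mode with return_dense and
mix_token enabled, forward() returns the class-token logits, the per-token
logits, and the mix-token bounding box; in eval mode it folds the auxiliary
token predictions into a single logits tensor. A minimal usage sketch follows
(not part of this repo; it assumes the VOLO constructor accepts the
return_dense and mix_token keyword arguments that forward() relies on):

import torch
from models import volo_d1

# Hypothetical kwargs: return_dense/mix_token are attributes read in forward().
model = volo_d1(return_dense=True, mix_token=True)

# Inference: one (B, num_classes) tensor; auxiliary token predictions are
# folded into the class-token logits via x_cls + 0.5 * x_aux.max(1)[0].
model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))

# Training: class-token logits, per-token logits, and the mix-token bounding
# box needed to build token-labeling targets.
model.train()
x_cls, x_aux, (bbx1, bby1, bbx2, bby2) = model(torch.randn(2, 3, 224, 224))
--------------------------------------------------------------------------------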
/train_kd.py:
--------------------------------------------------------------------------------
1 | """
2 | ImageNet Training Script
3 | This script is adapted from pytorch-image-models by Ross Wightman (https://github.com/rwightman/pytorch-image-models/)
4 | It was started from an early version of the PyTorch ImageNet example
5 | (https://github.com/pytorch/examples/tree/master/imagenet)
6 | """
7 | import warnings
8 | warnings.filterwarnings("ignore")
9 |
10 | import argparse
11 | import time
12 | import yaml
13 | import os
14 | import logging
15 | from collections import OrderedDict
16 | from contextlib import suppress
17 | from datetime import datetime
18 | import numpy as np
19 |
20 | import torch
21 | import torch.nn as nn
22 | import torchvision.utils
23 | from torch.nn.parallel import DistributedDataParallel as NativeDDP
24 |
25 | from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
26 | from timm.models import create_model, resume_checkpoint, load_checkpoint, convert_splitbn_model, model_parameters
27 | from timm.utils import *
28 | from timm.loss import LabelSmoothingCrossEntropy
29 | from loss_adv import SoftTargetCrossEntropy, DistillationLoss
30 | from timm.optim import create_optimizer
31 | from timm.scheduler import create_scheduler
32 | from timm.utils import ApexScaler, NativeScaler
33 |
34 | from tlt.utils import load_pretrained_weights
35 | import models
36 | import torch.nn.functional as F
37 |
38 | try:
39 | from apex import amp
40 | from apex.parallel import DistributedDataParallel as ApexDDP
41 | from apex.parallel import convert_syncbn_model
42 |
43 | has_apex = True
44 | except ImportError:
45 | has_apex = False
46 |
47 | has_native_amp = False
48 | try:
49 | if getattr(torch.cuda.amp, 'autocast') is not None:
50 | has_native_amp = True
51 | except AttributeError:
52 | pass
53 |
54 | torch.backends.cudnn.benchmark = True
55 | _logger = logging.getLogger('train')
56 |
57 | # The first arg parser parses out only the --config argument, this argument is used to
58 | # load a yaml file containing key-values that override the defaults for the main parser below
59 | config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False)
60 | parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
61 | help='YAML config file specifying default arguments')
62 |
63 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
64 |
65 | # Dataset / Model parameters
66 | parser.add_argument('--data_dir', metavar='DIR',
67 | help='path to dataset')
68 | parser.add_argument('--dataset', '-d', metavar='NAME', default='',
69 | help='dataset type (default: ImageFolder/ImageTar if empty)')
70 | parser.add_argument('--train-split', metavar='NAME', default='train',
71 | help='dataset train split (default: train)')
72 | parser.add_argument('--val-split', metavar='NAME', default='validation',
73 | help='dataset validation split (default: validation)')
74 | parser.add_argument('--model', default='volo_d1', type=str, metavar='MODEL',
75 | help='Name of model to train (default: "volo_d1"')
76 | parser.add_argument('--pretrained', action='store_true', default=False,
77 | help='Start with pretrained version of specified network (if avail)')
78 | parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
79 | help='Initialize model from this checkpoint (default: none)')
80 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
81 | help='Resume full model and optimizer state from checkpoint (default: none)')
82 | parser.add_argument('--no-resume-opt', action='store_true', default=False,
83 | help='prevent resume of optimizer state when resuming model')
84 | parser.add_argument('--num-classes', type=int, default=None, metavar='N',
85 | help='number of label classes (Model default if None)')
86 | parser.add_argument('--gp', default=None, type=str, metavar='POOL',
87 | help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
88 | parser.add_argument('--img-size', type=int, default=None, metavar='N',
89 | help='Image patch size (default: None => model default)')
90 | parser.add_argument('--input-size', default=None, nargs=3, type=int,
91 | metavar='N N N',
92 | help='Input all image dimensions (d h w, e.g. --input-size 3 224 224),'
93 | ' uses model default if empty')
94 | parser.add_argument('--crop-pct', default=None, type=float,
95 | metavar='N', help='Input image center crop percent (for validation only)')
96 | parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
97 | help='Override mean pixel value of dataset')
98 | parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
99 |                     help='Override std deviation of dataset')
100 | parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
101 | help='Image resize interpolation type (overrides model)')
102 | parser.add_argument('-b', '--batch-size', type=int, default=128, metavar='N',
103 | help='input batch size for training (default: 128)')
104 | parser.add_argument('-vb', '--validation-batch-size-multiplier', type=int, default=1, metavar='N',
105 | help='ratio of validation batch size to training batch size (default: 1)')
106 |
107 | # Optimizer parameters
108 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
109 |                     help='Optimizer (default: "adamw")')
110 | parser.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON',
111 | help='Optimizer Epsilon (default: None, use opt default)')
112 | parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
113 | help='Optimizer Betas (default: None, use opt default)')
114 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
115 | help='Optimizer momentum (default: 0.9)')
116 | parser.add_argument('--weight-decay', type=float, default=0.05,
117 | help='weight decay (default: 0.05)')
118 | parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
119 | help='Clip gradient norm (default: None, no clipping)')
120 | parser.add_argument('--clip-mode', type=str, default='norm',
121 | help='Gradient clipping mode. One of ("norm", "value", "agc")')
122 |
123 | # Learning rate schedule parameters
124 | parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
125 |                     help='LR scheduler (default: "cosine")')
126 | parser.add_argument('--lr', type=float, default=1.6e-3, metavar='LR',
127 | help='learning rate (default: 1.6e-3)')
128 | parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
129 | help='learning rate noise on/off epoch percentages')
130 | parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
131 | help='learning rate noise limit percent (default: 0.67)')
132 | parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
133 | help='learning rate noise std-dev (default: 1.0)')
134 | parser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',
135 | help='learning rate cycle len multiplier (default: 1.0)')
136 | parser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',
137 | help='learning rate cycle limit')
138 | parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
139 |                     help='warmup learning rate (default: 1e-6)')
140 | parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
141 | help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
142 | parser.add_argument('--epochs', type=int, default=300, metavar='N',
143 | help='number of epochs to train (default: 300)')
144 | parser.add_argument('--start-epoch', default=None, type=int, metavar='N',
145 | help='manual epoch number (useful on restarts)')
146 | parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
147 | help='epoch interval to decay LR')
148 | parser.add_argument('--warmup-epochs', type=int, default=20, metavar='N',
149 | help='epochs to warmup LR, if scheduler supports')
150 | parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
151 | help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
152 | parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
153 |                     help='patience epochs for Plateau LR scheduler (default: 10)')
154 | parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
155 | help='LR decay rate (default: 0.1)')
156 |
157 | # Augmentation & regularization parameters
158 | parser.add_argument('--no-aug', action='store_true', default=False,
159 | help='Disable all training augmentation, override other train aug args')
160 | parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
161 | help='Random resize scale (default: 0.08 1.0)')
162 | parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
163 | help='Random resize aspect ratio (default: 0.75 1.33)')
164 | parser.add_argument('--hflip', type=float, default=0.5,
165 | help='Horizontal flip training aug probability')
166 | parser.add_argument('--vflip', type=float, default=0.,
167 | help='Vertical flip training aug probability')
168 | parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
169 | help='Color jitter factor (default: 0.4)')
170 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
171 |                     help='Use AutoAugment policy. "v0" or "original". (default: rand-m9-mstd0.5-inc1)')
172 | parser.add_argument('--aug-splits', type=int, default=0,
173 | help='Number of augmentation splits (default: 0, valid: 0 or >=2)')
174 | parser.add_argument('--jsd', action='store_true', default=False,
175 | help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
176 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
177 | help='Random erase prob (default: 0.25)')
178 | parser.add_argument('--remode', type=str, default='pixel',
179 | help='Random erase mode (default: "pixel")')
180 | parser.add_argument('--recount', type=int, default=1,
181 | help='Random erase count (default: 1)')
182 | parser.add_argument('--resplit', action='store_true', default=False,
183 | help='Do not random erase first (clean) augmentation split')
184 | parser.add_argument('--mixup', type=float, default=0.8,
185 |                     help='mixup alpha, mixup enabled if > 0. (default: 0.8)')
186 | parser.add_argument('--cutmix', type=float, default=1.0,
187 |                     help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)')
188 | parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
189 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
190 | parser.add_argument('--mixup-prob', type=float, default=1.0,
191 | help='Probability of performing mixup or cutmix when either/both is enabled')
192 | parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
193 | help='Probability of switching to cutmix when both mixup and cutmix enabled')
194 | parser.add_argument('--mixup-mode', type=str, default='batch',
195 | help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
196 | parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
197 | help='Turn off mixup after this epoch, disabled if 0 (default: 0)')
198 | parser.add_argument('--smoothing', type=float, default=0.1,
199 | help='Label smoothing (default: 0.1)')
200 | parser.add_argument('--train-interpolation', type=str, default='random',
201 | help='Training interpolation (random, bilinear, bicubic default: "random")')
202 | parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
203 | help='Dropout rate (default: 0.)')
204 | parser.add_argument('--drop-connect', type=float, default=None, metavar='PCT',
205 | help='Drop connect rate, DEPRECATED, use drop-path (default: None)')
206 | parser.add_argument('--drop-path', type=float, default=None, metavar='PCT',
207 | help='Drop path rate (default: None)')
208 | parser.add_argument('--drop-block', type=float, default=None, metavar='PCT',
209 | help='Drop block rate (default: None)')
210 |
211 | # Batch norm parameters (only works with gen_efficientnet based models currently)
212 | parser.add_argument('--bn-tf', action='store_true', default=False,
213 | help='Use Tensorflow BatchNorm defaults for models that support it (default: False)')
214 | parser.add_argument('--bn-momentum', type=float, default=None,
215 | help='BatchNorm momentum override (if not None)')
216 | parser.add_argument('--bn-eps', type=float, default=None,
217 | help='BatchNorm epsilon override (if not None)')
218 | parser.add_argument('--sync-bn', action='store_true',
219 | help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')
220 | parser.add_argument('--dist-bn', type=str, default='',
221 | help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")')
222 | parser.add_argument('--split-bn', action='store_true',
223 | help='Enable separate BN layers per augmentation split.')
224 |
225 | # Model Exponential Moving Average
226 | parser.add_argument('--model-ema', action='store_true', default=False,
227 | help='Enable tracking moving average of model weights')
228 | parser.add_argument('--model-ema-force-cpu', action='store_true', default=False,
229 | help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
230 | parser.add_argument('--model-ema-decay', type=float, default=0.99992,
231 | help='decay factor for model weights moving average (default: 0.99992)')
232 |
233 | # Misc
234 | parser.add_argument('--seed', type=int, default=42, metavar='S',
235 | help='random seed (default: 42)')
236 | parser.add_argument('--log-interval', type=int, default=50, metavar='N',
237 | help='how many batches to wait before logging training status')
238 | parser.add_argument('--recovery-interval', type=int, default=0, metavar='N',
239 | help='how many batches to wait before writing recovery checkpoint')
240 | parser.add_argument('--checkpoint-hist', type=int, default=3, metavar='N',
241 |                     help='number of checkpoints to keep (default: 3)')
242 | parser.add_argument('-j', '--workers', type=int, default=8, metavar='N',
243 |                     help='how many training processes to use (default: 8)')
244 | parser.add_argument('--save-images', action='store_true', default=False,
245 |                     help='save images of input batches every log interval for debugging')
246 | parser.add_argument('--amp', action='store_true', default=False,
247 | help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
248 | parser.add_argument('--apex-amp', action='store_true', default=True,
249 | help='Use NVIDIA Apex AMP mixed precision')
250 | parser.add_argument('--native-amp', action='store_true', default=False,
251 | help='Use Native Torch AMP mixed precision')
252 | parser.add_argument('--channels-last', action='store_true', default=False,
253 | help='Use channels_last memory layout')
254 | parser.add_argument('--pin-mem', action='store_true', default=False,
255 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
256 | parser.add_argument('--no-prefetcher', action='store_true', default=False,
257 | help='disable fast prefetcher')
258 | parser.add_argument('--output', default='', type=str, metavar='PATH',
259 | help='path to output folder (default: none, current dir)')
260 | parser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',
261 |                     help='Best metric (default: "top1")')
262 | parser.add_argument('--tta', type=int, default=0, metavar='N',
263 | help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
264 | parser.add_argument("--local_rank", default=0, type=int)
265 | parser.add_argument('--use-multi-epochs-loader', action='store_true', default=False,
266 | help='use the multi-epochs-loader to save time at the beginning of every epoch')
267 | parser.add_argument('--torchscript', dest='torchscript', action='store_true',
268 | help='convert model torchscript for inference')
269 |
270 |
271 | # Finetune
272 | parser.add_argument('--finetune', default='', type=str, metavar='PATH',
273 | help='path to checkpoint file (default: none)')
274 |
275 |
276 | # Distillation parameters
277 | parser.add_argument('--teacher-model', default='regnety_160', type=str, metavar='MODEL',
278 |                     help='Name of teacher model (default: "regnety_160")')
279 | parser.add_argument('--teacher-path', type=str, default='')
280 | parser.add_argument('--distillation-type', default='none', choices=['none', 'soft', 'hard'], type=str, help='distillation type (default: "none")')
281 | parser.add_argument('--distillation-alpha', default=0.5, type=float, help='weight on the distillation loss (default: 0.5)')
282 | parser.add_argument('--distillation-tau', default=1.0, type=float, help='distillation temperature (default: 1.0)')
283 |
284 | # Adversarial
285 | parser.add_argument('--adv-epochs', default=200, type=int, metavar='N',
286 | help='number of epochs for performing adversarial training')
287 | parser.add_argument('--adv-iters', default=3, type=int, metavar='N',
288 | help='number of iterations for adversarial augmentation')
289 | parser.add_argument('--adv-eps', default=2.0/255, type=float,
290 | help='adversarial strength for adversarial augmentation')
291 | parser.add_argument('--adv-lr', default=1.0/255, type=float,
292 | help='learning rate for adversarial augmentation')
293 | parser.add_argument('--adv-kl-weight', default=0.01, type=float,
294 | help='weight of KL-divergence for adversarial augmentation')
295 | parser.add_argument('--adv-ce-weight', default=3.0, type=float,
296 | help='weight of ce-loss for adversarial augmentation')
297 |
298 |
299 | def _parse_args():
300 | # Do we have a config file to parse?
301 | args_config, remaining = config_parser.parse_known_args()
302 | if args_config.config:
303 | with open(args_config.config, 'r') as f:
304 | cfg = yaml.safe_load(f)
305 | parser.set_defaults(**cfg)
306 |
307 | # The main arg parser parses the rest of the args, the usual
308 | # defaults will have been overridden if config file specified.
309 | args = parser.parse_args(remaining)
310 |
311 | # Cache the args as a text string to save them in the output dir later
312 | args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
313 | return args, args_text
314 |
315 |
316 | def main():
317 | setup_default_logging()
318 | args, args_text = _parse_args()
319 |
320 | args.prefetcher = not args.no_prefetcher
321 | args.distributed = False
322 | if 'WORLD_SIZE' in os.environ:
323 | args.distributed = int(os.environ['WORLD_SIZE']) > 1
324 | args.device = 'cuda:0'
325 | args.world_size = 1
326 | args.rank = 0 # global rank
327 | if args.distributed:
328 | args.device = 'cuda:%d' % args.local_rank
329 | torch.cuda.set_device(args.local_rank)
330 | torch.distributed.init_process_group(backend='nccl',
331 | init_method='env://')
332 | args.world_size = torch.distributed.get_world_size()
333 | args.rank = torch.distributed.get_rank()
334 | _logger.info(
335 | 'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
336 | % (args.rank, args.world_size))
337 | else:
338 |         _logger.info('Training with a single process on 1 GPU.')
339 | assert args.rank >= 0
340 |
341 | # resolve AMP arguments based on PyTorch / Apex availability
342 | use_amp = None
343 | if args.amp:
344 | # `--amp` chooses native amp before apex (APEX ver not actively maintained)
345 | if has_native_amp:
346 | args.native_amp = True
347 | elif has_apex:
348 | args.apex_amp = True
349 | if args.apex_amp and has_apex:
350 | use_amp = 'apex'
351 | elif args.native_amp and has_native_amp:
352 | use_amp = 'native'
353 | elif args.apex_amp or args.native_amp:
354 | _logger.warning(
355 |             "Neither APEX nor native Torch AMP is available, using float32. "
356 |             "Install NVIDIA Apex or upgrade to PyTorch 1.6")
357 |
358 | torch.manual_seed(args.seed + args.rank)
359 | np.random.seed(args.seed + args.rank)
360 | model = create_model(
361 | args.model,
362 | pretrained=args.pretrained,
363 | num_classes=args.num_classes,
364 | drop_rate=args.drop,
365 | drop_connect_rate=args.drop_connect, # DEPRECATED, use drop_path
366 | drop_path_rate=args.drop_path,
367 | drop_block_rate=args.drop_block,
368 | global_pool=args.gp,
369 | bn_tf=args.bn_tf,
370 | bn_momentum=args.bn_momentum,
371 | bn_eps=args.bn_eps,
372 | scriptable=args.torchscript,
373 | checkpoint_path=args.initial_checkpoint,
374 | img_size=args.img_size)
375 | if args.num_classes is None:
376 | assert hasattr(
377 | model, 'num_classes'
378 | ), 'Model must have `num_classes` attr if not set on cmd line/config.'
379 | args.num_classes = model.num_classes # FIXME handle model default vs config num_classes more elegantly
380 |
381 | if args.finetune:
382 | load_pretrained_weights(model=model,
383 | checkpoint_path=args.finetune,
384 | use_ema=args.model_ema,
385 | strict=False,
386 | num_classes=args.num_classes)
387 |
388 | if args.rank == 0:
389 | _logger.info('Model %s created, param count: %d' %
390 | (args.model, sum([m.numel()
391 | for m in model.parameters()])))
392 |
393 | data_config = resolve_data_config(vars(args),
394 | model=model,
395 | verbose=args.local_rank == 0)
396 |
397 | # setup augmentation batch splits for contrastive loss or split bn
398 | num_aug_splits = 0
399 | if args.aug_splits > 0:
400 | assert args.aug_splits > 1, 'A split of 1 makes no sense'
401 | num_aug_splits = args.aug_splits
402 |
403 | # enable split bn (separate bn stats per batch-portion)
404 | if args.split_bn:
405 | assert num_aug_splits > 1 or args.resplit
406 | model = convert_splitbn_model(model, max(num_aug_splits, 2))
407 |
408 | # move model to GPU, enable channels last layout if set
409 | model.cuda()
410 | if args.channels_last:
411 | model = model.to(memory_format=torch.channels_last)
412 |
413 | # setup synchronized BatchNorm for distributed training
414 | if args.distributed and args.sync_bn:
415 | assert not args.split_bn
416 | if has_apex and use_amp != 'native':
417 | # Apex SyncBN preferred unless native amp is activated
418 | model = convert_syncbn_model(model)
419 | else:
420 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
421 | if args.rank == 0:
422 | _logger.info(
423 | 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
424 | 'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.'
425 | )
426 |
427 | if args.torchscript:
428 | assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
429 | assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
430 | model = torch.jit.script(model)
431 |
432 | linear_scaled_lr = args.lr * args.batch_size * args.world_size / 1024.0
433 | args.lr = linear_scaled_lr
434 | if args.rank == 0:
435 | print("learning rate:", args.lr)
436 | optimizer = create_optimizer(args, model)
437 |
438 | # setup automatic mixed-precision (AMP) loss scaling and op casting
439 | amp_autocast = suppress # do nothing
440 | loss_scaler = None
441 | optimizers = None
442 | if use_amp == 'apex':
443 | model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
444 | loss_scaler = ApexScaler()
445 | if args.rank == 0:
446 | _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
447 | elif use_amp == 'native':
448 | amp_autocast = torch.cuda.amp.autocast
449 | loss_scaler = NativeScaler()
450 | if args.rank == 0:
451 | _logger.info(
452 | 'Using native Torch AMP. Training in mixed precision.')
453 | else:
454 | if args.rank == 0:
455 | _logger.info('AMP not enabled. Training in float32.')
456 |
457 | # optionally resume from a checkpoint
458 | resume_epoch = None
459 | if args.resume and os.path.isfile(args.resume):
460 | resume_epoch = resume_checkpoint(
461 | model,
462 | args.resume,
463 | optimizer=None if args.no_resume_opt else optimizer,
464 | loss_scaler=None if args.no_resume_opt else loss_scaler,
465 | log_info=args.rank == 0)
466 |
467 | # setup exponential moving average of model weights, SWA could be used here too
468 | model_ema = None
469 | if args.model_ema:
470 | # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
471 | model_ema = ModelEmaV2(
472 | model,
473 | decay=args.model_ema_decay,
474 | device='cpu' if args.model_ema_force_cpu else None)
475 | if args.resume and os.path.isfile(args.resume):
476 | load_checkpoint(model_ema.module, args.resume, use_ema=True)
477 |
478 | # setup distributed training
479 | if args.distributed:
480 | if has_apex and use_amp != 'native':
481 | # Apex DDP preferred unless native amp is activated
482 | if args.rank == 0:
483 | _logger.info("Using NVIDIA APEX DistributedDataParallel.")
484 | model = ApexDDP(model, delay_allreduce=True)
485 | else:
486 | if args.rank == 0:
487 | _logger.info("Using native Torch DistributedDataParallel.")
488 | model = NativeDDP(model, device_ids=[
489 | args.local_rank
490 | ]) # can use device str in Torch >= 1.1
491 | # NOTE: EMA model does not need to be wrapped by DDP
492 |
493 | # setup learning rate schedule and starting epoch
494 | lr_scheduler, num_epochs = create_scheduler(args, optimizer)
495 | start_epoch = 0
496 | if args.start_epoch is not None:
497 | # a specified start_epoch will always override the resume epoch
498 | start_epoch = args.start_epoch
499 | elif resume_epoch is not None:
500 | start_epoch = resume_epoch
501 | if lr_scheduler is not None and start_epoch > 0:
502 | lr_scheduler.step(start_epoch)
503 |
504 | if args.rank == 0:
505 | _logger.info('Scheduled epochs: {}'.format(num_epochs))
506 |
509 |     # create the train and eval datasets
510 | dataset_train = create_dataset(args.dataset,
511 | root=args.data_dir,
512 | split=args.train_split,
513 | is_training=True,
514 | batch_size=args.batch_size)
515 | dataset_eval = create_dataset(args.dataset,
516 | root=args.data_dir,
517 | split=args.val_split,
518 | is_training=False,
519 | batch_size=args.batch_size)
520 |
521 | # setup mixup / cutmix
522 | collate_fn = None
523 | mixup_fn = None
524 | mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
525 | if mixup_active:
526 | mixup_args = dict(mixup_alpha=args.mixup,
527 | cutmix_alpha=args.cutmix,
528 | cutmix_minmax=args.cutmix_minmax,
529 | prob=args.mixup_prob,
530 | switch_prob=args.mixup_switch_prob,
531 | mode=args.mixup_mode,
532 | label_smoothing=args.smoothing,
533 | num_classes=args.num_classes)
534 | if args.prefetcher:
535 | assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup)
536 | collate_fn = FastCollateMixup(**mixup_args)
537 | else:
538 | mixup_fn = Mixup(**mixup_args)
539 |
540 | # wrap dataset in AugMix helper
541 | if num_aug_splits > 1:
542 | dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)
543 |
544 |     # create data loaders w/ augmentation pipeline
545 | train_interpolation = args.train_interpolation
546 | if args.no_aug or not train_interpolation:
547 | train_interpolation = data_config['interpolation']
548 |
549 | loader_train = create_loader(
550 | dataset_train,
551 | input_size=data_config['input_size'],
552 | batch_size=args.batch_size,
553 | is_training=True,
554 | use_prefetcher=args.prefetcher,
555 | no_aug=args.no_aug,
556 | re_prob=args.reprob,
557 | re_mode=args.remode,
558 | re_count=args.recount,
559 | re_split=args.resplit,
560 | scale=args.scale,
561 | ratio=args.ratio,
562 | hflip=args.hflip,
563 | vflip=args.vflip,
564 | color_jitter=args.color_jitter,
565 | auto_augment=args.aa,
566 | # num_aug_repeats=args.aug_repeats,
567 | num_aug_splits=num_aug_splits,
568 | interpolation=train_interpolation,
569 | mean=data_config['mean'],
570 | std=data_config['std'],
571 | num_workers=args.workers,
572 | distributed=args.distributed,
573 | collate_fn=collate_fn,
574 | pin_memory=args.pin_mem,
575 | use_multi_epochs_loader=args.use_multi_epochs_loader,
576 | # worker_seeding=args.worker_seeding,
577 | )
578 |
579 | loader_eval = create_loader(
580 | dataset_eval,
581 | input_size=data_config['input_size'],
582 | batch_size=args.validation_batch_size_multiplier * args.batch_size,
583 | is_training=False,
584 | use_prefetcher=args.prefetcher,
585 | interpolation=data_config['interpolation'],
586 | mean=data_config['mean'],
587 | std=data_config['std'],
588 | num_workers=args.workers,
589 | distributed=args.distributed,
590 | crop_pct=data_config['crop_pct'],
591 | pin_memory=args.pin_mem,
592 | )
593 |
594 | # setup loss function
595 | if mixup_active:
596 | train_loss_fn = SoftTargetCrossEntropy(reduce=True)
597 | elif args.smoothing:
598 | train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
599 | else:
600 | train_loss_fn = nn.CrossEntropyLoss()
601 | train_loss_fn = train_loss_fn.cuda()
602 | validate_loss_fn = nn.CrossEntropyLoss().cuda()
603 |
604 | teacher_model = None
605 | if args.distillation_type != 'none':
606 | assert args.teacher_path, 'need to specify teacher-path when using distillation'
607 | print(f"Creating teacher model: {args.teacher_model}")
608 | teacher_model = create_model(
609 | args.teacher_model,
610 | pretrained=False
611 | )
612 | if args.teacher_path.startswith('https'):
613 | checkpoint = torch.hub.load_state_dict_from_url(
614 | args.teacher_path, map_location='cpu', check_hash=True)
615 | else:
616 | checkpoint = torch.load(args.teacher_path, map_location='cpu')
617 | if 'model' in checkpoint.keys():
618 | teacher_model.load_state_dict(checkpoint['model'])
619 | elif 'state_dict' in checkpoint.keys():
620 | teacher_model.load_state_dict(checkpoint['state_dict'])
621 | teacher_model.to(args.device)
622 | teacher_model.eval()
623 |
624 | # wrap the criterion in our custom DistillationLoss, which
625 | # just dispatches to the original criterion if args.distillation_type is 'none'
626 | train_loss_fn = DistillationLoss(
627 | train_loss_fn, teacher_model, args.distillation_type, args.distillation_alpha, args.distillation_tau
628 | )
629 |
630 | # setup checkpoint saver and eval metric tracking
631 | eval_metric = args.eval_metric
632 | best_metric = None
633 | best_epoch = None
634 | saver = None
635 | output_dir = ''
636 | if args.rank == 0:
637 | output_base = args.output if args.output else './output'
638 | output_dir = get_outdir(output_base)
639 | decreasing = True if eval_metric == 'loss' else False
640 | saver = CheckpointSaver(model=model,
641 | optimizer=optimizer,
642 | args=args,
643 | model_ema=model_ema,
644 | amp_scaler=loss_scaler,
645 | checkpoint_dir=output_dir,
646 | recovery_dir=output_dir,
647 | decreasing=decreasing,
648 | max_history=args.checkpoint_hist)
649 | with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
650 | f.write(args_text)
651 |
652 | try:
653 | if args.finetune:
654 | validate(model,
655 | loader_eval,
656 | validate_loss_fn,
657 | args,
658 | amp_autocast=amp_autocast)
659 | for epoch in range(start_epoch, num_epochs):
660 |
661 | if epoch >= args.adv_epochs:
662 | args.adv_iters = 1
663 |
664 | if args.distributed and hasattr(loader_train.sampler, 'set_epoch'):
665 | loader_train.sampler.set_epoch(epoch)
666 |
667 | train_metrics = train_one_epoch(epoch,
668 | model,
669 | loader_train,
670 | optimizer,
671 | train_loss_fn,
672 | args,
673 | lr_scheduler=lr_scheduler,
674 | saver=saver,
675 | output_dir=output_dir,
676 | amp_autocast=amp_autocast,
677 | loss_scaler=loss_scaler,
678 | model_ema=model_ema,
679 | mixup_fn=mixup_fn,
680 | optimizers=optimizers)
681 |
682 | if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
683 | if args.rank == 0:
684 | _logger.info(
685 | "Distributing BatchNorm running means and vars")
686 | distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
687 |
688 | eval_metrics = validate(model,
689 | loader_eval,
690 | validate_loss_fn,
691 | args,
692 | amp_autocast=amp_autocast)
693 |
694 | if model_ema is not None and not args.model_ema_force_cpu:
695 | if args.distributed and args.dist_bn in ('broadcast',
696 | 'reduce'):
697 | distribute_bn(model_ema, args.world_size,
698 | args.dist_bn == 'reduce')
699 | ema_eval_metrics = validate(model_ema.module,
700 | loader_eval,
701 | validate_loss_fn,
702 | args,
703 | amp_autocast=amp_autocast,
704 | log_suffix=' (EMA)')
705 | eval_metrics = ema_eval_metrics
706 |
707 | if lr_scheduler is not None:
708 | # step LR for next epoch
709 | lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])
710 |
711 | update_summary(epoch,
712 | train_metrics,
713 | eval_metrics,
714 | os.path.join(output_dir, 'summary.csv'),
715 | write_header=best_metric is None)
716 |
717 | if saver is not None:
718 | # save proper checkpoint with eval metric
719 | save_metric = eval_metrics[eval_metric]
720 | best_metric, best_epoch = saver.save_checkpoint(
721 | epoch, metric=save_metric)
722 |
723 |
724 | except KeyboardInterrupt:
725 | pass
726 | if best_metric is not None:
727 | _logger.info('*** Best metric: {0} (epoch {1})'.format(
728 | best_metric, best_epoch))
729 |
730 |
731 | def train_one_epoch(epoch,
732 | model,
733 | loader,
734 | optimizer,
735 | loss_fn,
736 | args,
737 | lr_scheduler=None,
738 | saver=None,
739 | output_dir='',
740 | amp_autocast=suppress,
741 | loss_scaler=None,
742 | model_ema=None,
743 | mixup_fn=None,
744 | optimizers=None):
745 | assert isinstance(loss_scaler, ApexScaler)
746 |
747 | if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
748 | if args.prefetcher and loader.mixup_enabled:
749 | loader.mixup_enabled = False
750 | elif mixup_fn is not None:
751 | mixup_fn.mixup_enabled = False
752 |
753 | batch_time_m = AverageMeter()
754 | data_time_m = AverageMeter()
755 | losses_m = AverageMeter()
756 |
757 | model.train()
758 |
759 | end = time.time()
760 | last_idx = len(loader) - 1
761 | num_updates = epoch * len(loader)
762 | for batch_idx, (input, target) in enumerate(loader):
763 | last_batch = batch_idx == last_idx
764 | data_time_m.update(time.time() - end)
765 | if not args.prefetcher:
766 | input, target = input.cuda(), target.cuda()
767 | if mixup_fn is not None:
768 | input, target = mixup_fn(input, target)
769 | if args.channels_last:
770 | input = input.contiguous(memory_format=torch.channels_last)
771 |
772 | delta = torch.zeros_like(input)
773 |
774 | output_cle, output_cle_kd = model(input)
775 | output_cle = output_cle.detach()
776 | output_cle_kd = output_cle_kd.detach()
777 |
778 | output_cle_prob = F.softmax(output_cle, dim=1)
779 | output_cle_logprob = F.log_softmax(output_cle, dim=1)
780 |
781 | for a_iter in range(args.adv_iters):
782 | delta.requires_grad_()
783 |
784 | with amp_autocast():
785 |
786 | output_adv, output_adv_kd = model(input+delta)
787 |
788 | loss_ce = loss_fn(input+delta, [output_adv, output_adv_kd], target)
789 |
790 | output_adv_prob = F.softmax(output_adv, dim=1)
791 | output_adv_logprob = F.log_softmax(output_adv, dim=1)
792 |                 loss_kl = (F.kl_div(output_adv_logprob, output_cle_prob, reduction='none').sum(dim=1) +
793 |                            F.kl_div(output_cle_logprob, output_adv_prob, reduction='none').sum(dim=1)).mean() / 2
794 |
795 | if a_iter == 0:
796 | loss = (loss_ce + args.adv_kl_weight * loss_kl)
797 | else:
798 | loss = (args.adv_ce_weight * loss_ce + args.adv_kl_weight * loss_kl) / (args.adv_iters-1)
799 | # loss = loss.mean()
800 | # print(loss.item(), loss_ce.item(), loss_kl.item())
801 |
802 | with amp.scale_loss(loss, optimizer) as scaled_loss:
803 | scaled_loss.backward(retain_graph=False)
804 |
805 | delta_grad = delta.grad.clone().detach().float()
806 | delta = (delta + args.adv_lr * torch.sign(delta_grad)).detach()
807 | delta = torch.clamp(delta, -args.adv_eps, args.adv_eps).detach()
808 |
809 |
810 | if not args.distributed:
811 | losses_m.update(loss.item(), input.size(0))
812 |
813 | optimizer.step()
814 | optimizer.zero_grad()
815 |
816 | if model_ema is not None:
817 | model_ema.update(model)
818 |
819 | torch.cuda.synchronize()
820 | num_updates += 1
821 | batch_time_m.update(time.time() - end)
822 | if last_batch or batch_idx % args.log_interval == 0:
823 | lrl = [param_group['lr'] for param_group in optimizer.param_groups]
824 | lr = sum(lrl) / len(lrl)
825 |
826 | if args.distributed:
827 | reduced_loss = reduce_tensor(loss.data, args.world_size)
828 | losses_m.update(reduced_loss.item(), input.size(0))
829 |
830 | if args.rank == 0:
831 | _logger.info(
832 | 'Train: {} [{:>4d}/{} ({:>3.0f}%)] '
833 | 'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f}) '
834 | 'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s '
835 | '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
836 | 'LR: {lr:.3e} '
837 | 'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
838 | epoch,
839 | batch_idx,
840 | len(loader),
841 | 100. * batch_idx / last_idx,
842 | loss=losses_m,
843 | batch_time=batch_time_m,
844 | rate=input.size(0) * args.world_size /
845 | batch_time_m.val,
846 | rate_avg=input.size(0) * args.world_size /
847 | batch_time_m.avg,
848 | lr=lr,
849 | data_time=data_time_m))
850 |
851 | if args.save_images and output_dir:
852 | torchvision.utils.save_image(
853 | input,
854 | os.path.join(output_dir,
855 | 'train-batch-%d.jpg' % batch_idx),
856 | padding=0,
857 | normalize=True)
858 |
859 | if saver is not None and args.recovery_interval and (
860 | last_batch or (batch_idx + 1) % args.recovery_interval == 0):
861 | saver.save_recovery(epoch, batch_idx=batch_idx)
862 |
863 | if lr_scheduler is not None:
864 | lr_scheduler.step_update(num_updates=num_updates,
865 | metric=losses_m.avg)
866 |
867 | end = time.time()
868 | # end for
869 |
870 | if hasattr(optimizer, 'sync_lookahead'):
871 | optimizer.sync_lookahead()
872 |
873 | return OrderedDict([('loss', losses_m.avg)])
874 |
875 |
876 | def validate(model,
877 | loader,
878 | loss_fn,
879 | args,
880 | amp_autocast=suppress,
881 | log_suffix=''):
882 | batch_time_m = AverageMeter()
883 | losses_m = AverageMeter()
884 | top1_m = AverageMeter()
885 | top5_m = AverageMeter()
886 |
887 | model.eval()
888 |
889 | end = time.time()
890 | last_idx = len(loader) - 1
891 | with torch.no_grad():
892 | for batch_idx, (input, target) in enumerate(loader):
893 | last_batch = batch_idx == last_idx
894 | if not args.prefetcher:
895 | input = input.cuda()
896 | target = target.cuda()
897 | if args.channels_last:
898 | input = input.contiguous(memory_format=torch.channels_last)
899 |
900 | with amp_autocast():
901 | output = model(input)
902 | if isinstance(output, (tuple, list)):
903 | output = output[0]
904 |
905 | # augmentation reduction
906 | reduce_factor = args.tta
907 | if reduce_factor > 1:
908 | output = output.unfold(0, reduce_factor,
909 | reduce_factor).mean(dim=2)
910 | target = target[0:target.size(0):reduce_factor]
911 |
912 | loss = loss_fn(output, target)
913 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
914 |
915 | if args.distributed:
916 | reduced_loss = reduce_tensor(loss.data, args.world_size)
917 | acc1 = reduce_tensor(acc1, args.world_size)
918 | acc5 = reduce_tensor(acc5, args.world_size)
919 | else:
920 | reduced_loss = loss.data
921 |
922 | torch.cuda.synchronize()
923 |
924 | losses_m.update(reduced_loss.item(), input.size(0))
925 | top1_m.update(acc1.item(), output.size(0))
926 | top5_m.update(acc5.item(), output.size(0))
927 |
928 | batch_time_m.update(time.time() - end)
929 | end = time.time()
930 | if args.rank == 0 and (last_batch or
931 | batch_idx % args.log_interval == 0):
932 | log_name = 'Test' + log_suffix
933 | _logger.info(
934 | '{0}: [{1:>4d}/{2}] '
935 | 'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
936 | 'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
937 | 'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
938 | 'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
939 | log_name,
940 | batch_idx,
941 | last_idx,
942 | batch_time=batch_time_m,
943 | loss=losses_m,
944 | top1=top1_m,
945 | top5=top5_m))
946 |
947 | metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg),
948 | ('top5', top5_m.avg)])
949 |
950 | return metrics
951 |
952 |
953 | if __name__ == '__main__':
954 | main()
955 |
--------------------------------------------------------------------------------
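The adversarial inner loop in train_one_epoch above is easiest to read in
isolation. Below is a minimal, self-contained sketch (not from this repo) of
the same idea: ascend delta along the sign of its gradient under a combined
CE + symmetric-KL objective, then clamp it to an L-infinity ball of radius
adv_eps. It assumes a plain float32 model returning a single logits tensor and
a standard criterion(output, target) signature, i.e. none of the AMP, DDP, or
DistillationLoss plumbing used above.

import torch
import torch.nn.functional as F

def adversarial_delta(model, criterion, input, target,
                      adv_iters=3, adv_lr=1.0 / 255, adv_eps=2.0 / 255,
                      kl_weight=0.01):
    # Clean predictions act as fixed references for the symmetric KL term.
    with torch.no_grad():
        output_cle = model(input)
    cle_prob = F.softmax(output_cle, dim=1)
    cle_logprob = F.log_softmax(output_cle, dim=1)

    delta = torch.zeros_like(input)
    for _ in range(adv_iters):
        delta.requires_grad_()
        output_adv = model(input + delta)
        loss_ce = criterion(output_adv, target)
        adv_prob = F.softmax(output_adv, dim=1)
        adv_logprob = F.log_softmax(output_adv, dim=1)
        loss_kl = (F.kl_div(adv_logprob, cle_prob, reduction='none').sum(dim=1) +
                   F.kl_div(cle_logprob, adv_prob, reduction='none').sum(dim=1)).mean() / 2
        loss = loss_ce + kl_weight * loss_kl
        loss.backward()  # also accumulates model grads, as in train_one_epoch
        # Sign-gradient ascent on delta, then project back into the eps-ball.
        delta = (delta + adv_lr * torch.sign(delta.grad)).detach()
        delta = torch.clamp(delta, -adv_eps, adv_eps)
    return delta
--------------------------------------------------------------------------------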
/train.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | warnings.filterwarnings("ignore")
3 |
4 | import argparse
5 | import time
6 | import yaml
7 | import os
8 | import logging
9 | from collections import OrderedDict
10 | from contextlib import suppress
11 | from datetime import datetime
12 | import numpy as np
13 |
14 | import torch
15 | import torch.nn as nn
16 | import torchvision.utils
17 | from torch.nn.parallel import DistributedDataParallel as NativeDDP
18 |
19 | from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
20 | from timm.models import create_model, resume_checkpoint, load_checkpoint, convert_splitbn_model, model_parameters
21 | from timm.utils import *
22 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
23 | from timm.optim import create_optimizer
24 | from timm.scheduler import create_scheduler
25 | from timm.utils import ApexScaler, NativeScaler
26 |
27 | from tlt.data import create_token_label_target, TokenLabelMixup, FastCollateTokenLabelMixup, \
28 | create_token_label_loader, create_token_label_dataset
29 | from tlt.utils import load_pretrained_weights
30 | from loss import TokenLabelGTCrossEntropy, TokenLabelCrossEntropy, TokenLabelSoftTargetCrossEntropy
31 | import models
32 | import torch.nn.functional as F
33 |
34 | try:
35 | from apex import amp
36 | from apex.parallel import DistributedDataParallel as ApexDDP
37 | from apex.parallel import convert_syncbn_model
38 |
39 | has_apex = True
40 | except ImportError:
41 | has_apex = False
42 |
43 | has_native_amp = False
44 | try:
45 | if getattr(torch.cuda.amp, 'autocast') is not None:
46 | has_native_amp = True
47 | except AttributeError:
48 | pass
49 |
50 | torch.backends.cudnn.benchmark = True
51 | _logger = logging.getLogger('train')
52 |
53 | # The first arg parser parses out only the --config argument, this argument is used to
54 | # load a yaml file containing key-values that override the defaults for the main parser below
55 | config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False)
56 | parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
57 | help='YAML config file specifying default arguments')
58 |
59 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
60 |
61 | # Dataset / Model parameters
62 | parser.add_argument('--data_dir', metavar='DIR',
63 | help='path to dataset')
64 | parser.add_argument('--dataset', '-d', metavar='NAME', default='',
65 | help='dataset type (default: ImageFolder/ImageTar if empty)')
66 | parser.add_argument('--train-split', metavar='NAME', default='train',
67 | help='dataset train split (default: train)')
68 | parser.add_argument('--val-split', metavar='NAME', default='validation',
69 | help='dataset validation split (default: validation)')
70 | parser.add_argument('--model', default='volo_d1', type=str, metavar='MODEL',
71 |                     help='Name of model to train (default: "volo_d1")')
72 | parser.add_argument('--pretrained', action='store_true', default=False,
73 | help='Start with pretrained version of specified network (if avail)')
74 | parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
75 | help='Initialize model from this checkpoint (default: none)')
76 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
77 | help='Resume full model and optimizer state from checkpoint (default: none)')
78 | parser.add_argument('--no-resume-opt', action='store_true', default=False,
79 | help='prevent resume of optimizer state when resuming model')
80 | parser.add_argument('--num-classes', type=int, default=None, metavar='N',
81 | help='number of label classes (Model default if None)')
82 | parser.add_argument('--gp', default=None, type=str, metavar='POOL',
83 | help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
84 | parser.add_argument('--img-size', type=int, default=None, metavar='N',
85 | help='Image patch size (default: None => model default)')
86 | parser.add_argument('--input-size', default=None, nargs=3, type=int,
87 | metavar='N N N',
88 | help='Input all image dimensions (d h w, e.g. --input-size 3 224 224),'
89 | ' uses model default if empty')
90 | parser.add_argument('--crop-pct', default=None, type=float,
91 | metavar='N', help='Input image center crop percent (for validation only)')
92 | parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
93 | help='Override mean pixel value of dataset')
94 | parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
95 |                     help='Override std deviation of dataset')
96 | parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
97 | help='Image resize interpolation type (overrides model)')
98 | parser.add_argument('-b', '--batch-size', type=int, default=128, metavar='N',
99 | help='input batch size for training (default: 128)')
100 | parser.add_argument('-vb', '--validation-batch-size-multiplier', type=int, default=1, metavar='N',
101 | help='ratio of validation batch size to training batch size (default: 1)')
102 |
103 | # Optimizer parameters
104 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
105 |                     help='Optimizer (default: "adamw")')
106 | parser.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON',
107 | help='Optimizer Epsilon (default: None, use opt default)')
108 | parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
109 | help='Optimizer Betas (default: None, use opt default)')
110 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
111 | help='Optimizer momentum (default: 0.9)')
112 | parser.add_argument('--weight-decay', type=float, default=0.05,
113 | help='weight decay (default: 0.05)')
114 | parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
115 | help='Clip gradient norm (default: None, no clipping)')
116 | parser.add_argument('--clip-mode', type=str, default='norm',
117 | help='Gradient clipping mode. One of ("norm", "value", "agc")')
118 |
119 | # Learning rate schedule parameters
120 | parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
121 |                     help='LR scheduler (default: "cosine")')
122 | parser.add_argument('--lr', type=float, default=1.6e-3, metavar='LR',
123 | help='learning rate (default: 1.6e-3)')
124 | parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
125 | help='learning rate noise on/off epoch percentages')
126 | parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
127 | help='learning rate noise limit percent (default: 0.67)')
128 | parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
129 | help='learning rate noise std-dev (default: 1.0)')
130 | parser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',
131 | help='learning rate cycle len multiplier (default: 1.0)')
132 | parser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',
133 | help='learning rate cycle limit')
134 | parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
135 |                     help='warmup learning rate (default: 1e-6)')
136 | parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
137 | help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
138 | parser.add_argument('--epochs', type=int, default=300, metavar='N',
139 | help='number of epochs to train (default: 300)')
140 | parser.add_argument('--start-epoch', default=None, type=int, metavar='N',
141 | help='manual epoch number (useful on restarts)')
142 | parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
143 | help='epoch interval to decay LR')
144 | parser.add_argument('--warmup-epochs', type=int, default=20, metavar='N',
145 | help='epochs to warmup LR, if scheduler supports')
146 | parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
147 | help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
148 | parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
149 |                     help='patience epochs for Plateau LR scheduler (default: 10)')
150 | parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
151 | help='LR decay rate (default: 0.1)')
152 |
153 | # Augmentation & regularization parameters
154 | parser.add_argument('--no-aug', action='store_true', default=False,
155 | help='Disable all training augmentation, override other train aug args')
156 | parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
157 | help='Random resize scale (default: 0.08 1.0)')
158 | parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
159 | help='Random resize aspect ratio (default: 0.75 1.33)')
160 | parser.add_argument('--hflip', type=float, default=0.5,
161 | help='Horizontal flip training aug probability')
162 | parser.add_argument('--vflip', type=float, default=0.,
163 | help='Vertical flip training aug probability')
164 | parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
165 | help='Color jitter factor (default: 0.4)')
166 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
167 |                     help='Use AutoAugment policy. "v0" or "original". (default: rand-m9-mstd0.5-inc1)')
168 | parser.add_argument('--aug-splits', type=int, default=0,
169 | help='Number of augmentation splits (default: 0, valid: 0 or >=2)')
170 | parser.add_argument('--jsd', action='store_true', default=False,
171 | help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
172 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
173 | help='Random erase prob (default: 0.25)')
174 | parser.add_argument('--remode', type=str, default='pixel',
175 | help='Random erase mode (default: "pixel")')
176 | parser.add_argument('--recount', type=int, default=1,
177 | help='Random erase count (default: 1)')
178 | parser.add_argument('--resplit', action='store_true', default=False,
179 | help='Do not random erase first (clean) augmentation split')
180 | parser.add_argument('--mixup', type=float, default=0.8,
181 |                     help='mixup alpha, mixup enabled if > 0. (default: 0.8)')
182 | parser.add_argument('--cutmix', type=float, default=1.0,
183 |                     help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)')
184 | parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
185 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
186 | parser.add_argument('--mixup-prob', type=float, default=1.0,
187 | help='Probability of performing mixup or cutmix when either/both is enabled')
188 | parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
189 | help='Probability of switching to cutmix when both mixup and cutmix enabled')
190 | parser.add_argument('--mixup-mode', type=str, default='batch',
191 | help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
192 | parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
193 | help='Turn off mixup after this epoch, disabled if 0 (default: 0)')
194 | parser.add_argument('--smoothing', type=float, default=0.1,
195 | help='Label smoothing (default: 0.1)')
196 | parser.add_argument('--train-interpolation', type=str, default='random',
197 | help='Training interpolation (random, bilinear, bicubic default: "random")')
198 | parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
199 | help='Dropout rate (default: 0.)')
200 | parser.add_argument('--drop-connect', type=float, default=None, metavar='PCT',
201 | help='Drop connect rate, DEPRECATED, use drop-path (default: None)')
202 | parser.add_argument('--drop-path', type=float, default=None, metavar='PCT',
203 | help='Drop path rate (default: None)')
204 | parser.add_argument('--drop-block', type=float, default=None, metavar='PCT',
205 | help='Drop block rate (default: None)')
206 |
207 | # Batch norm parameters (only works with gen_efficientnet based models currently)
208 | parser.add_argument('--bn-tf', action='store_true', default=False,
209 | help='Use Tensorflow BatchNorm defaults for models that support it (default: False)')
210 | parser.add_argument('--bn-momentum', type=float, default=None,
211 | help='BatchNorm momentum override (if not None)')
212 | parser.add_argument('--bn-eps', type=float, default=None,
213 | help='BatchNorm epsilon override (if not None)')
214 | parser.add_argument('--sync-bn', action='store_true',
215 | help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')
216 | parser.add_argument('--dist-bn', type=str, default='',
217 | help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")')
218 | parser.add_argument('--split-bn', action='store_true',
219 | help='Enable separate BN layers per augmentation split.')
220 |
221 | # Model Exponential Moving Average
222 | parser.add_argument('--model-ema', action='store_true', default=False,
223 | help='Enable tracking moving average of model weights')
224 | parser.add_argument('--model-ema-force-cpu', action='store_true', default=False,
225 | help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
226 | parser.add_argument('--model-ema-decay', type=float, default=0.99992,
227 | help='decay factor for model weights moving average (default: 0.99992)')
228 |
229 | # Misc
230 | parser.add_argument('--seed', type=int, default=42, metavar='S',
231 | help='random seed (default: 42)')
232 | parser.add_argument('--log-interval', type=int, default=50, metavar='N',
233 | help='how many batches to wait before logging training status')
234 | parser.add_argument('--recovery-interval', type=int, default=0, metavar='N',
235 | help='how many batches to wait before writing recovery checkpoint')
236 | parser.add_argument('--checkpoint-hist', type=int, default=3, metavar='N',
237 | help='number of checkpoints to keep (default: 3)')
238 | parser.add_argument('-j', '--workers', type=int, default=8, metavar='N',
239 | help='how many training processes to use (default: 8)')
240 | parser.add_argument('--save-images', action='store_true', default=False,
241 | help='save images of input batches every log interval for debugging')
242 | parser.add_argument('--amp', action='store_true', default=False,
243 | help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
244 | parser.add_argument('--apex-amp', action='store_true', default=True,
245 | help='Use NVIDIA Apex AMP mixed precision (enabled by default)')
246 | parser.add_argument('--native-amp', action='store_true', default=False,
247 | help='Use Native Torch AMP mixed precision')
248 | parser.add_argument('--channels-last', action='store_true', default=False,
249 | help='Use channels_last memory layout')
250 | parser.add_argument('--pin-mem', action='store_true', default=False,
251 | help='Pin CPU memory in DataLoader for (sometimes) more efficient transfer to GPU.')
252 | parser.add_argument('--no-prefetcher', action='store_true', default=False,
253 | help='disable fast prefetcher')
254 | parser.add_argument('--output', default='', type=str, metavar='PATH',
255 | help='path to output folder (default: none, current dir)')
256 | parser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',
257 | help='Best metric (default: "top1")')
258 | parser.add_argument('--tta', type=int, default=0, metavar='N',
259 | help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
260 | parser.add_argument("--local_rank", default=0, type=int)
261 | parser.add_argument('--use-multi-epochs-loader', action='store_true', default=False,
262 | help='use the multi-epochs-loader to save time at the beginning of every epoch')
263 | parser.add_argument('--torchscript', dest='torchscript', action='store_true',
264 | help='convert model to torchscript for inference')
265 |
266 | # Token labeling
267 |
268 | parser.add_argument('--token-label', action='store_true', default=False,
269 | help='Use dense token-level label map for training')
270 | parser.add_argument('--token-label-data', type=str, default='', metavar='DIR',
271 | help='path to token_label data')
272 | parser.add_argument('--token-label-size', type=int, default=1, metavar='N',
273 | help='size of the resulting token label map')
274 | parser.add_argument('--dense-weight', type=float, default=0.5,
275 | help='Token labeling loss multiplier (default: 0.5)')
276 | parser.add_argument('--cls-weight', type=float, default=1.0,
277 | help='Cls token prediction loss multiplier (default: 1.0)')
278 | parser.add_argument('--ground-truth', action='store_true', default=False,
279 | help='mix in ground-truth labels when using token labeling')
280 |
281 | # Finetune
282 | parser.add_argument('--finetune', default='', type=str, metavar='PATH',
283 | help='path to checkpoint file (default: none)')
284 |
285 | # Adversarial
286 | parser.add_argument('--adv-epochs', default=200, type=int, metavar='N',
287 | help='number of epochs for performing adversarial training')
288 | parser.add_argument('--adv-iters', default=3, type=int, metavar='N',
289 | help='number of PGD steps')
290 | parser.add_argument('--adv-eps', default=2.0/255, type=float,
291 | help='L_inf budget for adversarial perturbations')
292 | parser.add_argument('--adv-lr', default=1.0/255, type=float,
293 | help='step size of PGD')
294 | parser.add_argument('--adv-kl-weight', default=0.01, type=float,
295 | help='weight of the KL-divergence term for adversarial training')
296 | parser.add_argument('--adv-ce-weight', default=3.0, type=float,
297 | help='weight of the cross-entropy term for adversarial training')
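    | # These four flags drive the L_inf PGD loop in train_one_epoch below:
    | #   delta <- clamp(delta + adv_lr * sign(grad_delta(loss)), -adv_eps, adv_eps)
    | # repeated adv_iters times per batch, with eps and lr on the [0, 1] pixel scale.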
298 |
299 | def _parse_args():
300 | # Do we have a config file to parse?
301 | args_config, remaining = config_parser.parse_known_args()
302 | if args_config.config:
303 | with open(args_config.config, 'r') as f:
304 | cfg = yaml.safe_load(f)
305 | parser.set_defaults(**cfg)
306 |
307 | # The main arg parser parses the rest of the args, the usual
308 | # defaults will have been overridden if config file specified.
309 | args = parser.parse_args(remaining)
310 |
311 | # Cache the args as a text string to save them in the output dir later
312 | args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
313 | return args, args_text
314 |
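    | # A minimal sketch of a YAML file consumed via --config (hypothetical
    | # values); keys must match the argparse destinations defined above:
    | #   model: volo_d1
    | #   batch_size: 128
    | #   lr: 0.0016
    | #   adv_iters: 3
    | #   token_label: true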
315 |
316 | def main():
317 | setup_default_logging()
318 | args, args_text = _parse_args()
319 |
320 | args.prefetcher = not args.no_prefetcher
321 | args.distributed = False
322 | if 'WORLD_SIZE' in os.environ:
323 | args.distributed = int(os.environ['WORLD_SIZE']) > 1
324 | args.device = 'cuda:0'
325 | args.world_size = 1
326 | args.rank = 0 # global rank
327 | if args.distributed:
328 | args.device = 'cuda:%d' % args.local_rank
329 | torch.cuda.set_device(args.local_rank)
330 | torch.distributed.init_process_group(backend='nccl',
331 | init_method='env://')
332 | args.world_size = torch.distributed.get_world_size()
333 | args.rank = torch.distributed.get_rank()
334 | _logger.info(
335 | 'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
336 | % (args.rank, args.world_size))
337 | else:
338 | _logger.info('Training with a single process on 1 GPU.')
339 | assert args.rank >= 0
340 |
341 | # resolve AMP arguments based on PyTorch / Apex availability
342 | use_amp = None
343 | if args.amp:
344 | # `--amp` chooses native amp before apex (APEX ver not actively maintained)
345 | if has_native_amp:
346 | args.native_amp = True
347 | elif has_apex:
348 | args.apex_amp = True
349 | if args.apex_amp and has_apex:
350 | use_amp = 'apex'
351 | elif args.native_amp and has_native_amp:
352 | use_amp = 'native'
353 | elif args.apex_amp or args.native_amp:
354 | _logger.warning(
355 | "Neither APEX or native Torch AMP is available, using float32. "
356 | "Install NVIDA apex or upgrade to PyTorch 1.6")
357 |
358 | torch.manual_seed(args.seed + args.rank)
359 | np.random.seed(args.seed + args.rank)
360 | model = create_model(
361 | args.model,
362 | pretrained=args.pretrained,
363 | num_classes=args.num_classes,
364 | drop_rate=args.drop,
365 | drop_connect_rate=args.drop_connect, # DEPRECATED, use drop_path
366 | drop_path_rate=args.drop_path,
367 | drop_block_rate=args.drop_block,
368 | global_pool=args.gp,
369 | bn_tf=args.bn_tf,
370 | bn_momentum=args.bn_momentum,
371 | bn_eps=args.bn_eps,
372 | scriptable=args.torchscript,
373 | checkpoint_path=args.initial_checkpoint,
374 | img_size=args.img_size)
375 | if args.num_classes is None:
376 | assert hasattr(
377 | model, 'num_classes'
378 | ), 'Model must have `num_classes` attr if not set on cmd line/config.'
379 | args.num_classes = model.num_classes # FIXME handle model default vs config num_classes more elegantly
380 |
381 | if args.finetune:
382 | load_pretrained_weights(model=model,
383 | checkpoint_path=args.finetune,
384 | use_ema=args.model_ema,
385 | strict=False,
386 | num_classes=args.num_classes)
387 |
388 | if args.rank == 0:
389 | _logger.info('Model %s created, param count: %d' %
390 | (args.model, sum([m.numel()
391 | for m in model.parameters()])))
392 |
393 | data_config = resolve_data_config(vars(args),
394 | model=model,
395 | verbose=args.local_rank == 0)
396 |
397 | # setup augmentation batch splits for contrastive loss or split bn
398 | num_aug_splits = 0
399 | if args.aug_splits > 0:
400 | assert args.aug_splits > 1, 'A split of 1 makes no sense'
401 | num_aug_splits = args.aug_splits
402 |
403 | # enable split bn (separate bn stats per batch-portion)
404 | if args.split_bn:
405 | assert num_aug_splits > 1 or args.resplit
406 | model = convert_splitbn_model(model, max(num_aug_splits, 2))
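    | # even when only --resplit is set, at least two BN splits are kept
    | # (one clean split plus the augmented remainder).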
407 |
408 | # move model to GPU, enable channels last layout if set
409 | model.cuda()
410 | if args.channels_last:
411 | model = model.to(memory_format=torch.channels_last)
412 |
413 | # setup synchronized BatchNorm for distributed training
414 | if args.distributed and args.sync_bn:
415 | assert not args.split_bn
416 | if has_apex and use_amp != 'native':
417 | # Apex SyncBN preferred unless native amp is activated
418 | model = convert_syncbn_model(model)
419 | else:
420 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
421 | if args.rank == 0:
422 | _logger.info(
423 | 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
424 | 'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.'
425 | )
426 |
427 | if args.torchscript:
428 | assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
429 | assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
430 | model = torch.jit.script(model)
431 |
432 | linear_scaled_lr = args.lr * args.batch_size * args.world_size / 1024.0
433 | args.lr = linear_scaled_lr
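    | # e.g. (hypothetical values) --lr 1.6e-3 with --batch-size 128 on 8 GPUs:
    | # 1.6e-3 * 128 * 8 / 1024 = 1.6e-3, i.e. the base lr is quoted per 1024 images.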
434 | if args.rank == 0:
435 | print("learning rate:", args.lr)
436 | optimizer = create_optimizer(args, model)
437 |
438 | # setup automatic mixed-precision (AMP) loss scaling and op casting
439 | amp_autocast = suppress # do nothing
440 | loss_scaler = None
441 | optimizers = None
442 | if use_amp == 'apex':
443 | model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
444 | loss_scaler = ApexScaler()
445 | if args.rank == 0:
446 | _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
447 | elif use_amp == 'native':
448 | amp_autocast = torch.cuda.amp.autocast
449 | loss_scaler = NativeScaler()
450 | if args.rank == 0:
451 | _logger.info(
452 | 'Using native Torch AMP. Training in mixed precision.')
453 | else:
454 | if args.rank == 0:
455 | _logger.info('AMP not enabled. Training in float32.')
456 |
457 | # optionally resume from a checkpoint
458 | resume_epoch = None
459 | if args.resume and os.path.isfile(args.resume):
460 | resume_epoch = resume_checkpoint(
461 | model,
462 | args.resume,
463 | optimizer=None if args.no_resume_opt else optimizer,
464 | loss_scaler=None if args.no_resume_opt else loss_scaler,
465 | log_info=args.rank == 0)
466 |
467 | # setup exponential moving average of model weights, SWA could be used here too
468 | model_ema = None
469 | if args.model_ema:
470 | # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
471 | model_ema = ModelEmaV2(
472 | model,
473 | decay=args.model_ema_decay,
474 | device='cpu' if args.model_ema_force_cpu else None)
475 | if args.resume and os.path.isfile(args.resume):
476 | load_checkpoint(model_ema.module, args.resume, use_ema=True)
477 |
478 | # setup distributed training
479 | if args.distributed:
480 | if has_apex and use_amp != 'native':
481 | # Apex DDP preferred unless native amp is activated
482 | if args.rank == 0:
483 | _logger.info("Using NVIDIA APEX DistributedDataParallel.")
484 | model = ApexDDP(model, delay_allreduce=True)
485 | else:
486 | if args.rank == 0:
487 | _logger.info("Using native Torch DistributedDataParallel.")
488 | model = NativeDDP(model, device_ids=[
489 | args.local_rank
490 | ]) # can use device str in Torch >= 1.1
491 | # NOTE: EMA model does not need to be wrapped by DDP
492 |
493 | # setup learning rate schedule and starting epoch
494 | lr_scheduler, num_epochs = create_scheduler(args, optimizer)
495 | start_epoch = 0
496 | if args.start_epoch is not None:
497 | # a specified start_epoch will always override the resume epoch
498 | start_epoch = args.start_epoch
499 | elif resume_epoch is not None:
500 | start_epoch = resume_epoch
501 | if lr_scheduler is not None and start_epoch > 0:
502 | lr_scheduler.step(start_epoch)
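    | # fast-forward the schedule so a resumed run starts from the lr it
    | # would have reached at start_epoch, rather than the initial lr.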
503 |
504 | if args.rank == 0:
505 | _logger.info('Scheduled epochs: {}'.format(num_epochs))
506 |
507 | # create the train and eval datasets
508 |
509 | # create token_label dataset
510 | if args.token_label_data:
511 | dataset_train = create_token_label_dataset(
512 | args.dataset, root=args.data_dir, label_root=args.token_label_data)
513 | else:
514 | dataset_train = create_dataset(args.dataset,
515 | root=args.data_dir,
516 | split=args.train_split,
517 | is_training=True,
518 | batch_size=args.batch_size)
519 | dataset_eval = create_dataset(args.dataset,
520 | root=args.data_dir,
521 | split=args.val_split,
522 | is_training=False,
523 | batch_size=args.batch_size)
524 |
525 | # setup mixup / cutmix
526 | collate_fn = None
527 | mixup_fn = None
528 | mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
529 | if mixup_active:
530 | mixup_args = dict(mixup_alpha=args.mixup,
531 | cutmix_alpha=args.cutmix,
532 | cutmix_minmax=args.cutmix_minmax,
533 | prob=args.mixup_prob,
534 | switch_prob=args.mixup_switch_prob,
535 | mode=args.mixup_mode,
536 | label_smoothing=args.smoothing,
537 | num_classes=args.num_classes)
538 | # create token_label mixup
539 | if args.token_label_data:
540 | mixup_args['label_size'] = args.token_label_size
541 | if args.prefetcher:
542 | assert not num_aug_splits
543 | collate_fn = FastCollateTokenLabelMixup(**mixup_args)
544 | else:
545 | mixup_fn = TokenLabelMixup(**mixup_args)
546 | else:
547 | if args.prefetcher:
548 | assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup)
549 | collate_fn = FastCollateMixup(**mixup_args)
550 | else:
551 | mixup_fn = Mixup(**mixup_args)
552 |
553 | # wrap dataset in AugMix helper
554 | if num_aug_splits > 1:
555 | assert not args.token_label
556 | dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)
557 |
558 | # create data loaders w/ augmentation pipeline
559 | train_interpolation = args.train_interpolation
560 | if args.no_aug or not train_interpolation:
561 | train_interpolation = data_config['interpolation']
562 | if args.token_label and args.token_label_data:
563 | use_token_label = True
564 | else:
565 | use_token_label = False
566 | loader_train = create_token_label_loader(
567 | dataset_train,
568 | input_size=data_config['input_size'],
569 | batch_size=args.batch_size,
570 | is_training=True,
571 | use_prefetcher=args.prefetcher,
572 | no_aug=args.no_aug,
573 | re_prob=args.reprob,
574 | re_mode=args.remode,
575 | re_count=args.recount,
576 | re_split=args.resplit,
577 | scale=args.scale,
578 | ratio=args.ratio,
579 | hflip=args.hflip,
580 | vflip=args.vflip,
581 | color_jitter=args.color_jitter,
582 | auto_augment=args.aa,
583 | num_aug_splits=num_aug_splits,
584 | interpolation=train_interpolation,
585 | mean=data_config['mean'],
586 | std=data_config['std'],
587 | num_workers=args.workers,
588 | distributed=args.distributed,
589 | collate_fn=collate_fn,
590 | pin_memory=args.pin_mem,
591 | use_multi_epochs_loader=args.use_multi_epochs_loader,
592 | use_token_label=use_token_label)
593 |
594 | loader_eval = create_loader(
595 | dataset_eval,
596 | input_size=data_config['input_size'],
597 | batch_size=args.validation_batch_size_multiplier * args.batch_size,
598 | is_training=False,
599 | use_prefetcher=args.prefetcher,
600 | interpolation=data_config['interpolation'],
601 | mean=data_config['mean'],
602 | std=data_config['std'],
603 | num_workers=args.workers,
604 | distributed=args.distributed,
605 | crop_pct=data_config['crop_pct'],
606 | pin_memory=args.pin_mem,
607 | )
608 |
609 | # setup loss function
610 | # use token_label loss
611 | if args.token_label:
612 | if args.token_label_size == 1:
613 | # a 1x1 token-label map reduces to the standard (relabeled) ImageNet label
614 | train_loss_fn = TokenLabelSoftTargetCrossEntropy().cuda()
615 | else:
616 | if args.ground_truth:
617 | train_loss_fn = TokenLabelGTCrossEntropy(dense_weight=args.dense_weight,
618 | cls_weight=args.cls_weight, mixup_active=mixup_active).cuda()
619 |
620 | else:
621 | train_loss_fn = TokenLabelCrossEntropy(dense_weight=args.dense_weight,
622 | cls_weight=args.cls_weight, mixup_active=mixup_active).cuda()
623 |
624 | else:
625 | # smoothing is handled by the mixup target transform or the create_token_label_target function
626 | train_loss_fn = SoftTargetCrossEntropy().cuda()
627 |
628 | validate_loss_fn = nn.CrossEntropyLoss().cuda()
629 |
630 | # setup checkpoint saver and eval metric tracking
631 | eval_metric = args.eval_metric
632 | best_metric = None
633 | best_epoch = None
634 | saver = None
635 | output_dir = ''
636 | if args.rank == 0:
637 | output_base = args.output if args.output else './output'
638 | output_dir = get_outdir(output_base)
639 | decreasing = True if eval_metric == 'loss' else False
640 | saver = CheckpointSaver(model=model,
641 | optimizer=optimizer,
642 | args=args,
643 | model_ema=model_ema,
644 | amp_scaler=loss_scaler,
645 | checkpoint_dir=output_dir,
646 | recovery_dir=output_dir,
647 | decreasing=decreasing,
648 | max_history=args.checkpoint_hist)
649 | with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
650 | f.write(args_text)
651 |
652 | try:
653 | if args.finetune:
654 | validate(model,
655 | loader_eval,
656 | validate_loss_fn,
657 | args,
658 | amp_autocast=amp_autocast)
659 | for epoch in range(start_epoch, num_epochs):
660 |
661 | if epoch >= args.adv_epochs:
662 | args.adv_iters = 1
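    | # past adv_epochs the loop below runs a single step with delta still
    | # zero, so training effectively continues on clean inputs only.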
663 |
664 | if args.distributed and hasattr(loader_train.sampler, 'set_epoch'):
665 | loader_train.sampler.set_epoch(epoch)
666 |
667 | train_metrics = train_one_epoch(epoch,
668 | model,
669 | loader_train,
670 | optimizer,
671 | train_loss_fn,
672 | args,
673 | lr_scheduler=lr_scheduler,
674 | saver=saver,
675 | output_dir=output_dir,
676 | amp_autocast=amp_autocast,
677 | loss_scaler=loss_scaler,
678 | model_ema=model_ema,
679 | mixup_fn=mixup_fn,
680 | optimizers=optimizers)
681 |
682 | if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
683 | if args.rank == 0:
684 | _logger.info(
685 | "Distributing BatchNorm running means and vars")
686 | distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
687 |
688 | eval_metrics = validate(model,
689 | loader_eval,
690 | validate_loss_fn,
691 | args,
692 | amp_autocast=amp_autocast)
693 |
694 | if model_ema is not None and not args.model_ema_force_cpu:
695 | if args.distributed and args.dist_bn in ('broadcast',
696 | 'reduce'):
697 | distribute_bn(model_ema, args.world_size,
698 | args.dist_bn == 'reduce')
699 | ema_eval_metrics = validate(model_ema.module,
700 | loader_eval,
701 | validate_loss_fn,
702 | args,
703 | amp_autocast=amp_autocast,
704 | log_suffix=' (EMA)')
705 | eval_metrics = ema_eval_metrics
706 |
707 | if lr_scheduler is not None:
708 | # step LR for next epoch
709 | lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])
710 |
711 | update_summary(epoch,
712 | train_metrics,
713 | eval_metrics,
714 | os.path.join(output_dir, 'summary.csv'),
715 | write_header=best_metric is None)
716 |
717 | if saver is not None:
718 | # save proper checkpoint with eval metric
719 | save_metric = eval_metrics[eval_metric]
720 | best_metric, best_epoch = saver.save_checkpoint(
721 | epoch, metric=save_metric)
722 |
723 |
724 | except KeyboardInterrupt:
725 | pass
726 | if best_metric is not None:
727 | _logger.info('*** Best metric: {0} (epoch {1})'.format(
728 | best_metric, best_epoch))
729 |
730 |
731 | def train_one_epoch(epoch,
732 | model,
733 | loader,
734 | optimizer,
735 | loss_fn,
736 | args,
737 | lr_scheduler=None,
738 | saver=None,
739 | output_dir='',
740 | amp_autocast=suppress,
741 | loss_scaler=None,
742 | model_ema=None,
743 | mixup_fn=None,
744 | optimizers=None):
745 |
746 | if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
747 | if args.prefetcher and loader.mixup_enabled:
748 | loader.mixup_enabled = False
749 | elif mixup_fn is not None:
750 | mixup_fn.mixup_enabled = False
751 |
752 | second_order = hasattr(optimizer,
753 | 'is_second_order') and optimizer.is_second_order
754 | batch_time_m = AverageMeter()
755 | data_time_m = AverageMeter()
756 | losses_m = AverageMeter()
757 |
758 | model.train()
759 |
760 | end = time.time()
761 | last_idx = len(loader) - 1
762 | num_updates = epoch * len(loader)
763 | for batch_idx, (input, target) in enumerate(loader):
764 | last_batch = batch_idx == last_idx
765 | data_time_m.update(time.time() - end)
766 | if not args.prefetcher:
767 | input, target = input.cuda(), target.cuda()
768 | if mixup_fn is not None:
769 | input, target = mixup_fn(input, target)
770 | else:
771 | # handle token_label without mixup
772 | if args.token_label and args.token_label_data:
773 | target = create_token_label_target(
774 | target,
775 | num_classes=args.num_classes,
776 | smoothing=args.smoothing,
777 | label_size=args.token_label_size)
778 | if len(target.shape) == 1:
779 | target = create_token_label_target(
780 | target,
781 | num_classes=args.num_classes,
782 | smoothing=args.smoothing)
783 | else:
784 | if args.token_label and args.token_label_data and not loader.mixup_enabled:
785 | target = create_token_label_target(
786 | target,
787 | num_classes=args.num_classes,
788 | smoothing=args.smoothing,
789 | label_size=args.token_label_size)
790 | if len(target.shape) == 1:
791 | target = create_token_label_target(
792 | target,
793 | num_classes=args.num_classes,
794 | smoothing=args.smoothing)
795 | if args.channels_last:
796 | input = input.contiguous(memory_format=torch.channels_last)
797 |
798 |
799 | delta = torch.zeros_like(input)
800 |
801 | if args.token_label:
802 | output_cle, output_aux, _ = model(input)
803 | output_cle = output_cle.detach()
804 | else:
805 | output_cle = model(input).detach()
806 |
807 | output_cle_prob = F.softmax(output_cle, dim=1)
808 | output_cle_logprob = F.log_softmax(output_cle, dim=1)
809 |
810 | for a_iter in range(args.adv_iters):
811 | delta.requires_grad_()
812 |
813 | with amp_autocast():
814 |
815 | if args.token_label:
816 | output = model(input + delta)
817 | output_adv = output[0]
818 | loss_ce = loss_fn(output, target)
819 | else:
820 | output_adv = model(input + delta)
821 | loss_ce = loss_fn(output_adv, target)
822 |
823 | output_adv_prob = F.softmax(output_adv, dim=1)
824 | output_adv_logprob = F.log_softmax(output_adv, dim=1)
825 | loss_kl = (F.kl_div(output_adv_logprob, output_cle_prob, reduction='none').sum(dim=1) +
826 | F.kl_div(output_cle_logprob, output_adv_prob, reduction='none').sum(dim=1)).mean() / 2
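    | # i.e. a symmetric KL between clean and adversarial predictions:
    | #   (KL(p_cle || p_adv) + KL(p_adv || p_cle)) / 2, averaged over the batch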
827 |
828 | if a_iter == 0:
829 | loss = (loss_ce + args.adv_kl_weight * loss_kl)
830 | else:
831 | loss = (args.adv_ce_weight * loss_ce + args.adv_kl_weight * loss_kl) / (args.adv_iters - 1)
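    | # on the first step delta is still zero, so loss_ce is the clean-input loss;
    | # the remaining adv_iters - 1 perturbed steps share the re-weighted objective,
    | # and every backward pass accumulates into one optimizer.step() after the loop.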
832 |
833 |
834 | with amp.scale_loss(loss, optimizer) as scaled_loss:
835 | scaled_loss.backward(retain_graph=False)
836 |
837 | delta_grad = delta.grad.clone().detach().float()
838 | delta = (delta + args.adv_lr * torch.sign(delta_grad)).detach()
839 | delta = torch.clamp(delta, -args.adv_eps, args.adv_eps).detach()
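    | # one PGD step: ascend along the gradient sign, then project delta back
    | # into the L_inf ball of radius adv_eps around the clean input.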
840 |
841 | if not args.distributed:
842 | losses_m.update(loss.item(), input.size(0))
843 |
844 |
845 | optimizer.step()
846 | optimizer.zero_grad()
847 |
848 |
849 | if model_ema is not None:
850 | model_ema.update(model)
851 |
852 | torch.cuda.synchronize()
853 | num_updates += 1
854 | batch_time_m.update(time.time() - end)
855 | if last_batch or batch_idx % args.log_interval == 0:
856 | lrl = [param_group['lr'] for param_group in optimizer.param_groups]
857 | lr = sum(lrl) / len(lrl)
858 |
859 | if args.distributed:
860 | reduced_loss = reduce_tensor(loss.data, args.world_size)
861 | losses_m.update(reduced_loss.item(), input.size(0))
862 |
863 | if args.rank == 0:
864 | _logger.info(
865 | 'Train: {} [{:>4d}/{} ({:>3.0f}%)] '
866 | 'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f}) '
867 | 'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s '
868 | '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
869 | 'LR: {lr:.3e} '
870 | 'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
871 | epoch,
872 | batch_idx,
873 | len(loader),
874 | 100. * batch_idx / last_idx,
875 | loss=losses_m,
876 | batch_time=batch_time_m,
877 | rate=input.size(0) * args.world_size /
878 | batch_time_m.val,
879 | rate_avg=input.size(0) * args.world_size /
880 | batch_time_m.avg,
881 | lr=lr,
882 | data_time=data_time_m))
883 |
884 | if args.save_images and output_dir:
885 | torchvision.utils.save_image(
886 | input,
887 | os.path.join(output_dir,
888 | 'train-batch-%d.jpg' % batch_idx),
889 | padding=0,
890 | normalize=True)
891 |
892 | if saver is not None and args.recovery_interval and (
893 | last_batch or (batch_idx + 1) % args.recovery_interval == 0):
894 | saver.save_recovery(epoch, batch_idx=batch_idx)
895 |
896 | if lr_scheduler is not None:
897 | lr_scheduler.step_update(num_updates=num_updates,
898 | metric=losses_m.avg)
899 |
900 | end = time.time()
901 | # end for
902 |
903 | if hasattr(optimizer, 'sync_lookahead'):
904 | optimizer.sync_lookahead()
905 |
906 | return OrderedDict([('loss', losses_m.avg)])
907 |
908 |
909 | def validate(model,
910 | loader,
911 | loss_fn,
912 | args,
913 | amp_autocast=suppress,
914 | log_suffix=''):
915 | batch_time_m = AverageMeter()
916 | losses_m = AverageMeter()
917 | top1_m = AverageMeter()
918 | top5_m = AverageMeter()
919 |
920 | model.eval()
921 |
922 | end = time.time()
923 | last_idx = len(loader) - 1
924 | with torch.no_grad():
925 | for batch_idx, (input, target) in enumerate(loader):
926 | last_batch = batch_idx == last_idx
927 | if not args.prefetcher:
928 | input = input.cuda()
929 | target = target.cuda()
930 | if args.channels_last:
931 | input = input.contiguous(memory_format=torch.channels_last)
932 |
933 | with amp_autocast():
934 | output = model(input)
935 | if isinstance(output, (tuple, list)):
936 | output = output[0]
937 | if args.cls_weight == 0:
938 | output = output[1].mean(1)
939 |
940 | # augmentation reduction
941 | reduce_factor = args.tta
942 | if reduce_factor > 1:
943 | output = output.unfold(0, reduce_factor,
944 | reduce_factor).mean(dim=2)
945 | target = target[0:target.size(0):reduce_factor]
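    | # each group of tta consecutive augmented outputs is averaged into a
    | # single prediction, and targets are subsampled to match.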
946 |
947 | loss = loss_fn(output, target)
948 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
949 |
950 | if args.distributed:
951 | reduced_loss = reduce_tensor(loss.data, args.world_size)
952 | acc1 = reduce_tensor(acc1, args.world_size)
953 | acc5 = reduce_tensor(acc5, args.world_size)
954 | else:
955 | reduced_loss = loss.data
956 |
957 | torch.cuda.synchronize()
958 |
959 | losses_m.update(reduced_loss.item(), input.size(0))
960 | top1_m.update(acc1.item(), output.size(0))
961 | top5_m.update(acc5.item(), output.size(0))
962 |
963 | batch_time_m.update(time.time() - end)
964 | end = time.time()
965 | if args.rank == 0 and (last_batch or
966 | batch_idx % args.log_interval == 0):
967 | log_name = 'Test' + log_suffix
968 | _logger.info(
969 | '{0}: [{1:>4d}/{2}] '
970 | 'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
971 | 'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
972 | 'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
973 | 'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
974 | log_name,
975 | batch_idx,
976 | last_idx,
977 | batch_time=batch_time_m,
978 | loss=losses_m,
979 | top1=top1_m,
980 | top5=top5_m))
981 |
982 | metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg),
983 | ('top5', top5_m.avg)])
984 |
985 | return metrics
986 |
987 |
988 | if __name__ == '__main__':
989 | main()
990 |
--------------------------------------------------------------------------------