├── FKD
│   ├── FKD.png
│   ├── FKD_SSL
│   │   └── README.md
│   ├── LICENSE
│   ├── FKD_ViT
│   │   ├── README.md
│   │   ├── utils_FKD.py
│   │   ├── SReT.py
│   │   └── train_ViT_FKD.py
│   ├── FKD_SLG
│   │   ├── README.md
│   │   ├── utils.py
│   │   └── generate_soft_label.py
│   ├── utils_FKD.py
│   ├── README.md
│   └── train_FKD.py
├── FerKD
│   ├── assets
│   │   ├── FerKD.png
│   │   └── converge.png
│   ├── LICENSE
│   ├── soft_label_zoo
│   │   └── README.md
│   └── README.md
└── README.md
/FKD/FKD.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/szq0214/FKD/HEAD/FKD/FKD.png
--------------------------------------------------------------------------------
/FerKD/assets/FerKD.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/szq0214/FKD/HEAD/FerKD/assets/FerKD.png
--------------------------------------------------------------------------------
/FerKD/assets/converge.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/szq0214/FKD/HEAD/FerKD/assets/converge.png
--------------------------------------------------------------------------------
/FKD/FKD_SSL/README.md:
--------------------------------------------------------------------------------
1 | ## Self-supervised Representation Learning Using FKD
2 |
3 |
4 |
5 | ### Preparation
6 |
7 | - Install PyTorch and prepare the ImageNet dataset following the [official PyTorch ImageNet training code](https://github.com/pytorch/examples/tree/master/imagenet).
8 |
9 | - Download our [soft label](http://zhiqiangshen.com/projects/FKD/index.html) for SSL.
10 |
11 |
12 | ### FKD Training on ReActNet & ResNet-50
13 |
14 | The training code is similar to the supervised scheme and will be available soon.
15 |
16 |
17 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Evolving Knowledge Distillation: The Role of Pre-generated Soft Labels
2 |
3 | ***This is a collection of our works on utilizing pre-generated soft labels for stronger, faster, and more efficient knowledge distillation.***
4 |
5 | > [**FerKD**](./FerKD) (```@ICCV'23```): **FerKD: Surgical Label Adaptation for Efficient Distillation**
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 | > [**FKD**](./FKD) (```@ECCV'22```): **A Fast Knowledge Distillation Framework for Visual Recognition**
15 |
16 |
17 |
18 |
19 |
20 |
21 | ## Bibtex
22 | ```bibtex
23 | @inproceedings{shen2023ferkd,
24 | title={FerKD: Surgical Label Adaptation for Efficient Distillation},
25 | author={Zhiqiang Shen},
26 | booktitle={ICCV},
27 | year={2023}
28 | }
29 |
30 | @inproceedings{shen2021afast,
31 | title={A Fast Knowledge Distillation Framework for Visual Recognition},
32 | author={Zhiqiang Shen and Eric Xing},
33 | booktitle={ECCV},
34 | year={2022}
35 | }
36 | ```
--------------------------------------------------------------------------------
/FKD/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Zhiqiang Shen
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/FerKD/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Zhiqiang Shen
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/FerKD/soft_label_zoo/README.md:
--------------------------------------------------------------------------------
1 | # FerKD Soft Label Zoo and Baselines
2 |
3 | ## Introduction
4 |
5 | This page documents a large collection of soft labels generated by different giant teacher models, together with the corresponding student baselines trained on them.
6 |
7 |
8 |
9 | | Teacher | Teacher Top-1 | Student | Student Top-1 | Student | Student Top-1 | Soft Label |
10 | |:-------:|:--------:|:--------:|:--------:|:--------:|:--------:|:--------:|
11 | | `Effi-L2-475` | 88.14 | ResNet-50 | 80.23 | ViT-S/16 | 81.16 | [Download]() |
12 | | `Effi-L2-800 ` | 88.39 | ResNet-50 | 80.16 | ViT-S/16 | 81.30 | [Download]() |
13 | | `RegY-128GF-384` | 88.24 | ResNet-50 | 80.34 | ViT-S/16 | 81.42 | [Download]() |
14 | | `ViT-L16-512` | 88.07 | ResNet-50 | 80.29 | ViT-S/16 | 81.43 | [Download]() |
15 | | `ViT-H14-518` | 88.55 | ResNet-50 | 80.18 | ViT-S/16 | 81.41 | [Download]() |
16 | | `BEIT-L-224` | 87.52 | ResNet-50 | 80.03 | ViT-S/16 | 81.16 | [Download]() |
17 | | `BEIT-L-384` | 88.40 | ResNet-50 | 80.06 | ViT-S/16 | 81.11 | [Download]() |
18 | | `BEIT-L-512` | 88.60 | ResNet-50 | 80.09 | ViT-S/16 | 81.07 | [Download]() |
19 | | `ViT-G14-336-30M` | 89.59 | ResNet-50 | 79.03 | ViT-S/16 | 79.62 | [Download]() |
20 | | `ViT-G14-336-CLIP ` | 89.38 | ResNet-50 | 79.59 | ViT-S/16 | 79.48 | [Download]() |
21 |
22 |
--------------------------------------------------------------------------------
/FKD/FKD_ViT/README.md:
--------------------------------------------------------------------------------
1 |
2 | ## Preparation
3 |
4 | - Install PyTorch and prepare the ImageNet dataset following the [official PyTorch ImageNet training code](https://github.com/pytorch/examples/tree/master/imagenet). This repo makes minimal modifications to that code.
5 |
6 | - Download our soft label and unzip it. We provide multiple types of [soft labels](http://zhiqiangshen.com/projects/FKD/index.html), and we recommend using [Marginal Smoothing Top-5 (500-crop)](https://drive.google.com/file/d/14leI6xGfnyxHPsBxo0PpCmOq71gWt008/view?usp=sharing).
7 |
8 |
9 | ## FKD Training on ViT/DeiT and SReT
10 |
11 | To train a ViT model, run `train_ViT_FKD.py` with the desired model architecture and the path to the soft label and ImageNet dataset:
12 |
13 | ```
14 | python train_ViT_FKD.py \
15 | --dist-url 'tcp://127.0.0.1:10001' \
16 | --dist-backend 'nccl' \
17 | --multiprocessing-distributed --world-size 1 --rank 0 \
18 | -a SReT_LT --lr 0.002 --wd 0.05 \
19 | --num_crops 4 -b 1024 --cos \
20 | --temp 1.0 --mixup_cutmix \
21 | --softlabel_path [soft label path, e.g., ./FKD_soft_label_500_crops_marginal_smoothing_k_5/imagenet] \
22 | [imagenet-folder with train and val folders]
23 | ```
24 |
25 | Add `--mixup_cutmix` to enable Mixup and CutMix augmentations. For details on the `SReT_LT` model, please refer to [SReT](https://github.com/szq0214/SReT).
26 |
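When `--mixup_cutmix` is enabled, the images and their pre-generated soft labels are mixed jointly via `mixup_cutmix()` in `utils_FKD.py`. A minimal usage sketch on dummy tensors, run from the `FKD_ViT` directory (the probability and beta parameters below are illustrative placeholders, not necessarily the defaults of `train_ViT_FKD.py`; the helper moves its permutation indices to GPU, so it expects CUDA tensors):

```python
# Hedged sketch: exercise mixup_cutmix() from utils_FKD.py on dummy CUDA tensors.
import argparse
import torch
from utils_FKD import mixup_cutmix

# Attribute names follow what mixup_cutmix() reads from `args`;
# the values here are illustrative placeholders.
args = argparse.Namespace(mixup_cutmix_prob=0.5, mixup_switch_prob=0.5,
                          mixup=0.2, cutmix=1.0)

images = torch.randn(8, 3, 224, 224).cuda()                      # small dummy batch of crops
soft_labels = torch.softmax(torch.randn(8, 1000), dim=1).cuda()  # dummy soft labels

# Inputs and soft labels are mixed consistently (Mixup or CutMix, chosen at random).
images, soft_labels = mixup_cutmix(images, soft_labels, args)
```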
27 | ## Evaluation
28 |
29 | ```
30 | python train_ViT_FKD.py -a SReT_LT -e --resume [model path] [imagenet-folder with train and val folders]
31 | ```
32 |
33 | ### Trained Models
34 |
35 | | Model | FLOPs| #params | accuracy (Top-1) |weights |configurations |
36 | |:-------:|:--------:|:--------:|:--------:|:--------:|:--------:|
37 | | [`DeiT-T-distill`](https://github.com/facebookresearch/deit) | 1.3B | 5.7M | 74.5 |-- | -- |
38 | | `FKD ViT/DeiT-T` | 1.3B | 5.7M | **75.2** |[link](https://drive.google.com/file/d/1m33c1wHdCV7ePETO_HvWNaboSd_W4nfC/view?usp=sharing) | [Table 13 of paper](http://zhiqiangshen.com/projects/FKD/FKD_camera-ready.pdf) |
39 | | [`SReT-LT-distill`](https://github.com/szq0214/SReT) | 1.2B | 5.0M | 77.7 |-- | -- |
40 | | `FKD SReT-LT` | 1.2B | 5.0M | **78.7** |[link](https://drive.google.com/file/d/1mmdPXKutHM9Li8xo5nGG6TB0aAXA9PFR/view?usp=sharing) | [Table 13 of paper](http://zhiqiangshen.com/projects/FKD/FKD_camera-ready.pdf) |
41 |
42 |
43 |
--------------------------------------------------------------------------------
/FerKD/README.md:
--------------------------------------------------------------------------------
1 | ## 🚀🚀 FerKD: Surgical Label Adaptation for Efficient Distillation
2 |
3 | [**FerKD: Surgical Label Adaptation for Efficient Distillation**](https://openaccess.thecvf.com/content/ICCV2023/papers/Shen_FerKD_Surgical_Label_Adaptation_for_Efficient_Distillation_ICCV_2023_paper.pdf), Zhiqiang Shen, ICCV 2023.
4 |
5 |
6 |
7 |
8 |
9 | ### Abstract
10 |
11 | **🚀🚀 FerKD (Faster Knowledge Distillation)** is a novel efficient knowledge distillation framework that incorporates *partial soft-hard label adaptation coupled with a region-calibration mechanism*. Our approach stems from the observation and intuition that standard data augmentations, such as RandomResizedCrop, tend to transform inputs into diverse conditions: easy positives, hard positives, or hard negatives. In traditional distillation frameworks, these transformed samples are utilized equally through their predictive probabilities derived from pretrained teacher models. However, merely relying on prediction values from a pretrained teacher neglects the reliability of these soft label predictions. To address this, we propose a new scheme that calibrates the less-confident regions to be the context using softened hard groundtruth labels. The proposed approach involves the processes of *hard regions mining* + *calibration*.
12 |
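The calibration step can be pictured with a toy sketch: crops for which the teacher is not confident are treated as hard regions, and their soft labels are replaced by a softened ground-truth label. The threshold, smoothing value, and helper function below are illustrative placeholders rather than the paper's exact recipe:

```python
# Toy sketch of "hard regions mining + calibration" (illustrative, not the paper's exact rule).
import torch

def calibrate_soft_labels(soft_labels, hard_targets, n_classes=1000,
                          conf_threshold=0.3, smoothing=0.1):
    """Replace low-confidence teacher predictions with softened ground-truth labels."""
    confidence, _ = soft_labels.max(dim=1)                   # teacher's peak probability per crop
    smoothed_gt = torch.full_like(soft_labels, smoothing / (n_classes - 1))
    smoothed_gt.scatter_(1, hard_targets.view(-1, 1), 1.0 - smoothing)
    is_hard = confidence < conf_threshold                    # less-confident crops -> hard regions
    return torch.where(is_hard.unsqueeze(1), smoothed_gt, soft_labels)

# Example with dummy tensors: 4 crops, 1000 classes.
soft = torch.softmax(torch.randn(4, 1000), dim=1)
hard = torch.randint(0, 1000, (4,))
calibrated = calibrate_soft_labels(soft, hard)
```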
13 | ## Citation
14 |
15 | @inproceedings{shen2023ferkd,
16 | title={FerKD: Surgical Label Adaptation for Efficient Distillation},
17 | author={Zhiqiang Shen},
18 | year={2023},
19 | booktitle={ICCV}
20 | }
21 |
22 |
23 |
24 | ## Soft Label Zoo
25 |
26 | Please check the [soft labels](./soft_label_zoo) generated from different giant teacher models.
27 |
28 | ## Fast Convergence of FerKD
29 |
30 |
31 |
32 |
(Figure: `./assets/converge.png`)
33 |
34 |
35 | ## Training
36 |
37 | FerKD follows the [FKD](https://github.com/szq0214/FKD/tree/main/FKD) training code and procedure but uses different preprocessed soft labels. Please download the soft labels for FerKD from the [soft label zoo](./soft_label_zoo).
38 |
39 | ## Trained Models
40 |
41 | | Method | Network | accuracy (Top-1) |weights |
42 | |:-------:|:--------:|:--------:|:--------:|
43 | | `FerKD` | ResNet-50 | 81.2 | [Download](https://drive.google.com/file/d/1wrs2-v8Dg8ghaJBDEmYLuGnk4mRbJek_/view?usp=sharing) |
44 | | `FerKD*` | ResNet-50 | 81.4 | [Download](https://drive.google.com/file/d/1HW9scG0OlKVa64C-sNxhmmsB0ZPTruIp/view?usp=sharing) |
45 |
46 | ## Acknowledgements
47 |
48 | We thank [High-Flyer AI](https://www.high-flyer.cn/en/) for providing the deep learning platform and computational resources for this work. We would especially like to thank Yanhong Xu, Xiaowen Sun, Wenjie Wu, and Le Su from High-Flyer AI for helping us organize computing resources.
49 |
50 | ## Contact
51 |
52 | Zhiqiang Shen (zhiqiangshen0214 at gmail.com)
53 |
54 |
--------------------------------------------------------------------------------
/FKD/FKD_SLG/README.md:
--------------------------------------------------------------------------------
1 |
2 | ## Soft Label Generation (SLG)
3 |
4 | ### Preparation
5 |
6 | - Install PyTorch and prepare the ImageNet dataset following the [official PyTorch ImageNet training code](https://github.com/pytorch/examples/tree/master/imagenet).
7 | - Install `timm` using:
8 |
9 | ```
10 | pip install git+https://github.com/rwightman/pytorch-image-models.git
11 | ```
12 |
13 |
14 | ### Generating Soft Labels from Supervised Teachers
15 |
16 | FKD flags:
17 |
18 | - `--num_crops `: number of crops to sample per image for generating soft labels. Default: 500.
19 | - `--num_seg `: the effective batch size on each GPU during generation. Make sure `--num_crops` is divisible by `--num_seg`. Default: 50.
20 | - `--label_type `: type of generated soft labels. Default: `marginal_smoothing_k5`.
21 | - `--use_fp16`: save soft labels as `fp16` to reduce storage. Default: `False`.
22 | - `--temp`: temperature applied to the soft labels. Default: `1.0`.
23 |
24 | Path flags:
25 |
26 | - `--save_path `: specify the folder to save soft labels.
27 | - `--reference_path `: specify the path to existing soft labels used as the reference for crop locations. This is used for soft-label ensembling in [FKD MEAL V2](https://github.com/szq0214/MEAL-V2).
28 | - [imagenet-folder with train and val folders]: ImageNet data folder.
29 |
30 | Model flags:
31 |
32 | - `--arch `: specify which model to use as the teacher network.
33 | - `--input_size `: input size of teacher network.
34 | - `--teacher_source `: source of teachers. Currently, it supports models from (1) `pytorch`; (2) `timm`; and (3) private pre-trained models.
35 |
36 | Some important notes:
37 |
38 | - Modify `normalize` values according to the training settings of teacher networks.
39 |
40 | ```
41 | # EfficientNet_V2_L, BEIT, etc.
42 | normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5],
43 | std=[0.5, 0.5, 0.5])
44 | # ResNet, efficientnet_l2_ns_475, etc.
45 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
46 | std=[0.229, 0.224, 0.225])
47 | ```
48 |
49 | - Modify `--min_scale_crops` and `--max_scale_crops` according to the training settings of the teacher networks. For example, `tf_efficientnet_l2_ns_475` in `timm` uses scale=(0.08, 0.936).
50 | - `--num_seg` is the effective batch size, so `-b` can be set to a relatively small value without slowing down generation.
51 | - Resuming is supported by simply restarting the command. You can also launch multiple runs in parallel to speed up generation; existing files are skipped automatically.
52 |
53 | **Important:** Test your teacher models using `--evaluate` to check whether the accuracy is correct before starting to generate soft labels.
54 |
55 | An example of the command line for generating soft labels from `tf_efficientnet_l2_ns_475`:
56 |
57 | ```
58 | python generate_soft_label.py \
59 | -a tf_efficientnet_l2_ns_475 \
60 | --input_size 475 \
61 | --min_scale_crops 0.08 \
62 | --max_scale_crops 0.936 \
63 | --num_crops 500 \
64 | --num_seg 50 \
65 | -b 4 \
66 | --temp 1.0 \
67 | --label_type marginal_smoothing_k5 \
68 | --save_path FKD_efficientnet_l2_ns_475_marginal_smoothing_k5 \
69 | [imagenet-folder with train and val folders]
70 | ```
71 |
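To sanity-check the generated files, the quantized top-5 format can be expanded back to a full distribution with `Recover_soft_label` from `../utils_FKD.py`. A minimal sketch (the quantized values below are synthetic, and the import assumes the parent `FKD/` directory is on your `PYTHONPATH`):

```python
# Hedged sketch: expand a synthetic marginal_smoothing_k5 label back to 1000 classes.
import torch
from utils_FKD import Recover_soft_label  # provided in FKD/utils_FKD.py

# Layout expected by Recover_soft_label: [N, 2, 5], where row 0 holds the top-5
# class indices and row 1 their probabilities (the values here are made up).
quantized = torch.tensor([[[207., 208., 176., 211., 219.],
                           [0.70, 0.10, 0.05, 0.03, 0.02]]])
full = Recover_soft_label(quantized, 'marginal_smoothing_k5', n_classes=1000)
print(full.shape)       # torch.Size([1, 1000])
print(full.sum(dim=1))  # ~1.0: the remaining mass is spread uniformly over the other 995 classes
```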
72 | Soft label generation from self-supervised teachers will be available soon.
73 |
74 |
--------------------------------------------------------------------------------
/FKD/FKD_SLG/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | import torch.distributed
5 | import torch.nn as nn
6 | import torchvision
7 | from torchvision.transforms import functional as t_F
8 | from torchvision.datasets.folder import ImageFolder
9 |
10 |
11 | class RandomResizedCropWithCoords(torchvision.transforms.RandomResizedCrop):
12 | def __init__(self, **kwargs):
13 | super(RandomResizedCropWithCoords, self).__init__(**kwargs)
14 |
15 | def __call__(self, img, coords):
16 | try:
17 | reference = (coords.any())
18 | except:
19 | reference = False
20 | if not reference:
21 | i, j, h, w = self.get_params(img, self.scale, self.ratio)
22 | coords = (i / img.size[1],
23 | j / img.size[0],
24 | h / img.size[1],
25 | w / img.size[0])
26 | coords = torch.FloatTensor(coords)
27 | else:
28 | i = coords[0].item() * img.size[1]
29 | j = coords[1].item() * img.size[0]
30 | h = coords[2].item() * img.size[1]
31 | w = coords[3].item() * img.size[0]
32 | return t_F.resized_crop(img, i, j, h, w, self.size,
33 | self.interpolation), coords
34 |
35 |
36 | class ComposeWithCoords(torchvision.transforms.Compose):
37 | def __init__(self, **kwargs):
38 | super(ComposeWithCoords, self).__init__(**kwargs)
39 |
40 | def __call__(self, img, coords, status):
41 | for t in self.transforms:
42 | if type(t).__name__ == 'RandomResizedCropWithCoords':
43 | img, coords = t(img, coords)
44 | elif type(t).__name__ == 'RandomCropWithCoords':
45 | img, coords = t(img, coords)
46 | elif type(t).__name__ == 'RandomHorizontalFlipWithRes':
47 | img, status = t(img, status)
48 | else:
49 | img = t(img)
50 | return img, status, coords
51 |
52 |
53 | class RandomHorizontalFlipWithRes(torch.nn.Module):
54 | """Horizontally flip the given image randomly with a given probability.
55 | If the image is torch Tensor, it is expected
56 | to have [..., H, W] shape, where ... means an arbitrary number of leading
57 | dimensions
58 |
59 | Args:
60 | p (float): probability of the image being flipped. Default value is 0.5
61 | """
62 |
63 | def __init__(self, p=0.5):
64 | super().__init__()
65 | self.p = p
66 |
67 | def forward(self, img, status):
68 | """
69 | Args:
70 | img (PIL Image or Tensor): Image to be flipped.
71 |
72 | Returns:
73 | PIL Image or Tensor: Randomly flipped image.
74 | """
75 |
76 | if status is not None:
77 | if status == True:
78 | return t_F.hflip(img), status
79 | else:
80 | return img, status
81 | else:
82 | status = False
83 | if torch.rand(1) < self.p:
84 | status = True
85 | return t_F.hflip(img), status
86 | return img, status
87 |
88 |
89 | def __repr__(self):
90 | return self.__class__.__name__ + '(p={})'.format(self.p)
91 |
92 |
93 | class ImageFolder_FKD_GSL(torchvision.datasets.ImageFolder):
94 | def __init__(self, **kwargs):
95 | self.num_crops = kwargs['num_crops']
96 | self.save_path = kwargs['save_path']
97 | self.reference_path = kwargs['reference_path']
98 | kwargs.pop('num_crops')
99 | kwargs.pop('save_path')
100 | kwargs.pop('reference_path')
101 | super(ImageFolder_FKD_GSL, self).__init__(**kwargs)
102 |
103 | def __getitem__(self, index):
104 | path, target = self.samples[index]
105 |
106 | if self.reference_path is not None:
107 | ref_path = os.path.join(self.reference_path,'/'.join(path.split('/')[-4:-1]))
108 | ref_filename = os.path.join(ref_path,'/'.join(path.split('/')[-1:]).split('.')[0] + '.tar')
109 | label = torch.load(ref_filename, map_location=torch.device('cpu'))
110 | coords_ref, flip_ref, _ = label
111 | else:
112 | coords_ref = None
113 |
114 | sample = self.loader(path)
115 | sample_all = []
116 | flip_status_all = []
117 | coords_all = []
118 | for i in range(self.num_crops):
119 | if self.transform is not None:
120 | if coords_ref is not None:
121 | coords_ = coords_ref[i]
122 | flip_ = flip_ref[i]
123 | else:
124 | coords_ = None
125 | flip_ = None
126 | sample_new, flip_status, coords_single = self.transform(sample, coords_, flip_)
127 | sample_all.append(sample_new)
128 | flip_status_all.append(flip_status)
129 | coords_all.append(coords_single)
130 | else:
131 | coords = None
132 | flip_status = None
133 | if self.target_transform is not None:
134 | target = self.target_transform(target)
135 |
136 | return sample_all, target, flip_status_all, coords_all, path
--------------------------------------------------------------------------------
/FKD/utils_FKD.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.distributed
4 | import torch.nn as nn
5 | import torchvision
6 | from torchvision.transforms import functional as t_F
7 | from torch.nn import functional as F
8 | from torchvision.datasets.folder import ImageFolder
9 | from torch.nn.modules import loss
10 | from torchvision.transforms import InterpolationMode
11 | import random
12 | import numpy as np
13 |
14 |
15 | class Soft_CrossEntropy(loss._Loss):
16 | def forward(self, model_output, soft_output):
17 |
18 | size_average = True
19 |
20 | model_output_log_prob = F.log_softmax(model_output, dim=1)
21 |
22 | soft_output = soft_output.unsqueeze(1)
23 | model_output_log_prob = model_output_log_prob.unsqueeze(2)
24 |
25 | cross_entropy_loss = -torch.bmm(soft_output, model_output_log_prob)
26 | if size_average:
27 | cross_entropy_loss = cross_entropy_loss.mean()
28 | else:
29 | cross_entropy_loss = cross_entropy_loss.sum()
30 |
31 | return cross_entropy_loss
32 |
33 |
34 | class RandomResizedCrop_FKD(torchvision.transforms.RandomResizedCrop):
35 | def __init__(self, **kwargs):
36 | super(RandomResizedCrop_FKD, self).__init__(**kwargs)
37 |
38 | def __call__(self, img, coords, status):
39 | i = coords[0].item() * img.size[1]
40 | j = coords[1].item() * img.size[0]
41 | h = coords[2].item() * img.size[1]
42 | w = coords[3].item() * img.size[0]
43 |
44 | if self.interpolation == 'bilinear':
45 | inter = InterpolationMode.BILINEAR
46 | elif self.interpolation == 'bicubic':
47 | inter = InterpolationMode.BICUBIC
48 | return t_F.resized_crop(img, i, j, h, w, self.size, inter)
49 |
50 |
51 | class RandomHorizontalFlip_FKD(torch.nn.Module):
52 | def __init__(self, p=0.5):
53 | super().__init__()
54 | self.p = p
55 |
56 | def forward(self, img, coords, status):
57 |
58 | if status == True:
59 | return t_F.hflip(img)
60 | else:
61 | return img
62 |
63 | def __repr__(self):
64 | return self.__class__.__name__ + '(p={})'.format(self.p)
65 |
66 |
67 | class Compose_FKD(torchvision.transforms.Compose):
68 | def __init__(self, **kwargs):
69 | super(Compose_FKD, self).__init__(**kwargs)
70 |
71 | def __call__(self, img, coords, status):
72 | for t in self.transforms:
73 | if type(t).__name__ == 'RandomResizedCrop_FKD':
74 | img = t(img, coords, status)
75 | elif type(t).__name__ == 'RandomCrop_FKD':
76 | img, coords = t(img)
77 | elif type(t).__name__ == 'RandomHorizontalFlip_FKD':
78 | img = t(img, coords, status)
79 | else:
80 | img = t(img)
81 | return img
82 |
83 |
84 | class ImageFolder_FKD(torchvision.datasets.ImageFolder):
85 | def __init__(self, **kwargs):
86 | self.num_crops = kwargs['num_crops']
87 | self.softlabel_path = kwargs['softlabel_path']
88 | kwargs.pop('num_crops')
89 | kwargs.pop('softlabel_path')
90 | super(ImageFolder_FKD, self).__init__(**kwargs)
91 |
92 | def __getitem__(self, index):
93 | path, target = self.samples[index]
94 |
95 | label_path = os.path.join(self.softlabel_path, '/'.join(path.split('/')[-3:]).split('.')[0] + '.tar')
96 |
97 | label = torch.load(label_path, map_location=torch.device('cpu'))
98 |
99 | coords, flip_status, output = label
100 |
101 | rand_index = torch.randperm(len(output))
102 | soft_target = []
103 |
104 | sample = self.loader(path)
105 | sample_all = []
106 | hard_target = []
107 |
108 | for i in range(self.num_crops):
109 | if self.transform is not None:
110 | soft_target.append(output[rand_index[i]])
111 | sample_trans = self.transform(sample, coords[rand_index[i]], flip_status[rand_index[i]])
112 | sample_all.append(sample_trans)
113 | hard_target.append(target)
114 | else:
115 | coords = None
116 | flip_status = None
117 | if self.target_transform is not None:
118 | target = self.target_transform(target)
119 |
120 | return sample_all, hard_target, soft_target
121 |
122 |
123 | def Recover_soft_label(label, label_type, n_classes):
124 | # recover quantized soft label to n_classes dimension.
125 | if label_type == 'hard':
126 |
127 | return torch.zeros(label.size(0), n_classes).scatter_(1, label.view(-1, 1), 1)
128 |
129 | elif label_type == 'smoothing':
130 | index = label[:,0].to(dtype=int)
131 | value = label[:,1]
132 | minor_value = (torch.ones_like(value) - value)/(n_classes-1)
133 | minor_value = minor_value.reshape(-1,1).repeat_interleave(n_classes, dim=1)
134 | soft_label = (minor_value * torch.ones(index.size(0), n_classes)).scatter_(1, index.view(-1, 1), value.view(-1, 1))
135 |
136 | return soft_label
137 |
138 | elif label_type == 'marginal_smoothing_k5':
139 | index = label[:,0,:].to(dtype=int)
140 | value = label[:,1,:]
141 | minor_value = (torch.ones(label.size(0),1) - torch.sum(value, dim=1, keepdim=True))/(n_classes-5)
142 | minor_value = minor_value.reshape(-1,1).repeat_interleave(n_classes, dim=1)
143 | soft_label = (minor_value * torch.ones(index.size(0), n_classes)).scatter_(1, index, value)
144 |
145 | return soft_label
146 |
147 | elif label_type == 'marginal_renorm':
148 | index = label[:,0,:].to(dtype=int)
149 | value = label[:,1,:]
150 | soft_label = torch.zeros(index.size(0), n_classes).scatter_(1, index, value)
151 | soft_label = F.normalize(soft_label, p=1.0, dim=1, eps=1e-12)
152 |
153 | return soft_label
154 |
155 | elif label_type == 'marginal_smoothing_k10':
156 | index = label[:,0,:].to(dtype=int)
157 | value = label[:,1,:]
158 | minor_value = (torch.ones(label.size(0),1) - torch.sum(value, dim=1, keepdim=True))/(n_classes-10)
159 | minor_value = minor_value.reshape(-1,1).repeat_interleave(n_classes, dim=1)
160 | soft_label = (minor_value * torch.ones(index.size(0), n_classes)).scatter_(1, index, value)
161 |
162 | return soft_label
163 |
164 |
165 | def rand_bbox(size, lam):
166 | W = size[2]
167 | H = size[3]
168 | cut_rat = np.sqrt(1. - lam)
169 | cut_w = int(W * cut_rat)  # built-in int; np.int was removed in NumPy 1.24+
170 | cut_h = int(H * cut_rat)
171 |
172 | # uniform
173 | cx = np.random.randint(W)
174 | cy = np.random.randint(H)
175 |
176 | bbx1 = np.clip(cx - cut_w // 2, 0, W)
177 | bby1 = np.clip(cy - cut_h // 2, 0, H)
178 | bbx2 = np.clip(cx + cut_w // 2, 0, W)
179 | bby2 = np.clip(cy + cut_h // 2, 0, H)
180 |
181 | return bbx1, bby1, bbx2, bby2
182 |
183 |
184 | def mixup_cutmix(images, soft_label, args):
185 | enable_p = np.random.rand(1)
186 | if enable_p < args.mixup_cutmix_prob:
187 | switch_p = np.random.rand(1)
188 | if switch_p < args.mixup_switch_prob:
189 | lam = np.random.beta(args.mixup, args.mixup)
190 | rand_index = torch.randperm(images.size()[0]).cuda()
191 | target_a = soft_label
192 | target_b = soft_label[rand_index]
193 | mixed_x = lam * images + (1 - lam) * images[rand_index]
194 | target_mix = target_a * lam + target_b * (1 - lam)
195 | return mixed_x, target_mix
196 | else:
197 | lam = np.random.beta(args.cutmix, args.cutmix)
198 | rand_index = torch.randperm(images.size()[0]).cuda()
199 | target_a = soft_label
200 | target_b = soft_label[rand_index]
201 | bbx1, bby1, bbx2, bby2 = rand_bbox(images.size(), lam)
202 | images[:, :, bbx1:bbx2, bby1:bby2] = images[rand_index, :, bbx1:bbx2, bby1:bby2]
203 | # adjust lambda to exactly match pixel ratio
204 | lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (images.size()[-1] * images.size()[-2]))
205 | target_mix = target_a * lam + target_b * (1 - lam)
206 | else:
207 | target_mix = soft_label
208 |
209 | return images, target_mix
210 |
--------------------------------------------------------------------------------
/FKD/FKD_ViT/utils_FKD.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.distributed
4 | import torch.nn as nn
5 | import torchvision
6 | from torchvision.transforms import functional as t_F
7 | from torch.nn import functional as F
8 | from torchvision.datasets.folder import ImageFolder
9 | from torch.nn.modules import loss
10 | from torchvision.transforms import InterpolationMode
11 | import random
12 | import numpy as np
13 |
14 |
15 | class Soft_CrossEntropy(loss._Loss):
16 | def forward(self, model_output, soft_output):
17 |
18 | size_average = True
19 |
20 | model_output_log_prob = F.log_softmax(model_output, dim=1)
21 |
22 | soft_output = soft_output.unsqueeze(1)
23 | model_output_log_prob = model_output_log_prob.unsqueeze(2)
24 |
25 | cross_entropy_loss = -torch.bmm(soft_output, model_output_log_prob)
26 | if size_average:
27 | cross_entropy_loss = cross_entropy_loss.mean()
28 | else:
29 | cross_entropy_loss = cross_entropy_loss.sum()
30 |
31 | return cross_entropy_loss
32 |
33 |
34 | class RandomResizedCrop_FKD(torchvision.transforms.RandomResizedCrop):
35 | def __init__(self, **kwargs):
36 | super(RandomResizedCrop_FKD, self).__init__(**kwargs)
37 |
38 | def __call__(self, img, coords, status):
39 | i = coords[0].item() * img.size[1]
40 | j = coords[1].item() * img.size[0]
41 | h = coords[2].item() * img.size[1]
42 | w = coords[3].item() * img.size[0]
43 |
44 | if self.interpolation == 'bilinear':
45 | inter = InterpolationMode.BILINEAR
46 | elif self.interpolation == 'bicubic':
47 | inter = InterpolationMode.BICUBIC
48 | return t_F.resized_crop(img, i, j, h, w, self.size, inter)
49 |
50 |
51 | class RandomHorizontalFlip_FKD(torch.nn.Module):
52 | def __init__(self, p=0.5):
53 | super().__init__()
54 | self.p = p
55 |
56 | def forward(self, img, coords, status):
57 | if status == True:
58 | return t_F.hflip(img)
59 | else:
60 | return img
61 |
62 | def __repr__(self):
63 | return self.__class__.__name__ + '(p={})'.format(self.p)
64 |
65 |
66 | class Compose_FKD(torchvision.transforms.Compose):
67 | def __init__(self, **kwargs):
68 | super(Compose_FKD, self).__init__(**kwargs)
69 |
70 | def __call__(self, img, coords, status):
71 | for t in self.transforms:
72 | if type(t).__name__ == 'RandomResizedCrop_FKD':
73 | img = t(img, coords, status)
74 | elif type(t).__name__ == 'RandomCrop_FKD':
75 | img, coords = t(img)
76 | elif type(t).__name__ == 'RandomHorizontalFlip_FKD':
77 | img = t(img, coords, status)
78 | else:
79 | img = t(img)
80 | return img
81 |
82 |
83 | class ImageFolder_FKD(torchvision.datasets.ImageFolder):
84 | def __init__(self, **kwargs):
85 | self.num_crops = kwargs['num_crops']
86 | self.softlabel_path = kwargs['softlabel_path']
87 | kwargs.pop('num_crops')
88 | kwargs.pop('softlabel_path')
89 | super(ImageFolder_FKD, self).__init__(**kwargs)
90 |
91 | def __getitem__(self, index):
92 | path, target = self.samples[index]
93 |
94 | label_path = os.path.join(self.softlabel_path, '/'.join(path.split('/')[-3:]).split('.')[0] + '.tar')
95 |
96 | label = torch.load(label_path, map_location=torch.device('cpu'))
97 |
98 | coords, flip_status, output = label
99 |
100 | rand_index = torch.randperm(len(output))
101 | soft_target = []
102 |
103 | sample = self.loader(path)
104 | sample_all = []
105 | hard_target = []
106 |
107 | for i in range(self.num_crops):
108 | if self.transform is not None:
109 | soft_target.append(output[rand_index[i]])
110 | sample_trans = self.transform(sample, coords[rand_index[i]], flip_status[rand_index[i]])
111 | sample_all.append(sample_trans)
112 | hard_target.append(target)
113 | else:
114 | coords = None
115 | flip_status = None
116 | if self.target_transform is not None:
117 | target = self.target_transform(target)
118 |
119 | return sample_all, hard_target, soft_target
120 |
121 |
122 | def Recover_soft_label(label, label_type, n_classes):
123 | # recover quantized soft label to n_classes dimension.
124 | if label_type == 'hard':
125 |
126 | return torch.zeros(label.size(0), n_classes).scatter_(1, label.view(-1, 1), 1)
127 |
128 | elif label_type == 'smoothing':
129 | index = label[:,0].to(dtype=int)
130 | value = label[:,1]
131 | minor_value = (torch.ones_like(value) - value)/(n_classes-1)
132 | minor_value = minor_value.reshape(-1,1).repeat_interleave(n_classes, dim=1)
133 | soft_label = (minor_value * torch.ones(index.size(0), n_classes)).scatter_(1, index.view(-1, 1), value.view(-1, 1))
134 |
135 | return soft_label
136 |
137 | elif label_type == 'marginal_smoothing_k5':
138 | index = label[:,0,:].to(dtype=int)
139 | value = label[:,1,:]
140 | minor_value = (torch.ones(label.size(0),1) - torch.sum(value, dim=1, keepdim=True))/(n_classes-5)
141 | minor_value = minor_value.reshape(-1,1).repeat_interleave(n_classes, dim=1)
142 | soft_label = (minor_value * torch.ones(index.size(0), n_classes)).scatter_(1, index, value)
143 |
144 | return soft_label
145 |
146 | elif label_type == 'marginal_renorm':
147 | index = label[:,0,:].to(dtype=int)
148 | value = label[:,1,:]
149 | soft_label = torch.zeros(index.size(0), n_classes).scatter_(1, index, value)
150 | soft_label = F.normalize(soft_label, p=1.0, dim=1, eps=1e-12)
151 |
152 | return soft_label
153 |
154 | elif label_type == 'marginal_smoothing_k10':
155 | index = label[:,0,:].to(dtype=int)
156 | value = label[:,1,:]
157 | minor_value = (torch.ones(label.size(0),1) - torch.sum(value, dim=1, keepdim=True))/(n_classes-10)
158 | minor_value = minor_value.reshape(-1,1).repeat_interleave(n_classes, dim=1)
159 | soft_label = (minor_value * torch.ones(index.size(0), n_classes)).scatter_(1, index, value)
160 |
161 | return soft_label
162 |
163 |
164 | def rand_bbox(size, lam):
165 | W = size[2]
166 | H = size[3]
167 | cut_rat = np.sqrt(1. - lam)
168 | cut_w = int(W * cut_rat)  # built-in int; np.int was removed in NumPy 1.24+
169 | cut_h = int(H * cut_rat)
170 |
171 | # uniform
172 | cx = np.random.randint(W)
173 | cy = np.random.randint(H)
174 |
175 | bbx1 = np.clip(cx - cut_w // 2, 0, W)
176 | bby1 = np.clip(cy - cut_h // 2, 0, H)
177 | bbx2 = np.clip(cx + cut_w // 2, 0, W)
178 | bby2 = np.clip(cy + cut_h // 2, 0, H)
179 |
180 | return bbx1, bby1, bbx2, bby2
181 |
182 |
183 | def mixup_cutmix(images, soft_label, args):
184 | enable_p = np.random.rand(1)
185 | if enable_p < args.mixup_cutmix_prob:
186 | switch_p = np.random.rand(1)
187 | if switch_p < args.mixup_switch_prob:
188 | lam = np.random.beta(args.mixup, args.mixup)
189 | rand_index = torch.randperm(images.size()[0]).cuda()
190 | target_a = soft_label
191 | target_b = soft_label[rand_index]
192 | mixed_x = lam * images + (1 - lam) * images[rand_index]
193 | target_mix = target_a * lam + target_b * (1 - lam)
194 | return mixed_x, target_mix
195 | else:
196 | lam = np.random.beta(args.cutmix, args.cutmix)
197 | rand_index = torch.randperm(images.size()[0]).cuda()
198 | target_a = soft_label
199 | target_b = soft_label[rand_index]
200 | bbx1, bby1, bbx2, bby2 = rand_bbox(images.size(), lam)
201 | images[:, :, bbx1:bbx2, bby1:bby2] = images[rand_index, :, bbx1:bbx2, bby1:bby2]
202 | # adjust lambda to exactly match pixel ratio
203 | lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (images.size()[-1] * images.size()[-2]))
204 | target_mix = target_a * lam + target_b * (1 - lam)
205 | else:
206 | target_mix = soft_label
207 |
208 | return images, target_mix
209 |
--------------------------------------------------------------------------------
/FKD/README.md:
--------------------------------------------------------------------------------
1 | ## 🚀 FKD: A Fast Knowledge Distillation Framework for Visual Recognition
2 |
3 | Official PyTorch implementation of paper [**A Fast Knowledge Distillation Framework for Visual Recognition**](http://zhiqiangshen.com/projects/FKD/FKD_camera-ready.pdf) (ECCV 2022, [ECCV paper](http://zhiqiangshen.com/projects/FKD/FKD_camera-ready.pdf), [arXiv](https://arxiv.org/abs/2112.01528)), Zhiqiang Shen and Eric Xing.
4 |
5 |
6 |
7 |
8 |
9 |
10 | ### Abstract
11 |
12 | Knowledge Distillation (KD) has been recognized as a useful tool in many visual tasks, such as **supervised classification** and **self-supervised representation learning**. The main drawback of a vanilla KD framework is that most of the computational overhead is spent on forward passes through the giant teacher network, which makes the whole learning procedure inefficient and costly.
13 |
14 | **🚀 Fast Knowledge Distillation (FKD)** is a novel framework that addresses this low-efficiency drawback: it simulates the distillation training phase and generates soft labels following the multi-crop KD procedure, while enjoying a faster training speed than other methods. FKD is even more efficient than the conventional classification framework when multiple crops of the same image are used for data loading. It achieves **80.1%** (SGD) and **80.5%** (AdamW) with ResNet-50 on ImageNet-1K under plain training settings. This work also demonstrates the efficiency advantage of FKD on the self-supervised learning task.
15 |
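Concretely, the student is trained against the pre-generated crop-level soft labels with a soft cross-entropy; `Soft_CrossEntropy` in `utils_FKD.py` implements this objective. A minimal usage sketch with dummy tensors (the batch size and class count are illustrative):

```python
# Hedged sketch: the soft cross-entropy objective from utils_FKD.py on dummy tensors.
import torch
from utils_FKD import Soft_CrossEntropy

criterion = Soft_CrossEntropy()
logits = torch.randn(8, 1000)                             # student outputs for 8 crops
soft_labels = torch.softmax(torch.randn(8, 1000), dim=1)  # pre-generated teacher distributions
loss = criterion(logits, soft_labels)                     # mean of -sum_c p_teacher(c) * log p_student(c)
```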
16 | ## Citation
17 |
18 | @article{shen2021afast,
19 | title={A Fast Knowledge Distillation Framework for Visual Recognition},
20 | author={Zhiqiang Shen and Eric Xing},
21 | year={2021},
22 | journal={arXiv preprint arXiv:2112.01528}
23 | }
24 |
25 | ## What's New
26 | * Please refer to our work [here](https://github.com/VILA-Lab/SRe2L/tree/main/SRe2L/relabel#make-fkd-compatible-with-mixup-and-cutmix) if you would like to utilize mixture-based data augmentations (Mixup, CutMix, etc.) during the soft label generation and model training.
27 | * Included [code for soft label generation](FKD_SLG) for customization. We will also set up a [soft label zoo and baselines](FKD_SLG) with multiple soft labels from various teachers.
28 | * FKD with AdamW on ResNet-50 achieves **80.5%** using a plain training scheme. Pre-trained model is available [here](https://drive.google.com/file/d/14HgpE-9SMOFUN3cb7gT9OjURqq7s7q2_/view?usp=sharing).
29 |
30 |
31 |
32 | ## Supervised Training
33 |
34 | ### Preparation
35 |
36 | - Install PyTorch and prepare the ImageNet dataset following the [official PyTorch ImageNet training code](https://github.com/pytorch/examples/tree/master/imagenet). This repo makes minimal modifications to that code.
37 |
38 | - Download our soft label and unzip it. We provide multiple types of [soft labels](http://zhiqiangshen.com/projects/FKD/index.html), and we recommend using [Marginal Smoothing Top-5 (500-crop)](https://drive.google.com/file/d/14leI6xGfnyxHPsBxo0PpCmOq71gWt008/view?usp=sharing).
39 |
40 | - [Optional] Generate customized soft labels using [./FKD_SLG](FKD_SLG).
41 |
42 |
43 | ### FKD Training on CNNs
44 |
45 | To train a model, run `train_FKD.py` with the desired model architecture and the path to the soft label and ImageNet dataset:
46 |
47 |
48 | ```
49 | python train_FKD.py -a resnet50 --lr 0.1 --num_crops 4 -b 1024 --cos --temp 1.0 --softlabel_path [soft label path] [imagenet-folder with train and val folders]
50 | ```
51 |
52 | Add `--mixup_cutmix` to enable Mixup and CutMix augmentations. For `--softlabel_path`, use the format `./FKD_soft_label_500_crops_marginal_smoothing_k_5/imagenet`.
53 |
54 | Multi-processing distributed training on a single node with multiple GPUs:
55 |
56 | ```
57 | python train_FKD.py \
58 | --dist-url 'tcp://127.0.0.1:10001' \
59 | --dist-backend 'nccl' \
60 | --multiprocessing-distributed --world-size 1 --rank 0 \
61 | -a resnet50 --lr 0.1 --num_crops 4 -b 1024 \
62 | --temp 1.0 --cos -j 32 \
63 | --save_checkpoint_path ./FKD_nc_4_res50_plain \
64 | --softlabel_path [soft label path, e.g., ./FKD_soft_label_500_crops_marginal_smoothing_k_5/imagenet] \
65 | [imagenet-folder with train and val folders]
66 | ```
67 |
68 |
69 | **For multiple nodes multi-processing distributed training, please refer to [official PyTorch ImageNet training code](https://github.com/pytorch/examples/tree/master/imagenet) for details.**
70 |
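Under the hood, the loader is built from the FKD-specific dataset and transforms in `utils_FKD.py`, which replay the stored crop coordinates and flip decisions and return the matching soft labels. A minimal sketch of how these pieces fit together (paths and transform values below are illustrative; `train_FKD.py` is the authoritative setup):

```python
# Hedged sketch: assembling the FKD data pipeline from the classes in utils_FKD.py.
import torchvision.transforms as transforms
from utils_FKD import (Compose_FKD, ImageFolder_FKD,
                       RandomHorizontalFlip_FKD, RandomResizedCrop_FKD)

train_dataset = ImageFolder_FKD(
    num_crops=4,
    softlabel_path='./FKD_soft_label_500_crops_marginal_smoothing_k_5/imagenet',  # placeholder
    root='/path/to/imagenet/train',                                               # placeholder
    transform=Compose_FKD(transforms=[
        # 'bilinear' as a string is what RandomResizedCrop_FKD.__call__ expects.
        RandomResizedCrop_FKD(size=224, interpolation='bilinear', scale=(0.08, 1.0)),
        RandomHorizontalFlip_FKD(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]))

# Each item replays `num_crops` stored crop coordinates and flip decisions and
# returns the crops together with their hard targets and quantized soft labels.
images, hard_targets, soft_targets = train_dataset[0]
```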
71 |
72 | ### Evaluation
73 |
74 | ```
75 | python train_FKD.py -a resnet50 -e --resume [model path] [imagenet-folder with train and val folders]
76 | ```
77 |
78 | ### Training Speed Comparison
79 |
80 | The training speed of each epoch is tested on the HPC/CIAI cluster at MBZUAI with 8 NVIDIA V100 GPUs. The batch size is 1024 for all three methods: **(i)** the regular/vanilla classification framework, **(ii)** ReLabel, and **(iii)** FKD. For `Vanilla` and `ReLabel`, we use the average over 10 epochs after the speed has stabilized. For FKD, we use `num_crops = 4` and average over (4 $\times$ 10) epochs; note that `num_crops = 8` gives an even faster training speed. All other settings are identical for the comparison.
81 |
82 | | Method | Network | Training time per-epoch |
83 | |:-------:|:--------:|:--------:|
84 | | Vanilla | ResNet-50 | 579.36 sec/epoch |
85 | | ReLabel | ResNet-50 | 762.11 sec/epoch |
86 | | FKD (Ours) | ResNet-50 | 486.77 sec/epoch |
87 |
88 | ### Trained Models
89 |
90 | | Method | Network | accuracy (Top-1) |weights |configurations |
91 | |:-------:|:--------:|:--------:|:--------:|:--------:|
92 | | [`ReLabel`](https://github.com/naver-ai/relabel_imagenet) | ResNet-50 | 78.9 | -- | -- |
93 | | `FKD`| ResNet-50 | **80.1** (+1.2 over ReLabel) | [link](https://drive.google.com/file/d/1qQK3kae4pXBZOldegnZqw7j_aJWtbPgV/view?usp=sharing) | same as ReLabel, with initial lr = $0.1 \times \frac{\text{batch size}}{512}$ |
94 | | | | |
95 | | `FKD`(Plain)| ResNet-50 | **79.8** | [link](https://drive.google.com/file/d/1s6Tx5xmXnAseMZJBwaa4bnuvzZZGjMdk/view?usp=sharing) | [Table 12 in paper](http://zhiqiangshen.com/projects/FKD/FKD_camera-ready.pdf) (w/o warmup & colorJ) |
96 | | `FKD`(AdamW) | ResNet-50 | **80.5** | [link](https://drive.google.com/file/d/14HgpE-9SMOFUN3cb7gT9OjURqq7s7q2_/view?usp=sharing) | [Table 13 in paper](http://zhiqiangshen.com/projects/FKD/FKD_camera-ready.pdf) (same as our settings on ViT and SReT) |
97 | | | | |
98 | | [`ReLabel`](https://github.com/naver-ai/relabel_imagenet) | ResNet-101 | 80.7 | -- | -- |
99 | | `FKD` | ResNet-101 | **81.9** (+1.2 over ReLabel) | [link](TBA) | [Table 12 in paper](http://zhiqiangshen.com/projects/FKD/FKD_camera-ready.pdf) |
100 | | | | |
101 | | `FKD`(Plain)| ResNet-101 | **81.7** | [link](https://drive.google.com/file/d/13bVpHpTykCaYYXIAbWHa2W2C2tSxZlW5/view?usp=sharing) | [Table 12 in paper](http://zhiqiangshen.com/projects/FKD/FKD_camera-ready.pdf) (w/o warmup & colorJ) |
102 |
103 | ### Mobile-level Efficient Networks
104 |
105 | | Method | Network | FLOPs | accuracy (Top-1) |weights |
106 | |:-------:|:--------:|:--------:|:--------:|:--------:|
107 | | [`FBNet`](https://arxiv.org/abs/1812.03443)| FBNet-c100 | 375M | 75.12% | -- |
108 | | `FKD`| FBNet-c100 | 375M | **77.13%** (+2.01) | [link](https://drive.google.com/file/d/1s2pnIedXgwYAPpY2GBT3OC24ZP-0vfWe/view?usp=sharing) |
109 | | | | |
110 | | [`EfficientNetv2`](https://arxiv.org/abs/2104.00298)| EfficientNetv2-B0 | 700M | 78.35% | -- |
111 | | `FKD`| EfficientNetv2-B0 | 700M | **79.94%** (+1.59) | [link](https://drive.google.com/file/d/1qL21XOnTRWt6CvZLvUY5IpULISESEfZm/view?usp=sharing) |
112 |
113 | The training protocol is the same as we used for ViT/SReT:
114 |
115 | ```
116 | # Use the same settings as on ViT and SReT
117 | cd FKD_ViT
118 | # Train the model
119 | python -u train_ViT_FKD.py \
120 | --dist-url 'tcp://127.0.0.1:10001' \
121 | --dist-backend 'nccl' \
122 | --multiprocessing-distributed --world-size 1 --rank 0 \
123 | -a tf_efficientnetv2_b0 \
124 | --lr 0.002 --wd 0.05 \
125 | --epochs 300 --cos -j 32 \
126 | --num_classes 1000 --temp 1.0 \
127 | -b 1024 --num_crops 4 \
128 | --save_checkpoint_path ./FKD_nc_4_224_efficientnetv2_b0 \
129 | --soft_label_type marginal_smoothing_k5 \
130 | --softlabel_path [soft label path] \
131 | [imagenet-folder with train and val folders]
132 | ```
133 |
134 | ### FKD Training on ViT/DeiT and SReT
135 |
136 | To train a ViT model, run `train_ViT_FKD.py` with the desired model architecture and the path to the soft label and ImageNet dataset:
137 |
138 | ```
139 | cd FKD_ViT
140 | python train_ViT_FKD.py \
141 | --dist-url 'tcp://127.0.0.1:10001' \
142 | --dist-backend 'nccl' \
143 | --multiprocessing-distributed --world-size 1 --rank 0 \
144 | -a SReT_LT --lr 0.002 --wd 0.05 --num_crops 4 \
145 | --temp 1.0 -b 1024 --cos \
146 | --softlabel_path [soft label path] \
147 | [imagenet-folder with train and val folders]
148 | ```
149 |
150 | For details on the `SReT_LT` model, please refer to [SReT](https://github.com/szq0214/SReT).
151 |
152 | ### Evaluation
153 |
154 | ```
155 | python train_ViT_FKD.py -a SReT_LT -e --resume [model path] [imagenet-folder with train and val folders]
156 | ```
157 |
158 | ### Trained Models
159 |
160 | | Model | FLOPs| #params | accuracy (Top-1) |weights |configurations |
161 | |:-------:|:--------:|:--------:|:--------:|:--------:|:--------:|
162 | | [`DeiT-T-distill`](https://github.com/facebookresearch/deit) | 1.3B | 5.7M | 74.5 |-- | -- |
163 | | `FKD ViT/DeiT-T` | 1.3B | 5.7M | **75.2** |[link](https://drive.google.com/file/d/1m33c1wHdCV7ePETO_HvWNaboSd_W4nfC/view?usp=sharing) | [Table 13 in paper](http://zhiqiangshen.com/projects/FKD/FKD_camera-ready.pdf) |
164 | | [`SReT-LT-distill`](https://github.com/szq0214/SReT) | 1.2B | 5.0M | 77.7 |-- | -- |
165 | | `FKD SReT-LT` | 1.2B | 5.0M | **78.7** |[link](https://drive.google.com/file/d/1mmdPXKutHM9Li8xo5nGG6TB0aAXA9PFR/view?usp=sharing) | [Table 13 in paper](http://zhiqiangshen.com/projects/FKD/FKD_camera-ready.pdf) |
166 |
167 | ## Fast MEAL V2
168 |
169 | Please see [MEAL V2](https://github.com/szq0214/MEAL-V2) for instructions on running FKD with MEAL V2.
170 |
171 | ## Self-supervised Representation Learning Using FKD
172 |
173 | Please see [FKD-SSL](https://github.com/szq0214/FKD/tree/main/FKD_SSL) for instructions on running FKD for the SSL task.
174 |
175 |
176 | ## Contact
177 |
178 | Zhiqiang Shen (zhiqiangshen0214 at gmail.com or zhiqians at andrew.cmu.edu)
179 |
180 |
--------------------------------------------------------------------------------
/FKD/FKD_ViT/SReT.py:
--------------------------------------------------------------------------------
1 | # SReT (Sliced Recursive Transformer: https://arxiv.org/abs/2111.05297)
2 | # Zhiqiang Shen
3 | # CMU & MBZUAI
4 |
5 | # PiT (Rethinking Spatial Dimensions of Vision Transformers)
6 | # Copyright 2021-present NAVER Corp.
7 | # Apache License v2.0
8 |
9 | # Timm (https://github.com/rwightman/pytorch-image-models)
10 | # Ross Wightman
11 | # Apache License v2.0
12 |
13 | import torch
14 | from einops import rearrange
15 | from torch import nn
16 | import math
17 |
18 | from functools import partial
19 | from timm.models.layers import trunc_normal_
20 | from timm.models.layers import DropPath, to_2tuple, lecun_normal_
21 | from timm.models.registry import register_model
22 |
23 |
24 | __all__ = [
25 | "SReT",
26 | "SReT_T",
27 | "SReT_LT",
28 | "SReT_S",
29 | ]
30 |
31 |
32 | class LearnableCoefficient(nn.Module):
33 | def __init__(self):
34 | super(LearnableCoefficient, self).__init__()
35 | self.bias = nn.Parameter(torch.ones(1), requires_grad=True)
36 |
37 | def forward(self, x):
38 | out = x * self.bias
39 | return out
40 |
41 | class Mlp(nn.Module):
42 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
43 | super().__init__()
44 | out_features = out_features or in_features
45 | hidden_features = hidden_features or in_features
46 | self.fc1 = nn.Linear(in_features, hidden_features)
47 | self.act = act_layer()
48 | self.fc2 = nn.Linear(hidden_features, out_features)
49 | self.drop = nn.Dropout(drop)
50 |
51 | def forward(self, x):
52 | x = self.fc1(x)
53 | x = self.act(x)
54 | x = self.drop(x)
55 | x = self.fc2(x)
56 | x = self.drop(x)
57 | return x
58 |
59 |
60 | class Non_proj(nn.Module):
61 |
62 | def __init__(self, dim, num_heads, mlp_ratio=1., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
63 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
64 | super().__init__()
65 | self.norm1 = norm_layer(dim)
66 | mlp_hidden_dim = int(dim * mlp_ratio)
67 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
68 | self.coefficient1 = LearnableCoefficient()
69 | self.coefficient2 = LearnableCoefficient()
70 |
71 | def forward(self, x, recursive_index):
72 | x = self.coefficient1(x) + self.coefficient2(self.mlp(self.norm1(x)))
73 | return x
74 |
75 | class Attention(nn.Module):
76 | def __init__(self, dim, num_groups1=8, num_groups2=4, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
77 | super().__init__()
78 | self.num_heads = num_heads
79 | self.num_groups1 = num_groups1
80 | self.num_groups2 = num_groups2
81 | head_dim = dim // num_heads
82 | self.scale = qk_scale or head_dim ** -0.5
83 |
84 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
85 | self.attn_drop = nn.Dropout(attn_drop)
86 | self.proj = nn.Linear(dim, dim)
87 | self.proj_drop = nn.Dropout(proj_drop)
88 |
89 | def forward(self, x, recursive_index):
90 | B, N, C = x.shape
91 | if recursive_index == False:
92 | num_groups = self.num_groups1
93 | else:
94 | num_groups = self.num_groups2
95 | if num_groups != 1:
96 | idx = torch.randperm(N)
97 | x = x[:,idx,:]
98 | inverse = torch.argsort(idx)
99 | qkv = self.qkv(x).reshape(B, num_groups, N // num_groups, 3, self.num_heads, C // self.num_heads).permute(3, 0, 1, 4, 2, 5)
100 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
101 |
102 | attn = (q @ k.transpose(-2, -1)) * self.scale
103 | attn = attn.softmax(dim=-1)
104 | attn = self.attn_drop(attn)
105 |
106 | x = (attn @ v).transpose(2, 3).reshape(B, num_groups, N // num_groups, C)
107 | x = x.permute(0, 3, 1, 2).reshape(B, C, N).transpose(1, 2)
108 | if recursive_index == True and num_groups != 1:
109 | x = x[:,inverse,:]
110 | x = self.proj(x)
111 | x = self.proj_drop(x)
112 | return x
113 |
114 | class Transformer_Block(nn.Module):
115 |
116 | def __init__(self, dim, num_groups1, num_groups2, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
117 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
118 | super().__init__()
119 | self.norm1 = norm_layer(dim)
120 | self.attn = Attention(
121 | dim, num_groups1=num_groups1, num_groups2=num_groups2, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
122 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
123 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
124 | self.norm2 = norm_layer(dim)
125 | mlp_hidden_dim = int(dim * mlp_ratio)
126 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
127 | self.coefficient1 = LearnableCoefficient()
128 | self.coefficient2 = LearnableCoefficient()
129 | self.coefficient3 = LearnableCoefficient()
130 | self.coefficient4 = LearnableCoefficient()
131 |
132 | def forward(self, x, recursive_index):
133 | x = self.coefficient1(x) + self.coefficient2(self.drop_path(self.attn(self.norm1(x),recursive_index)))
134 |
135 | x = self.coefficient3(x) + self.coefficient4(self.drop_path(self.mlp(self.norm2(x))))
136 | return x
137 |
138 | class Transformer(nn.Module):
139 | def __init__(self, base_dim, depth, recursive_num, groups1, groups2, heads, mlp_ratio, np_mlp_ratio,
140 | drop_rate=.0, attn_drop_rate=.0, drop_path_prob=None):
141 | super(Transformer, self).__init__()
142 | self.layers = nn.ModuleList([])
143 | embed_dim = base_dim * heads
144 |
145 | if drop_path_prob is None:
146 | drop_path_prob = [0.0 for _ in range(depth)]
147 |
148 | blocks = [
149 | Transformer_Block(
150 | dim=embed_dim,
151 | num_groups1=groups1,
152 | num_groups2=groups2,
153 | num_heads=heads,
154 | mlp_ratio=mlp_ratio,
155 | qkv_bias=True,
156 | drop=drop_rate,
157 | attn_drop=attn_drop_rate,
158 | drop_path=drop_path_prob[i],
159 | act_layer=nn.GELU,
160 | norm_layer=partial(nn.LayerNorm, eps=1e-6)
161 | )
162 | for i in range(recursive_num)]
163 |
164 | recursive_loops = int(depth/recursive_num)
165 | non_projs = [
166 | Non_proj(
167 | dim=embed_dim, num_heads=heads, mlp_ratio=np_mlp_ratio, drop=drop_rate, attn_drop=attn_drop_rate,
168 | drop_path=drop_path_prob[i], norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU)
169 | for i in range(depth)]
170 | RT = []
171 | for rn in range(recursive_num):
172 | for rl in range(recursive_loops):
173 | RT.append(blocks[rn])
174 | RT.append(non_projs[rn*recursive_loops+rl])
175 |
176 | self.blocks = nn.ModuleList(RT)
177 |
178 | def forward(self, x):
179 | h, w = x.shape[2:4]
180 | x = rearrange(x, 'b c h w -> b (h w) c')
181 |
182 | for i, blk in enumerate(self.blocks):
183 | if (i+2)%4 == 0: # mark the recursive layers
184 | recursive_index = True
185 | else:
186 | recursive_index = False
187 | x = blk(x, recursive_index)
188 |
189 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
190 |
191 | return x
192 |
193 |
194 | class conv_head_pooling(nn.Module):
195 | def __init__(self, in_feature, out_feature, stride,
196 | padding_mode='zeros'):
197 | super(conv_head_pooling, self).__init__()
198 |
199 | self.conv = nn.Conv2d(in_feature, out_feature, kernel_size=stride + 1,
200 | padding=stride // 2, stride=stride,
201 | padding_mode=padding_mode, groups=in_feature)
202 |
203 | def forward(self, x):
204 |
205 | x = self.conv(x)
206 |
207 | return x
208 |
209 |
210 | class conv_embedding(nn.Module):
211 | def __init__(self, in_channels, out_channels, patch_size,
212 | stride, padding):
213 | super(conv_embedding, self).__init__()
214 | norm_layer = nn.BatchNorm2d
215 | self.conv1 = nn.Conv2d(in_channels, int(out_channels/2), kernel_size=3,
216 | stride=2, padding=1, bias=True)
217 | self.bn1 = norm_layer(int(out_channels/2))
218 | self.relu1 = nn.ReLU(inplace=True)
219 | self.conv2 = nn.Conv2d(int(out_channels/2), out_channels, kernel_size=3,
220 | stride=2, padding=1, bias=True)
221 | self.bn2 = norm_layer(out_channels)
222 | self.relu2 = nn.ReLU(inplace=True)
223 | self.conv3 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
224 | stride=2, padding=1, bias=True)
225 | self.bn3 = norm_layer(out_channels)
226 | self.relu3 = nn.ReLU(inplace=True)
227 |
228 | def forward(self, x):
229 | x = self.conv1(x)
230 | x = self.bn1(x)
231 | x = self.relu1(x)
232 | x = self.conv2(x)
233 | x = self.bn2(x)
234 | x = self.relu2(x)
235 | x = self.conv3(x)
236 | x = self.bn3(x)
237 | x = self.relu3(x)
238 | return x
239 |
240 |
241 | class SReT(nn.Module):
242 | def __init__(self, image_size, patch_size, stride, base_dims, depth, recursive_num, groups1, groups2, heads,
243 | mlp_ratio, np_mlp_ratio, num_classes=1000, in_chans=3,
244 | attn_drop_rate=.0, drop_rate=.0, drop_path_rate=.0):
245 | super(SReT, self).__init__()
246 |
247 | total_block = sum(depth)
248 | padding = 0
249 | block_idx = 0
250 |
251 | width = int(image_size/8)
252 |
253 | self.base_dims = base_dims
254 | self.heads = heads
255 | self.num_classes = num_classes
256 |
257 | self.patch_size = patch_size
258 | self.pos_embed = nn.Parameter(
259 | torch.randn(1, base_dims[0] * heads[0], width, width),
260 | requires_grad=True
261 | )
262 | self.patch_embed = conv_embedding(in_chans, base_dims[0] * heads[0],
263 | patch_size, stride, padding)
264 |
265 | self.pos_drop = nn.Dropout(p=drop_rate)
266 |
267 | self.transformers = nn.ModuleList([])
268 | self.pools = nn.ModuleList([])
269 |
270 | for stage in range(len(depth)):
271 | drop_path_prob = [drop_path_rate * i / total_block
272 | for i in range(block_idx, block_idx + depth[stage])]
273 | block_idx += depth[stage]
274 |
275 | self.transformers.append(
276 | Transformer(base_dims[stage], depth[stage], recursive_num[stage], groups1[stage], groups2[stage], heads[stage],
277 | mlp_ratio, np_mlp_ratio,
278 | drop_rate, attn_drop_rate, drop_path_prob)
279 | )
280 | if stage < len(heads) - 1:
281 | self.pools.append(
282 | conv_head_pooling(base_dims[stage] * heads[stage],
283 | base_dims[stage + 1] * heads[stage + 1],
284 | stride=2
285 | )
286 | )
287 |
288 | self.norm = nn.LayerNorm(base_dims[-1] * heads[-1], eps=1e-6)
289 | self.embed_dim = base_dims[-1] * heads[-1]
290 |
291 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
292 |
293 | # Classifier head
294 | if num_classes > 0:
295 | self.head = nn.Linear(base_dims[-1] * heads[-1], num_classes)
296 | else:
297 | self.head = nn.Identity()
298 |
299 | trunc_normal_(self.pos_embed, std=.02)
300 | self.apply(self._init_weights)
301 |
302 | def _init_weights(self, m):
303 | if isinstance(m, nn.LayerNorm):
304 | nn.init.constant_(m.bias, 0)
305 | nn.init.constant_(m.weight, 1.0)
306 |
307 | @torch.jit.ignore
308 | def no_weight_decay(self):
309 | return {'pos_embed', 'cls_token'}
310 |
311 | def get_classifier(self):
312 | return self.head
313 |
314 | def reset_classifier(self, num_classes, global_pool=''):
315 | self.num_classes = num_classes
316 | if num_classes > 0:
317 | self.head = nn.Linear(self.embed_dim, num_classes)
318 | else:
319 | self.head = nn.Identity()
320 |
321 | def forward_features(self, x):
322 | x = self.patch_embed(x)
323 |
324 | pos_embed = self.pos_embed
325 | x = self.pos_drop(x + pos_embed)
326 |
327 | for stage in range(len(self.pools)):
328 | x = self.transformers[stage](x)
329 | x = self.pools[stage](x)
330 | x = self.transformers[-1](x)
331 |
332 | x = self.avgpool(x)
333 | x = torch.flatten(x, 1)
334 |
335 | x = self.norm(x)
336 |
337 | return x
338 |
339 | def forward(self, x):
340 | x = self.forward_features(x)
341 | x = self.head(x)
342 | return x
343 |
344 |
345 | class Distilled_SReT(SReT):
346 | def __init__(self, *args, **kwargs):
347 | super().__init__(*args, **kwargs)
348 |
349 | def forward(self, x):
350 | x = self.forward_features(x)
351 | x_cls = self.head(x)
352 |         # returning `x_cls, x_cls` would match the DeiT codebase interface; SReT uses global average pooling
353 |         # and soft labels only for knowledge distillation, so returning `x_cls` alone is sufficient
354 | if self.training:
355 | # return x_cls, x_cls
356 | return x_cls
357 | else:
358 | return x_cls
359 |
360 |
361 | @register_model
362 | def SReT_T(pretrained=False, **kwargs):
363 | model = SReT(
364 | image_size=224,
365 | patch_size=16,
366 | stride=8,
367 | base_dims=[32, 32, 32],
368 | depth=[4, 10, 6],
369 | recursive_num=[2,5,3],
370 | heads=[2, 4, 8],
371 | groups1=[8, 4, 1],
372 | groups2=[2, 1, 1],
373 | mlp_ratio=3.6,
374 | np_mlp_ratio=1,
375 | drop_path_rate=0.1,
376 | **kwargs
377 | )
378 | if pretrained:
379 | state_dict = \
380 | torch.load('SReT_T.pth', map_location='cpu')
381 | model.load_state_dict(state_dict['model'])
382 | return model
383 |
384 | @register_model
385 | def SReT_LT(pretrained=False, **kwargs):
386 | model = SReT(
387 | image_size=224,
388 | patch_size=16,
389 | stride=8,
390 | base_dims=[32, 32, 32],
391 | depth=[4, 10, 6],
392 | recursive_num=[2, 5, 3],
393 | heads=[2, 4, 8],
394 | groups1=[8, 4, 1],
395 | groups2=[2, 1, 1],
396 | mlp_ratio=4.0,
397 | np_mlp_ratio=1,
398 | drop_path_rate=0.1,
399 | **kwargs
400 | )
401 | if pretrained:
402 | state_dict = \
403 | torch.load('SReT_LT.pth', map_location='cpu')
404 | model.load_state_dict(state_dict['model'])
405 | return model
406 |
407 | def SReT_S(pretrained=False, **kwargs):
408 | model = SReT(
409 | image_size=224,
410 | patch_size=16,
411 | stride=8,
412 | base_dims=[42, 42, 42],
413 | depth=[4, 10, 6],
414 | recursive_num=[2, 5, 3],
415 | heads=[3, 6, 12],
416 | groups1=[8, 4, 1],
417 | groups2=[2, 1, 1],
418 | mlp_ratio=3.0,
419 | np_mlp_ratio=2,
420 | drop_path_rate=0.2,
421 | **kwargs
422 | )
423 | if pretrained:
424 | state_dict = \
425 | torch.load('SReT_S.pth', map_location='cpu')
426 | model.load_state_dict(state_dict['model'])
427 | return model
428 |
429 | # Knowledge Distillation
430 | @register_model
431 | def SReT_T_distill(pretrained=False, **kwargs):
432 | model = Distilled_SReT(
433 | image_size=224,
434 | patch_size=16,
435 | stride=8,
436 | base_dims=[32, 32, 32],
437 | depth=[4, 10, 6],
438 | recursive_num=[2, 5, 3],
439 | heads=[2, 4, 8],
440 | groups1=[8, 4, 1],
441 | groups2=[2, 1, 1],
442 | mlp_ratio=3.6,
443 | np_mlp_ratio=1,
444 | drop_path_rate=0.1,
445 | **kwargs
446 | )
447 | if pretrained:
448 | state_dict = \
449 | torch.load('SReT_T_distill.pth', map_location='cpu')
450 | model.load_state_dict(state_dict['model'])
451 | return model
452 |
453 | @register_model
454 | def SReT_LT_distill(pretrained=False, **kwargs):
455 | model = Distilled_SReT(
456 | image_size=224,
457 | patch_size=16,
458 | stride=8,
459 | base_dims=[32, 32, 32],
460 | depth=[4, 10, 6],
461 | recursive_num=[2, 5, 3],
462 | heads=[2, 4, 8],
463 | groups1=[8, 4, 1],
464 | groups2=[2, 1, 1],
465 | mlp_ratio=4.0,
466 | np_mlp_ratio=1,
467 | drop_path_rate=0.1,
468 | **kwargs
469 | )
470 | if pretrained:
471 | state_dict = \
472 | torch.load('SReT_LT_distill.pth', map_location='cpu')
473 | model.load_state_dict(state_dict['model'])
474 | return model
475 |
476 | def SReT_S_distill(pretrained=False, **kwargs):
477 | model = Distilled_SReT(
478 | image_size=224,
479 | patch_size=16,
480 | stride=8,
481 | base_dims=[42, 42, 42],
482 | depth=[4, 10, 6],
483 | recursive_num=[2, 5, 3],
484 | heads=[3, 6, 12],
485 | groups1=[8, 4, 1],
486 | groups2=[2, 1, 1],
487 | mlp_ratio=3.0,
488 | np_mlp_ratio=2,
489 | drop_path_rate=0.2,
490 | **kwargs
491 | )
492 | if pretrained:
493 | state_dict = \
494 | torch.load('SReT_S_distill.pth', map_location='cpu')
495 | model.load_state_dict(state_dict['model'])
496 | return model
--------------------------------------------------------------------------------
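A minimal usage sketch for the SReT definitions above (an illustrative assumption, not part of the repository): it builds the `SReT_T_distill` student registered in `FKD_ViT/SReT.py` and runs a dummy forward pass, assuming `timm` is installed and the file is importable as `SReT`.

```python
# Hypothetical sketch: instantiate an SReT student from FKD_ViT/SReT.py and run a dummy forward pass.
import torch
import SReT  # the module defined above

model = SReT.SReT_T_distill(pretrained=False, num_classes=1000)
model.eval()  # in eval mode, Distilled_SReT.forward returns only the classifier logits
with torch.no_grad():
    logits = model(torch.randn(2, 3, 224, 224))  # image_size=224 per the registered config
print(logits.shape)  # expected: torch.Size([2, 1000])
```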
/FKD/FKD_SLG/generate_soft_label.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import random
4 | import shutil
5 | import time
6 | import warnings
7 | import timm
8 | import numpy as np
9 |
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.parallel
13 | import torch.backends.cudnn as cudnn
14 | import torch.distributed as dist
15 | import torch.optim
16 | import torch.multiprocessing as mp
17 | import torch.utils.data
18 | import torch.utils.data.distributed
19 | import torchvision.transforms as transforms
20 | import torchvision.datasets as datasets
21 | import torchvision.models as models
22 |
23 | from utils import RandomResizedCropWithCoords
24 | from utils import RandomHorizontalFlipWithRes
25 | from utils import ImageFolder_FKD_GSL
26 | from utils import ComposeWithCoords
27 |
28 | from torchvision.transforms import InterpolationMode
29 |
30 | import torch.multiprocessing
31 |
32 |
33 | parser = argparse.ArgumentParser(description='FKD Soft Label Generation on ImageNet-1K')
34 | parser.add_argument('data', metavar='DIR',
35 | help='path to dataset')
36 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18')
37 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
38 | help='number of data loading workers (default: 4)')
39 | parser.add_argument('-b', '--batch-size', default=256, type=int,
40 | metavar='N',
41 | help='mini-batch size (default: 256), this is the total '
42 | 'batch size of all GPUs on the current node when '
43 | 'using Data Parallel or Distributed Data Parallel')
44 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
45 | help='evaluate model on validation set')
46 | parser.add_argument('--pretrained', dest='pretrained', action='store_true',
47 | help='use pre-trained model')
48 | parser.add_argument('--world-size', default=-1, type=int,
49 | help='number of nodes for distributed training')
50 | parser.add_argument('--rank', default=-1, type=int,
51 | help='node rank for distributed training')
52 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
53 | help='url used to set up distributed training')
54 | parser.add_argument('--dist-backend', default='nccl', type=str,
55 | help='distributed backend')
56 | parser.add_argument('--seed', default=None, type=int,
57 | help='seed for initializing training. ')
58 | parser.add_argument('--gpu', default=None, type=int,
59 | help='GPU id to use.')
60 | parser.add_argument('--multiprocessing-distributed', action='store_true',
61 | help='Use multi-processing distributed training to launch '
62 | 'N processes per node, which has N GPUs. This is the '
63 | 'fastest way to use PyTorch for either single node or '
64 | 'multi node data parallel training')
65 | # FKD soft label generation args
66 | parser.add_argument('--num_crops', default=500, type=int,
67 | help='total number of crops in each image')
68 | parser.add_argument('--num_seg', default=50, type=int,
69 |                     help='number of crops processed per forward pass on each GPU during generation')
70 | parser.add_argument("--min_scale_crops", type=float, default=0.08,
71 | help="argument in RandomResizedCrop")
72 | parser.add_argument("--max_scale_crops", type=float, default=0.936,
73 | help="argument in RandomResizedCrop")
74 | parser.add_argument("--temp", type=float, default=1.0,
75 | help="temperature on soft label")
76 | parser.add_argument('--save_path', default='./FKD_effL2_475_soft_label', type=str, metavar='PATH',
77 | help='path to save soft labels')
78 | parser.add_argument('--reference_path', default=None, type=str, metavar='PATH',
79 |                     help='path to existing soft label files; their crop locations can be reused to generate new soft labels')
80 | parser.add_argument('--input_size', default=475, type=int, metavar='S',
81 | help='input size of teacher model')
82 | parser.add_argument('--teacher_path', default='', type=str, metavar='TEACHER',
83 | help='path of pre-trained teacher')
84 | parser.add_argument('--teacher_source', default='timm', type=str, metavar='SOURCE',
85 | help='source of pre-trained teacher models: pytorch or timm')
86 | parser.add_argument('--label_type', default='marginal_smoothing_k5', type=str, metavar='TYPE',
87 | help='type of generated soft labels')
88 | parser.add_argument('--use_fp16', dest='use_fp16', action='store_true',
89 | help='save soft labels as `fp16`')
90 | parser.add_argument('-p', '--print-freq', default=10, type=int, metavar='N', help='print frequency (default: 10)')
91 |
92 | sharing_strategy = "file_system"
93 | torch.multiprocessing.set_sharing_strategy(sharing_strategy)
94 |
95 | def set_worker_sharing_strategy(worker_id: int) -> None:
96 | torch.multiprocessing.set_sharing_strategy(sharing_strategy)
97 |
98 |
99 | def main():
100 | args = parser.parse_args()
101 |
102 | if not os.path.exists(args.save_path):
103 | os.makedirs(args.save_path, exist_ok=True)
104 |
105 | if args.seed is not None:
106 | random.seed(args.seed)
107 | torch.manual_seed(args.seed)
108 | cudnn.deterministic = True
109 | warnings.warn('You have chosen to seed training. '
110 | 'This will turn on the CUDNN deterministic setting, '
111 | 'which can slow down your training considerably! '
112 | 'You may see unexpected behavior when restarting '
113 | 'from checkpoints.')
114 |
115 | if args.gpu is not None:
116 | warnings.warn('You have chosen a specific GPU. This will completely '
117 | 'disable data parallelism.')
118 |
119 | if args.dist_url == "env://" and args.world_size == -1:
120 | args.world_size = int(os.environ["WORLD_SIZE"])
121 |
122 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed
123 |
124 | ngpus_per_node = torch.cuda.device_count()
125 | if args.multiprocessing_distributed:
126 | # Since we have ngpus_per_node processes per node, the total world_size
127 | # needs to be adjusted accordingly
128 | args.world_size = ngpus_per_node * args.world_size
129 | # Use torch.multiprocessing.spawn to launch distributed processes: the
130 | # main_worker process function
131 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
132 | else:
133 | # Simply call main_worker function
134 | main_worker(args.gpu, ngpus_per_node, args)
135 |
136 |
137 | def main_worker(gpu, ngpus_per_node, args):
138 | args.gpu = gpu
139 |
140 | if args.gpu is not None:
141 | print("Use GPU: {} for training".format(args.gpu))
142 |
143 | if args.distributed:
144 | if args.dist_url == "env://" and args.rank == -1:
145 | args.rank = int(os.environ["RANK"])
146 | if args.multiprocessing_distributed:
147 | # For multiprocessing distributed training, rank needs to be the
148 | # global rank among all the processes
149 | args.rank = args.rank * ngpus_per_node + gpu
150 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
151 | world_size=args.world_size, rank=args.rank)
152 | # create model
153 | if os.path.isfile(args.teacher_path):
154 | print("=> using pre-trained model from '{}'".format(args.teacher_path))
155 | model = models.__dict__[args.arch](pretrained=False)
156 | checkpoint = torch.load(args.teacher_path, map_location=torch.device('cpu'))
157 | model.load_state_dict(checkpoint)
158 | elif args.teacher_source == 'timm':
159 | # Timm pre-trained models
160 | print("=> using pre-trained model '{}'".format(args.arch))
161 | model = timm.create_model(args.arch, pretrained=True, num_classes=1000)
162 | elif args.teacher_source == 'pytorch':
163 | # PyTorch pre-trained models
164 | print("=> using pre-trained model '{}'".format(args.arch))
165 | model = models.__dict__[args.arch](pretrained=True)
166 | else:
167 |         print("'{}' is not currently supported. Please use pytorch, timm, or your own pre-trained model as the teacher.".format(args.teacher_source))
168 | # add your code of loading teacher here.
169 | return
170 |
171 | if not torch.cuda.is_available():
172 | print('using CPU, this will be slow')
173 | elif args.distributed:
174 | # For multiprocessing distributed, DistributedDataParallel constructor
175 | # should always set the single device scope, otherwise,
176 | # DistributedDataParallel will use all available devices.
177 | if args.gpu is not None:
178 | torch.cuda.set_device(args.gpu)
179 | model.cuda(args.gpu)
180 | # When using a single GPU per process and per
181 | # DistributedDataParallel, we need to divide the batch size
182 | # ourselves based on the total number of GPUs we have
183 | args.batch_size = int(args.batch_size / ngpus_per_node)
184 | args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
185 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
186 | else:
187 | model.cuda()
188 | # DistributedDataParallel will divide and allocate batch_size to all
189 | # available GPUs if device_ids are not set
190 | model = torch.nn.parallel.DistributedDataParallel(model)
191 | elif args.gpu is not None:
192 | torch.cuda.set_device(args.gpu)
193 | model = model.cuda(args.gpu)
194 | else:
195 | # DataParallel will divide and allocate batch_size to all available GPUs
196 | if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
197 | model.features = torch.nn.DataParallel(model.features)
198 | model.cuda()
199 | else:
200 | model = torch.nn.DataParallel(model).cuda()
201 |
202 | # freeze all layers
203 | for name, param in model.named_parameters():
204 | param.requires_grad = False
205 |
206 | # define loss function (criterion) and optimizer
207 | criterion = nn.CrossEntropyLoss().cuda(args.gpu)
208 |
209 | cudnn.benchmark = True
210 |
211 | # Data loading code
212 | traindir = os.path.join(args.data, 'train')
213 | valdir = os.path.join(args.data, 'val')
214 |
215 | # BEIT, etc.
216 | # normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5],
217 | # std=[0.5, 0.5, 0.5])
218 | # ResNet, efficientnet_l2_ns_475, etc.
219 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
220 | std=[0.229, 0.224, 0.225])
221 |
222 | train_dataset = ImageFolder_FKD_GSL(
223 | num_crops=args.num_crops,
224 | save_path=args.save_path,
225 | reference_path=args.reference_path,
226 | root=traindir,
227 | transform=ComposeWithCoords(transforms=[
228 | RandomResizedCropWithCoords(size=args.input_size,
229 | scale=(args.min_scale_crops, args.max_scale_crops),
230 | interpolation=InterpolationMode.BICUBIC),
231 | RandomHorizontalFlipWithRes(),
232 | transforms.ToTensor(),
233 | normalize,
234 | ]))
235 |
236 | if args.distributed:
237 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
238 | else:
239 | train_sampler = None
240 |
241 | #(train_sampler is None)
242 | train_loader = torch.utils.data.DataLoader(
243 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
244 | num_workers=args.workers, pin_memory=True, sampler=train_sampler, worker_init_fn=set_worker_sharing_strategy)
245 |
246 | val_loader = torch.utils.data.DataLoader(
247 | datasets.ImageFolder(valdir, transforms.Compose([
248 | transforms.Resize(int(256/224*args.input_size)),
249 | transforms.CenterCrop(args.input_size),
250 | transforms.ToTensor(),
251 | normalize,
252 | ])),
253 | batch_size=args.batch_size, shuffle=False,
254 | num_workers=args.workers, pin_memory=True)
255 |
256 | # test the accuracy of teacher model
257 | if args.evaluate:
258 | validate(val_loader, model, criterion, args)
259 | return
260 |
261 | # generate soft labels
262 | generate_soft_labels(train_loader, model, args)
263 |
264 |
265 | def generate_soft_labels(train_loader, model, args):
266 | batch_time = AverageMeter('Time', ':6.3f')
267 | data_time = AverageMeter('Data', ':6.3f')
268 |
269 | # switch to eval mode
270 | model.eval()
271 |
272 | end = time.time()
273 | for i, (images, target, flip_status, coords, path) in enumerate(train_loader):
274 | # measure data loading time
275 | data_time.update(time.time() - end)
276 |
277 | images = torch.stack(images, dim=0).permute(1,0,2,3,4)
278 | flip_status = torch.stack(flip_status, dim=0).permute(1,0)
279 | coords = torch.stack(coords, dim=0).permute(1,0,2)
280 |
281 | for k in range(images.size()[0]):
282 | save_path = os.path.join(args.save_path,'/'.join(path[k].split('/')[-4:-1]))
283 | if not os.path.exists(save_path):
284 | os.makedirs(save_path, exist_ok=True)
285 | new_filename = os.path.join(save_path,'/'.join(path[k].split('/')[-1:]).split('.')[0] + '.tar')
286 | if not os.path.exists(new_filename):
287 | if args.num_crops <= args.num_seg:
288 | if args.gpu is not None:
289 | images[k] = images[k].cuda(args.gpu, non_blocking=True)
290 | if torch.cuda.is_available():
291 | target = target.cuda(args.gpu, non_blocking=True)
292 |
293 | # compute output
294 | output = model(images[k])
295 |
296 | output = nn.functional.softmax(output / args.temp, dim=1)
297 | images[k] = images[k].detach()#.cpu()
298 | output = output.detach().cpu()
299 |
300 | output = label_quantization(output, args.label_type)
301 |
302 | state = [coords[k].detach().numpy(), flip_status[k].detach().numpy(), output]
303 | torch.save(state, new_filename)
304 | else:
305 | output_all = []
306 | for split in range(int(args.num_crops / args.num_seg)):
307 | if args.gpu is not None:
308 | images[k][split*args.num_seg:(split+1)*args.num_seg] = images[k][split*args.num_seg:(split+1)*args.num_seg].cuda(args.gpu, non_blocking=True)
309 | if torch.cuda.is_available():
310 | target = target.cuda(args.gpu, non_blocking=True)
311 |
312 | # compute output
313 | output = model(images[k][split*args.num_seg:(split+1)*args.num_seg])
314 |
315 | output = nn.functional.softmax(output / args.temp, dim=1)
316 | images[k][split*args.num_seg:(split+1)*args.num_seg] = images[k][split*args.num_seg:(split+1)*args.num_seg].detach()#.cpu()
317 | output = output.detach().cpu()
318 | output_all.append(output)
319 |
320 | output_all = torch.cat(output_all, dim=0)
321 | output_quan = label_quantization(output_all, args.label_type)
322 |
323 | if args.use_fp16:
324 | state = [np.float16(coords[k].detach().numpy()), flip_status[k].detach().numpy(), np.float16(output_quan)]
325 | else:
326 | state = [coords[k].detach().numpy(), flip_status[k].detach().numpy(), output_quan]
327 |
328 | torch.save(state, new_filename)
329 |
330 | print(i,'/', len(train_loader), i/len(train_loader)*100, "%")
331 |
332 |
333 | def validate(val_loader, model, criterion, args):
334 | batch_time = AverageMeter('Time', ':6.3f')
335 | losses = AverageMeter('Loss', ':.4e')
336 | top1 = AverageMeter('Acc@1', ':6.2f')
337 | top5 = AverageMeter('Acc@5', ':6.2f')
338 | progress = ProgressMeter(
339 | len(val_loader),
340 | [batch_time, losses, top1, top5],
341 | prefix='Test: ')
342 |
343 | # switch to evaluate mode
344 | model.eval()
345 |
346 | with torch.no_grad():
347 | end = time.time()
348 | for i, (images, target) in enumerate(val_loader):
349 | if args.gpu is not None:
350 | images = images.cuda(args.gpu, non_blocking=True)
351 | if torch.cuda.is_available():
352 | target = target.cuda(args.gpu, non_blocking=True)
353 |
354 | # compute output
355 | output = model(images)
356 | loss = criterion(output, target)
357 |
358 | # measure accuracy and record loss
359 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
360 | losses.update(loss.item(), images.size(0))
361 | top1.update(acc1[0], images.size(0))
362 | top5.update(acc5[0], images.size(0))
363 |
364 | # measure elapsed time
365 | batch_time.update(time.time() - end)
366 | end = time.time()
367 |
368 | if i % args.print_freq == 0:
369 | progress.display(i)
370 |
371 | # TODO: this should also be done with the ProgressMeter
372 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
373 | .format(top1=top1, top5=top5))
374 |
375 | return top1.avg
376 |
377 | def label_quantization(full_label, label_type):
378 |     # (1) hard; (2) smoothing; (3) marginal_smoothing_k5 and marginal_renorm_k5; (4) marginal_smoothing_k10
379 | output_quantized = []
380 | for kk,p in enumerate(full_label):
381 | # 1 hard
382 | if label_type == 'hard':
383 | output = torch.argmax(p)
384 | output_quantized.append(output)
385 | # 2 smoothing
386 | elif label_type == 'smoothing':
387 | output = torch.argmax(p)
388 | value = p[output]#.item()
389 | output_quantized.append(torch.stack([output, value], dim=0))
390 | # 3 marginal_smoothing_k5 and marginal_renorm_k5
391 | elif label_type == 'marginal_smoothing_k5' or label_type == 'marginal_renorm_k5':
392 | output = torch.argsort(p, descending=True)
393 | value = p[output[:5]]
394 | output_quantized.append(torch.stack([output[:5], value], dim=0))
395 | # 4 marginal_smoothing_k10
396 | elif label_type == 'marginal_smoothing_k10':
397 | output = torch.argsort(p, descending=True)
398 | value = p[output[:10]]
399 | output_quantized.append(torch.stack([output[:10], value], dim=0))
400 |
401 | output_quantized = torch.stack(output_quantized, dim=0)
402 |
403 | return output_quantized.detach().numpy()
404 |
405 |
406 | class AverageMeter(object):
407 | """Computes and stores the average and current value"""
408 | def __init__(self, name, fmt=':f'):
409 | self.name = name
410 | self.fmt = fmt
411 | self.reset()
412 |
413 | def reset(self):
414 | self.val = 0
415 | self.avg = 0
416 | self.sum = 0
417 | self.count = 0
418 |
419 | def update(self, val, n=1):
420 | self.val = val
421 | self.sum += val * n
422 | self.count += n
423 | self.avg = self.sum / self.count
424 |
425 | def __str__(self):
426 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
427 | return fmtstr.format(**self.__dict__)
428 |
429 |
430 | class ProgressMeter(object):
431 | def __init__(self, num_batches, meters, prefix=""):
432 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
433 | self.meters = meters
434 | self.prefix = prefix
435 |
436 | def display(self, batch):
437 | entries = [self.prefix + self.batch_fmtstr.format(batch)]
438 | entries += [str(meter) for meter in self.meters]
439 | print('\t'.join(entries))
440 |
441 | def _get_batch_fmtstr(self, num_batches):
442 | num_digits = len(str(num_batches // 1))
443 | fmt = '{:' + str(num_digits) + 'd}'
444 | return '[' + fmt + '/' + fmt.format(num_batches) + ']'
445 |
446 |
447 | def accuracy(output, target, topk=(1,)):
448 | """Computes the accuracy over the k top predictions for the specified values of k"""
449 | with torch.no_grad():
450 | maxk = max(topk)
451 | batch_size = target.size(0)
452 |
453 | _, pred = output.topk(maxk, 1, True, True)
454 | pred = pred.t()
455 | correct = pred.eq(target.view(1, -1).expand_as(pred))
456 |
457 | res = []
458 | for k in topk:
459 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
460 | res.append(correct_k.mul_(100.0 / batch_size))
461 | return res
462 |
463 |
464 | if __name__ == '__main__':
465 | main()
466 |
--------------------------------------------------------------------------------
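A small inspection sketch for the soft-label files written by `generate_soft_labels()` above (hedged, not part of the repository): each per-image `.tar` file stores `[crop coordinates, flip flags, quantized teacher outputs]` via `torch.save`; the path below is a made-up example that mirrors the directory layout under `--save_path`.

```python
# Hypothetical sketch: load one soft-label file produced by generate_soft_label.py.
import torch

label_file = './FKD_effL2_475_soft_label/train/n01440764/n01440764_10026.tar'  # example path (assumed)
# note: newer PyTorch releases may require torch.load(..., weights_only=False) since the file stores numpy arrays
coords, flip_status, soft_label = torch.load(label_file, map_location='cpu')

print(coords.shape)       # (num_crops, 4): RandomResizedCrop coordinates of each crop
print(flip_status.shape)  # (num_crops,): horizontal-flip flag of each crop
print(soft_label.shape)   # e.g. (num_crops, 2, 5) for `marginal_smoothing_k5`:
                          # row 0 holds the top-5 class indices, row 1 their probabilities
```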
/FKD/train_FKD.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import random
4 | import shutil
5 | import time
6 | import math
7 | import warnings
8 | import numpy as np
9 | import builtins
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.parallel
14 | import torch.backends.cudnn as cudnn
15 | import torch.distributed as dist
16 | import torch.optim
17 | import torch.multiprocessing as mp
18 | import torch.utils.data
19 | import torch.utils.data.distributed
20 | import torchvision.transforms as transforms
21 | import torchvision.datasets as datasets
22 | import torchvision.models as models
23 | from torchvision.transforms import InterpolationMode
24 |
25 | from utils_FKD import RandomResizedCrop_FKD, RandomHorizontalFlip_FKD
26 | from utils_FKD import ImageFolder_FKD, Compose_FKD
27 | from utils_FKD import Soft_CrossEntropy, Recover_soft_label
28 | from utils_FKD import mixup_cutmix
29 |
30 |
31 | model_names = sorted(name for name in models.__dict__
32 | if name.islower() and not name.startswith("__")
33 | and callable(models.__dict__[name]))
34 |
35 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training with FKD Scheme')
36 | parser.add_argument('data', metavar='DIR',
37 | help='path to dataset')
38 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
39 | choices=model_names,
40 | help='model architecture: ' +
41 | ' | '.join(model_names) +
42 |                         ' (default: resnet50)')
43 | parser.add_argument('-j', '--workers', default=24, type=int, metavar='N',
44 |                     help='number of data loading workers (default: 24)')
45 | parser.add_argument('--epochs', default=300, type=int, metavar='N',
46 | help='number of total epochs to run')
47 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
48 | help='manual epoch number (useful on restarts)')
49 | parser.add_argument('-b', '--batch-size', default=1024, type=int,
50 | metavar='N',
51 | help='mini-batch size (default: 1024), this is the total '
52 | 'batch size of all GPUs on the current node when '
53 | 'using Data Parallel or Distributed Data Parallel')
54 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
55 | metavar='LR', help='initial learning rate', dest='lr')
56 | parser.add_argument('--schedule', default=[120, 240], nargs='*', type=int,
57 | help='learning rate schedule (when to drop lr by 10x)')
58 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
59 | help='momentum')
60 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
61 | metavar='W', help='weight decay (default: 1e-4)',
62 | dest='weight_decay')
63 | parser.add_argument('-p', '--print-freq', default=10, type=int,
64 | metavar='N', help='print frequency (default: 10)')
65 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
66 | help='path to latest checkpoint (default: none)')
67 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
68 | help='evaluate model on validation set')
69 | parser.add_argument('--pretrained', dest='pretrained', action='store_true',
70 | help='use pre-trained model')
71 | parser.add_argument('--world-size', default=-1, type=int,
72 | help='number of nodes for distributed training')
73 | parser.add_argument('--rank', default=-1, type=int,
74 | help='node rank for distributed training')
75 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
76 | help='url used to set up distributed training')
77 | parser.add_argument('--dist-backend', default='nccl', type=str,
78 | help='distributed backend')
79 | parser.add_argument('--seed', default=None, type=int,
80 | help='seed for initializing training. ')
81 | parser.add_argument('--gpu', default=None, type=int,
82 | help='GPU id to use.')
83 | parser.add_argument('--multiprocessing-distributed', action='store_true',
84 | help='Use multi-processing distributed training to launch '
85 | 'N processes per node, which has N GPUs. This is the '
86 | 'fastest way to use PyTorch for either single node or '
87 | 'multi node data parallel training')
88 | parser.add_argument('--num_crops', default=4, type=int,
89 | help='number of crops in each image, 1 is the standard training')
90 | parser.add_argument('--softlabel_path', default='./soft_label', type=str, metavar='PATH',
91 |                     help='path to soft label files (default: ./soft_label)')
92 | parser.add_argument("--temp", type=float, default=1.0,
93 |                     help="temperature on student during training (default: 1.0)")
94 | parser.add_argument('--cos', action='store_true',
95 | help='use cosine lr schedule')
96 | parser.add_argument('--save_checkpoint_path', default='./FKD_checkpoints_output', type=str, metavar='PATH',
97 |                     help='path to save checkpoints (default: ./FKD_checkpoints_output)')
98 | parser.add_argument('--soft_label_type', default='marginal_smoothing_k5', type=str, metavar='TYPE',
99 | help='(1) ori; (2) hard; (3) smoothing; (4) marginal_smoothing_k5; (5) marginal_smoothing_k10; (6) marginal_renorm_k5')
100 | parser.add_argument('--num_classes', default=1000, type=int,
101 | help='number of classes.')
102 |
103 | # mixup and cutmix parameters
104 | parser.add_argument('--mixup_cutmix', default=False, action='store_true',
105 | help='use mixup and cutmix data augmentation')
106 | parser.add_argument('--mixup', type=float, default=0.8,
107 | help='mixup alpha, mixup enabled if > 0. (default: 0.8)')
108 | parser.add_argument('--cutmix', type=float, default=1.0,
109 | help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)')
110 | parser.add_argument('--mixup_cutmix_prob', type=float, default=1.0,
111 | help='Probability of performing mixup or cutmix when either/both is enabled')
112 | parser.add_argument('--mixup_switch_prob', type=float, default=0.5,
113 | help='Probability of switching to cutmix when both mixup and cutmix enabled. Mixup only: set to 1.0, Cutmix only: set to 0.0')
114 |
115 | best_acc1 = 0
116 |
117 |
118 | def main():
119 | args = parser.parse_args()
120 |
121 | if not os.path.exists(args.save_checkpoint_path):
122 | os.makedirs(args.save_checkpoint_path)
123 |
124 |     # convert to the true number of images loaded per mini-batch, since we take multiple crops from the
125 |     # same image within a mini-batch (e.g. batch_size 1024 with num_crops 4 loads 256 images per step)
125 | args.batch_size = math.ceil(args.batch_size / args.num_crops)
126 |
127 | if args.seed is not None:
128 | random.seed(args.seed)
129 | torch.manual_seed(args.seed)
130 | cudnn.deterministic = True
131 | warnings.warn('You have chosen to seed training. '
132 | 'This will turn on the CUDNN deterministic setting, '
133 | 'which can slow down your training considerably! '
134 | 'You may see unexpected behavior when restarting '
135 | 'from checkpoints.')
136 |
137 | if args.gpu is not None:
138 | warnings.warn('You have chosen a specific GPU. This will completely '
139 | 'disable data parallelism.')
140 |
141 | if args.dist_url == "env://" and args.world_size == -1:
142 | args.world_size = int(os.environ["WORLD_SIZE"])
143 |
144 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed
145 |
146 | ngpus_per_node = torch.cuda.device_count()
147 | if args.multiprocessing_distributed:
148 | # Since we have ngpus_per_node processes per node, the total world_size
149 | # needs to be adjusted accordingly
150 | args.world_size = ngpus_per_node * args.world_size
151 | # Use torch.multiprocessing.spawn to launch distributed processes: the
152 | # main_worker process function
153 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
154 | else:
155 | # Simply call main_worker function
156 | main_worker(args.gpu, ngpus_per_node, args)
157 |
158 |
159 | def main_worker(gpu, ngpus_per_node, args):
160 | global best_acc1
161 | args.gpu = gpu
162 |
163 | # suppress printing if not master
164 | if args.multiprocessing_distributed and args.gpu != 0:
165 | def print_pass(*args):
166 | pass
167 | builtins.print = print_pass
168 |
169 | if args.gpu is not None:
170 | print("Use GPU: {} for training".format(args.gpu))
171 |
172 | if args.distributed:
173 | if args.dist_url == "env://" and args.rank == -1:
174 | args.rank = int(os.environ["RANK"])
175 | if args.multiprocessing_distributed:
176 | # For multiprocessing distributed training, rank needs to be the
177 | # global rank among all the processes
178 | args.rank = args.rank * ngpus_per_node + gpu
179 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
180 | world_size=args.world_size, rank=args.rank)
181 | # create model
182 | if args.pretrained:
183 | print("=> using pre-trained model '{}'".format(args.arch))
184 | model = models.__dict__[args.arch](pretrained=True)
185 | else:
186 | print("=> creating model '{}'".format(args.arch))
187 | model = models.__dict__[args.arch](pretrained=False, num_classes=args.num_classes)
188 |
189 | if not torch.cuda.is_available():
190 | print('using CPU, this will be slow')
191 | elif args.distributed:
192 | # For multiprocessing distributed, DistributedDataParallel constructor
193 | # should always set the single device scope, otherwise,
194 | # DistributedDataParallel will use all available devices.
195 | if args.gpu is not None:
196 | torch.cuda.set_device(args.gpu)
197 | model.cuda(args.gpu)
198 | # When using a single GPU per process and per
199 | # DistributedDataParallel, we need to divide the batch size
200 | # ourselves based on the total number of GPUs we have
201 | args.batch_size = int(args.batch_size / ngpus_per_node)
202 | args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
203 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
204 | else:
205 | model.cuda()
206 | # DistributedDataParallel will divide and allocate batch_size to all
207 | # available GPUs if device_ids are not set
208 | model = torch.nn.parallel.DistributedDataParallel(model)
209 | elif args.gpu is not None:
210 | torch.cuda.set_device(args.gpu)
211 | model = model.cuda(args.gpu)
212 | else:
213 | # DataParallel will divide and allocate batch_size to all available GPUs
214 | if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
215 | model.features = torch.nn.DataParallel(model.features)
216 | model.cuda()
217 | else:
218 | model = torch.nn.DataParallel(model).cuda()
219 |
220 | # define loss function (criterion) and optimizer
221 | criterion_sce = Soft_CrossEntropy()
222 | criterion_ce = nn.CrossEntropyLoss().cuda(args.gpu)
223 |
224 | optimizer = torch.optim.SGD(model.parameters(), args.lr,
225 | momentum=args.momentum,
226 | weight_decay=args.weight_decay)
227 |
228 | # optionally resume from a checkpoint
229 | if args.resume:
230 | if os.path.isfile(args.resume):
231 | print("=> loading checkpoint '{}'".format(args.resume))
232 | if args.gpu is None:
233 | checkpoint = torch.load(args.resume)
234 | else:
235 | # Map model to be loaded to specified single gpu.
236 | loc = 'cuda:{}'.format(args.gpu)
237 | checkpoint = torch.load(args.resume, map_location=loc)
238 | args.start_epoch = checkpoint['epoch']
239 | best_acc1 = checkpoint['best_acc1']
240 | if args.gpu is not None:
241 | # best_acc1 may be from a checkpoint from a different GPU
242 | best_acc1 = best_acc1.to(args.gpu)
243 | model.load_state_dict(checkpoint['state_dict'])
244 | optimizer.load_state_dict(checkpoint['optimizer'])
245 | print("=> loaded checkpoint '{}' (epoch {})"
246 | .format(args.resume, checkpoint['epoch']))
247 | else:
248 | print("=> no checkpoint found at '{}'".format(args.resume))
249 |
250 | cudnn.benchmark = True
251 |
252 | # Data loading code
253 | traindir = os.path.join(args.data, 'train')
254 | valdir = os.path.join(args.data, 'val')
255 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
256 | std=[0.229, 0.224, 0.225])
257 |
258 | train_dataset = ImageFolder_FKD(
259 | num_crops=args.num_crops,
260 | softlabel_path=args.softlabel_path,
261 | root=traindir,
262 | transform=Compose_FKD(transforms=[
263 | RandomResizedCrop_FKD(size=224,
264 | interpolation='bilinear'),
265 | RandomHorizontalFlip_FKD(),
266 | transforms.ToTensor(),
267 | normalize,
268 | ]))
269 | train_dataset_single_crop = ImageFolder_FKD(
270 | num_crops=1,
271 | softlabel_path=args.softlabel_path,
272 | root=traindir,
273 | transform=Compose_FKD(transforms=[
274 | RandomResizedCrop_FKD(size=224,
275 | interpolation='bilinear'),
276 | RandomHorizontalFlip_FKD(),
277 | transforms.ToTensor(),
278 | normalize,
279 | ]))
280 |
281 | if args.distributed:
282 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
283 | else:
284 | train_sampler = None
285 |
286 | train_loader = torch.utils.data.DataLoader(
287 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
288 | num_workers=args.workers, pin_memory=True, sampler=train_sampler)
289 |
290 | train_loader_single_crop = torch.utils.data.DataLoader(
291 | train_dataset_single_crop, batch_size=args.batch_size*args.num_crops, shuffle=(train_sampler is None),
292 | num_workers=args.workers, pin_memory=True, sampler=train_sampler)
293 |
294 | val_loader = torch.utils.data.DataLoader(
295 | datasets.ImageFolder(valdir, transforms.Compose([
296 | transforms.Resize(256),
297 | transforms.CenterCrop(224),
298 | transforms.ToTensor(),
299 | normalize,
300 | ])),
301 | batch_size=args.batch_size, shuffle=False,
302 | num_workers=args.workers, pin_memory=True)
303 |
304 | if args.evaluate:
305 | validate(val_loader, model, criterion_ce, args)
306 | return
307 |
308 |     # when resuming, realign start_epoch with the num_crops-strided epoch loop below
309 |     if args.start_epoch != 0 and args.start_epoch < (args.epochs - args.num_crops):
310 | args.start_epoch = args.start_epoch + args.num_crops - 1
311 |
312 | for epoch in range(args.start_epoch, args.epochs, args.num_crops):
313 | if args.distributed:
314 | train_sampler.set_epoch(epoch)
315 | adjust_learning_rate(optimizer, epoch, args)
316 |         # for fine-grained evaluation during the last few epochs: switch to the single-crop loader and step epoch by epoch
317 | if epoch >= (args.epochs-args.num_crops):
318 | start_epoch = epoch
319 | for epoch in range(start_epoch, args.epochs):
320 | if args.distributed:
321 | train_sampler.set_epoch(epoch)
322 | adjust_learning_rate(optimizer, epoch, args)
323 |
324 | # train for one epoch
325 | train(train_loader_single_crop, model, criterion_sce, optimizer, epoch, args)
326 |
327 | # evaluate on validation set
328 | acc1 = validate(val_loader, model, criterion_ce, args)
329 |
330 | # remember best acc@1 and save checkpoint
331 | is_best = acc1 > best_acc1
332 | best_acc1 = max(acc1, best_acc1)
333 |
334 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed
335 | and args.rank % ngpus_per_node == 0):
336 | save_checkpoint({
337 | 'epoch': epoch + 1,
338 | 'arch': args.arch,
339 | 'state_dict': model.state_dict(),
340 | 'best_acc1': best_acc1,
341 | 'optimizer' : optimizer.state_dict(),
342 | }, is_best, filename=args.save_checkpoint_path+'/checkpoint.pth.tar')
343 | return
344 | else:
345 | # train for one epoch
346 | train(train_loader, model, criterion_sce, optimizer, epoch, args)
347 |
348 | # evaluate on validation set
349 | acc1 = validate(val_loader, model, criterion_ce, args)
350 |
351 | # remember best acc@1 and save checkpoint
352 | is_best = acc1 > best_acc1
353 | best_acc1 = max(acc1, best_acc1)
354 |
355 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed
356 | and args.rank % ngpus_per_node == 0):
357 | save_checkpoint({
358 | 'epoch': epoch + 1,
359 | 'arch': args.arch,
360 | 'state_dict': model.state_dict(),
361 | 'best_acc1': best_acc1,
362 | 'optimizer' : optimizer.state_dict(),
363 | }, is_best, filename=args.save_checkpoint_path+'/checkpoint.pth.tar')
364 |
365 |
366 | def train(train_loader, model, criterion, optimizer, epoch, args):
367 | batch_time = AverageMeter('Time', ':6.3f')
368 | data_time = AverageMeter('Data', ':6.3f')
369 | losses = AverageMeter('Loss', ':.4e')
370 | top1 = AverageMeter('Acc@1', ':6.2f')
371 | top5 = AverageMeter('Acc@5', ':6.2f')
372 | progress = ProgressMeter(
373 | len(train_loader),
374 | [batch_time, data_time, losses, top1, top5, 'LR {lr:.5f}'.format(lr=_get_learning_rate(optimizer))],
375 | prefix="Epoch: [{}]".format(epoch))
376 |
377 | # switch to train mode
378 | model.train()
379 |
380 | end = time.time()
381 | for i, (images, target, soft_label) in enumerate(train_loader):
382 | # measure data loading time
383 | data_time.update(time.time() - end)
384 |
385 | # reshape images and soft label
386 | images = torch.cat(images, dim=0)
387 | soft_label = torch.cat(soft_label, dim=0)
388 | target = torch.cat(target, dim=0)
389 |
390 | if args.soft_label_type != 'ori':
391 | soft_label = Recover_soft_label(soft_label, args.soft_label_type, args.num_classes)
392 |
393 | if args.gpu is not None:
394 | images = images.cuda(args.gpu, non_blocking=True)
395 | if torch.cuda.is_available():
396 | target = target.cuda(args.gpu, non_blocking=True)
397 | soft_label = soft_label.cuda(args.gpu, non_blocking=True)
398 |
399 | if args.mixup_cutmix:
400 | images, soft_label = mixup_cutmix(images, soft_label, args)
401 |
402 | # compute output
403 | output = model(images)
404 | loss = criterion(output / args.temp, soft_label)
405 |
406 | # measure accuracy and record loss
407 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
408 | losses.update(loss.item(), images.size(0))
409 | top1.update(acc1[0], images.size(0))
410 | top5.update(acc5[0], images.size(0))
411 |
412 | # compute gradient and do SGD step
413 | optimizer.zero_grad()
414 | loss.backward()
415 | optimizer.step()
416 |
417 | # measure elapsed time
418 | batch_time.update(time.time() - end)
419 | end = time.time()
420 |
421 | if i % args.print_freq == 0:
422 | progress.display(i)
423 |
424 |
425 | def validate(val_loader, model, criterion, args):
426 | batch_time = AverageMeter('Time', ':6.3f')
427 | losses = AverageMeter('Loss', ':.4e')
428 | top1 = AverageMeter('Acc@1', ':6.2f')
429 | top5 = AverageMeter('Acc@5', ':6.2f')
430 | progress = ProgressMeter(
431 | len(val_loader),
432 | [batch_time, losses, top1, top5],
433 | prefix='Test: ')
434 |
435 | # switch to evaluate mode
436 | model.eval()
437 |
438 | with torch.no_grad():
439 | end = time.time()
440 | for i, (images, target) in enumerate(val_loader):
441 | if args.gpu is not None:
442 | images = images.cuda(args.gpu, non_blocking=True)
443 | if torch.cuda.is_available():
444 | target = target.cuda(args.gpu, non_blocking=True)
445 |
446 | # compute output
447 | output = model(images)
448 | loss = criterion(output, target)
449 |
450 | # measure accuracy and record loss
451 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
452 | losses.update(loss.item(), images.size(0))
453 | top1.update(acc1[0], images.size(0))
454 | top5.update(acc5[0], images.size(0))
455 |
456 | # measure elapsed time
457 | batch_time.update(time.time() - end)
458 | end = time.time()
459 |
460 | if i % args.print_freq == 0:
461 | progress.display(i)
462 |
463 | # TODO: this should also be done with the ProgressMeter
464 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
465 | .format(top1=top1, top5=top5))
466 |
467 | return top1.avg
468 |
469 |
470 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
471 | torch.save(state, filename)
472 | if is_best:
473 | shutil.copyfile(filename, filename[:-19]+'/model_best.pth.tar')
474 |
475 |
476 | class AverageMeter(object):
477 | """Computes and stores the average and current value"""
478 | def __init__(self, name, fmt=':f'):
479 | self.name = name
480 | self.fmt = fmt
481 | self.reset()
482 |
483 | def reset(self):
484 | self.val = 0
485 | self.avg = 0
486 | self.sum = 0
487 | self.count = 0
488 |
489 | def update(self, val, n=1):
490 | self.val = val
491 | self.sum += val * n
492 | self.count += n
493 | self.avg = self.sum / self.count
494 |
495 | def __str__(self):
496 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
497 | return fmtstr.format(**self.__dict__)
498 |
499 |
500 | class ProgressMeter(object):
501 | def __init__(self, num_batches, meters, prefix=""):
502 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
503 | self.meters = meters
504 | self.prefix = prefix
505 |
506 | def display(self, batch):
507 | entries = [self.prefix + self.batch_fmtstr.format(batch)]
508 | entries += [str(meter) for meter in self.meters]
509 | print('\t'.join(entries))
510 |
511 | def _get_batch_fmtstr(self, num_batches):
512 | num_digits = len(str(num_batches // 1))
513 | fmt = '{:' + str(num_digits) + 'd}'
514 | return '[' + fmt + '/' + fmt.format(num_batches) + ']'
515 |
516 |
517 | def adjust_learning_rate(optimizer, epoch, args):
518 | """Decay the learning rate based on schedule"""
519 | lr = args.lr
520 | if args.cos: # cosine lr schedule
521 | lr *= 0.5 * (1. + math.cos(math.pi * epoch / args.epochs))
522 | else: # stepwise lr schedule
523 | for milestone in args.schedule:
524 | lr *= 0.1 if epoch >= milestone else 1.
525 | for param_group in optimizer.param_groups:
526 | param_group['lr'] = lr
527 |
528 | def _get_learning_rate(optimizer):
529 | return max(param_group['lr'] for param_group in optimizer.param_groups)
530 |
531 | def accuracy(output, target, topk=(1,)):
532 | """Computes the accuracy over the k top predictions for the specified values of k"""
533 | with torch.no_grad():
534 | maxk = max(topk)
535 | batch_size = target.size(0)
536 |
537 | _, pred = output.topk(maxk, 1, True, True)
538 | pred = pred.t()
539 | correct = pred.eq(target.view(1, -1).expand_as(pred))
540 |
541 | res = []
542 | for k in topk:
543 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
544 | res.append(correct_k.mul_(100.0 / batch_size))
545 | return res
546 |
547 |
548 | if __name__ == '__main__':
549 | main()
--------------------------------------------------------------------------------
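A short sketch of the multi-crop batch arithmetic used by `train_FKD.py` above (an illustrative assumption, not repository code): with `--batch-size 1024 --num_crops 4`, the loader yields `ceil(1024/4) = 256` images per step, each contributing 4 crops, so the model still sees 1024 samples; `train()` concatenates the per-crop lists exactly as below.

```python
# Hypothetical sketch: how train_FKD.py reshapes a multi-crop mini-batch before the forward pass.
import torch

num_crops, images_per_batch, num_classes = 4, 256, 1000
# the FKD dataset returns one list entry per crop, each of shape [images_per_batch, 3, 224, 224]
images = [torch.randn(images_per_batch, 3, 224, 224) for _ in range(num_crops)]
soft_label = [torch.randn(images_per_batch, num_classes).softmax(dim=1) for _ in range(num_crops)]

images = torch.cat(images, dim=0)          # -> [1024, 3, 224, 224], as in train()
soft_label = torch.cat(soft_label, dim=0)  # -> [1024, 1000]
print(images.shape, soft_label.shape)
```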
/FKD/FKD_ViT/train_ViT_FKD.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import random
4 | import shutil
5 | import time
6 | import math
7 | import warnings
8 | import numpy as np
9 | import builtins
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.parallel
14 | import torch.backends.cudnn as cudnn
15 | import torch.distributed as dist
16 | import torch.optim
17 | import torch.multiprocessing as mp
18 | import torch.utils.data
19 | import torch.utils.data.distributed
20 | import torchvision.transforms as transforms
21 | import torchvision.datasets as datasets
22 | import torchvision.models as models
23 | from torchvision.transforms import InterpolationMode
24 |
25 | from utils_FKD import RandomResizedCrop_FKD, RandomHorizontalFlip_FKD
26 | from utils_FKD import ImageFolder_FKD, Compose_FKD
27 | from utils_FKD import Soft_CrossEntropy, Recover_soft_label
28 | from utils_FKD import mixup_cutmix
29 |
30 | import timm
31 | from timm.scheduler import create_scheduler
32 | from timm.optim import create_optimizer
33 | import SReT
34 |
35 | # timm is used to build the optimizer and learning rate scheduler (https://github.com/rwightman/pytorch-image-models)
36 |
37 |
38 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training with FKD Scheme')
39 | parser.add_argument('data', metavar='DIR',
40 | help='path to dataset')
41 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50')
42 | parser.add_argument('-j', '--workers', default=24, type=int, metavar='N',
43 | help='number of data loading workers (default: 24)')
44 | parser.add_argument('--epochs', default=300, type=int, metavar='N',
45 | help='number of total epochs to run')
46 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
47 | help='manual epoch number (useful on restarts)')
48 | parser.add_argument('-b', '--batch-size', default=256, type=int,
49 | metavar='N',
50 | help='mini-batch size (default: 256), this is the total '
51 | 'batch size of all GPUs on the current node when '
52 | 'using Data Parallel or Distributed Data Parallel')
53 | parser.add_argument('--lr', '--learning-rate', default=0.002, type=float,
54 | metavar='LR', help='initial learning rate', dest='lr')
55 | parser.add_argument('--schedule', default=[120, 160], nargs='*', type=int,
56 | help='learning rate schedule (when to drop lr by 10x)')
57 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
58 | help='momentum')
59 | parser.add_argument('--wd', '--weight-decay', default=0.05, type=float,
60 |                     metavar='W', help='weight decay (default: 0.05)',
61 | dest='weight_decay')
62 | parser.add_argument('-p', '--print-freq', default=10, type=int,
63 | metavar='N', help='print frequency (default: 10)')
64 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
65 | help='path to latest checkpoint (default: none)')
66 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
67 | help='evaluate model on validation set')
68 | parser.add_argument('--pretrained', dest='pretrained', action='store_true',
69 | help='use pre-trained model')
70 | parser.add_argument('--world-size', default=-1, type=int,
71 | help='number of nodes for distributed training')
72 | parser.add_argument('--rank', default=-1, type=int,
73 | help='node rank for distributed training')
74 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
75 | help='url used to set up distributed training')
76 | parser.add_argument('--dist-backend', default='nccl', type=str,
77 | help='distributed backend')
78 | parser.add_argument('--seed', default=None, type=int,
79 | help='seed for initializing training. ')
80 | parser.add_argument('--gpu', default=None, type=int,
81 | help='GPU id to use.')
82 | parser.add_argument('--multiprocessing-distributed', action='store_true',
83 | help='Use multi-processing distributed training to launch '
84 | 'N processes per node, which has N GPUs. This is the '
85 | 'fastest way to use PyTorch for either single node or '
86 | 'multi node data parallel training')
87 | parser.add_argument('--num_crops', default=4, type=int,
88 | help='number of crops in each image, 1 is the standard training')
89 | parser.add_argument('--softlabel_path', default='../soft_label', type=str, metavar='PATH',
90 | help='path to soft label files (default: none)')
91 | parser.add_argument("--temp", type=float, default=1.0,
92 |                     help="temperature on student during training (default: 1.0)")
93 | parser.add_argument('--save_checkpoint_path', default='./FKD_checkpoints_output', type=str, metavar='PATH',
94 |                     help='path to save checkpoints (default: ./FKD_checkpoints_output)')
95 | parser.add_argument('--soft_label_type', default='marginal_smoothing_k5', type=str, metavar='TYPE',
96 | help='(1) ori; (2) hard; (3) smoothing; (4) marginal_smoothing_k5; (5) marginal_smoothing_k10; (6) marginal_renorm_k5')
97 | parser.add_argument('--num_classes', default=1000, type=int,
98 | help='number of classes.')
99 | parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
100 | help='Dropout rate (default: 0.)')
101 | parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT',
102 | help='Drop path rate (default: 0.1)')
103 |
104 | # Optimizer parameters
105 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
106 |                     help='Optimizer (default: "adamw")')
107 | parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
108 | help='Optimizer Epsilon (default: 1e-8)')
109 | parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
110 | help='Optimizer Betas (default: None, use opt default)')
111 | parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
112 | help='Clip gradient norm (default: None, no clipping)')
113 |
114 | # Learning rate schedule parameters
115 | parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
116 |                     help='LR scheduler (default: "cosine")')
117 | parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
118 | help='learning rate noise on/off epoch percentages')
119 | parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
120 | help='learning rate noise limit percent (default: 0.67)')
121 | parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
122 | help='learning rate noise std-dev (default: 1.0)')
123 | parser.add_argument('--warmup-lr', type=float, default=1e-5, metavar='LR',
124 | help='warmup learning rate (default: 1e-5)')
125 | parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
126 | help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
127 | parser.add_argument('--cos', action='store_true',
128 | help='use cosine lr schedule')
129 | parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
130 | help='epoch interval to decay LR')
131 | parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
132 | help='epochs to warmup LR, if scheduler supports')
133 | parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
134 | help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
135 | parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
136 |                     help='patience epochs for Plateau LR scheduler (default: 10)')
137 | parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
138 | help='LR decay rate (default: 0.1)')
139 |
140 | # mixup and cutmix parameters
141 | parser.add_argument('--mixup_cutmix', default=False, action='store_true',
142 | help='use mixup and cutmix data augmentation')
143 | parser.add_argument('--mixup', type=float, default=0.8,
144 | help='mixup alpha, mixup enabled if > 0. (default: 0.8)')
145 | parser.add_argument('--cutmix', type=float, default=1.0,
146 | help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)')
147 | parser.add_argument('--mixup_cutmix_prob', type=float, default=1.0,
148 | help='Probability of performing mixup or cutmix when either/both is enabled')
149 | parser.add_argument('--mixup_switch_prob', type=float, default=0.5,
150 | help='Probability of switching to cutmix when both mixup and cutmix enabled. Mixup only: set to 1.0, Cutmix only: set to 0.0')
151 |
152 | best_acc1 = 0
153 |
154 |
155 | def main():
156 | args = parser.parse_args()
157 |
158 | if not os.path.exists(args.save_checkpoint_path):
159 | os.makedirs(args.save_checkpoint_path)
160 |
161 |     # convert to the true number of images loaded per mini-batch (and effective epochs), since we take multiple crops from the same image within a mini-batch
162 | args.batch_size = math.ceil(args.batch_size / args.num_crops)
163 |
164 | if args.seed is not None:
165 | random.seed(args.seed)
166 | torch.manual_seed(args.seed)
167 | cudnn.deterministic = True
168 | warnings.warn('You have chosen to seed training. '
169 | 'This will turn on the CUDNN deterministic setting, '
170 | 'which can slow down your training considerably! '
171 | 'You may see unexpected behavior when restarting '
172 | 'from checkpoints.')
173 |
174 | if args.gpu is not None:
175 | warnings.warn('You have chosen a specific GPU. This will completely '
176 | 'disable data parallelism.')
177 |
178 | if args.dist_url == "env://" and args.world_size == -1:
179 | args.world_size = int(os.environ["WORLD_SIZE"])
180 |
181 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed
182 |
183 | ngpus_per_node = torch.cuda.device_count()
184 | if args.multiprocessing_distributed:
185 | # Since we have ngpus_per_node processes per node, the total world_size
186 | # needs to be adjusted accordingly
187 | args.world_size = ngpus_per_node * args.world_size
188 | # Use torch.multiprocessing.spawn to launch distributed processes: the
189 | # main_worker process function
190 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
191 | else:
192 | # Simply call main_worker function
193 | main_worker(args.gpu, ngpus_per_node, args)
194 |
195 |
196 | def main_worker(gpu, ngpus_per_node, args):
197 | global best_acc1
198 | args.gpu = gpu
199 |
200 | # suppress printing if not master
201 | if args.multiprocessing_distributed and args.gpu != 0:
202 | def print_pass(*args):
203 | pass
204 | builtins.print = print_pass
205 |
206 | if args.gpu is not None:
207 | print("Use GPU: {} for training".format(args.gpu))
208 |
209 | if args.distributed:
210 | if args.dist_url == "env://" and args.rank == -1:
211 | args.rank = int(os.environ["RANK"])
212 | if args.multiprocessing_distributed:
213 | # For multiprocessing distributed training, rank needs to be the
214 | # global rank among all the processes
215 | args.rank = args.rank * ngpus_per_node + gpu
216 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
217 | world_size=args.world_size, rank=args.rank)
218 | # create model
219 | if args.pretrained:
220 | print("=> using pre-trained model '{}'".format(args.arch))
221 | model = timm.create_model(args.arch, pretrained=True, num_classes=args.num_classes)
222 | else:
223 | print("=> creating model '{}'".format(args.arch))
224 | if args.arch.split('_')[0] == 'SReT':
225 | model = SReT.__dict__[args.arch](pretrained=False, num_classes=args.num_classes)
226 | else:
227 | model = timm.create_model(args.arch, pretrained=False, num_classes=args.num_classes)
228 |
229 | if not torch.cuda.is_available():
230 | print('using CPU, this will be slow')
231 | elif args.distributed:
232 | # For multiprocessing distributed, DistributedDataParallel constructor
233 | # should always set the single device scope, otherwise,
234 | # DistributedDataParallel will use all available devices.
235 | if args.gpu is not None:
236 | torch.cuda.set_device(args.gpu)
237 | model.cuda(args.gpu)
238 | # When using a single GPU per process and per
239 | # DistributedDataParallel, we need to divide the batch size
240 | # ourselves based on the total number of GPUs we have
241 | args.batch_size = int(args.batch_size / ngpus_per_node)
242 | args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
243 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
244 | else:
245 | model.cuda()
246 | # DistributedDataParallel will divide and allocate batch_size to all
247 | # available GPUs if device_ids are not set
248 | model = torch.nn.parallel.DistributedDataParallel(model)
249 | elif args.gpu is not None:
250 | torch.cuda.set_device(args.gpu)
251 | model = model.cuda(args.gpu)
252 | else:
253 | # DataParallel will divide and allocate batch_size to all available GPUs
254 | if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
255 | model.features = torch.nn.DataParallel(model.features)
256 | model.cuda()
257 | else:
258 | model = torch.nn.DataParallel(model).cuda()
259 |
260 | # define loss function (criterion) and optimizer
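    | # Soft_CrossEntropy is used during training against the pre-generated soft labels;
    | # the plain CrossEntropyLoss is only used for validation against hard labels.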
261 | criterion_sce = Soft_CrossEntropy()
262 | criterion_ce = nn.CrossEntropyLoss().cuda(args.gpu)
263 |
264 | optimizer = create_optimizer(args, model)
265 |
266 | lr_scheduler, _ = create_scheduler(args, optimizer)
267 |
268 | # optionally resume from a checkpoint
269 | if args.resume:
270 | if os.path.isfile(args.resume):
271 | print("=> loading checkpoint '{}'".format(args.resume))
272 | if args.gpu is None:
273 | checkpoint = torch.load(args.resume)
274 | else:
275 | # Map model to be loaded to specified single gpu.
276 | loc = 'cuda:{}'.format(args.gpu)
277 | checkpoint = torch.load(args.resume, map_location=loc)
278 | args.start_epoch = checkpoint['epoch']
279 | best_acc1 = checkpoint['best_acc1']
280 | if args.gpu is not None:
281 | # best_acc1 may be from a checkpoint from a different GPU
282 | best_acc1 = best_acc1.to(args.gpu)
283 | model.load_state_dict(checkpoint['state_dict'])
284 | optimizer.load_state_dict(checkpoint['optimizer'])
285 | print("=> loaded checkpoint '{}' (epoch {})"
286 | .format(args.resume, checkpoint['epoch']))
287 | else:
288 | print("=> no checkpoint found at '{}'".format(args.resume))
289 |
290 | cudnn.benchmark = True
291 |
292 | # Data loading code
293 | traindir = os.path.join(args.data, 'train')
294 | valdir = os.path.join(args.data, 'val')
295 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
296 | std=[0.229, 0.224, 0.225])
297 |
298 | train_dataset = ImageFolder_FKD(
299 | num_crops=args.num_crops,
300 | softlabel_path=args.softlabel_path,
301 | root=traindir,
302 | transform=Compose_FKD(transforms=[
303 | RandomResizedCrop_FKD(size=224,
304 | interpolation='bilinear'),
305 | RandomHorizontalFlip_FKD(),
306 | transforms.ToTensor(),
307 | normalize,
308 | ]))
309 | train_dataset_single_crop = ImageFolder_FKD(
310 | num_crops=1,
311 | softlabel_path=args.softlabel_path,
312 | root=traindir,
313 | transform=Compose_FKD(transforms=[
314 | RandomResizedCrop_FKD(size=224,
315 | interpolation='bilinear'),
316 | RandomHorizontalFlip_FKD(),
317 | transforms.ToTensor(),
318 | normalize,
319 | ]))
320 |
321 | if args.distributed:
322 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
323 | else:
324 | train_sampler = None
325 |
326 | train_loader = torch.utils.data.DataLoader(
327 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
328 | num_workers=args.workers, pin_memory=True, sampler=train_sampler)
329 |
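    | # The single-crop loader below is used for warm-up and for the last few epochs; its batch size
    | # is multiplied by num_crops so each iteration sees as many samples as the multi-crop loader above.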
330 | train_loader_single_crop = torch.utils.data.DataLoader(
331 | train_dataset_single_crop, batch_size=args.batch_size*args.num_crops, shuffle=(train_sampler is None),
332 | num_workers=args.workers, pin_memory=True, sampler=train_sampler)
333 |
334 | val_loader = torch.utils.data.DataLoader(
335 | datasets.ImageFolder(valdir, transforms.Compose([
336 | transforms.Resize(256),
337 | transforms.CenterCrop(224),
338 | transforms.ToTensor(),
339 | normalize,
340 | ])),
341 | batch_size=args.batch_size, shuffle=False,
342 | num_workers=args.workers, pin_memory=True)
343 |
344 | if args.evaluate:
345 | validate(val_loader, model, criterion_ce, args)
346 | return
347 |
348 | # warm up with a single crop; the "=" in "<=" handles the corner case so that start_epoch is properly reset to 0 below (for resuming).
349 | if args.start_epoch <= args.warmup_epochs:
350 | for epoch in range(args.start_epoch, args.warmup_epochs):
351 | if args.distributed:
352 | train_sampler.set_epoch(epoch)
353 | # train for one epoch
354 | train(train_loader_single_crop, model, criterion_sce, optimizer, epoch, args)
355 | lr_scheduler.step(epoch + 1)
356 |
357 | # evaluate on validation set
358 | acc1 = validate(val_loader, model, criterion_ce, args)
359 |
360 | # remember best acc@1 and save checkpoint
361 | is_best = acc1 > best_acc1
362 | best_acc1 = max(acc1, best_acc1)
363 |
364 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed
365 | and args.rank % ngpus_per_node == 0):
366 | save_checkpoint({
367 | 'epoch': epoch + 1,
368 | 'arch': args.arch,
369 | 'state_dict': model.state_dict(),
370 | 'best_acc1': best_acc1,
371 | 'optimizer' : optimizer.state_dict(),
372 | }, is_best, filename=args.save_checkpoint_path+'/checkpoint.pth.tar')
373 | args.start_epoch = 0 # for resume
374 | else:
375 | args.warmup_epochs = args.num_crops - 1 # for resume
376 |
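    | # Main training loop: each pass over the multi-crop loader trains on num_crops crops per image,
    | # so the epoch counter (and the lr scheduler) advances by num_crops at a time.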
377 | for epoch in range(args.start_epoch+args.warmup_epochs, args.epochs, args.num_crops):
378 | if args.distributed:
379 | train_sampler.set_epoch(epoch)
380 |
381 | # for the last few epochs, switch to the single-crop loader and evaluate after every epoch for fine-grained tracking
382 | if epoch >= (args.epochs-args.num_crops):
383 | start_epoch = epoch
384 | for epoch in range(start_epoch, args.epochs):
385 | if args.distributed:
386 | train_sampler.set_epoch(epoch)
387 | lr_scheduler.step(epoch+1)
388 | # train for one epoch
389 | train(train_loader_single_crop, model, criterion_sce, optimizer, epoch, args)
390 |
391 | # evaluate on validation set
392 | acc1 = validate(val_loader, model, criterion_ce, args)
393 |
394 | # remember best acc@1 and save checkpoint
395 | is_best = acc1 > best_acc1
396 | best_acc1 = max(acc1, best_acc1)
397 |
398 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed
399 | and args.rank % ngpus_per_node == 0):
400 | save_checkpoint({
401 | 'epoch': epoch + 1,
402 | 'arch': args.arch,
403 | 'state_dict': model.state_dict(),
404 | 'best_acc1': best_acc1,
405 | 'optimizer' : optimizer.state_dict(),
406 | }, is_best, filename=args.save_checkpoint_path+'/checkpoint.pth.tar')
407 | return
408 | else:
409 | # train for one epoch
410 | train(train_loader, model, criterion_sce, optimizer, epoch, args)
411 | lr_scheduler.step(epoch+args.num_crops)
412 |
413 | # evaluate on validation set
414 | acc1 = validate(val_loader, model, criterion_ce, args)
415 |
416 | # remember best acc@1 and save checkpoint
417 | is_best = acc1 > best_acc1
418 | best_acc1 = max(acc1, best_acc1)
419 |
420 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed
421 | and args.rank % ngpus_per_node == 0):
422 | save_checkpoint({
423 | 'epoch': epoch + 1,
424 | 'arch': args.arch,
425 | 'state_dict': model.state_dict(),
426 | 'best_acc1': best_acc1,
427 | 'optimizer' : optimizer.state_dict(),
428 | }, is_best, filename=args.save_checkpoint_path+'/checkpoint.pth.tar')
429 |
430 |
431 | def train(train_loader, model, criterion, optimizer, epoch, args):
432 | batch_time = AverageMeter('Time', ':6.3f')
433 | data_time = AverageMeter('Data', ':6.3f')
434 | losses = AverageMeter('Loss', ':.4e')
435 | top1 = AverageMeter('Acc@1', ':6.2f')
436 | top5 = AverageMeter('Acc@5', ':6.2f')
437 | progress = ProgressMeter(
438 | len(train_loader),
439 | [batch_time, data_time, losses, top1, top5, 'LR {lr:.5f}'.format(lr=_get_learning_rate(optimizer))],
440 | prefix="Epoch: [{}]".format(epoch))
441 |
442 | # switch to train mode
443 | model.train()
444 |
445 | end = time.time()
446 | for i, (images, target, soft_label) in enumerate(train_loader):
447 | # measure data loading time
448 | data_time.update(time.time() - end)
449 |
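    | # Each loader item is a list of num_crops crops with matching targets and soft labels;
    | # flatten the lists along the batch dimension.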
450 | images = torch.cat(images, dim=0)
451 | soft_label = torch.cat(soft_label, dim=0)
452 | target = torch.cat(target, dim=0)
453 |
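    | # Soft labels stored in a compressed format (anything other than 'ori') are expanded back
    | # to full num_classes-dimensional vectors by the Recover_soft_label helper.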
454 | if args.soft_label_type != 'ori':
455 | soft_label = Recover_soft_label(soft_label, args.soft_label_type, args.num_classes)
456 |
457 | if args.gpu is not None:
458 | images = images.cuda(args.gpu, non_blocking=True)
459 | if torch.cuda.is_available():
460 | target = target.cuda(args.gpu, non_blocking=True)
461 | soft_label = soft_label.cuda(args.gpu, non_blocking=True)
462 |
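    | # Optionally apply mixup/cutmix jointly to the images and their (already soft) labels.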
463 | if args.mixup_cutmix:
464 | images, soft_label = mixup_cutmix(images, soft_label, args)
465 |
466 | # compute output
467 | output = model(images)
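    | # soft cross-entropy against the pre-generated soft labels, with logits scaled by the temperature args.temp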
468 | loss = criterion(output / args.temp, soft_label)
469 |
470 | # measure accuracy and record loss
471 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
472 | losses.update(loss.item(), images.size(0))
473 | top1.update(acc1[0], images.size(0))
474 | top5.update(acc5[0], images.size(0))
475 |
476 | # compute gradient and update
477 | optimizer.zero_grad()
478 | loss.backward()
479 | optimizer.step()
480 |
481 | # measure elapsed time
482 | batch_time.update(time.time() - end)
483 | end = time.time()
484 |
485 | if i % args.print_freq == 0:
486 | t = time.localtime()
487 | current_time = time.strftime("%H:%M:%S", t)
488 | print(current_time)
489 | progress.display(i)
490 |
491 |
492 | def validate(val_loader, model, criterion, args):
493 | batch_time = AverageMeter('Time', ':6.3f')
494 | losses = AverageMeter('Loss', ':.4e')
495 | top1 = AverageMeter('Acc@1', ':6.2f')
496 | top5 = AverageMeter('Acc@5', ':6.2f')
497 | progress = ProgressMeter(
498 | len(val_loader),
499 | [batch_time, losses, top1, top5],
500 | prefix='Test: ')
501 |
502 | # switch to evaluate mode
503 | model.eval()
504 |
505 | with torch.no_grad():
506 | end = time.time()
507 | for i, (images, target) in enumerate(val_loader):
508 | if args.gpu is not None:
509 | images = images.cuda(args.gpu, non_blocking=True)
510 | if torch.cuda.is_available():
511 | target = target.cuda(args.gpu, non_blocking=True)
512 |
513 | # compute output
514 | output = model(images)
515 | loss = criterion(output, target)
516 |
517 | # measure accuracy and record loss
518 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
519 | losses.update(loss.item(), images.size(0))
520 | top1.update(acc1[0], images.size(0))
521 | top5.update(acc5[0], images.size(0))
522 |
523 | # measure elapsed time
524 | batch_time.update(time.time() - end)
525 | end = time.time()
526 |
527 | if i % args.print_freq == 0:
528 | progress.display(i)
529 |
530 | # TODO: this should also be done with the ProgressMeter
531 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
532 | .format(top1=top1, top5=top5))
533 |
534 | return top1.avg
535 |
536 |
537 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
538 | torch.save(state, filename)
539 | if is_best:
540 | shutil.copyfile(filename, os.path.join(os.path.dirname(filename), 'model_best.pth.tar'))
541 |
542 |
543 | class AverageMeter(object):
544 | """Computes and stores the average and current value"""
545 | def __init__(self, name, fmt=':f'):
546 | self.name = name
547 | self.fmt = fmt
548 | self.reset()
549 |
550 | def reset(self):
551 | self.val = 0
552 | self.avg = 0
553 | self.sum = 0
554 | self.count = 0
555 |
556 | def update(self, val, n=1):
557 | self.val = val
558 | self.sum += val * n
559 | self.count += n
560 | self.avg = self.sum / self.count
561 |
562 | def __str__(self):
563 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
564 | return fmtstr.format(**self.__dict__)
565 |
566 |
567 | class ProgressMeter(object):
568 | def __init__(self, num_batches, meters, prefix=""):
569 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
570 | self.meters = meters
571 | self.prefix = prefix
572 |
573 | def display(self, batch):
574 | entries = [self.prefix + self.batch_fmtstr.format(batch)]
575 | entries += [str(meter) for meter in self.meters]
576 | print('\t'.join(entries))
577 |
578 | def _get_batch_fmtstr(self, num_batches):
579 | num_digits = len(str(num_batches))
580 | fmt = '{:' + str(num_digits) + 'd}'
581 | return '[' + fmt + '/' + fmt.format(num_batches) + ']'
582 |
583 |
584 | def adjust_learning_rate(optimizer, epoch, args):
585 | """Decay the learning rate based on schedule"""
586 | lr = args.lr
587 | if args.cos: # cosine lr schedule
588 | lr *= 0.5 * (1. + math.cos(math.pi * epoch / args.epochs))
589 | else: # stepwise lr schedule
590 | for milestone in args.schedule:
591 | lr *= 0.1 if epoch >= milestone else 1.
592 | for param_group in optimizer.param_groups:
593 | param_group['lr'] = lr
594 |
595 | def _get_learning_rate(optimizer):
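    | # report the largest lr across parameter groups (groups may carry different learning rates)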
596 | return max(param_group['lr'] for param_group in optimizer.param_groups)
597 |
598 | def accuracy(output, target, topk=(1,)):
599 | """Computes the accuracy over the k top predictions for the specified values of k"""
600 | with torch.no_grad():
601 | maxk = max(topk)
602 | batch_size = target.size(0)
603 |
604 | _, pred = output.topk(maxk, 1, True, True)
605 | pred = pred.t()
606 | correct = pred.eq(target.view(1, -1).expand_as(pred))
607 |
608 | res = []
609 | for k in topk:
610 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
611 | res.append(correct_k.mul_(100.0 / batch_size))
612 | return res
613 |
614 |
615 | if __name__ == '__main__':
616 | main()
--------------------------------------------------------------------------------