├── FKD
│   ├── FKD.png
│   ├── FKD_SSL
│   │   └── README.md
│   ├── LICENSE
│   ├── FKD_ViT
│   │   ├── README.md
│   │   ├── utils_FKD.py
│   │   ├── SReT.py
│   │   └── train_ViT_FKD.py
│   ├── FKD_SLG
│   │   ├── README.md
│   │   ├── utils.py
│   │   └── generate_soft_label.py
│   ├── utils_FKD.py
│   ├── README.md
│   └── train_FKD.py
├── FerKD
│   ├── assets
│   │   ├── FerKD.png
│   │   └── converge.png
│   ├── LICENSE
│   ├── soft_label_zoo
│   │   └── README.md
│   └── README.md
└── README.md

/FKD/FKD.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/szq0214/FKD/HEAD/FKD/FKD.png
--------------------------------------------------------------------------------
/FerKD/assets/FerKD.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/szq0214/FKD/HEAD/FerKD/assets/FerKD.png
--------------------------------------------------------------------------------
/FerKD/assets/converge.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/szq0214/FKD/HEAD/FerKD/assets/converge.png
--------------------------------------------------------------------------------
/FKD/FKD_SSL/README.md:
--------------------------------------------------------------------------------
## Self-supervised Representation Learning Using FKD


### Preparation

- Install PyTorch and the ImageNet dataset following the [official PyTorch ImageNet training code](https://github.com/pytorch/examples/tree/master/imagenet).

- Download our [soft labels](http://zhiqiangshen.com/projects/FKD/index.html) for SSL.


### FKD Training on ReActNet & ResNet-50

The training code is similar to the supervised scheme and will be available soon.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Evolving Knowledge Distillation: The Role of Pre-generated Soft Labels

***This is a collection of our works on utilizing pre-generated soft labels for stronger, faster, and more efficient knowledge distillation.***

> [**FerKD**](./FerKD) (```@ICCV'23```): **FerKD: Surgical Label Adaptation for Efficient Distillation**

8 | 9 |
10 | 11 | 12 | 13 | 14 | > [**FKD**](./FKD) (```@ECCV'22```): **A Fast Knowledge Distillation Framework for Visual Recognition** 15 | 16 |
17 | 18 |
## Bibtex

```bibtex
@inproceedings{shen2023ferkd,
  title={FerKD: Surgical Label Adaptation for Efficient Distillation},
  author={Zhiqiang Shen},
  booktitle={ICCV},
  year={2023}
}

@inproceedings{shen2021afast,
  title={A Fast Knowledge Distillation Framework for Visual Recognition},
  author={Zhiqiang Shen and Eric Xing},
  booktitle={ECCV},
  year={2022}
}
```
--------------------------------------------------------------------------------
/FKD/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 Zhiqiang Shen

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/FerKD/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 Zhiqiang Shen

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/FerKD/soft_label_zoo/README.md:
--------------------------------------------------------------------------------
# FerKD Soft Label Zoo and Baselines

## Introduction

This file documents a large collection of soft labels and the corresponding baselines trained with them.
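Assuming these labels use FKD's on-disk format (FerKD's training follows the [FKD](../../FKD) code and procedure), each training image has one label file storing the sampled crop coordinates, flip status, and the teacher's quantized predictions for every pre-generated crop. A minimal sketch of inspecting a single entry; the file name below is hypothetical:

```
import torch

# Hypothetical per-image label file; actual files follow the ImageNet folder
# layout, e.g., <label_root>/train/n01440764/n01440764_10026.tar in FKD's format.
coords, flip_status, quantized = torch.load('n01440764_10026.tar',
                                            map_location=torch.device('cpu'))

print(len(quantized))      # number of stored crops per image (e.g., 500)
print(coords[0])           # normalized (i, j, h, w) of the first crop
print(flip_status[0])      # whether the first crop was horizontally flipped
print(quantized[0].shape)  # e.g., torch.Size([2, 5]) for marginal smoothing top-5
```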
| Teacher | Top-1 | Student | Top-1 | Student | Top-1 | Soft Label |
|:-------:|:--------:|:--------:|:--------:|:--------:|:--------:|:--------:|
| `Effi-L2-475` | 88.14 | ResNet-50 | 80.23 | ViT-S/16 | 81.16 | [Download]() |
| `Effi-L2-800` | 88.39 | ResNet-50 | 80.16 | ViT-S/16 | 81.30 | [Download]() |
| `RegY-128GF-384` | 88.24 | ResNet-50 | 80.34 | ViT-S/16 | 81.42 | [Download]() |
| `ViT-L16-512` | 88.07 | ResNet-50 | 80.29 | ViT-S/16 | 81.43 | [Download]() |
| `ViT-H14-518` | 88.55 | ResNet-50 | 80.18 | ViT-S/16 | 81.41 | [Download]() |
| `BEIT-L-224` | 87.52 | ResNet-50 | 80.03 | ViT-S/16 | 81.16 | [Download]() |
| `BEIT-L-384` | 88.40 | ResNet-50 | 80.06 | ViT-S/16 | 81.11 | [Download]() |
| `BEIT-L-512` | 88.60 | ResNet-50 | 80.09 | ViT-S/16 | 81.07 | [Download]() |
| `ViT-G14-336-30M` | 89.59 | ResNet-50 | 79.03 | ViT-S/16 | 79.62 | [Download]() |
| `ViT-G14-336-CLIP` | 89.38 | ResNet-50 | 79.59 | ViT-S/16 | 79.48 | [Download]() |

--------------------------------------------------------------------------------
/FKD/FKD_ViT/README.md:
--------------------------------------------------------------------------------

## Preparation

- Install PyTorch and the ImageNet dataset following the [official PyTorch ImageNet training code](https://github.com/pytorch/examples/tree/master/imagenet). This repo has minimal modifications on that code.

- Download our soft label and unzip it. We provide multiple types of [soft labels](http://zhiqiangshen.com/projects/FKD/index.html), and we recommend using [Marginal Smoothing Top-5 (500-crop)](https://drive.google.com/file/d/14leI6xGfnyxHPsBxo0PpCmOq71gWt008/view?usp=sharing).


## FKD Training on ViT/DeiT and SReT

To train a ViT model, run `train_ViT_FKD.py` with the desired model architecture and the path to the soft label and ImageNet dataset:

```
python train_ViT_FKD.py \
--dist-url 'tcp://127.0.0.1:10001' \
--dist-backend 'nccl' \
--multiprocessing-distributed --world-size 1 --rank 0 \
-a SReT_LT --lr 0.002 --wd 0.05 \
--num_crops 4 -b 1024 --cos \
--temp 1.0 --mixup_cutmix \
--softlabel_path [soft label path, e.g., ./FKD_soft_label_500_crops_marginal_smoothing_k_5/imagenet] \
[imagenet-folder with train and val folders]
```

Add `--mixup_cutmix` to enable Mixup and Cutmix augmentations. For instructions on the `SReT_LT` model, please refer to [SReT](https://github.com/szq0214/SReT) for details.
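The reason pre-generated labels remain valid during student training is that FKD replays the recorded augmentations rather than re-sampling them: each stored label entry carries the exact crop coordinates and flip status that the teacher scored. A minimal sketch of replaying one stored crop with the classes from `utils_FKD.py` in this folder; the image and label file names are hypothetical:

```
import torch
from PIL import Image
from utils_FKD import RandomResizedCrop_FKD, RandomHorizontalFlip_FKD

# Each label file stores (coords, flip_status, quantized_soft_labels),
# with one entry per pre-generated crop.
coords, flip_status, soft_labels = torch.load('n01440764_10026.tar',
                                              map_location=torch.device('cpu'))

img = Image.open('n01440764_10026.JPEG').convert('RGB')
# String interpolation matches the comparisons in RandomResizedCrop_FKD.__call__.
crop = RandomResizedCrop_FKD(size=224, interpolation='bilinear', scale=(0.08, 1.0))
flip = RandomHorizontalFlip_FKD()

# Replay the first stored crop: the student trains on exactly the view
# that the teacher scored when the soft label was generated.
view = crop(img, coords[0], flip_status[0])
view = flip(view, coords[0], flip_status[0])
```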
26 | 27 | ## Evaluation 28 | 29 | ``` 30 | python train_ViT_FKD.py -a SReT_LT -e --resume [model path] [imagenet-folder with train and val folders] 31 | ``` 32 | 33 | ### Trained Models 34 | 35 | | Model | FLOPs| #params | accuracy (Top-1) |weights |configurations | 36 | |:-------:|:--------:|:--------:|:--------:|:--------:|:--------:| 37 | | [`DeiT-T-distill`](https://github.com/facebookresearch/deit) | 1.3B | 5.7M | 74.5 |-- | -- | 38 | | `FKD ViT/DeiT-T` | 1.3B | 5.7M | **75.2** |[link](https://drive.google.com/file/d/1m33c1wHdCV7ePETO_HvWNaboSd_W4nfC/view?usp=sharing) | [Table 13 of paper](http://zhiqiangshen.com/projects/FKD/FKD_camera-ready.pdf) | 39 | | [`SReT-LT-distill`](https://github.com/szq0214/SReT) | 1.2B | 5.0M | 77.7 |-- | -- | 40 | | `FKD SReT-LT` | 1.2B | 5.0M | **78.7** |[link](https://drive.google.com/file/d/1mmdPXKutHM9Li8xo5nGG6TB0aAXA9PFR/view?usp=sharing) | [Table 13 of paper](http://zhiqiangshen.com/projects/FKD/FKD_camera-ready.pdf) | 41 | 42 | 43 | -------------------------------------------------------------------------------- /FerKD/README.md: -------------------------------------------------------------------------------- 1 | ## 🚀🚀 FerKD: Surgical Label Adaptation for Efficient Distillation 2 | 3 | [**FerKD: Surgical Label Adaptation for Efficient Distillation**](https://openaccess.thecvf.com/content/ICCV2023/papers/Shen_FerKD_Surgical_Label_Adaptation_for_Efficient_Distillation_ICCV_2023_paper.pdf), Zhiqiang Shen, ICCV 2023. 4 | 5 |
6 | 7 |
8 | 9 | ### Abstract 10 | 11 | **🚀🚀 FerKD (Faster Knowledge Distillation)** is a novel efficient knowledge distillation framework that incorporates *partial soft-hard label adaptation coupled with a region-calibration mechanism*. Our approach stems from the observation and intuition that standard data augmentations, such as RandomResizedCrop, tend to transform inputs into diverse conditions: easy positives, hard positives, or hard negatives. In traditional distillation frameworks, these transformed samples are utilized equally through their predictive probabilities derived from pretrained teacher models. However, merely relying on prediction values from a pretrained teacher neglects the reliability of these soft label predictions. To address this, we propose a new scheme that calibrates the less-confident regions to be the context using softened hard groundtruth labels. The proposed approach involves the processes of *hard regions mining* + *calibration*. 12 | 13 | ## Citation 14 | 15 | @inproceedings{shen2023ferkd, 16 | title={FerKD: Surgical Label Adaptation for Efficient Distillation}, 17 | author={Zhiqiang Shen}, 18 | year={2023}, 19 | booktitle={ICCV} 20 | } 21 | 22 | 23 | 24 | ## Soft Label Zoo 25 | 26 | Please check the [soft labels](./soft_label_zoo) generated from different giant teacher models. 27 | 28 | ## Fast Convergence of FerKD 29 | 30 | 31 |
32 | 33 |
## Training

FerKD follows the [FKD](https://github.com/szq0214/FKD/tree/main/FKD) training code and procedure while using different preprocessed soft labels. Please download the soft labels for FerKD at [link](./soft_label_zoo).

## Trained Models

| Method | Network | accuracy (Top-1) | weights |
|:-------:|:--------:|:--------:|:--------:|
| `FerKD` | ResNet-50 | 81.2 | [Download](https://drive.google.com/file/d/1wrs2-v8Dg8ghaJBDEmYLuGnk4mRbJek_/view?usp=sharing) |
| `FerKD*` | ResNet-50 | 81.4 | [Download](https://drive.google.com/file/d/1HW9scG0OlKVa64C-sNxhmmsB0ZPTruIp/view?usp=sharing) |

## Acknowledgements

We thank [High-Flyer AI](https://www.high-flyer.cn/en/) for providing the deep learning platform and computational resources for this work. We would especially like to thank Yanhong Xu, Xiaowen Sun, Wenjie Wu, and Le Su from High-Flyer AI for helping us organize computing resources.

## Contact

Zhiqiang Shen (zhiqiangshen0214 at gmail.com)

--------------------------------------------------------------------------------
/FKD/FKD_SLG/README.md:
--------------------------------------------------------------------------------

## Soft Label Generation (SLG)

### Preparation

- Install PyTorch and the ImageNet dataset following the [official PyTorch ImageNet training code](https://github.com/pytorch/examples/tree/master/imagenet).
- Install `timm` using:

```
pip install git+https://github.com/rwightman/pytorch-image-models.git
```


### Generating Soft Labels from Supervised Teachers

FKD flags:

- `--num_crops`: number of crops per image for which to generate soft labels. Default: 500.
- `--num_seg`: the effective batch size on GPUs during generation. Make sure `--num_crops` is divisible by `--num_seg`. Default: 50.
- `--label_type`: type of generated soft labels. Default: `marginal_smoothing_k5`.
- `--use_fp16`: save soft labels as `fp16` to reduce storage. Default: `False`.
- `--temp`: temperature on the soft labels. Default: `1.0`.

Path flags:

- `--save_path`: specify the folder to save soft labels.
- `--reference_path`: specify the path to existing soft labels used as the reference for crop locations. This is used for soft label ensembling in [FKD MEAL V2](https://github.com/szq0214/MEAL-V2).
- [imagenet-folder with train and val folders]: ImageNet data folder.

Model flags:

- `--arch`: specify which model to use as the teacher network.
- `--input_size`: input size of the teacher network.
- `--teacher_source`: source of teachers. Currently, it supports models from (1) `pytorch`; (2) `timm`; and (3) private pre-trained models.

Some important notes:

- Modify the `normalize` values according to the training settings of the teacher networks.

```
# EfficientNet_V2_L, BEIT, etc.
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                 std=[0.5, 0.5, 0.5])
# ResNet, efficientnet_l2_ns_475, etc.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
```

- Modify `--min_scale_crops` and `--max_scale_crops` according to the training setting of the teacher networks. For example, `tf_efficientnet_l2_ns_475` in `timm` uses scale=(0.08, 0.936).
- `--num_seg` is the effective batch size, so `-b` can be set to a relatively small value.
It will not slow down the training.
- `Resume` is supported by simply restarting the command. You can also launch multiple experiments in parallel to speed up generation; they will automatically skip the existing files.

**Important:** Test your teacher models using `--evaluate` to check whether the accuracy is correct before starting to generate soft labels.

An example of the command line for generating soft labels from `tf_efficientnet_l2_ns_475`:

```
python generate_soft_label.py \
-a tf_efficientnet_l2_ns_475 \
--input_size 475 \
--min_scale_crops 0.08 \
--max_scale_crops 0.936 \
--num_crops 500 \
--num_seg 50 \
-b 4 \
--temp 1.0 \
--label_type marginal_smoothing_k5 \
--save_path FKD_efficientnet_l2_ns_475_marginal_smoothing_k5 \
[imagenet-folder with train and val folders]
```

Soft label generation from self-supervised teachers will be available soon.

--------------------------------------------------------------------------------
/FKD/FKD_SLG/utils.py:
--------------------------------------------------------------------------------
import os

import torch
import torch.distributed
import torch.nn as nn
import torchvision
from torchvision.transforms import functional as t_F
from torchvision.datasets.folder import ImageFolder


class RandomResizedCropWithCoords(torchvision.transforms.RandomResizedCrop):
    def __init__(self, **kwargs):
        super(RandomResizedCropWithCoords, self).__init__(**kwargs)

    def __call__(self, img, coords):
        try:
            reference = (coords.any())
        except:
            reference = False
        if not reference:
            i, j, h, w = self.get_params(img, self.scale, self.ratio)
            coords = (i / img.size[1],
                      j / img.size[0],
                      h / img.size[1],
                      w / img.size[0])
            coords = torch.FloatTensor(coords)
        else:
            i = coords[0].item() * img.size[1]
            j = coords[1].item() * img.size[0]
            h = coords[2].item() * img.size[1]
            w = coords[3].item() * img.size[0]
        return t_F.resized_crop(img, i, j, h, w, self.size,
                                self.interpolation), coords


class ComposeWithCoords(torchvision.transforms.Compose):
    def __init__(self, **kwargs):
        super(ComposeWithCoords, self).__init__(**kwargs)

    def __call__(self, img, coords, status):
        for t in self.transforms:
            if type(t).__name__ == 'RandomResizedCropWithCoords':
                img, coords = t(img, coords)
            elif type(t).__name__ == 'RandomCropWithCoords':
                img, coords = t(img, coords)
            elif type(t).__name__ == 'RandomHorizontalFlipWithRes':
                img, status = t(img, status)
            else:
                img = t(img)
        return img, status, coords


class RandomHorizontalFlipWithRes(torch.nn.Module):
    """Horizontally flip the given image randomly with a given probability.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading
    dimensions

    Args:
        p (float): probability of the image being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, img, status):
        """
        Args:
            img (PIL Image or Tensor): Image to be flipped.

        Returns:
            PIL Image or Tensor: Randomly flipped image.
74 | """ 75 | 76 | if status is not None: 77 | if status == True: 78 | return t_F.hflip(img), status 79 | else: 80 | return img, status 81 | else: 82 | status = False 83 | if torch.rand(1) < self.p: 84 | status = True 85 | return t_F.hflip(img), status 86 | return img, status 87 | 88 | 89 | def __repr__(self): 90 | return self.__class__.__name__ + '(p={})'.format(self.p) 91 | 92 | 93 | class ImageFolder_FKD_GSL(torchvision.datasets.ImageFolder): 94 | def __init__(self, **kwargs): 95 | self.num_crops = kwargs['num_crops'] 96 | self.save_path = kwargs['save_path'] 97 | self.reference_path = kwargs['reference_path'] 98 | kwargs.pop('num_crops') 99 | kwargs.pop('save_path') 100 | kwargs.pop('reference_path') 101 | super(ImageFolder_FKD_GSL, self).__init__(**kwargs) 102 | 103 | def __getitem__(self, index): 104 | path, target = self.samples[index] 105 | 106 | if self.reference_path is not None: 107 | ref_path = os.path.join(self.reference_path,'/'.join(path.split('/')[-4:-1])) 108 | ref_filename = os.path.join(ref_path,'/'.join(path.split('/')[-1:]).split('.')[0] + '.tar') 109 | label = torch.load(ref_filename, map_location=torch.device('cpu')) 110 | coords_ref, flip_ref, _ = label 111 | else: 112 | coords_ref = None 113 | 114 | sample = self.loader(path) 115 | sample_all = [] 116 | flip_status_all = [] 117 | coords_all = [] 118 | for i in range(self.num_crops): 119 | if self.transform is not None: 120 | if coords_ref is not None: 121 | coords_ = coords_ref[i] 122 | flip_ = flip_ref[i] 123 | else: 124 | coords_ = None 125 | flip_ = None 126 | sample_new, flip_status, coords_single = self.transform(sample, coords_, flip_) 127 | sample_all.append(sample_new) 128 | flip_status_all.append(flip_status) 129 | coords_all.append(coords_single) 130 | else: 131 | coords = None 132 | flip_status = None 133 | if self.target_transform is not None: 134 | target = self.target_transform(target) 135 | 136 | return sample_all, target, flip_status_all, coords_all, path -------------------------------------------------------------------------------- /FKD/utils_FKD.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.distributed 4 | import torch.nn as nn 5 | import torchvision 6 | from torchvision.transforms import functional as t_F 7 | from torch.nn import functional as F 8 | from torchvision.datasets.folder import ImageFolder 9 | from torch.nn.modules import loss 10 | from torchvision.transforms import InterpolationMode 11 | import random 12 | import numpy as np 13 | 14 | 15 | class Soft_CrossEntropy(loss._Loss): 16 | def forward(self, model_output, soft_output): 17 | 18 | size_average = True 19 | 20 | model_output_log_prob = F.log_softmax(model_output, dim=1) 21 | 22 | soft_output = soft_output.unsqueeze(1) 23 | model_output_log_prob = model_output_log_prob.unsqueeze(2) 24 | 25 | cross_entropy_loss = -torch.bmm(soft_output, model_output_log_prob) 26 | if size_average: 27 | cross_entropy_loss = cross_entropy_loss.mean() 28 | else: 29 | cross_entropy_loss = cross_entropy_loss.sum() 30 | 31 | return cross_entropy_loss 32 | 33 | 34 | class RandomResizedCrop_FKD(torchvision.transforms.RandomResizedCrop): 35 | def __init__(self, **kwargs): 36 | super(RandomResizedCrop_FKD, self).__init__(**kwargs) 37 | 38 | def __call__(self, img, coords, status): 39 | i = coords[0].item() * img.size[1] 40 | j = coords[1].item() * img.size[0] 41 | h = coords[2].item() * img.size[1] 42 | w = coords[3].item() * img.size[0] 43 | 44 | if self.interpolation == 
'bilinear': 45 | inter = InterpolationMode.BILINEAR 46 | elif self.interpolation == 'bicubic': 47 | inter = InterpolationMode.BICUBIC 48 | return t_F.resized_crop(img, i, j, h, w, self.size, inter) 49 | 50 | 51 | class RandomHorizontalFlip_FKD(torch.nn.Module): 52 | def __init__(self, p=0.5): 53 | super().__init__() 54 | self.p = p 55 | 56 | def forward(self, img, coords, status): 57 | 58 | if status == True: 59 | return t_F.hflip(img) 60 | else: 61 | return img 62 | 63 | def __repr__(self): 64 | return self.__class__.__name__ + '(p={})'.format(self.p) 65 | 66 | 67 | class Compose_FKD(torchvision.transforms.Compose): 68 | def __init__(self, **kwargs): 69 | super(Compose_FKD, self).__init__(**kwargs) 70 | 71 | def __call__(self, img, coords, status): 72 | for t in self.transforms: 73 | if type(t).__name__ == 'RandomResizedCrop_FKD': 74 | img = t(img, coords, status) 75 | elif type(t).__name__ == 'RandomCrop_FKD': 76 | img, coords = t(img) 77 | elif type(t).__name__ == 'RandomHorizontalFlip_FKD': 78 | img = t(img, coords, status) 79 | else: 80 | img = t(img) 81 | return img 82 | 83 | 84 | class ImageFolder_FKD(torchvision.datasets.ImageFolder): 85 | def __init__(self, **kwargs): 86 | self.num_crops = kwargs['num_crops'] 87 | self.softlabel_path = kwargs['softlabel_path'] 88 | kwargs.pop('num_crops') 89 | kwargs.pop('softlabel_path') 90 | super(ImageFolder_FKD, self).__init__(**kwargs) 91 | 92 | def __getitem__(self, index): 93 | path, target = self.samples[index] 94 | 95 | label_path = os.path.join(self.softlabel_path, '/'.join(path.split('/')[-3:]).split('.')[0] + '.tar') 96 | 97 | label = torch.load(label_path, map_location=torch.device('cpu')) 98 | 99 | coords, flip_status, output = label 100 | 101 | rand_index = torch.randperm(len(output)) 102 | soft_target = [] 103 | 104 | sample = self.loader(path) 105 | sample_all = [] 106 | hard_target = [] 107 | 108 | for i in range(self.num_crops): 109 | if self.transform is not None: 110 | soft_target.append(output[rand_index[i]]) 111 | sample_trans = self.transform(sample, coords[rand_index[i]], flip_status[rand_index[i]]) 112 | sample_all.append(sample_trans) 113 | hard_target.append(target) 114 | else: 115 | coords = None 116 | flip_status = None 117 | if self.target_transform is not None: 118 | target = self.target_transform(target) 119 | 120 | return sample_all, hard_target, soft_target 121 | 122 | 123 | def Recover_soft_label(label, label_type, n_classes): 124 | # recover quantized soft label to n_classes dimension. 
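    # Stored (quantized) formats, one entry per crop:
    #   'hard'                  -> a single class index, expanded to a one-hot vector
    #   'smoothing'             -> [top-1 index, top-1 prob]; the remaining probability
    #                              mass is spread uniformly over the other classes
    #   'marginal_smoothing_k5' / 'marginal_smoothing_k10'
    #                           -> top-5/top-10 indices and probs; remainder uniform
    #   'marginal_renorm'       -> top-k indices and probs, L1-renormalized to sum to 1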
125 | if label_type == 'hard': 126 | 127 | return torch.zeros(label.size(0), n_classes).scatter_(1, label.view(-1, 1), 1) 128 | 129 | elif label_type == 'smoothing': 130 | index = label[:,0].to(dtype=int) 131 | value = label[:,1] 132 | minor_value = (torch.ones_like(value) - value)/(n_classes-1) 133 | minor_value = minor_value.reshape(-1,1).repeat_interleave(n_classes, dim=1) 134 | soft_label = (minor_value * torch.ones(index.size(0), n_classes)).scatter_(1, index.view(-1, 1), value.view(-1, 1)) 135 | 136 | return soft_label 137 | 138 | elif label_type == 'marginal_smoothing_k5': 139 | index = label[:,0,:].to(dtype=int) 140 | value = label[:,1,:] 141 | minor_value = (torch.ones(label.size(0),1) - torch.sum(value, dim=1, keepdim=True))/(n_classes-5) 142 | minor_value = minor_value.reshape(-1,1).repeat_interleave(n_classes, dim=1) 143 | soft_label = (minor_value * torch.ones(index.size(0), n_classes)).scatter_(1, index, value) 144 | 145 | return soft_label 146 | 147 | elif label_type == 'marginal_renorm': 148 | index = label[:,0,:].to(dtype=int) 149 | value = label[:,1,:] 150 | soft_label = torch.zeros(index.size(0), n_classes).scatter_(1, index, value) 151 | soft_label = F.normalize(soft_label, p=1.0, dim=1, eps=1e-12) 152 | 153 | return soft_label 154 | 155 | elif label_type == 'marginal_smoothing_k10': 156 | index = label[:,0,:].to(dtype=int) 157 | value = label[:,1,:] 158 | minor_value = (torch.ones(label.size(0),1) - torch.sum(value, dim=1, keepdim=True))/(n_classes-10) 159 | minor_value = minor_value.reshape(-1,1).repeat_interleave(n_classes, dim=1) 160 | soft_label = (minor_value * torch.ones(index.size(0), n_classes)).scatter_(1, index, value) 161 | 162 | return soft_label 163 | 164 | 165 | def rand_bbox(size, lam): 166 | W = size[2] 167 | H = size[3] 168 | cut_rat = np.sqrt(1. 
- lam)
    # use the built-in int: np.int was deprecated in NumPy 1.20 and later removed
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # uniform
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2


def mixup_cutmix(images, soft_label, args):
    enable_p = np.random.rand(1)
    if enable_p < args.mixup_cutmix_prob:
        switch_p = np.random.rand(1)
        if switch_p < args.mixup_switch_prob:
            lam = np.random.beta(args.mixup, args.mixup)
            rand_index = torch.randperm(images.size()[0]).cuda()
            target_a = soft_label
            target_b = soft_label[rand_index]
            mixed_x = lam * images + (1 - lam) * images[rand_index]
            target_mix = target_a * lam + target_b * (1 - lam)
            return mixed_x, target_mix
        else:
            lam = np.random.beta(args.cutmix, args.cutmix)
            rand_index = torch.randperm(images.size()[0]).cuda()
            target_a = soft_label
            target_b = soft_label[rand_index]
            bbx1, bby1, bbx2, bby2 = rand_bbox(images.size(), lam)
            images[:, :, bbx1:bbx2, bby1:bby2] = images[rand_index, :, bbx1:bbx2, bby1:bby2]
            # adjust lambda to exactly match pixel ratio
            lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (images.size()[-1] * images.size()[-2]))
            target_mix = target_a * lam + target_b * (1 - lam)
    else:
        target_mix = soft_label

    return images, target_mix
--------------------------------------------------------------------------------
/FKD/FKD_ViT/utils_FKD.py:
--------------------------------------------------------------------------------
import os
import torch
import torch.distributed
import torch.nn as nn
import torchvision
from torchvision.transforms import functional as t_F
from torch.nn import functional as F
from torchvision.datasets.folder import ImageFolder
from torch.nn.modules import loss
from torchvision.transforms import InterpolationMode
import random
import numpy as np


class Soft_CrossEntropy(loss._Loss):
    def forward(self, model_output, soft_output):

        size_average = True

        model_output_log_prob = F.log_softmax(model_output, dim=1)

        soft_output = soft_output.unsqueeze(1)
        model_output_log_prob = model_output_log_prob.unsqueeze(2)

        cross_entropy_loss = -torch.bmm(soft_output, model_output_log_prob)
        if size_average:
            cross_entropy_loss = cross_entropy_loss.mean()
        else:
            cross_entropy_loss = cross_entropy_loss.sum()

        return cross_entropy_loss


class RandomResizedCrop_FKD(torchvision.transforms.RandomResizedCrop):
    def __init__(self, **kwargs):
        super(RandomResizedCrop_FKD, self).__init__(**kwargs)

    def __call__(self, img, coords, status):
        i = coords[0].item() * img.size[1]
        j = coords[1].item() * img.size[0]
        h = coords[2].item() * img.size[1]
        w = coords[3].item() * img.size[0]

        if self.interpolation == 'bilinear':
            inter = InterpolationMode.BILINEAR
        elif self.interpolation == 'bicubic':
            inter = InterpolationMode.BICUBIC
        return t_F.resized_crop(img, i, j, h, w, self.size, inter)


class RandomHorizontalFlip_FKD(torch.nn.Module):
    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, img, coords, status):
        if status == True:
return t_F.hflip(img) 59 | else: 60 | return img 61 | 62 | def __repr__(self): 63 | return self.__class__.__name__ + '(p={})'.format(self.p) 64 | 65 | 66 | class Compose_FKD(torchvision.transforms.Compose): 67 | def __init__(self, **kwargs): 68 | super(Compose_FKD, self).__init__(**kwargs) 69 | 70 | def __call__(self, img, coords, status): 71 | for t in self.transforms: 72 | if type(t).__name__ == 'RandomResizedCrop_FKD': 73 | img = t(img, coords, status) 74 | elif type(t).__name__ == 'RandomCrop_FKD': 75 | img, coords = t(img) 76 | elif type(t).__name__ == 'RandomHorizontalFlip_FKD': 77 | img = t(img, coords, status) 78 | else: 79 | img = t(img) 80 | return img 81 | 82 | 83 | class ImageFolder_FKD(torchvision.datasets.ImageFolder): 84 | def __init__(self, **kwargs): 85 | self.num_crops = kwargs['num_crops'] 86 | self.softlabel_path = kwargs['softlabel_path'] 87 | kwargs.pop('num_crops') 88 | kwargs.pop('softlabel_path') 89 | super(ImageFolder_FKD, self).__init__(**kwargs) 90 | 91 | def __getitem__(self, index): 92 | path, target = self.samples[index] 93 | 94 | label_path = os.path.join(self.softlabel_path, '/'.join(path.split('/')[-3:]).split('.')[0] + '.tar') 95 | 96 | label = torch.load(label_path, map_location=torch.device('cpu')) 97 | 98 | coords, flip_status, output = label 99 | 100 | rand_index = torch.randperm(len(output)) 101 | soft_target = [] 102 | 103 | sample = self.loader(path) 104 | sample_all = [] 105 | hard_target = [] 106 | 107 | for i in range(self.num_crops): 108 | if self.transform is not None: 109 | soft_target.append(output[rand_index[i]]) 110 | sample_trans = self.transform(sample, coords[rand_index[i]], flip_status[rand_index[i]]) 111 | sample_all.append(sample_trans) 112 | hard_target.append(target) 113 | else: 114 | coords = None 115 | flip_status = None 116 | if self.target_transform is not None: 117 | target = self.target_transform(target) 118 | 119 | return sample_all, hard_target, soft_target 120 | 121 | 122 | def Recover_soft_label(label, label_type, n_classes): 123 | # recover quantized soft label to n_classes dimension. 
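    # Same quantized-label formats as in FKD/utils_FKD.py: 'hard' (class index),
    # 'smoothing' (top-1 index + prob), 'marginal_smoothing_k5'/'k10' (top-k
    # indices + probs, remainder spread uniformly), 'marginal_renorm' (top-k renormalized).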
124 | if label_type == 'hard': 125 | 126 | return torch.zeros(label.size(0), n_classes).scatter_(1, label.view(-1, 1), 1) 127 | 128 | elif label_type == 'smoothing': 129 | index = label[:,0].to(dtype=int) 130 | value = label[:,1] 131 | minor_value = (torch.ones_like(value) - value)/(n_classes-1) 132 | minor_value = minor_value.reshape(-1,1).repeat_interleave(n_classes, dim=1) 133 | soft_label = (minor_value * torch.ones(index.size(0), n_classes)).scatter_(1, index.view(-1, 1), value.view(-1, 1)) 134 | 135 | return soft_label 136 | 137 | elif label_type == 'marginal_smoothing_k5': 138 | index = label[:,0,:].to(dtype=int) 139 | value = label[:,1,:] 140 | minor_value = (torch.ones(label.size(0),1) - torch.sum(value, dim=1, keepdim=True))/(n_classes-5) 141 | minor_value = minor_value.reshape(-1,1).repeat_interleave(n_classes, dim=1) 142 | soft_label = (minor_value * torch.ones(index.size(0), n_classes)).scatter_(1, index, value) 143 | 144 | return soft_label 145 | 146 | elif label_type == 'marginal_renorm': 147 | index = label[:,0,:].to(dtype=int) 148 | value = label[:,1,:] 149 | soft_label = torch.zeros(index.size(0), n_classes).scatter_(1, index, value) 150 | soft_label = F.normalize(soft_label, p=1.0, dim=1, eps=1e-12) 151 | 152 | return soft_label 153 | 154 | elif label_type == 'marginal_smoothing_k10': 155 | index = label[:,0,:].to(dtype=int) 156 | value = label[:,1,:] 157 | minor_value = (torch.ones(label.size(0),1) - torch.sum(value, dim=1, keepdim=True))/(n_classes-10) 158 | minor_value = minor_value.reshape(-1,1).repeat_interleave(n_classes, dim=1) 159 | soft_label = (minor_value * torch.ones(index.size(0), n_classes)).scatter_(1, index, value) 160 | 161 | return soft_label 162 | 163 | 164 | def rand_bbox(size, lam): 165 | W = size[2] 166 | H = size[3] 167 | cut_rat = np.sqrt(1. 
- lam)
    # use the built-in int: np.int was deprecated in NumPy 1.20 and later removed
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # uniform
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2


def mixup_cutmix(images, soft_label, args):
    enable_p = np.random.rand(1)
    if enable_p < args.mixup_cutmix_prob:
        switch_p = np.random.rand(1)
        if switch_p < args.mixup_switch_prob:
            lam = np.random.beta(args.mixup, args.mixup)
            rand_index = torch.randperm(images.size()[0]).cuda()
            target_a = soft_label
            target_b = soft_label[rand_index]
            mixed_x = lam * images + (1 - lam) * images[rand_index]
            target_mix = target_a * lam + target_b * (1 - lam)
            return mixed_x, target_mix
        else:
            lam = np.random.beta(args.cutmix, args.cutmix)
            rand_index = torch.randperm(images.size()[0]).cuda()
            target_a = soft_label
            target_b = soft_label[rand_index]
            bbx1, bby1, bbx2, bby2 = rand_bbox(images.size(), lam)
            images[:, :, bbx1:bbx2, bby1:bby2] = images[rand_index, :, bbx1:bbx2, bby1:bby2]
            # adjust lambda to exactly match pixel ratio
            lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (images.size()[-1] * images.size()[-2]))
            target_mix = target_a * lam + target_b * (1 - lam)
    else:
        target_mix = soft_label

    return images, target_mix
--------------------------------------------------------------------------------
/FKD/README.md:
--------------------------------------------------------------------------------
## 🚀 FKD: A Fast Knowledge Distillation Framework for Visual Recognition

Official PyTorch implementation of the paper [**A Fast Knowledge Distillation Framework for Visual Recognition**](http://zhiqiangshen.com/projects/FKD/FKD_camera-ready.pdf) (ECCV 2022, [ECCV paper](http://zhiqiangshen.com/projects/FKD/FKD_camera-ready.pdf), [arXiv](https://arxiv.org/abs/2112.01528)), Zhiqiang Shen and Eric Xing.

7 | 8 |
### Abstract

Knowledge Distillation (KD) has been recognized as a useful tool in many visual tasks, such as **supervised classification** and **self-supervised representation learning**. The main drawback of a vanilla KD framework is that most of the computational overhead is consumed by forward passes through the giant teacher networks, which makes the whole learning procedure inefficient and costly.

**🚀 Fast Knowledge Distillation (FKD)** is a novel framework that addresses this low-efficiency drawback, simulates the distillation training phase, and generates soft labels following the multi-crop KD procedure, meanwhile enjoying a faster training speed than other methods. FKD is even more efficient than the conventional classification framework when employing multi-crop in the same image for data loading. It achieves **80.1%** (SGD) and **80.5%** (AdamW) using ResNet-50 on ImageNet-1K with plain training settings. This work also demonstrates the efficiency advantage of FKD on the self-supervised learning task.

## Citation

    @article{shen2021afast,
        title={A Fast Knowledge Distillation Framework for Visual Recognition},
        author={Zhiqiang Shen and Eric Xing},
        year={2021},
        journal={arXiv preprint arXiv:2112.01528}
    }

## What's New
* Please refer to our work [here](https://github.com/VILA-Lab/SRe2L/tree/main/SRe2L/relabel#make-fkd-compatible-with-mixup-and-cutmix) if you would like to utilize mixture-based data augmentations (Mixup, CutMix, etc.) during soft label generation and model training.
* Includes [code for soft label generation](FKD_SLG) for customization. We will also set up a [soft label zoo and baselines](FKD_SLG) with multiple soft labels from various teachers.
* FKD with AdamW on ResNet-50 achieves **80.5%** using a plain training scheme. The pre-trained model is available [here](https://drive.google.com/file/d/14HgpE-9SMOFUN3cb7gT9OjURqq7s7q2_/view?usp=sharing).


## Supervised Training

### Preparation

- Install PyTorch and the ImageNet dataset following the [official PyTorch ImageNet training code](https://github.com/pytorch/examples/tree/master/imagenet). This repo has minimal modifications on that code.

- Download our soft label and unzip it. We provide multiple types of [soft labels](http://zhiqiangshen.com/projects/FKD/index.html), and we recommend using [Marginal Smoothing Top-5 (500-crop)](https://drive.google.com/file/d/14leI6xGfnyxHPsBxo0PpCmOq71gWt008/view?usp=sharing).

- [Optional] Generate customized soft labels using [./FKD_SLG](FKD_SLG).


### FKD Training on CNNs

To train a model, run `train_FKD.py` with the desired model architecture and the path to the soft label and ImageNet dataset:

```
python train_FKD.py -a resnet50 --lr 0.1 --num_crops 4 -b 1024 --cos --temp 1.0 --softlabel_path [soft label path] [imagenet-folder with train and val folders]
```

Add `--mixup_cutmix` to enable Mixup and Cutmix augmentations. For `--softlabel_path`, use a format such as `./FKD_soft_label_500_crops_marginal_smoothing_k_5/imagenet`.
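For orientation, the sketch below condenses the core FKD training step: the multiple crops of each image are flattened into the batch dimension, the quantized soft labels are recovered to full distributions, and the student is trained with a soft cross-entropy. The dataset, transform, loss, and recovery helpers (`ImageFolder_FKD`, `Compose_FKD`, `RandomResizedCrop_FKD`, `RandomHorizontalFlip_FKD`, `Soft_CrossEntropy`, `Recover_soft_label`) are the ones defined in `utils_FKD.py`; the surrounding loop is a simplified assumption that omits the distributed setup, Mixup/CutMix, and LR scheduling of the actual `train_FKD.py`:

```
import torch
import torchvision
import torchvision.transforms as transforms
from utils_FKD import (ImageFolder_FKD, Compose_FKD, RandomResizedCrop_FKD,
                       RandomHorizontalFlip_FKD, Soft_CrossEntropy,
                       Recover_soft_label)

num_crops = 4
train_dataset = ImageFolder_FKD(
    num_crops=num_crops,
    softlabel_path='./FKD_soft_label_500_crops_marginal_smoothing_k_5/imagenet',
    root='./imagenet/train',
    transform=Compose_FKD(transforms=[
        # string interpolation matches the comparisons in RandomResizedCrop_FKD
        RandomResizedCrop_FKD(size=224, interpolation='bilinear', scale=(0.08, 1.0)),
        RandomHorizontalFlip_FKD(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ]))
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=256,
                                           shuffle=True, num_workers=16)

model = torchvision.models.resnet50().cuda()
criterion = Soft_CrossEntropy()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1,
                            momentum=0.9, weight_decay=1e-4)

for images, hard_target, soft_label in train_loader:
    # images / soft_label arrive as lists of length num_crops; flatten the
    # crops into the batch dimension so every crop is its own training sample
    images = torch.cat(images, dim=0).cuda()
    soft_label = torch.cat(soft_label, dim=0)
    # stored labels are quantized (top-5 indices + probabilities);
    # expand them back to full 1000-way distributions
    soft_label = Recover_soft_label(soft_label, 'marginal_smoothing_k5', 1000).cuda()
    loss = criterion(model(images), soft_label)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```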
53 | 54 | Multi-processing distributed training on a single node with multiple GPUs: 55 | 56 | ``` 57 | python train_FKD.py \ 58 | --dist-url 'tcp://127.0.0.1:10001' \ 59 | --dist-backend 'nccl' \ 60 | --multiprocessing-distributed --world-size 1 --rank 0 \ 61 | -a resnet50 --lr 0.1 --num_crops 4 -b 1024 \ 62 | --temp 1.0 --cos -j 32 \ 63 | --save_checkpoint_path ./FKD_nc_4_res50_plain \ 64 | --softlabel_path [soft label path, e.g., ./FKD_soft_label_500_crops_marginal_smoothing_k_5/imagenet] \ 65 | [imagenet-folder with train and val folders] 66 | ``` 67 | 68 | 69 | **For multiple nodes multi-processing distributed training, please refer to [official PyTorch ImageNet training code](https://github.com/pytorch/examples/tree/master/imagenet) for details.** 70 | 71 | 72 | ### Evaluation 73 | 74 | ``` 75 | python train_FKD.py -a resnet50 -e --resume [model path] [imagenet-folder with train and val folders] 76 | ``` 77 | 78 | ### Training Speed Comparison 79 | 80 | The training speed of each epoch is tested on HPC/CIAI cluster at MBZUAI with 8 NVIDIA V100 GPUs. The batch size is 1024 for all three methods: **(i)** regular/vanilla classification framework, **(ii)** Relabel and **(iii)** FKD. For `Vanilla` and `ReLabel`, we use the average of 10 epochs after the speed is stable. For FKD, we perform `num_crops = 4` to calculate the average of (4 $\times$ 10) epochs, note that using 8 will give faster training speed. All other settings are the same for the comparison. 81 | 82 | | Method | Network | Training time per-epoch | 83 | |:-------:|:--------:|:--------:| 84 | | Vanilla | ResNet-50 | 579.36 sec/epoch | 85 | | ReLabel | ResNet-50 | 762.11 sec/epoch | 86 | | FKD (Ours) | ResNet-50 | 486.77 sec/epoch | 87 | 88 | ### Trained Models 89 | 90 | | Method | Network | accuracy (Top-1) |weights |configurations | 91 | |:-------:|:--------:|:--------:|:--------:|:--------:| 92 | | [`ReLabel`](https://github.com/naver-ai/relabel_imagenet) | ResNet-50 | 78.9 | -- | -- | 93 | | `FKD`| ResNet-50 |     **80.1+1.2%** | [link](https://drive.google.com/file/d/1qQK3kae4pXBZOldegnZqw7j_aJWtbPgV/view?usp=sharing) | same as ReLabel while initial lr = 0.1 $\times$ $batch size \over 512$ | 94 | | | | | 95 | | `FKD`(Plain)| ResNet-50 | **79.8** | [link](https://drive.google.com/file/d/1s6Tx5xmXnAseMZJBwaa4bnuvzZZGjMdk/view?usp=sharing) | [Table 12 in paper](http://zhiqiangshen.com/projects/FKD/FKD_camera-ready.pdf)
(w/o warmup&colorJ ) | 96 | | `FKD`(AdamW) | ResNet-50 | **80.5** | [link](https://drive.google.com/file/d/14HgpE-9SMOFUN3cb7gT9OjURqq7s7q2_/view?usp=sharing) | [Table 13 in paper](http://zhiqiangshen.com/projects/FKD/FKD_camera-ready.pdf)
(same as our settings on ViT and SReT) | 97 | | | | | 98 | | [`ReLabel`](https://github.com/naver-ai/relabel_imagenet) | ResNet-101 | 80.7 | -- | -- | 99 | | `FKD` | ResNet-101 |     **81.9+1.2%** | [link](TBA) | [Table 12 in paper](http://zhiqiangshen.com/projects/FKD/FKD_camera-ready.pdf) | 100 | | | | | 101 | | `FKD`(Plain)| ResNet-101 | **81.7** | [link](https://drive.google.com/file/d/13bVpHpTykCaYYXIAbWHa2W2C2tSxZlW5/view?usp=sharing) | [Table 12 in paper](http://zhiqiangshen.com/projects/FKD/FKD_camera-ready.pdf)
(w/o warmup&colorJ ) | 102 | 103 | ### Mobile-level Efficient Networks 104 | 105 | | Method | Network | FLOPs | accuracy (Top-1) |weights | 106 | |:-------:|:--------:|:--------:|:--------:|:--------:| 107 | | [`FBNet`](https://arxiv.org/abs/1812.03443)| FBNet-c100 | 375M | 75.12% | -- | 108 | | `FKD`| FBNet-c100 | 375M |     **77.13%+2.01%** | [link](https://drive.google.com/file/d/1s2pnIedXgwYAPpY2GBT3OC24ZP-0vfWe/view?usp=sharing) | 109 | | | | | 110 | | [`EfficientNetv2`](https://arxiv.org/abs/2104.00298)| EfficientNetv2-B0 | 700M | 78.35% | -- | 111 | | `FKD`| EfficientNetv2-B0 | 700M |     **79.94%+1.59%** | [link](https://drive.google.com/file/d/1qL21XOnTRWt6CvZLvUY5IpULISESEfZm/view?usp=sharing) | 112 | 113 | The training protocol is the same as we used for ViT/SReT: 114 | 115 | ``` 116 | # Use the same settings as on ViT and SReT 117 | cd train_ViT 118 | # Train the model 119 | python -u train_ViT_FKD.py \ 120 | --dist-url 'tcp://127.0.0.1:10001' \ 121 | --dist-backend 'nccl' \ 122 | --multiprocessing-distributed --world-size 1 --rank 0 \ 123 | -a tf_efficientnetv2_b0 \ 124 | --lr 0.002 --wd 0.05 \ 125 | --epochs 300 --cos -j 32 \ 126 | --num_classes 1000 --temp 1.0 \ 127 | -b 1024 --num_crops 4 \ 128 | --save_checkpoint_path ./FKD_nc_4_224_efficientnetv2_b0 \ 129 | --soft_label_type marginal_smoothing_k5 \ 130 | --softlabel_path [soft label path] \ 131 | [imagenet-folder with train and val folders] 132 | ``` 133 | 134 | ### FKD Training on ViT/DeiT and SReT 135 | 136 | To train a ViT model, run `train_ViT_FKD.py` with the desired model architecture and the path to the soft label and ImageNet dataset: 137 | 138 | ``` 139 | cd train_ViT 140 | python train_ViT_FKD.py \ 141 | --dist-url 'tcp://127.0.0.1:10001' \ 142 | --dist-backend 'nccl' \ 143 | --multiprocessing-distributed --world-size 1 --rank 0 \ 144 | -a SReT_LT --lr 0.002 --wd 0.05 --num_crops 4 \ 145 | --temp 1.0 -b 1024 --cos \ 146 | --softlabel_path [soft label path] \ 147 | [imagenet-folder with train and val folders] 148 | ``` 149 | 150 | For the instructions of `SReT_LT` model, please refer to [SReT](https://github.com/szq0214/SReT) for details. 151 | 152 | ### Evaluation 153 | 154 | ``` 155 | python train_ViT_FKD.py -a SReT_LT -e --resume [model path] [imagenet-folder with train and val folders] 156 | ``` 157 | 158 | ### Trained Models 159 | 160 | | Model | FLOPs| #params | accuracy (Top-1) |weights |configurations | 161 | |:-------:|:--------:|:--------:|:--------:|:--------:|:--------:| 162 | | [`DeiT-T-distill`](https://github.com/facebookresearch/deit) | 1.3B | 5.7M | 74.5 |-- | -- | 163 | | `FKD ViT/DeiT-T` | 1.3B | 5.7M | **75.2** |[link](https://drive.google.com/file/d/1m33c1wHdCV7ePETO_HvWNaboSd_W4nfC/view?usp=sharing) | [Table 13 in paper](http://zhiqiangshen.com/projects/FKD/FKD_camera-ready.pdf) | 164 | | [`SReT-LT-distill`](https://github.com/szq0214/SReT) | 1.2B | 5.0M | 77.7 |-- | -- | 165 | | `FKD SReT-LT` | 1.2B | 5.0M | **78.7** |[link](https://drive.google.com/file/d/1mmdPXKutHM9Li8xo5nGG6TB0aAXA9PFR/view?usp=sharing) | [Table 13 in paper](http://zhiqiangshen.com/projects/FKD/FKD_camera-ready.pdf) | 166 | 167 | ## Fast MEAL V2 168 | 169 | Please see [MEAL V2](https://github.com/szq0214/MEAL-V2) for the instructions to run FKD with MEAL V2. 170 | 171 | ## Self-supervised Representation Learning Using FKD 172 | 173 | Please see [FKD-SSL](https://github.com/szq0214/FKD/tree/main/FKD_SSL) for the instructions to run FKD for SSL task. 
174 | 175 | 176 | ## Contact 177 | 178 | Zhiqiang Shen (zhiqiangshen0214 at gmail.com or zhiqians at andrew.cmu.edu) 179 | 180 | -------------------------------------------------------------------------------- /FKD/FKD_ViT/SReT.py: -------------------------------------------------------------------------------- 1 | # SReT (Sliced Recursive Transformer: https://arxiv.org/abs/2111.05297) 2 | # Zhiqiang Shen 3 | # CMU & MBZUAI 4 | 5 | # PiT (Rethinking Spatial Dimensions of Vision Transformers) 6 | # Copyright 2021-present NAVER Corp. 7 | # Apache License v2.0 8 | 9 | # Timm (https://github.com/rwightman/pytorch-image-models) 10 | # Ross Wightman 11 | # Apache License v2.0 12 | 13 | import torch 14 | from einops import rearrange 15 | from torch import nn 16 | import math 17 | 18 | from functools import partial 19 | from timm.models.layers import trunc_normal_ 20 | from timm.models.layers import DropPath, to_2tuple, lecun_normal_ 21 | from timm.models.registry import register_model 22 | 23 | 24 | __all__ = [ 25 | "SReT", 26 | "SReT_T", 27 | "SReT_LT", 28 | "SReT_S", 29 | ] 30 | 31 | 32 | class LearnableCoefficient(nn.Module): 33 | def __init__(self): 34 | super(LearnableCoefficient, self).__init__() 35 | self.bias = nn.Parameter(torch.ones(1), requires_grad=True) 36 | 37 | def forward(self, x): 38 | out = x * self.bias 39 | return out 40 | 41 | class Mlp(nn.Module): 42 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 43 | super().__init__() 44 | out_features = out_features or in_features 45 | hidden_features = hidden_features or in_features 46 | self.fc1 = nn.Linear(in_features, hidden_features) 47 | self.act = act_layer() 48 | self.fc2 = nn.Linear(hidden_features, out_features) 49 | self.drop = nn.Dropout(drop) 50 | 51 | def forward(self, x): 52 | x = self.fc1(x) 53 | x = self.act(x) 54 | x = self.drop(x) 55 | x = self.fc2(x) 56 | x = self.drop(x) 57 | return x 58 | 59 | 60 | class Non_proj(nn.Module): 61 | 62 | def __init__(self, dim, num_heads, mlp_ratio=1., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 63 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 64 | super().__init__() 65 | self.norm1 = norm_layer(dim) 66 | mlp_hidden_dim = int(dim * mlp_ratio) 67 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 68 | self.coefficient1 = LearnableCoefficient() 69 | self.coefficient2 = LearnableCoefficient() 70 | 71 | def forward(self, x, recursive_index): 72 | x = self.coefficient1(x) + self.coefficient2(self.mlp(self.norm1(x))) 73 | return x 74 | 75 | class Attention(nn.Module): 76 | def __init__(self, dim, num_groups1=8, num_groups2=4, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 77 | super().__init__() 78 | self.num_heads = num_heads 79 | self.num_groups1 = num_groups1 80 | self.num_groups2 = num_groups2 81 | head_dim = dim // num_heads 82 | self.scale = qk_scale or head_dim ** -0.5 83 | 84 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 85 | self.attn_drop = nn.Dropout(attn_drop) 86 | self.proj = nn.Linear(dim, dim) 87 | self.proj_drop = nn.Dropout(proj_drop) 88 | 89 | def forward(self, x, recursive_index): 90 | B, N, C = x.shape 91 | if recursive_index == False: 92 | num_groups = self.num_groups1 93 | else: 94 | num_groups = self.num_groups2 95 | if num_groups != 1: 96 | idx = torch.randperm(N) 97 | x = x[:,idx,:] 98 | inverse = torch.argsort(idx) 99 | qkv = self.qkv(x).reshape(B, num_groups, N // num_groups, 3, self.num_heads, C // 
self.num_heads).permute(3, 0, 1, 4, 2, 5) 100 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 101 | 102 | attn = (q @ k.transpose(-2, -1)) * self.scale 103 | attn = attn.softmax(dim=-1) 104 | attn = self.attn_drop(attn) 105 | 106 | x = (attn @ v).transpose(2, 3).reshape(B, num_groups, N // num_groups, C) 107 | x = x.permute(0, 3, 1, 2).reshape(B, C, N).transpose(1, 2) 108 | if recursive_index == True and num_groups != 1: 109 | x = x[:,inverse,:] 110 | x = self.proj(x) 111 | x = self.proj_drop(x) 112 | return x 113 | 114 | class Transformer_Block(nn.Module): 115 | 116 | def __init__(self, dim, num_groups1, num_groups2, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 117 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 118 | super().__init__() 119 | self.norm1 = norm_layer(dim) 120 | self.attn = Attention( 121 | dim, num_groups1=num_groups1, num_groups2=num_groups2, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 122 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 123 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 124 | self.norm2 = norm_layer(dim) 125 | mlp_hidden_dim = int(dim * mlp_ratio) 126 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 127 | self.coefficient1 = LearnableCoefficient() 128 | self.coefficient2 = LearnableCoefficient() 129 | self.coefficient3 = LearnableCoefficient() 130 | self.coefficient4 = LearnableCoefficient() 131 | 132 | def forward(self, x, recursive_index): 133 | x = self.coefficient1(x) + self.coefficient2(self.drop_path(self.attn(self.norm1(x),recursive_index))) 134 | 135 | x = self.coefficient3(x) + self.coefficient4(self.drop_path(self.mlp(self.norm2(x)))) 136 | return x 137 | 138 | class Transformer(nn.Module): 139 | def __init__(self, base_dim, depth, recursive_num, groups1, groups2, heads, mlp_ratio, np_mlp_ratio, 140 | drop_rate=.0, attn_drop_rate=.0, drop_path_prob=None): 141 | super(Transformer, self).__init__() 142 | self.layers = nn.ModuleList([]) 143 | embed_dim = base_dim * heads 144 | 145 | if drop_path_prob is None: 146 | drop_path_prob = [0.0 for _ in range(depth)] 147 | 148 | blocks = [ 149 | Transformer_Block( 150 | dim=embed_dim, 151 | num_groups1=groups1, 152 | num_groups2=groups2, 153 | num_heads=heads, 154 | mlp_ratio=mlp_ratio, 155 | qkv_bias=True, 156 | drop=drop_rate, 157 | attn_drop=attn_drop_rate, 158 | drop_path=drop_path_prob[i], 159 | act_layer=nn.GELU, 160 | norm_layer=partial(nn.LayerNorm, eps=1e-6) 161 | ) 162 | for i in range(recursive_num)] 163 | 164 | recursive_loops = int(depth/recursive_num) 165 | non_projs = [ 166 | Non_proj( 167 | dim=embed_dim, num_heads=heads, mlp_ratio=np_mlp_ratio, drop=drop_rate, attn_drop=attn_drop_rate, 168 | drop_path=drop_path_prob[i], norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU) 169 | for i in range(depth)] 170 | RT = [] 171 | for rn in range(recursive_num): 172 | for rl in range(recursive_loops): 173 | RT.append(blocks[rn]) 174 | RT.append(non_projs[rn*recursive_loops+rl]) 175 | 176 | self.blocks = nn.ModuleList(RT) 177 | 178 | def forward(self, x): 179 | h, w = x.shape[2:4] 180 | x = rearrange(x, 'b c h w -> b (h w) c') 181 | 182 | for i, blk in enumerate(self.blocks): 183 | if (i+2)%4 == 0: # mark the recursive layers 184 | recursive_index = True 185 | else: 186 | recursive_index = False 187 | x = blk(x, recursive_index) 188 | 
189 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) 190 | 191 | return x 192 | 193 | 194 | class conv_head_pooling(nn.Module): 195 | def __init__(self, in_feature, out_feature, stride, 196 | padding_mode='zeros'): 197 | super(conv_head_pooling, self).__init__() 198 | 199 | self.conv = nn.Conv2d(in_feature, out_feature, kernel_size=stride + 1, 200 | padding=stride // 2, stride=stride, 201 | padding_mode=padding_mode, groups=in_feature) 202 | 203 | def forward(self, x): 204 | 205 | x = self.conv(x) 206 | 207 | return x 208 | 209 | 210 | class conv_embedding(nn.Module): 211 | def __init__(self, in_channels, out_channels, patch_size, 212 | stride, padding): 213 | super(conv_embedding, self).__init__() 214 | norm_layer = nn.BatchNorm2d 215 | self.conv1 = nn.Conv2d(in_channels, int(out_channels/2), kernel_size=3, 216 | stride=2, padding=1, bias=True) 217 | self.bn1 = norm_layer(int(out_channels/2)) 218 | self.relu1 = nn.ReLU(inplace=True) 219 | self.conv2 = nn.Conv2d(int(out_channels/2), out_channels, kernel_size=3, 220 | stride=2, padding=1, bias=True) 221 | self.bn2 = norm_layer(out_channels) 222 | self.relu2 = nn.ReLU(inplace=True) 223 | self.conv3 = nn.Conv2d(out_channels, out_channels, kernel_size=3, 224 | stride=2, padding=1, bias=True) 225 | self.bn3 = norm_layer(out_channels) 226 | self.relu3 = nn.ReLU(inplace=True) 227 | 228 | def forward(self, x): 229 | x = self.conv1(x) 230 | x = self.bn1(x) 231 | x = self.relu1(x) 232 | x = self.conv2(x) 233 | x = self.bn2(x) 234 | x = self.relu2(x) 235 | x = self.conv3(x) 236 | x = self.bn3(x) 237 | x = self.relu3(x) 238 | return x 239 | 240 | 241 | class SReT(nn.Module): 242 | def __init__(self, image_size, patch_size, stride, base_dims, depth, recursive_num, groups1, groups2, heads, 243 | mlp_ratio, np_mlp_ratio, num_classes=1000, in_chans=3, 244 | attn_drop_rate=.0, drop_rate=.0, drop_path_rate=.0): 245 | super(SReT, self).__init__() 246 | 247 | total_block = sum(depth) 248 | padding = 0 249 | block_idx = 0 250 | 251 | width = int(image_size/8) 252 | 253 | self.base_dims = base_dims 254 | self.heads = heads 255 | self.num_classes = num_classes 256 | 257 | self.patch_size = patch_size 258 | self.pos_embed = nn.Parameter( 259 | torch.randn(1, base_dims[0] * heads[0], width, width), 260 | requires_grad=True 261 | ) 262 | self.patch_embed = conv_embedding(in_chans, base_dims[0] * heads[0], 263 | patch_size, stride, padding) 264 | 265 | self.pos_drop = nn.Dropout(p=drop_rate) 266 | 267 | self.transformers = nn.ModuleList([]) 268 | self.pools = nn.ModuleList([]) 269 | 270 | for stage in range(len(depth)): 271 | drop_path_prob = [drop_path_rate * i / total_block 272 | for i in range(block_idx, block_idx + depth[stage])] 273 | block_idx += depth[stage] 274 | 275 | self.transformers.append( 276 | Transformer(base_dims[stage], depth[stage], recursive_num[stage], groups1[stage], groups2[stage], heads[stage], 277 | mlp_ratio, np_mlp_ratio, 278 | drop_rate, attn_drop_rate, drop_path_prob) 279 | ) 280 | if stage < len(heads) - 1: 281 | self.pools.append( 282 | conv_head_pooling(base_dims[stage] * heads[stage], 283 | base_dims[stage + 1] * heads[stage + 1], 284 | stride=2 285 | ) 286 | ) 287 | 288 | self.norm = nn.LayerNorm(base_dims[-1] * heads[-1], eps=1e-6) 289 | self.embed_dim = base_dims[-1] * heads[-1] 290 | 291 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 292 | 293 | # Classifier head 294 | if num_classes > 0: 295 | self.head = nn.Linear(base_dims[-1] * heads[-1], num_classes) 296 | else: 297 | self.head = nn.Identity() 298 | 299 | 
trunc_normal_(self.pos_embed, std=.02) 300 | self.apply(self._init_weights) 301 | 302 | def _init_weights(self, m): 303 | if isinstance(m, nn.LayerNorm): 304 | nn.init.constant_(m.bias, 0) 305 | nn.init.constant_(m.weight, 1.0) 306 | 307 | @torch.jit.ignore 308 | def no_weight_decay(self): 309 | return {'pos_embed', 'cls_token'} 310 | 311 | def get_classifier(self): 312 | return self.head 313 | 314 | def reset_classifier(self, num_classes, global_pool=''): 315 | self.num_classes = num_classes 316 | if num_classes > 0: 317 | self.head = nn.Linear(self.embed_dim, num_classes) 318 | else: 319 | self.head = nn.Identity() 320 | 321 | def forward_features(self, x): 322 | x = self.patch_embed(x) 323 | 324 | pos_embed = self.pos_embed 325 | x = self.pos_drop(x + pos_embed) 326 | 327 | for stage in range(len(self.pools)): 328 | x = self.transformers[stage](x) 329 | x = self.pools[stage](x) 330 | x = self.transformers[-1](x) 331 | 332 | x = self.avgpool(x) 333 | x = torch.flatten(x, 1) 334 | 335 | x = self.norm(x) 336 | 337 | return x 338 | 339 | def forward(self, x): 340 | x = self.forward_features(x) 341 | x = self.head(x) 342 | return x 343 | 344 | 345 | class Distilled_SReT(SReT): 346 | def __init__(self, *args, **kwargs): 347 | super().__init__(*args, **kwargs) 348 | 349 | def forward(self, x): 350 | x = self.forward_features(x) 351 | x_cls = self.head(x) 352 | # `x_cls, x_cls` is used to make it compatible with DeiT codebase, while SReT uses global_average pooling, and soft label only for knowledge distillation 353 | # so `x_cls` is enough 354 | if self.training: 355 | # return x_cls, x_cls 356 | return x_cls 357 | else: 358 | return x_cls 359 | 360 | 361 | @register_model 362 | def SReT_T(pretrained=False, **kwargs): 363 | model = SReT( 364 | image_size=224, 365 | patch_size=16, 366 | stride=8, 367 | base_dims=[32, 32, 32], 368 | depth=[4, 10, 6], 369 | recursive_num=[2,5,3], 370 | heads=[2, 4, 8], 371 | groups1=[8, 4, 1], 372 | groups2=[2, 1, 1], 373 | mlp_ratio=3.6, 374 | np_mlp_ratio=1, 375 | drop_path_rate=0.1, 376 | **kwargs 377 | ) 378 | if pretrained: 379 | state_dict = \ 380 | torch.load('SReT_T.pth', map_location='cpu') 381 | model.load_state_dict(state_dict['model']) 382 | return model 383 | 384 | @register_model 385 | def SReT_LT(pretrained=False, **kwargs): 386 | model = SReT( 387 | image_size=224, 388 | patch_size=16, 389 | stride=8, 390 | base_dims=[32, 32, 32], 391 | depth=[4, 10, 6], 392 | recursive_num=[2, 5, 3], 393 | heads=[2, 4, 8], 394 | groups1=[8, 4, 1], 395 | groups2=[2, 1, 1], 396 | mlp_ratio=4.0, 397 | np_mlp_ratio=1, 398 | drop_path_rate=0.1, 399 | **kwargs 400 | ) 401 | if pretrained: 402 | state_dict = \ 403 | torch.load('SReT_LT.pth', map_location='cpu') 404 | model.load_state_dict(state_dict['model']) 405 | return model 406 | 407 | def SReT_S(pretrained=False, **kwargs): 408 | model = SReT( 409 | image_size=224, 410 | patch_size=16, 411 | stride=8, 412 | base_dims=[42, 42, 42], 413 | depth=[4, 10, 6], 414 | recursive_num=[2, 5, 3], 415 | heads=[3, 6, 12], 416 | groups1=[8, 4, 1], 417 | groups2=[2, 1, 1], 418 | mlp_ratio=3.0, 419 | np_mlp_ratio=2, 420 | drop_path_rate=0.2, 421 | **kwargs 422 | ) 423 | if pretrained: 424 | state_dict = \ 425 | torch.load('SReT_S.pth', map_location='cpu') 426 | model.load_state_dict(state_dict['model']) 427 | return model 428 | 429 | # Knowledge Distillation 430 | @register_model 431 | def SReT_T_distill(pretrained=False, **kwargs): 432 | model = Distilled_SReT( 433 | image_size=224, 434 | patch_size=16, 435 | stride=8, 436 | 
base_dims=[32, 32, 32], 437 | depth=[4, 10, 6], 438 | recursive_num=[2, 5, 3], 439 | heads=[2, 4, 8], 440 | groups1=[8, 4, 1], 441 | groups2=[2, 1, 1], 442 | mlp_ratio=3.6, 443 | np_mlp_ratio=1, 444 | drop_path_rate=0.1, 445 | **kwargs 446 | ) 447 | if pretrained: 448 | state_dict = \ 449 | torch.load('SReT_T_distill.pth', map_location='cpu') 450 | model.load_state_dict(state_dict['model']) 451 | return model 452 | 453 | @register_model 454 | def SReT_LT_distill(pretrained=False, **kwargs): 455 | model = Distilled_SReT( 456 | image_size=224, 457 | patch_size=16, 458 | stride=8, 459 | base_dims=[32, 32, 32], 460 | depth=[4, 10, 6], 461 | recursive_num=[2, 5, 3], 462 | heads=[2, 4, 8], 463 | groups1=[8, 4, 1], 464 | groups2=[2, 1, 1], 465 | mlp_ratio=4.0, 466 | np_mlp_ratio=1, 467 | drop_path_rate=0.1, 468 | **kwargs 469 | ) 470 | if pretrained: 471 | state_dict = \ 472 | torch.load('SReT_LT_distill.pth', map_location='cpu') 473 | model.load_state_dict(state_dict['model']) 474 | return model 475 | 476 | def SReT_S_distill(pretrained=False, **kwargs): 477 | model = Distilled_SReT( 478 | image_size=224, 479 | patch_size=16, 480 | stride=8, 481 | base_dims=[42, 42, 42], 482 | depth=[4, 10, 6], 483 | recursive_num=[2, 5, 3], 484 | heads=[3, 6, 12], 485 | groups1=[8, 4, 1], 486 | groups2=[2, 1, 1], 487 | mlp_ratio=3.0, 488 | np_mlp_ratio=2, 489 | drop_path_rate=0.2, 490 | **kwargs 491 | ) 492 | if pretrained: 493 | state_dict = \ 494 | torch.load('SReT_S_distill.pth', map_location='cpu') 495 | model.load_state_dict(state_dict['model']) 496 | return model -------------------------------------------------------------------------------- /FKD/FKD_SLG/generate_soft_label.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import shutil 5 | import time 6 | import warnings 7 | import timm 8 | import numpy as np 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.parallel 13 | import torch.backends.cudnn as cudnn 14 | import torch.distributed as dist 15 | import torch.optim 16 | import torch.multiprocessing as mp 17 | import torch.utils.data 18 | import torch.utils.data.distributed 19 | import torchvision.transforms as transforms 20 | import torchvision.datasets as datasets 21 | import torchvision.models as models 22 | 23 | from utils import RandomResizedCropWithCoords 24 | from utils import RandomHorizontalFlipWithRes 25 | from utils import ImageFolder_FKD_GSL 26 | from utils import ComposeWithCoords 27 | 28 | from torchvision.transforms import InterpolationMode 29 | 30 | import torch.multiprocessing 31 | 32 | 33 | parser = argparse.ArgumentParser(description='FKD Soft Label Generation on ImageNet-1K') 34 | parser.add_argument('data', metavar='DIR', 35 | help='path to dataset') 36 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18') 37 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 38 | help='number of data loading workers (default: 4)') 39 | parser.add_argument('-b', '--batch-size', default=256, type=int, 40 | metavar='N', 41 | help='mini-batch size (default: 256), this is the total ' 42 | 'batch size of all GPUs on the current node when ' 43 | 'using Data Parallel or Distributed Data Parallel') 44 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 45 | help='evaluate model on validation set') 46 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 47 | help='use pre-trained model') 48 | 
parser.add_argument('--world-size', default=-1, type=int, 49 | help='number of nodes for distributed training') 50 | parser.add_argument('--rank', default=-1, type=int, 51 | help='node rank for distributed training') 52 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, 53 | help='url used to set up distributed training') 54 | parser.add_argument('--dist-backend', default='nccl', type=str, 55 | help='distributed backend') 56 | parser.add_argument('--seed', default=None, type=int, 57 | help='seed for initializing training. ') 58 | parser.add_argument('--gpu', default=None, type=int, 59 | help='GPU id to use.') 60 | parser.add_argument('--multiprocessing-distributed', action='store_true', 61 | help='Use multi-processing distributed training to launch ' 62 | 'N processes per node, which has N GPUs. This is the ' 63 | 'fastest way to use PyTorch for either single node or ' 64 | 'multi node data parallel training') 65 | # FKD soft label generation args 66 | parser.add_argument('--num_crops', default=500, type=int, 67 | help='total number of crops in each image') 68 | parser.add_argument('--num_seg', default=50, type=int, 69 | help='number of crops on each GPU during generating') 70 | parser.add_argument("--min_scale_crops", type=float, default=0.08, 71 | help="argument in RandomResizedCrop") 72 | parser.add_argument("--max_scale_crops", type=float, default=0.936, 73 | help="argument in RandomResizedCrop") 74 | parser.add_argument("--temp", type=float, default=1.0, 75 | help="temperature on soft label") 76 | parser.add_argument('--save_path', default='./FKD_effL2_475_soft_label', type=str, metavar='PATH', 77 | help='path to save soft labels') 78 | parser.add_argument('--reference_path', default=None, type=str, metavar='PATH', 79 | help='path to existing soft labels files, we can use existing crop locations to generate new soft labels') 80 | parser.add_argument('--input_size', default=475, type=int, metavar='S', 81 | help='input size of teacher model') 82 | parser.add_argument('--teacher_path', default='', type=str, metavar='TEACHER', 83 | help='path of pre-trained teacher') 84 | parser.add_argument('--teacher_source', default='timm', type=str, metavar='SOURCE', 85 | help='source of pre-trained teacher models: pytorch or timm') 86 | parser.add_argument('--label_type', default='marginal_smoothing_k5', type=str, metavar='TYPE', 87 | help='type of generated soft labels') 88 | parser.add_argument('--use_fp16', dest='use_fp16', action='store_true', 89 | help='save soft labels as `fp16`') 90 | 91 | 92 | sharing_strategy = "file_system" 93 | torch.multiprocessing.set_sharing_strategy(sharing_strategy) 94 | 95 | def set_worker_sharing_strategy(worker_id: int) -> None: 96 | torch.multiprocessing.set_sharing_strategy(sharing_strategy) 97 | 98 | 99 | def main(): 100 | args = parser.parse_args() 101 | 102 | if not os.path.exists(args.save_path): 103 | os.makedirs(args.save_path, exist_ok=True) 104 | 105 | if args.seed is not None: 106 | random.seed(args.seed) 107 | torch.manual_seed(args.seed) 108 | cudnn.deterministic = True 109 | warnings.warn('You have chosen to seed training. ' 110 | 'This will turn on the CUDNN deterministic setting, ' 111 | 'which can slow down your training considerably! ' 112 | 'You may see unexpected behavior when restarting ' 113 | 'from checkpoints.') 114 | 115 | if args.gpu is not None: 116 | warnings.warn('You have chosen a specific GPU. 
This will completely ' 117 | 'disable data parallelism.') 118 | 119 | if args.dist_url == "env://" and args.world_size == -1: 120 | args.world_size = int(os.environ["WORLD_SIZE"]) 121 | 122 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 123 | 124 | ngpus_per_node = torch.cuda.device_count() 125 | if args.multiprocessing_distributed: 126 | # Since we have ngpus_per_node processes per node, the total world_size 127 | # needs to be adjusted accordingly 128 | args.world_size = ngpus_per_node * args.world_size 129 | # Use torch.multiprocessing.spawn to launch distributed processes: the 130 | # main_worker process function 131 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 132 | else: 133 | # Simply call main_worker function 134 | main_worker(args.gpu, ngpus_per_node, args) 135 | 136 | 137 | def main_worker(gpu, ngpus_per_node, args): 138 | args.gpu = gpu 139 | 140 | if args.gpu is not None: 141 | print("Use GPU: {} for training".format(args.gpu)) 142 | 143 | if args.distributed: 144 | if args.dist_url == "env://" and args.rank == -1: 145 | args.rank = int(os.environ["RANK"]) 146 | if args.multiprocessing_distributed: 147 | # For multiprocessing distributed training, rank needs to be the 148 | # global rank among all the processes 149 | args.rank = args.rank * ngpus_per_node + gpu 150 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 151 | world_size=args.world_size, rank=args.rank) 152 | # create model 153 | if os.path.isfile(args.teacher_path): 154 | print("=> using pre-trained model from '{}'".format(args.teacher_path)) 155 | model = models.__dict__[args.arch](pretrained=False) 156 | checkpoint = torch.load(args.teacher_path, map_location=torch.device('cpu')) 157 | model.load_state_dict(checkpoint) 158 | elif args.teacher_source == 'timm': 159 | # Timm pre-trained models 160 | print("=> using pre-trained model '{}'".format(args.arch)) 161 | model = timm.create_model(args.arch, pretrained=True, num_classes=1000) 162 | elif args.teacher_source == 'pytorch': 163 | # PyTorch pre-trained models 164 | print("=> using pre-trained model '{}'".format(args.arch)) 165 | model = models.__dict__[args.arch](pretrained=True) 166 | else: 167 | print("'{}' currently is not supported. Please use pytorch, timm or your own pre-trained models as teachers.".format(args.teacher_source)) 168 | # add your code of loading teacher here. 169 | return 170 | 171 | if not torch.cuda.is_available(): 172 | print('using CPU, this will be slow') 173 | elif args.distributed: 174 | # For multiprocessing distributed, DistributedDataParallel constructor 175 | # should always set the single device scope, otherwise, 176 | # DistributedDataParallel will use all available devices. 
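# The teacher is used for inference only in this script: its parameters are frozen below and it is switched to eval mode before generation, so DDP/DataParallel here merely shards the crop-scoring workload across the available GPUs.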
177 | if args.gpu is not None: 178 | torch.cuda.set_device(args.gpu) 179 | model.cuda(args.gpu) 180 | # When using a single GPU per process and per 181 | # DistributedDataParallel, we need to divide the batch size 182 | # ourselves based on the total number of GPUs we have 183 | args.batch_size = int(args.batch_size / ngpus_per_node) 184 | args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node) 185 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 186 | else: 187 | model.cuda() 188 | # DistributedDataParallel will divide and allocate batch_size to all 189 | # available GPUs if device_ids are not set 190 | model = torch.nn.parallel.DistributedDataParallel(model) 191 | elif args.gpu is not None: 192 | torch.cuda.set_device(args.gpu) 193 | model = model.cuda(args.gpu) 194 | else: 195 | # DataParallel will divide and allocate batch_size to all available GPUs 196 | if args.arch.startswith('alexnet') or args.arch.startswith('vgg'): 197 | model.features = torch.nn.DataParallel(model.features) 198 | model.cuda() 199 | else: 200 | model = torch.nn.DataParallel(model).cuda() 201 | 202 | # freeze all layers 203 | for name, param in model.named_parameters(): 204 | param.requires_grad = False 205 | 206 | # define loss function (criterion) and optimizer 207 | criterion = nn.CrossEntropyLoss().cuda(args.gpu) 208 | 209 | cudnn.benchmark = True 210 | 211 | # Data loading code 212 | traindir = os.path.join(args.data, 'train') 213 | valdir = os.path.join(args.data, 'val') 214 | 215 | # BEIT, etc. 216 | # normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], 217 | # std=[0.5, 0.5, 0.5]) 218 | # ResNet, efficientnet_l2_ns_475, etc. 219 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 220 | std=[0.229, 0.224, 0.225]) 221 | 222 | train_dataset = ImageFolder_FKD_GSL( 223 | num_crops=args.num_crops, 224 | save_path=args.save_path, 225 | reference_path=args.reference_path, 226 | root=traindir, 227 | transform=ComposeWithCoords(transforms=[ 228 | RandomResizedCropWithCoords(size=args.input_size, 229 | scale=(args.min_scale_crops, args.max_scale_crops), 230 | interpolation=InterpolationMode.BICUBIC), 231 | RandomHorizontalFlipWithRes(), 232 | transforms.ToTensor(), 233 | normalize, 234 | ])) 235 | 236 | if args.distributed: 237 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 238 | else: 239 | train_sampler = None 240 | 241 | #(train_sampler is None) 242 | train_loader = torch.utils.data.DataLoader( 243 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 244 | num_workers=args.workers, pin_memory=True, sampler=train_sampler, worker_init_fn=set_worker_sharing_strategy) 245 | 246 | val_loader = torch.utils.data.DataLoader( 247 | datasets.ImageFolder(valdir, transforms.Compose([ 248 | transforms.Resize(int(256/224*args.input_size)), 249 | transforms.CenterCrop(args.input_size), 250 | transforms.ToTensor(), 251 | normalize, 252 | ])), 253 | batch_size=args.batch_size, shuffle=False, 254 | num_workers=args.workers, pin_memory=True) 255 | 256 | # test the accuracy of teacher model 257 | if args.evaluate: 258 | validate(val_loader, model, criterion, args) 259 | return 260 | 261 | # generate soft labels 262 | generate_soft_labels(train_loader, model, args) 263 | 264 | 265 | def generate_soft_labels(train_loader, model, args): 266 | batch_time = AverageMeter('Time', ':6.3f') 267 | data_time = AverageMeter('Data', ':6.3f') 268 | 269 | # switch to eval mode 270 | model.eval() 271 | 272 | end = time.time() 
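# Each image yields `num_crops` random crops; the loop below scores them in chunks of `num_seg` crops to bound GPU memory, and writes the results for all crops of one image into a single `<image>.tar` file under `save_path`, mirroring the class-folder structure of the dataset.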
273 | for i, (images, target, flip_status, coords, path) in enumerate(train_loader): 274 | # measure data loading time 275 | data_time.update(time.time() - end) 276 | 277 | images = torch.stack(images, dim=0).permute(1,0,2,3,4) 278 | flip_status = torch.stack(flip_status, dim=0).permute(1,0) 279 | coords = torch.stack(coords, dim=0).permute(1,0,2) 280 | 281 | for k in range(images.size()[0]): 282 | save_path = os.path.join(args.save_path,'/'.join(path[k].split('/')[-4:-1])) 283 | if not os.path.exists(save_path): 284 | os.makedirs(save_path, exist_ok=True) 285 | new_filename = os.path.join(save_path,'/'.join(path[k].split('/')[-1:]).split('.')[0] + '.tar') 286 | if not os.path.exists(new_filename): 287 | if args.num_crops <= args.num_seg: 288 | if args.gpu is not None: 289 | images[k] = images[k].cuda(args.gpu, non_blocking=True) 290 | if torch.cuda.is_available(): 291 | target = target.cuda(args.gpu, non_blocking=True) 292 | 293 | # compute output 294 | output = model(images[k]) 295 | 296 | output = nn.functional.softmax(output / args.temp, dim=1) 297 | images[k] = images[k].detach()#.cpu() 298 | output = output.detach().cpu() 299 | 300 | output = label_quantization(output, args.label_type) 301 | 302 | state = [coords[k].detach().numpy(), flip_status[k].detach().numpy(), output] 303 | torch.save(state, new_filename) 304 | else: 305 | output_all = [] 306 | for split in range(int(args.num_crops / args.num_seg)): 307 | if args.gpu is not None: 308 | images[k][split*args.num_seg:(split+1)*args.num_seg] = images[k][split*args.num_seg:(split+1)*args.num_seg].cuda(args.gpu, non_blocking=True) 309 | if torch.cuda.is_available(): 310 | target = target.cuda(args.gpu, non_blocking=True) 311 | 312 | # compute output 313 | output = model(images[k][split*args.num_seg:(split+1)*args.num_seg]) 314 | 315 | output = nn.functional.softmax(output / args.temp, dim=1) 316 | images[k][split*args.num_seg:(split+1)*args.num_seg] = images[k][split*args.num_seg:(split+1)*args.num_seg].detach()#.cpu() 317 | output = output.detach().cpu() 318 | output_all.append(output) 319 | 320 | output_all = torch.cat(output_all, dim=0) 321 | output_quan = label_quantization(output_all, args.label_type) 322 | 323 | if args.use_fp16: 324 | state = [np.float16(coords[k].detach().numpy()), flip_status[k].detach().numpy(), np.float16(output_quan)] 325 | else: 326 | state = [coords[k].detach().numpy(), flip_status[k].detach().numpy(), output_quan] 327 | 328 | torch.save(state, new_filename) 329 | 330 | print(i,'/', len(train_loader), i/len(train_loader)*100, "%") 331 | 332 | 333 | def validate(val_loader, model, criterion, args): 334 | batch_time = AverageMeter('Time', ':6.3f') 335 | losses = AverageMeter('Loss', ':.4e') 336 | top1 = AverageMeter('Acc@1', ':6.2f') 337 | top5 = AverageMeter('Acc@5', ':6.2f') 338 | progress = ProgressMeter( 339 | len(val_loader), 340 | [batch_time, losses, top1, top5], 341 | prefix='Test: ') 342 | 343 | # switch to evaluate mode 344 | model.eval() 345 | 346 | with torch.no_grad(): 347 | end = time.time() 348 | for i, (images, target) in enumerate(val_loader): 349 | if args.gpu is not None: 350 | images = images.cuda(args.gpu, non_blocking=True) 351 | if torch.cuda.is_available(): 352 | target = target.cuda(args.gpu, non_blocking=True) 353 | 354 | # compute output 355 | output = model(images) 356 | loss = criterion(output, target) 357 | 358 | # measure accuracy and record loss 359 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 360 | losses.update(loss.item(), images.size(0)) 361 | 
top1.update(acc1[0], images.size(0)) 362 | top5.update(acc5[0], images.size(0)) 363 | 364 | # measure elapsed time 365 | batch_time.update(time.time() - end) 366 | end = time.time() 367 | 368 | if i % args.print_freq == 0: 369 | progress.display(i) 370 | 371 | # TODO: this should also be done with the ProgressMeter 372 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}' 373 | .format(top1=top1, top5=top5)) 374 | 375 | return top1.avg 376 | 377 | def label_quantization(full_label, label_type): 378 | # '(1) hard; (2) smoothing; (3) marginal_smoothing_k5 and marginal_renorm_k5; (4) marginal_smoothing_k10' 379 | output_quantized = [] 380 | for kk,p in enumerate(full_label): 381 | # 1 hard 382 | if label_type == 'hard': 383 | output = torch.argmax(p) 384 | output_quantized.append(output) 385 | # 2 smoothing 386 | elif label_type == 'smoothing': 387 | output = torch.argmax(p) 388 | value = p[output]#.item() 389 | output_quantized.append(torch.stack([output, value], dim=0)) 390 | # 3 marginal_smoothing_k5 and marginal_renorm_k5 391 | elif label_type == 'marginal_smoothing_k5' or label_type == 'marginal_renorm_k5': 392 | output = torch.argsort(p, descending=True) 393 | value = p[output[:5]] 394 | output_quantized.append(torch.stack([output[:5], value], dim=0)) 395 | # 4 marginal_smoothing_k10 396 | elif label_type == 'marginal_smoothing_k10': 397 | output = torch.argsort(p, descending=True) 398 | value = p[output[:10]] 399 | output_quantized.append(torch.stack([output[:10], value], dim=0)) 400 | 401 | output_quantized = torch.stack(output_quantized, dim=0) 402 | 403 | return output_quantized.detach().numpy() 404 | 405 | 406 | class AverageMeter(object): 407 | """Computes and stores the average and current value""" 408 | def __init__(self, name, fmt=':f'): 409 | self.name = name 410 | self.fmt = fmt 411 | self.reset() 412 | 413 | def reset(self): 414 | self.val = 0 415 | self.avg = 0 416 | self.sum = 0 417 | self.count = 0 418 | 419 | def update(self, val, n=1): 420 | self.val = val 421 | self.sum += val * n 422 | self.count += n 423 | self.avg = self.sum / self.count 424 | 425 | def __str__(self): 426 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 427 | return fmtstr.format(**self.__dict__) 428 | 429 | 430 | class ProgressMeter(object): 431 | def __init__(self, num_batches, meters, prefix=""): 432 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 433 | self.meters = meters 434 | self.prefix = prefix 435 | 436 | def display(self, batch): 437 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 438 | entries += [str(meter) for meter in self.meters] 439 | print('\t'.join(entries)) 440 | 441 | def _get_batch_fmtstr(self, num_batches): 442 | num_digits = len(str(num_batches // 1)) 443 | fmt = '{:' + str(num_digits) + 'd}' 444 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 445 | 446 | 447 | def accuracy(output, target, topk=(1,)): 448 | """Computes the accuracy over the k top predictions for the specified values of k""" 449 | with torch.no_grad(): 450 | maxk = max(topk) 451 | batch_size = target.size(0) 452 | 453 | _, pred = output.topk(maxk, 1, True, True) 454 | pred = pred.t() 455 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 456 | 457 | res = [] 458 | for k in topk: 459 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 460 | res.append(correct_k.mul_(100.0 / batch_size)) 461 | return res 462 | 463 | 464 | if __name__ == '__main__': 465 | main() 466 | 
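# For reference, a sketch of reading back one generated soft-label file (the path is a placeholder; shapes assume the default `marginal_smoothing_k5` label type with `--num_crops` crops per image):
#
#   coords, flip_status, soft_label = torch.load('<save_path>/<class>/<image>.tar')
#   # coords:      (num_crops, 4)    crop coordinates from RandomResizedCropWithCoords
#   # flip_status: (num_crops,)      horizontal-flip flags
#   # soft_label:  (num_crops, 2, 5) top-5 class indices and their probabilities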
-------------------------------------------------------------------------------- /FKD/train_FKD.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import shutil 5 | import time 6 | import math 7 | import warnings 8 | import numpy as np 9 | import builtins 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.parallel 14 | import torch.backends.cudnn as cudnn 15 | import torch.distributed as dist 16 | import torch.optim 17 | import torch.multiprocessing as mp 18 | import torch.utils.data 19 | import torch.utils.data.distributed 20 | import torchvision.transforms as transforms 21 | import torchvision.datasets as datasets 22 | import torchvision.models as models 23 | from torchvision.transforms import InterpolationMode 24 | 25 | from utils_FKD import RandomResizedCrop_FKD, RandomHorizontalFlip_FKD 26 | from utils_FKD import ImageFolder_FKD, Compose_FKD 27 | from utils_FKD import Soft_CrossEntropy, Recover_soft_label 28 | from utils_FKD import mixup_cutmix 29 | 30 | 31 | model_names = sorted(name for name in models.__dict__ 32 | if name.islower() and not name.startswith("__") 33 | and callable(models.__dict__[name])) 34 | 35 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training with FKD Scheme') 36 | parser.add_argument('data', metavar='DIR', 37 | help='path to dataset') 38 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', 39 | choices=model_names, 40 | help='model architecture: ' + 41 | ' | '.join(model_names) + 42 | ' (default: resnet50)') 43 | parser.add_argument('-j', '--workers', default=24, type=int, metavar='N', 44 | help='number of data loading workers (default: 24)') 45 | parser.add_argument('--epochs', default=300, type=int, metavar='N', 46 | help='number of total epochs to run') 47 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 48 | help='manual epoch number (useful on restarts)') 49 | parser.add_argument('-b', '--batch-size', default=1024, type=int, 50 | metavar='N', 51 | help='mini-batch size (default: 1024), this is the total ' 52 | 'batch size of all GPUs on the current node when ' 53 | 'using Data Parallel or Distributed Data Parallel') 54 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 55 | metavar='LR', help='initial learning rate', dest='lr') 56 | parser.add_argument('--schedule', default=[120, 240], nargs='*', type=int, 57 | help='learning rate schedule (when to drop lr by 10x)') 58 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 59 | help='momentum') 60 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, 61 | metavar='W', help='weight decay (default: 1e-4)', 62 | dest='weight_decay') 63 | parser.add_argument('-p', '--print-freq', default=10, type=int, 64 | metavar='N', help='print frequency (default: 10)') 65 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 66 | help='path to latest checkpoint (default: none)') 67 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 68 | help='evaluate model on validation set') 69 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 70 | help='use pre-trained model') 71 | parser.add_argument('--world-size', default=-1, type=int, 72 | help='number of nodes for distributed training') 73 | parser.add_argument('--rank', default=-1, type=int, 74 | help='node rank for distributed training') 75 | 
parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, 76 | help='url used to set up distributed training') 77 | parser.add_argument('--dist-backend', default='nccl', type=str, 78 | help='distributed backend') 79 | parser.add_argument('--seed', default=None, type=int, 80 | help='seed for initializing training. ') 81 | parser.add_argument('--gpu', default=None, type=int, 82 | help='GPU id to use.') 83 | parser.add_argument('--multiprocessing-distributed', action='store_true', 84 | help='Use multi-processing distributed training to launch ' 85 | 'N processes per node, which has N GPUs. This is the ' 86 | 'fastest way to use PyTorch for either single node or ' 87 | 'multi node data parallel training') 88 | parser.add_argument('--num_crops', default=4, type=int, 89 | help='number of crops in each image, 1 is the standard training') 90 | parser.add_argument('--softlabel_path', default='./soft_label', type=str, metavar='PATH', 91 | help='path to soft label files (default: ./soft_label)') 92 | parser.add_argument("--temp", type=float, default=1.0, 93 | help="temperature on student during training (default: 1.0)") 94 | parser.add_argument('--cos', action='store_true', 95 | help='use cosine lr schedule') 96 | parser.add_argument('--save_checkpoint_path', default='./FKD_checkpoints_output', type=str, metavar='PATH', 97 | help='path to save checkpoints (default: ./FKD_checkpoints_output)') 98 | parser.add_argument('--soft_label_type', default='marginal_smoothing_k5', type=str, metavar='TYPE', 99 | help='(1) ori; (2) hard; (3) smoothing; (4) marginal_smoothing_k5; (5) marginal_smoothing_k10; (6) marginal_renorm_k5') 100 | parser.add_argument('--num_classes', default=1000, type=int, 101 | help='number of classes.') 102 | 103 | # mixup and cutmix parameters 104 | parser.add_argument('--mixup_cutmix', default=False, action='store_true', 105 | help='use mixup and cutmix data augmentation') 106 | parser.add_argument('--mixup', type=float, default=0.8, 107 | help='mixup alpha, mixup enabled if > 0. (default: 0.8)') 108 | parser.add_argument('--cutmix', type=float, default=1.0, 109 | help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)') 110 | parser.add_argument('--mixup_cutmix_prob', type=float, default=1.0, 111 | help='Probability of performing mixup or cutmix when either/both is enabled') 112 | parser.add_argument('--mixup_switch_prob', type=float, default=0.5, 113 | help='Probability of switching to cutmix when both mixup and cutmix enabled. Mixup only: set to 1.0, Cutmix only: set to 0.0') 114 | 115 | best_acc1 = 0 116 | 117 | 118 | def main(): 119 | args = parser.parse_args() 120 | 121 | if not os.path.exists(args.save_checkpoint_path): 122 | os.makedirs(args.save_checkpoint_path) 123 | 124 | # convert to the TRUE number of loaded images, since we use multiple crops from the same image within a minibatch 125 | args.batch_size = math.ceil(args.batch_size / args.num_crops) 126 | 127 | if args.seed is not None: 128 | random.seed(args.seed) 129 | torch.manual_seed(args.seed) 130 | cudnn.deterministic = True 131 | warnings.warn('You have chosen to seed training. ' 132 | 'This will turn on the CUDNN deterministic setting, ' 133 | 'which can slow down your training considerably! ' 134 | 'You may see unexpected behavior when restarting ' 135 | 'from checkpoints.') 136 | 137 | if args.gpu is not None: 138 | warnings.warn('You have chosen a specific GPU. 
This will completely ' 139 | 'disable data parallelism.') 140 | 141 | if args.dist_url == "env://" and args.world_size == -1: 142 | args.world_size = int(os.environ["WORLD_SIZE"]) 143 | 144 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 145 | 146 | ngpus_per_node = torch.cuda.device_count() 147 | if args.multiprocessing_distributed: 148 | # Since we have ngpus_per_node processes per node, the total world_size 149 | # needs to be adjusted accordingly 150 | args.world_size = ngpus_per_node * args.world_size 151 | # Use torch.multiprocessing.spawn to launch distributed processes: the 152 | # main_worker process function 153 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 154 | else: 155 | # Simply call main_worker function 156 | main_worker(args.gpu, ngpus_per_node, args) 157 | 158 | 159 | def main_worker(gpu, ngpus_per_node, args): 160 | global best_acc1 161 | args.gpu = gpu 162 | 163 | # suppress printing if not master 164 | if args.multiprocessing_distributed and args.gpu != 0: 165 | def print_pass(*args): 166 | pass 167 | builtins.print = print_pass 168 | 169 | if args.gpu is not None: 170 | print("Use GPU: {} for training".format(args.gpu)) 171 | 172 | if args.distributed: 173 | if args.dist_url == "env://" and args.rank == -1: 174 | args.rank = int(os.environ["RANK"]) 175 | if args.multiprocessing_distributed: 176 | # For multiprocessing distributed training, rank needs to be the 177 | # global rank among all the processes 178 | args.rank = args.rank * ngpus_per_node + gpu 179 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 180 | world_size=args.world_size, rank=args.rank) 181 | # create model 182 | if args.pretrained: 183 | print("=> using pre-trained model '{}'".format(args.arch)) 184 | model = models.__dict__[args.arch](pretrained=True) 185 | else: 186 | print("=> creating model '{}'".format(args.arch)) 187 | model = models.__dict__[args.arch](pretrained=False, num_classes=args.num_classes) 188 | 189 | if not torch.cuda.is_available(): 190 | print('using CPU, this will be slow') 191 | elif args.distributed: 192 | # For multiprocessing distributed, DistributedDataParallel constructor 193 | # should always set the single device scope, otherwise, 194 | # DistributedDataParallel will use all available devices. 
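# Note that no teacher model is instantiated in this script: the teacher's knowledge arrives entirely through the pre-generated soft labels loaded by ImageFolder_FKD below, which is what removes the teacher forward pass from the training loop.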
195 | if args.gpu is not None: 196 | torch.cuda.set_device(args.gpu) 197 | model.cuda(args.gpu) 198 | # When using a single GPU per process and per 199 | # DistributedDataParallel, we need to divide the batch size 200 | # ourselves based on the total number of GPUs we have 201 | args.batch_size = int(args.batch_size / ngpus_per_node) 202 | args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node) 203 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 204 | else: 205 | model.cuda() 206 | # DistributedDataParallel will divide and allocate batch_size to all 207 | # available GPUs if device_ids are not set 208 | model = torch.nn.parallel.DistributedDataParallel(model) 209 | elif args.gpu is not None: 210 | torch.cuda.set_device(args.gpu) 211 | model = model.cuda(args.gpu) 212 | else: 213 | # DataParallel will divide and allocate batch_size to all available GPUs 214 | if args.arch.startswith('alexnet') or args.arch.startswith('vgg'): 215 | model.features = torch.nn.DataParallel(model.features) 216 | model.cuda() 217 | else: 218 | model = torch.nn.DataParallel(model).cuda() 219 | 220 | # define loss function (criterion) and optimizer 221 | criterion_sce = Soft_CrossEntropy() 222 | criterion_ce = nn.CrossEntropyLoss().cuda(args.gpu) 223 | 224 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 225 | momentum=args.momentum, 226 | weight_decay=args.weight_decay) 227 | 228 | # optionally resume from a checkpoint 229 | if args.resume: 230 | if os.path.isfile(args.resume): 231 | print("=> loading checkpoint '{}'".format(args.resume)) 232 | if args.gpu is None: 233 | checkpoint = torch.load(args.resume) 234 | else: 235 | # Map model to be loaded to specified single gpu. 236 | loc = 'cuda:{}'.format(args.gpu) 237 | checkpoint = torch.load(args.resume, map_location=loc) 238 | args.start_epoch = checkpoint['epoch'] 239 | best_acc1 = checkpoint['best_acc1'] 240 | if args.gpu is not None: 241 | # best_acc1 may be from a checkpoint from a different GPU 242 | best_acc1 = best_acc1.to(args.gpu) 243 | model.load_state_dict(checkpoint['state_dict']) 244 | optimizer.load_state_dict(checkpoint['optimizer']) 245 | print("=> loaded checkpoint '{}' (epoch {})" 246 | .format(args.resume, checkpoint['epoch'])) 247 | else: 248 | print("=> no checkpoint found at '{}'".format(args.resume)) 249 | 250 | cudnn.benchmark = True 251 | 252 | # Data loading code 253 | traindir = os.path.join(args.data, 'train') 254 | valdir = os.path.join(args.data, 'val') 255 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 256 | std=[0.229, 0.224, 0.225]) 257 | 258 | train_dataset = ImageFolder_FKD( 259 | num_crops=args.num_crops, 260 | softlabel_path=args.softlabel_path, 261 | root=traindir, 262 | transform=Compose_FKD(transforms=[ 263 | RandomResizedCrop_FKD(size=224, 264 | interpolation='bilinear'), 265 | RandomHorizontalFlip_FKD(), 266 | transforms.ToTensor(), 267 | normalize, 268 | ])) 269 | train_dataset_single_crop = ImageFolder_FKD( 270 | num_crops=1, 271 | softlabel_path=args.softlabel_path, 272 | root=traindir, 273 | transform=Compose_FKD(transforms=[ 274 | RandomResizedCrop_FKD(size=224, 275 | interpolation='bilinear'), 276 | RandomHorizontalFlip_FKD(), 277 | transforms.ToTensor(), 278 | normalize, 279 | ])) 280 | 281 | if args.distributed: 282 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 283 | else: 284 | train_sampler = None 285 | 286 | train_loader = torch.utils.data.DataLoader( 287 | train_dataset, batch_size=args.batch_size, 
shuffle=(train_sampler is None), 288 | num_workers=args.workers, pin_memory=True, sampler=train_sampler) 289 | 290 | train_loader_single_crop = torch.utils.data.DataLoader( 291 | train_dataset_single_crop, batch_size=args.batch_size*args.num_crops, shuffle=(train_sampler is None), 292 | num_workers=args.workers, pin_memory=True, sampler=train_sampler) 293 | 294 | val_loader = torch.utils.data.DataLoader( 295 | datasets.ImageFolder(valdir, transforms.Compose([ 296 | transforms.Resize(256), 297 | transforms.CenterCrop(224), 298 | transforms.ToTensor(), 299 | normalize, 300 | ])), 301 | batch_size=args.batch_size, shuffle=False, 302 | num_workers=args.workers, pin_memory=True) 303 | 304 | if args.evaluate: 305 | validate(val_loader, model, criterion_ce, args) 306 | return 307 | 308 | # for resume 309 | if args.start_epoch != 0 and args.start_epoch < (args.epochs-args.num_crops): 310 | args.start_epoch = args.start_epoch + args.num_crops - 1 311 | 312 | for epoch in range(args.start_epoch, args.epochs, args.num_crops): 313 | if args.distributed: 314 | train_sampler.set_epoch(epoch) 315 | adjust_learning_rate(optimizer, epoch, args) 316 | # for fine-grained evaluation in the last few epochs 317 | if epoch >= (args.epochs-args.num_crops): 318 | start_epoch = epoch 319 | for epoch in range(start_epoch, args.epochs): 320 | if args.distributed: 321 | train_sampler.set_epoch(epoch) 322 | adjust_learning_rate(optimizer, epoch, args) 323 | 324 | # train for one epoch 325 | train(train_loader_single_crop, model, criterion_sce, optimizer, epoch, args) 326 | 327 | # evaluate on validation set 328 | acc1 = validate(val_loader, model, criterion_ce, args) 329 | 330 | # remember best acc@1 and save checkpoint 331 | is_best = acc1 > best_acc1 332 | best_acc1 = max(acc1, best_acc1) 333 | 334 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed 335 | and args.rank % ngpus_per_node == 0): 336 | save_checkpoint({ 337 | 'epoch': epoch + 1, 338 | 'arch': args.arch, 339 | 'state_dict': model.state_dict(), 340 | 'best_acc1': best_acc1, 341 | 'optimizer' : optimizer.state_dict(), 342 | }, is_best, filename=args.save_checkpoint_path+'/checkpoint.pth.tar') 343 | return 344 | else: 345 | # train for one epoch 346 | train(train_loader, model, criterion_sce, optimizer, epoch, args) 347 | 348 | # evaluate on validation set 349 | acc1 = validate(val_loader, model, criterion_ce, args) 350 | 351 | # remember best acc@1 and save checkpoint 352 | is_best = acc1 > best_acc1 353 | best_acc1 = max(acc1, best_acc1) 354 | 355 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed 356 | and args.rank % ngpus_per_node == 0): 357 | save_checkpoint({ 358 | 'epoch': epoch + 1, 359 | 'arch': args.arch, 360 | 'state_dict': model.state_dict(), 361 | 'best_acc1': best_acc1, 362 | 'optimizer' : optimizer.state_dict(), 363 | }, is_best, filename=args.save_checkpoint_path+'/checkpoint.pth.tar') 364 | 365 | 366 | def train(train_loader, model, criterion, optimizer, epoch, args): 367 | batch_time = AverageMeter('Time', ':6.3f') 368 | data_time = AverageMeter('Data', ':6.3f') 369 | losses = AverageMeter('Loss', ':.4e') 370 | top1 = AverageMeter('Acc@1', ':6.2f') 371 | top5 = AverageMeter('Acc@5', ':6.2f') 372 | progress = ProgressMeter( 373 | len(train_loader), 374 | [batch_time, data_time, losses, top1, top5, 'LR {lr:.5f}'.format(lr=_get_learning_rate(optimizer))], 375 | prefix="Epoch: [{}]".format(epoch)) 376 | 377 | # switch to train mode 378 | model.train() 379 | 380 | end = time.time() 381 | for 
i, (images, target, soft_label) in enumerate(train_loader): 382 | # measure data loading time 383 | data_time.update(time.time() - end) 384 | 385 | # reshape images and soft label 386 | images = torch.cat(images, dim=0) 387 | soft_label = torch.cat(soft_label, dim=0) 388 | target = torch.cat(target, dim=0) 389 | 390 | if args.soft_label_type != 'ori': 391 | soft_label = Recover_soft_label(soft_label, args.soft_label_type, args.num_classes) 392 | 393 | if args.gpu is not None: 394 | images = images.cuda(args.gpu, non_blocking=True) 395 | if torch.cuda.is_available(): 396 | target = target.cuda(args.gpu, non_blocking=True) 397 | soft_label = soft_label.cuda(args.gpu, non_blocking=True) 398 | 399 | if args.mixup_cutmix: 400 | images, soft_label = mixup_cutmix(images, soft_label, args) 401 | 402 | # compute output 403 | output = model(images) 404 | loss = criterion(output / args.temp, soft_label) 405 | 406 | # measure accuracy and record loss 407 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 408 | losses.update(loss.item(), images.size(0)) 409 | top1.update(acc1[0], images.size(0)) 410 | top5.update(acc5[0], images.size(0)) 411 | 412 | # compute gradient and do SGD step 413 | optimizer.zero_grad() 414 | loss.backward() 415 | optimizer.step() 416 | 417 | # measure elapsed time 418 | batch_time.update(time.time() - end) 419 | end = time.time() 420 | 421 | if i % args.print_freq == 0: 422 | progress.display(i) 423 | 424 | 425 | def validate(val_loader, model, criterion, args): 426 | batch_time = AverageMeter('Time', ':6.3f') 427 | losses = AverageMeter('Loss', ':.4e') 428 | top1 = AverageMeter('Acc@1', ':6.2f') 429 | top5 = AverageMeter('Acc@5', ':6.2f') 430 | progress = ProgressMeter( 431 | len(val_loader), 432 | [batch_time, losses, top1, top5], 433 | prefix='Test: ') 434 | 435 | # switch to evaluate mode 436 | model.eval() 437 | 438 | with torch.no_grad(): 439 | end = time.time() 440 | for i, (images, target) in enumerate(val_loader): 441 | if args.gpu is not None: 442 | images = images.cuda(args.gpu, non_blocking=True) 443 | if torch.cuda.is_available(): 444 | target = target.cuda(args.gpu, non_blocking=True) 445 | 446 | # compute output 447 | output = model(images) 448 | loss = criterion(output, target) 449 | 450 | # measure accuracy and record loss 451 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 452 | losses.update(loss.item(), images.size(0)) 453 | top1.update(acc1[0], images.size(0)) 454 | top5.update(acc5[0], images.size(0)) 455 | 456 | # measure elapsed time 457 | batch_time.update(time.time() - end) 458 | end = time.time() 459 | 460 | if i % args.print_freq == 0: 461 | progress.display(i) 462 | 463 | # TODO: this should also be done with the ProgressMeter 464 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}' 465 | .format(top1=top1, top5=top5)) 466 | 467 | return top1.avg 468 | 469 | 470 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 471 | torch.save(state, filename) 472 | if is_best: 473 | shutil.copyfile(filename, filename[:-19]+'/model_best.pth.tar') 474 | 475 | 476 | class AverageMeter(object): 477 | """Computes and stores the average and current value""" 478 | def __init__(self, name, fmt=':f'): 479 | self.name = name 480 | self.fmt = fmt 481 | self.reset() 482 | 483 | def reset(self): 484 | self.val = 0 485 | self.avg = 0 486 | self.sum = 0 487 | self.count = 0 488 | 489 | def update(self, val, n=1): 490 | self.val = val 491 | self.sum += val * n 492 | self.count += n 493 | self.avg = self.sum / self.count 494 | 495 | def __str__(self): 
496 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 497 | return fmtstr.format(**self.__dict__) 498 | 499 | 500 | class ProgressMeter(object): 501 | def __init__(self, num_batches, meters, prefix=""): 502 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 503 | self.meters = meters 504 | self.prefix = prefix 505 | 506 | def display(self, batch): 507 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 508 | entries += [str(meter) for meter in self.meters] 509 | print('\t'.join(entries)) 510 | 511 | def _get_batch_fmtstr(self, num_batches): 512 | num_digits = len(str(num_batches // 1)) 513 | fmt = '{:' + str(num_digits) + 'd}' 514 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 515 | 516 | 517 | def adjust_learning_rate(optimizer, epoch, args): 518 | """Decay the learning rate based on schedule""" 519 | lr = args.lr 520 | if args.cos: # cosine lr schedule 521 | lr *= 0.5 * (1. + math.cos(math.pi * epoch / args.epochs)) 522 | else: # stepwise lr schedule 523 | for milestone in args.schedule: 524 | lr *= 0.1 if epoch >= milestone else 1. 525 | for param_group in optimizer.param_groups: 526 | param_group['lr'] = lr 527 | 528 | def _get_learning_rate(optimizer): 529 | return max(param_group['lr'] for param_group in optimizer.param_groups) 530 | 531 | def accuracy(output, target, topk=(1,)): 532 | """Computes the accuracy over the k top predictions for the specified values of k""" 533 | with torch.no_grad(): 534 | maxk = max(topk) 535 | batch_size = target.size(0) 536 | 537 | _, pred = output.topk(maxk, 1, True, True) 538 | pred = pred.t() 539 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 540 | 541 | res = [] 542 | for k in topk: 543 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 544 | res.append(correct_k.mul_(100.0 / batch_size)) 545 | return res 546 | 547 | 548 | if __name__ == '__main__': 549 | main() -------------------------------------------------------------------------------- /FKD/FKD_ViT/train_ViT_FKD.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import shutil 5 | import time 6 | import math 7 | import warnings 8 | import numpy as np 9 | import builtins 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.parallel 14 | import torch.backends.cudnn as cudnn 15 | import torch.distributed as dist 16 | import torch.optim 17 | import torch.multiprocessing as mp 18 | import torch.utils.data 19 | import torch.utils.data.distributed 20 | import torchvision.transforms as transforms 21 | import torchvision.datasets as datasets 22 | import torchvision.models as models 23 | from torchvision.transforms import InterpolationMode 24 | 25 | from utils_FKD import RandomResizedCrop_FKD, RandomHorizontalFlip_FKD 26 | from utils_FKD import ImageFolder_FKD, Compose_FKD 27 | from utils_FKD import Soft_CrossEntropy, Recover_soft_label 28 | from utils_FKD import mixup_cutmix 29 | 30 | import timm 31 | from timm.scheduler import create_scheduler 32 | from timm.optim import create_optimizer 33 | import SReT 34 | 35 | # timm is used to build the optimizer and learning rate scheduler (https://github.com/rwightman/pytorch-image-models) 36 | 37 | 38 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training with FKD Scheme') 39 | parser.add_argument('data', metavar='DIR', 40 | help='path to dataset') 41 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50') 42 | parser.add_argument('-j', '--workers', 
default=24, type=int, metavar='N', 43 | help='number of data loading workers (default: 24)') 44 | parser.add_argument('--epochs', default=300, type=int, metavar='N', 45 | help='number of total epochs to run') 46 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 47 | help='manual epoch number (useful on restarts)') 48 | parser.add_argument('-b', '--batch-size', default=256, type=int, 49 | metavar='N', 50 | help='mini-batch size (default: 256), this is the total ' 51 | 'batch size of all GPUs on the current node when ' 52 | 'using Data Parallel or Distributed Data Parallel') 53 | parser.add_argument('--lr', '--learning-rate', default=0.002, type=float, 54 | metavar='LR', help='initial learning rate', dest='lr') 55 | parser.add_argument('--schedule', default=[120, 160], nargs='*', type=int, 56 | help='learning rate schedule (when to drop lr by 10x)') 57 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 58 | help='momentum') 59 | parser.add_argument('--wd', '--weight-decay', default=0.05, type=float, 60 | metavar='W', help='weight decay (default: 0.05)', 61 | dest='weight_decay') 62 | parser.add_argument('-p', '--print-freq', default=10, type=int, 63 | metavar='N', help='print frequency (default: 10)') 64 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 65 | help='path to latest checkpoint (default: none)') 66 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 67 | help='evaluate model on validation set') 68 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 69 | help='use pre-trained model') 70 | parser.add_argument('--world-size', default=-1, type=int, 71 | help='number of nodes for distributed training') 72 | parser.add_argument('--rank', default=-1, type=int, 73 | help='node rank for distributed training') 74 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, 75 | help='url used to set up distributed training') 76 | parser.add_argument('--dist-backend', default='nccl', type=str, 77 | help='distributed backend') 78 | parser.add_argument('--seed', default=None, type=int, 79 | help='seed for initializing training. ') 80 | parser.add_argument('--gpu', default=None, type=int, 81 | help='GPU id to use.') 82 | parser.add_argument('--multiprocessing-distributed', action='store_true', 83 | help='Use multi-processing distributed training to launch ' 84 | 'N processes per node, which has N GPUs. 
This is the ' 85 | 'fastest way to use PyTorch for either single node or ' 86 | 'multi node data parallel training') 87 | parser.add_argument('--num_crops', default=4, type=int, 88 | help='number of crops in each image, 1 is the standard training') 89 | parser.add_argument('--softlabel_path', default='../soft_label', type=str, metavar='PATH', 90 | help='path to soft label files (default: ../soft_label)') 91 | parser.add_argument("--temp", type=float, default=1.0, 92 | help="temperature on student during training (default: 1.0)") 93 | parser.add_argument('--save_checkpoint_path', default='./FKD_checkpoints_output', type=str, metavar='PATH', 94 | help='path to save checkpoints (default: ./FKD_checkpoints_output)') 95 | parser.add_argument('--soft_label_type', default='marginal_smoothing_k5', type=str, metavar='TYPE', 96 | help='(1) ori; (2) hard; (3) smoothing; (4) marginal_smoothing_k5; (5) marginal_smoothing_k10; (6) marginal_renorm_k5') 97 | parser.add_argument('--num_classes', default=1000, type=int, 98 | help='number of classes.') 99 | parser.add_argument('--drop', type=float, default=0.0, metavar='PCT', 100 | help='Dropout rate (default: 0.)') 101 | parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT', 102 | help='Drop path rate (default: 0.1)') 103 | 104 | # Optimizer parameters 105 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', 106 | help='Optimizer (default: "adamw")') 107 | parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON', 108 | help='Optimizer Epsilon (default: 1e-8)') 109 | parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA', 110 | help='Optimizer Betas (default: None, use opt default)') 111 | parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM', 112 | help='Clip gradient norm (default: None, no clipping)') 113 | 114 | # Learning rate schedule parameters 115 | parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER', 116 | help='LR scheduler (default: "cosine")') 117 | parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct', 118 | help='learning rate noise on/off epoch percentages') 119 | parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT', 120 | help='learning rate noise limit percent (default: 0.67)') 121 | parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV', 122 | help='learning rate noise std-dev (default: 1.0)') 123 | parser.add_argument('--warmup-lr', type=float, default=1e-5, metavar='LR', 124 | help='warmup learning rate (default: 1e-5)') 125 | parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR', 126 | help='lower lr bound for cyclic schedulers that hit 0 (1e-5)') 127 | parser.add_argument('--cos', action='store_true', 128 | help='use cosine lr schedule') 129 | parser.add_argument('--decay-epochs', type=float, default=30, metavar='N', 130 | help='epoch interval to decay LR') 131 | parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N', 132 | help='epochs to warmup LR, if scheduler supports') 133 | parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N', 134 | help='epochs to cooldown LR at min_lr, after cyclic schedule ends') 135 | parser.add_argument('--patience-epochs', type=int, default=10, metavar='N', 136 | help='patience epochs for Plateau LR scheduler (default: 10)') 137 | parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', 138 | help='LR decay rate (default: 
0.1)') 139 | 140 | # mixup and cutmix parameters 141 | parser.add_argument('--mixup_cutmix', default=False, action='store_true', 142 | help='use mixup and cutmix data augmentation') 143 | parser.add_argument('--mixup', type=float, default=0.8, 144 | help='mixup alpha, mixup enabled if > 0. (default: 0.8)') 145 | parser.add_argument('--cutmix', type=float, default=1.0, 146 | help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)') 147 | parser.add_argument('--mixup_cutmix_prob', type=float, default=1.0, 148 | help='Probability of performing mixup or cutmix when either/both is enabled') 149 | parser.add_argument('--mixup_switch_prob', type=float, default=0.5, 150 | help='Probability of switching to cutmix when both mixup and cutmix enabled. Mixup only: set to 1.0, Cutmix only: set to 0.0') 151 | 152 | best_acc1 = 0 153 | 154 | 155 | def main(): 156 | args = parser.parse_args() 157 | 158 | if not os.path.exists(args.save_checkpoint_path): 159 | os.makedirs(args.save_checkpoint_path) 160 | 161 | # convert to the TRUE number of loaded images and epochs, since we use multiple crops from the same image within a minibatch 162 | args.batch_size = math.ceil(args.batch_size / args.num_crops) 163 | 164 | if args.seed is not None: 165 | random.seed(args.seed) 166 | torch.manual_seed(args.seed) 167 | cudnn.deterministic = True 168 | warnings.warn('You have chosen to seed training. ' 169 | 'This will turn on the CUDNN deterministic setting, ' 170 | 'which can slow down your training considerably! ' 171 | 'You may see unexpected behavior when restarting ' 172 | 'from checkpoints.') 173 | 174 | if args.gpu is not None: 175 | warnings.warn('You have chosen a specific GPU. This will completely ' 176 | 'disable data parallelism.') 177 | 178 | if args.dist_url == "env://" and args.world_size == -1: 179 | args.world_size = int(os.environ["WORLD_SIZE"]) 180 | 181 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 182 | 183 | ngpus_per_node = torch.cuda.device_count() 184 | if args.multiprocessing_distributed: 185 | # Since we have ngpus_per_node processes per node, the total world_size 186 | # needs to be adjusted accordingly 187 | args.world_size = ngpus_per_node * args.world_size 188 | # Use torch.multiprocessing.spawn to launch distributed processes: the 189 | # main_worker process function 190 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 191 | else: 192 | # Simply call main_worker function 193 | main_worker(args.gpu, ngpus_per_node, args) 194 | 195 | 196 | def main_worker(gpu, ngpus_per_node, args): 197 | global best_acc1 198 | args.gpu = gpu 199 | 200 | # suppress printing if not master 201 | if args.multiprocessing_distributed and args.gpu != 0: 202 | def print_pass(*args): 203 | pass 204 | builtins.print = print_pass 205 | 206 | if args.gpu is not None: 207 | print("Use GPU: {} for training".format(args.gpu)) 208 | 209 | if args.distributed: 210 | if args.dist_url == "env://" and args.rank == -1: 211 | args.rank = int(os.environ["RANK"]) 212 | if args.multiprocessing_distributed: 213 | # For multiprocessing distributed training, rank needs to be the 214 | # global rank among all the processes 215 | args.rank = args.rank * ngpus_per_node + gpu 216 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 217 | world_size=args.world_size, rank=args.rank) 218 | # create model 219 | if args.pretrained: 220 | print("=> using pre-trained model '{}'".format(args.arch)) 221 | model = timm.create_model(args.arch, pretrained=True, 
num_classes=args.num_classes) 222 | else: 223 | print("=> creating model '{}'".format(args.arch)) 224 | if args.arch.split('_')[0] == 'SReT': 225 | model = SReT.__dict__[args.arch](pretrained=False, num_classes=args.num_classes) 226 | else: 227 | model = timm.create_model(args.arch, pretrained=False, num_classes=args.num_classes) 228 | 229 | if not torch.cuda.is_available(): 230 | print('using CPU, this will be slow') 231 | elif args.distributed: 232 | # For multiprocessing distributed, DistributedDataParallel constructor 233 | # should always set the single device scope, otherwise, 234 | # DistributedDataParallel will use all available devices. 235 | if args.gpu is not None: 236 | torch.cuda.set_device(args.gpu) 237 | model.cuda(args.gpu) 238 | # When using a single GPU per process and per 239 | # DistributedDataParallel, we need to divide the batch size 240 | # ourselves based on the total number of GPUs we have 241 | args.batch_size = int(args.batch_size / ngpus_per_node) 242 | args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node) 243 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 244 | else: 245 | model.cuda() 246 | # DistributedDataParallel will divide and allocate batch_size to all 247 | # available GPUs if device_ids are not set 248 | model = torch.nn.parallel.DistributedDataParallel(model) 249 | elif args.gpu is not None: 250 | torch.cuda.set_device(args.gpu) 251 | model = model.cuda(args.gpu) 252 | else: 253 | # DataParallel will divide and allocate batch_size to all available GPUs 254 | if args.arch.startswith('alexnet') or args.arch.startswith('vgg'): 255 | model.features = torch.nn.DataParallel(model.features) 256 | model.cuda() 257 | else: 258 | model = torch.nn.DataParallel(model).cuda() 259 | 260 | # define loss function (criterion) and optimizer 261 | criterion_sce = Soft_CrossEntropy() 262 | criterion_ce = nn.CrossEntropyLoss().cuda(args.gpu) 263 | 264 | optimizer = create_optimizer(args, model) 265 | 266 | lr_scheduler, _ = create_scheduler(args, optimizer) 267 | 268 | # optionally resume from a checkpoint 269 | if args.resume: 270 | if os.path.isfile(args.resume): 271 | print("=> loading checkpoint '{}'".format(args.resume)) 272 | if args.gpu is None: 273 | checkpoint = torch.load(args.resume) 274 | else: 275 | # Map model to be loaded to specified single gpu. 
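# (loading straight onto the target device also avoids transiently materializing the whole checkpoint on GPU 0)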
276 | loc = 'cuda:{}'.format(args.gpu) 277 | checkpoint = torch.load(args.resume, map_location=loc) 278 | args.start_epoch = checkpoint['epoch'] 279 | best_acc1 = checkpoint['best_acc1'] 280 | if args.gpu is not None: 281 | # best_acc1 may be from a checkpoint from a different GPU 282 | best_acc1 = best_acc1.to(args.gpu) 283 | model.load_state_dict(checkpoint['state_dict']) 284 | optimizer.load_state_dict(checkpoint['optimizer']) 285 | print("=> loaded checkpoint '{}' (epoch {})" 286 | .format(args.resume, checkpoint['epoch'])) 287 | else: 288 | print("=> no checkpoint found at '{}'".format(args.resume)) 289 | 290 | cudnn.benchmark = True 291 | 292 | # Data loading code 293 | traindir = os.path.join(args.data, 'train') 294 | valdir = os.path.join(args.data, 'val') 295 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 296 | std=[0.229, 0.224, 0.225]) 297 | 298 | train_dataset = ImageFolder_FKD( 299 | num_crops=args.num_crops, 300 | softlabel_path=args.softlabel_path, 301 | root=traindir, 302 | transform=Compose_FKD(transforms=[ 303 | RandomResizedCrop_FKD(size=224, 304 | interpolation='bilinear'), 305 | RandomHorizontalFlip_FKD(), 306 | transforms.ToTensor(), 307 | normalize, 308 | ])) 309 | train_dataset_single_crop = ImageFolder_FKD( 310 | num_crops=1, 311 | softlabel_path=args.softlabel_path, 312 | root=traindir, 313 | transform=Compose_FKD(transforms=[ 314 | RandomResizedCrop_FKD(size=224, 315 | interpolation='bilinear'), 316 | RandomHorizontalFlip_FKD(), 317 | transforms.ToTensor(), 318 | normalize, 319 | ])) 320 | 321 | if args.distributed: 322 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 323 | else: 324 | train_sampler = None 325 | 326 | train_loader = torch.utils.data.DataLoader( 327 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 328 | num_workers=args.workers, pin_memory=True, sampler=train_sampler) 329 | 330 | train_loader_single_crop = torch.utils.data.DataLoader( 331 | train_dataset_single_crop, batch_size=args.batch_size*args.num_crops, shuffle=(train_sampler is None), 332 | num_workers=args.workers, pin_memory=True, sampler=train_sampler) 333 | 334 | val_loader = torch.utils.data.DataLoader( 335 | datasets.ImageFolder(valdir, transforms.Compose([ 336 | transforms.Resize(256), 337 | transforms.CenterCrop(224), 338 | transforms.ToTensor(), 339 | normalize, 340 | ])), 341 | batch_size=args.batch_size, shuffle=False, 342 | num_workers=args.workers, pin_memory=True) 343 | 344 | if args.evaluate: 345 | validate(val_loader, model, criterion_ce, args) 346 | return 347 | 348 | # warmup with a single crop; the "=" in "<=" is used so that start_epoch can be 0 in the corner case. 
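# With m = num_crops, one pass over `train_loader` sees m crops per image, i.e. roughly m epochs' worth of samples, so the main loop below advances `epoch` in steps of num_crops; warmup and the final epochs instead use `train_loader_single_crop` for standard one-crop epochs.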
349 | if args.start_epoch <= args.warmup_epochs: 350 | for epoch in range(args.start_epoch, args.warmup_epochs): 351 | if args.distributed: 352 | train_sampler.set_epoch(epoch) 353 | # train for one epoch 354 | train(train_loader_single_crop, model, criterion_sce, optimizer, epoch, args) 355 | lr_scheduler.step(epoch + 1) 356 | 357 | # evaluate on validation set 358 | acc1 = validate(val_loader, model, criterion_ce, args) 359 | 360 | # remember best acc@1 and save checkpoint 361 | is_best = acc1 > best_acc1 362 | best_acc1 = max(acc1, best_acc1) 363 | 364 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed 365 | and args.rank % ngpus_per_node == 0): 366 | save_checkpoint({ 367 | 'epoch': epoch + 1, 368 | 'arch': args.arch, 369 | 'state_dict': model.state_dict(), 370 | 'best_acc1': best_acc1, 371 | 'optimizer' : optimizer.state_dict(), 372 | }, is_best, filename=args.save_checkpoint_path+'/checkpoint.pth.tar') 373 | args.start_epoch = 0 # for resume 374 | else: 375 | args.warmup_epochs = args.num_crops - 1 # for resume 376 | 377 | for epoch in range(args.start_epoch+args.warmup_epochs, args.epochs, args.num_crops): 378 | if args.distributed: 379 | train_sampler.set_epoch(epoch) 380 | 381 | # fine-grained schedule for the last few epochs: train with single crops and validate after every epoch 382 | if epoch >= (args.epochs-args.num_crops): 383 | start_epoch = epoch 384 | for epoch in range(start_epoch, args.epochs): 385 | if args.distributed: 386 | train_sampler.set_epoch(epoch) 387 | lr_scheduler.step(epoch+1) 388 | # train for one epoch 389 | train(train_loader_single_crop, model, criterion_sce, optimizer, epoch, args) 390 | 391 | # evaluate on validation set 392 | acc1 = validate(val_loader, model, criterion_ce, args) 393 | 394 | # remember best acc@1 and save checkpoint 395 | is_best = acc1 > best_acc1 396 | best_acc1 = max(acc1, best_acc1) 397 | 398 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed 399 | and args.rank % ngpus_per_node == 0): 400 | save_checkpoint({ 401 | 'epoch': epoch + 1, 402 | 'arch': args.arch, 403 | 'state_dict': model.state_dict(), 404 | 'best_acc1': best_acc1, 405 | 'optimizer' : optimizer.state_dict(), 406 | }, is_best, filename=args.save_checkpoint_path+'/checkpoint.pth.tar') 407 | return 408 | else: 409 | # train for one epoch 410 | train(train_loader, model, criterion_sce, optimizer, epoch, args) 411 | lr_scheduler.step(epoch+args.num_crops) 412 | 413 | # evaluate on validation set 414 | acc1 = validate(val_loader, model, criterion_ce, args) 415 | 416 | # remember best acc@1 and save checkpoint 417 | is_best = acc1 > best_acc1 418 | best_acc1 = max(acc1, best_acc1) 419 | 420 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed 421 | and args.rank % ngpus_per_node == 0): 422 | save_checkpoint({ 423 | 'epoch': epoch + 1, 424 | 'arch': args.arch, 425 | 'state_dict': model.state_dict(), 426 | 'best_acc1': best_acc1, 427 | 'optimizer' : optimizer.state_dict(), 428 | }, is_best, filename=args.save_checkpoint_path+'/checkpoint.pth.tar') 429 | 430 | 431 | def train(train_loader, model, criterion, optimizer, epoch, args): 432 | batch_time = AverageMeter('Time', ':6.3f') 433 | data_time = AverageMeter('Data', ':6.3f') 434 | losses = AverageMeter('Loss', ':.4e') 435 | top1 = AverageMeter('Acc@1', ':6.2f') 436 | top5 = AverageMeter('Acc@5', ':6.2f') 437 | progress = ProgressMeter( 438 | len(train_loader), 439 | [batch_time, data_time, losses, top1, top5, 'LR {lr:.5f}'.format(lr=_get_learning_rate(optimizer))], # the LR entry is a plain str; ProgressMeter.display() calls str() on every entry 440 | prefix="Epoch: [{}]".format(epoch))
[{}]".format(epoch)) 441 | 442 | # switch to train mode 443 | model.train() 444 | 445 | end = time.time() 446 | for i, (images, target, soft_label) in enumerate(train_loader): 447 | # measure data loading time 448 | data_time.update(time.time() - end) 449 | 450 | images = torch.cat(images, dim=0) 451 | soft_label = torch.cat(soft_label, dim=0) 452 | target = torch.cat(target, dim=0) 453 | 454 | if args.soft_label_type != 'ori': 455 | soft_label = Recover_soft_label(soft_label, args.soft_label_type, args.num_classes) 456 | 457 | if args.gpu is not None: 458 | images = images.cuda(args.gpu, non_blocking=True) 459 | if torch.cuda.is_available(): 460 | target = target.cuda(args.gpu, non_blocking=True) 461 | soft_label = soft_label.cuda(args.gpu, non_blocking=True) 462 | 463 | if args.mixup_cutmix: 464 | images, soft_label = mixup_cutmix(images, soft_label, args) 465 | 466 | # compute output 467 | output = model(images) 468 | loss = criterion(output / args.temp, soft_label) 469 | 470 | # measure accuracy and record loss 471 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 472 | losses.update(loss.item(), images.size(0)) 473 | top1.update(acc1[0], images.size(0)) 474 | top5.update(acc5[0], images.size(0)) 475 | 476 | # compute gradient and update 477 | optimizer.zero_grad() 478 | loss.backward() 479 | optimizer.step() 480 | 481 | # measure elapsed time 482 | batch_time.update(time.time() - end) 483 | end = time.time() 484 | 485 | if i % args.print_freq == 0: 486 | t = time.localtime() 487 | current_time = time.strftime("%H:%M:%S", t) 488 | print(current_time) 489 | progress.display(i) 490 | 491 | 492 | def validate(val_loader, model, criterion, args): 493 | batch_time = AverageMeter('Time', ':6.3f') 494 | losses = AverageMeter('Loss', ':.4e') 495 | top1 = AverageMeter('Acc@1', ':6.2f') 496 | top5 = AverageMeter('Acc@5', ':6.2f') 497 | progress = ProgressMeter( 498 | len(val_loader), 499 | [batch_time, losses, top1, top5], 500 | prefix='Test: ') 501 | 502 | # switch to evaluate mode 503 | model.eval() 504 | 505 | with torch.no_grad(): 506 | end = time.time() 507 | for i, (images, target) in enumerate(val_loader): 508 | if args.gpu is not None: 509 | images = images.cuda(args.gpu, non_blocking=True) 510 | if torch.cuda.is_available(): 511 | target = target.cuda(args.gpu, non_blocking=True) 512 | 513 | # compute output 514 | output = model(images) 515 | loss = criterion(output, target) 516 | 517 | # measure accuracy and record loss 518 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 519 | losses.update(loss.item(), images.size(0)) 520 | top1.update(acc1[0], images.size(0)) 521 | top5.update(acc5[0], images.size(0)) 522 | 523 | # measure elapsed time 524 | batch_time.update(time.time() - end) 525 | end = time.time() 526 | 527 | if i % args.print_freq == 0: 528 | progress.display(i) 529 | 530 | # TODO: this should also be done with the ProgressMeter 531 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}' 532 | .format(top1=top1, top5=top5)) 533 | 534 | return top1.avg 535 | 536 | 537 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 538 | torch.save(state, filename) 539 | if is_best: 540 | shutil.copyfile(filename, filename[:-19]+'/model_best.pth.tar') 541 | 542 | 543 | class AverageMeter(object): 544 | """Computes and stores the average and current value""" 545 | def __init__(self, name, fmt=':f'): 546 | self.name = name 547 | self.fmt = fmt 548 | self.reset() 549 | 550 | def reset(self): 551 | self.val = 0 552 | self.avg = 0 553 | self.sum = 0 554 | self.count = 0 
555 | 556 | def update(self, val, n=1): 557 | self.val = val 558 | self.sum += val * n 559 | self.count += n 560 | self.avg = self.sum / self.count 561 | 562 | def __str__(self): 563 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 564 | return fmtstr.format(**self.__dict__) 565 | 566 | 567 | class ProgressMeter(object): 568 | def __init__(self, num_batches, meters, prefix=""): 569 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 570 | self.meters = meters 571 | self.prefix = prefix 572 | 573 | def display(self, batch): 574 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 575 | entries += [str(meter) for meter in self.meters] 576 | print('\t'.join(entries)) 577 | 578 | def _get_batch_fmtstr(self, num_batches): 579 | num_digits = len(str(num_batches)) 580 | fmt = '{:' + str(num_digits) + 'd}' 581 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 582 | 583 | 584 | def adjust_learning_rate(optimizer, epoch, args): 585 | """Decay the learning rate based on schedule (unused here: the timm lr_scheduler created in main() drives the learning rate)""" 586 | lr = args.lr 587 | if args.cos: # cosine lr schedule 588 | lr *= 0.5 * (1. + math.cos(math.pi * epoch / args.epochs)) 589 | else: # stepwise lr schedule 590 | for milestone in args.schedule: 591 | lr *= 0.1 if epoch >= milestone else 1. 592 | for param_group in optimizer.param_groups: 593 | param_group['lr'] = lr 594 | 595 | def _get_learning_rate(optimizer): 596 | return max(param_group['lr'] for param_group in optimizer.param_groups) 597 | 598 | def accuracy(output, target, topk=(1,)): 599 | """Computes the accuracy over the k top predictions for the specified values of k""" 600 | with torch.no_grad(): 601 | maxk = max(topk) 602 | batch_size = target.size(0) 603 | 604 | _, pred = output.topk(maxk, 1, True, True) 605 | pred = pred.t() 606 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 607 | 608 | res = [] 609 | for k in topk: 610 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 611 | res.append(correct_k.mul_(100.0 / batch_size)) 612 | return res 613 | 614 | 615 | if __name__ == '__main__': 616 | main() --------------------------------------------------------------------------------
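Note: `Soft_CrossEntropy` is imported from `utils_FKD.py`, which is not included in this dump. As a rough guide, a soft-label cross-entropy of the shape used in `train()` (`criterion(output / args.temp, soft_label)`) can be sketched as below; this is an illustrative assumption, not the repository's exact implementation:

```python
import torch.nn as nn
import torch.nn.functional as F

class SoftCrossEntropySketch(nn.Module):
    """Illustrative stand-in for utils_FKD.Soft_CrossEntropy (hypothetical)."""

    def forward(self, logits, soft_target):
        # Cross-entropy against a soft target distribution: the batch mean of
        # -sum_c p_c * log softmax(logits)_c, where p is the stored teacher label.
        log_prob = F.log_softmax(logits, dim=-1)
        return (-soft_target * log_prob).sum(dim=-1).mean()
```

Dividing the student logits by `args.temp` before this loss is the standard distillation temperature; with `temp=1` it reduces to plain cross-entropy against the pre-generated soft label.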
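Similarly, `Recover_soft_label` (also in `utils_FKD.py`) expands a compressed stored label back into a full `num_classes` distribution whenever `args.soft_label_type != 'ori'`. One plausible recovery scheme, assuming a hypothetical top-K `(values, indices)` storage format; the actual supported formats are defined in `utils_FKD.py`:

```python
import torch

def recover_topk_soft_label_sketch(values, indices, num_classes):
    """Hypothetical recovery of a full distribution from stored top-K entries:
    scatter the saved probabilities back and spread the leftover probability
    mass uniformly over the classes that were not stored."""
    n, k = values.shape
    full = torch.zeros(n, num_classes)
    full.scatter_(1, indices, values)                 # put top-K probs back in place
    residual = (1.0 - values.sum(dim=1, keepdim=True)).clamp(min=0.0)
    mask = torch.ones_like(full).scatter(1, indices, 0.0)  # 1 for non-stored classes
    return full + mask * residual / (num_classes - k)
```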