├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── TRAINING.md ├── benchmark.py ├── evaluate.py ├── presets.py ├── sampler.py ├── supermask.py ├── train.py ├── transforms.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | */*.pyc 2 | 3 | # Editor temporaries 4 | *.swa 5 | *.swb 6 | *.swc 7 | *.swd 8 | *.swe 9 | *.swf 10 | *.swg 11 | *.swh 12 | *.swi 13 | *.swj 14 | *.swk 15 | *.swl 16 | *.swm 17 | *.swn 18 | *.swo 19 | *.swp 20 | *~ 21 | .~lock.* 22 | 23 | # macOS dir files 24 | .DS_Store 25 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 
54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to superblock 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Contributor License Agreement ("CLA") 6 | In order to accept your pull request, we need you to submit a CLA. You only need 7 | to do this once to work on any of Meta's open source projects. 8 | 9 | Complete your CLA here: 10 | 11 | ## Issues 12 | We use GitHub issues to track public bugs. Please ensure your description is 13 | clear and has sufficient instructions to be able to reproduce the issue. 14 | 15 | Meta has a [bounty program](https://bugbounty.meta.com/) for the safe 16 | disclosure of security bugs. In those cases, please go through the process 17 | outlined on that page and do not file a public issue. 18 | 19 | ## License 20 | By contributing to superblock, you agree that your contributions will be licensed 21 | under the LICENSE file in the root directory of this source tree. 22 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 PyTorch Labs 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | **_SuperBlock has now been transferred to torchao repo [here](https://github.com/pytorch/ao/tree/main/torchao/sparsity/prototype/superblock)._** 2 | 3 | # SuperBlock 4 | 5 | SuperBlock combines two techniques for efficient neural network training and inference: Supermask and Block Compressed Sparse Row (BSR). 6 | The techniques are described in this [blog post](https://pytorch.org/blog/speeding-up-vits/). 7 | 8 | ### Supermask 9 | [Supermask](https://arxiv.org/abs/2207.00670) is a technique for applying structured sparsity to neural networks using a learned mask. It works by learning a continuous mask (scores) that is applied element-wise to the weights of a neural network layer. The mask scores are learned separately from the weights and are thresholded based on a target sparsity level to obtain a binary mask. The mask, learned during training, determines which weights are kept and which are pruned. 10 | 11 | During inference, the binary mask is applied element-wise to the weights, pruning the weights that correspond to a 0 in the mask and resulting in a sparse network that can be computed efficiently. 12 | 13 | ### Block Compressed Sparse Row Format (BSR) 14 | [The BSR format](https://pytorch.org/docs/main/sparse.html#sparse-bsr-tensor) is a sparse matrix representation that stores dense sub-blocks of non-zero elements instead of individual non-zero elements. The matrix is divided into equal-sized blocks, and only the non-zero blocks are stored. 15 | 16 | The BSR format is efficient for sparse matrices with a block structure, where non-zero elements tend to cluster in dense sub-blocks. It reduces storage requirements and enables efficient matrix operations on the non-zero blocks. A short sketch of how Supermask and BSR fit together is included at the end of this README. 17 | 18 | Currently, the BSR format is optimized for NVIDIA A100 GPUs only. 19 | 20 | ## Setup 21 | To use SuperBlock, you will need: 22 | * [PyTorch](https://pytorch.org/get-started/locally/) 23 | 24 | To train the model or evaluate accuracy, you will need: 25 | * ImageNet2012-blurred dataset 26 | 27 | At least one GPU: 28 | * A100 or H100 29 | 30 | ## Installation 31 | * Clone this repo 32 | ``` 33 | git clone https://github.com/pytorch-labs/superblock.git 34 | cd superblock 35 | ``` 36 | * Create a new conda environment 37 | ``` 38 | conda create -n superblock 39 | conda activate superblock 40 | ``` 41 | * Install PyTorch.
For best performance, we recommend `2.3.0.dev20240305+cu121` nightly 42 | ``` 43 | pip install --pre torch==2.3.0.dev20240305+cu121 --index-url https://download.pytorch.org/whl/nightly/cu121 44 | pip install --pre torchvision==0.18.0 --no-deps 45 | ``` 46 | 47 | 48 | ## Benchmarking 49 | Baseline: 50 | ``` 51 | python benchmark.py \ 52 | --model vit_b_16 \ 53 | --batch-size 256 \ 54 | > /dev/null 55 | ``` 56 | Result: 57 | ``` 58 | 532.1160546875 ms 59 | ``` 60 | 61 | 62 | 80% sparsity, block size 64 (random weights): 63 | ``` 64 | python benchmark.py --model vit_b_16 \ 65 | --batch-size 256 \ 66 | --sparsity-linear 0.8 \ 67 | --sp-linear-tile-size 64 \ 68 | --sparsify-weights \ 69 | --bsr 64 \ 70 | > /dev/null 71 | ``` 72 | Result: 73 | ``` 74 | 393.864453125 ms 75 | ``` 76 | 77 | 78 | ## Training 79 | Please refer to [TRAINING.md](TRAINING.md) for training from scratch. We use [Torchvision](https://github.com/pytorch/vision/tree/main/references/classification) as our framework for training. Supermask can be applied during training. 80 | 81 | To apply supermask, we have the following arguments at our disposal, 82 | 83 | * Apply Supermask to linear layers: 84 | ``` 85 | --sparsity-linear 86 | --sp-linear-tile-size 87 | ``` 88 | * Apply Supermask to conv1x1 layers: 89 | ``` 90 | --sparsity-conv1x1 91 | --sp-conv1x1-tile-size 92 | ``` 93 | * Apply Supermask to all other convolutional layers: 94 | ``` 95 | --sparsity-conv 96 | --sp-conv-tile-size 97 | ``` 98 | * Skip the first transformer layer and/or last linear layer (ViT only): 99 | ``` 100 | --skip-last-layer-sparsity 101 | --skip-first-transformer-sparsity 102 | ``` 103 | 104 | For example, if you would like to train a `vit_b_16` from scratch using Supermask, you can use the respective torchvision command found in [TRAINING.md](TRAINING.md) and append the supermask arguments: 105 | ``` 106 | torchrun --nproc_per_node=8 train.py\ 107 | --model vit_b_16 --epochs 300 --batch-size 512 --opt adamw --lr 0.003 --wd 0.3\ 108 | --lr-scheduler cosineannealinglr --lr-warmup-method linear --lr-warmup-epochs 30\ 109 | --lr-warmup-decay 0.033 --amp --label-smoothing 0.11 --mixup-alpha 0.2 --auto-augment ra\ 110 | --clip-grad-norm 1 --ra-sampler --cutmix-alpha 1.0 --model-ema\ 111 | --sparsity-linear 0.9 --sp-linear-tile-size 32 112 | ``` 113 | Through this command, we are training a `vit_b_16` with 90% sparsity to linear layers using 32x32 tiles. 114 | 115 | Please run `python train.py --help` for a full list of available arguments. 116 | 117 | ## Evaluation 118 | 119 | To run an evaluation of a Supermask-trained model, you can use [evaluate.py](evaluate.py). Our current version has signficant speedup with float32 only and not float16, hence, to illustrate speedup, we don't pass `--amp` in the example commands below. 120 | 121 | ``` 122 | MODEL_PATH= 123 | IMAGENET_PATH= 124 | NGPUS=1 # put number of available GPUS here 125 | ``` 126 | 127 | * Offline sparsification with BSR: 128 | ``` 129 | torchrun --nproc_per_node=${NGPUS} evaluate.py --model vit_b_16 --batch-size 256 --sparsity-linear 0.9 --sp-linear-tile-size 32 --weights-path ${MODEL_PATH} --data-path ${IMAGENET_PATH} --sparsify-weights --bsr 32 130 | ``` 131 | This command applies 90% sparsity to linear layers using 32x32 tiles, loads the model weights from ${MODEL_PATH}, loads the ImageNet validation set located at the specified path, applies offline sparsification to the weights, and converts the sparse weights to BSR format with a block size of 32. 
It is recommended to set `--bsr` the same as tile size. 132 | 133 | * Online sparsification without BSR: 134 | ``` 135 | torchrun --nproc_per_node=${NGPUS} evaluate.py --model vit_b_16 --batch-size 256 --sparsity-linear 0.9 --sp-linear-tile-size 32 --weights-path ${MODEL_PATH} --data-path ${IMAGENET_PATH} 136 | ``` 137 | This is similar to the previous command, but it does not apply offline sparsification or BSR conversion. Instead, the sparsity is applied on-the-fly during evaluation. 138 | 139 | Please run `python evaluate.py --help` for a full list of available arguments. 140 | 141 | Results (1x A100): 142 | * Baseline 143 | ``` 144 | Test: Total time: 0:02:11 145 | Test: Acc@1 78.392 Acc@5 93.592 146 | ``` 147 | 148 | * Sparsity= 0.9, Tile Size = 32, Online Sparsification, BSR = None 149 | ``` 150 | Test: Total time: 0:01:52 151 | Test: Acc@1 76.092 Acc@5 92.656 152 | ``` 153 | 154 | * Sparsity= 0.9, Tile Size = 32, Offline Sparsification, BSR = None 155 | ``` 156 | Test: Total time: 0:01:54 157 | Test: Acc@1 76.092 Acc@5 92.656 158 | ``` 159 | 160 | * Sparsity= 0.9, Tile Size = 32, Offline Sparsification, BSR = 32 161 | ``` 162 | Test: Total time: 0:01:25 163 | Test: Acc@1 76.092 Acc@5 92.656 164 | ``` 165 | 166 | ## Pretrained Weights 167 | 168 | ### Download: 169 | Instead of training from scratch, if you'd like to use the Supermask weights of `vit_b_16` trained on privacy mitigated Imagenet-blurred, you can download them here: 170 | ``` 171 | SPARSITY=0.80 # Checkpoints available for: 0.70, 0.80, 0.82, 0.84, 0.86, 0.88, 0.90 172 | BLOCK_SIZE=32 # Checkpoints available for: 16, 32, 64 173 | ``` 174 | 175 | ``` 176 | mkdir checkpoints 177 | # For baseline, 178 | wget https://huggingface.co/facebook/superblock-vit-b-16/resolve/main/checkpoints/baseline.pth -P checkpoints/ 179 | # For sparsified checkpoints, 180 | wget https://huggingface.co/facebook/superblock-vit-b-16/resolve/main/checkpoints/sp${SPARSITY}-ts${BLOCK_SIZE}.pth -P checkpoints/ 181 | ``` 182 | 183 | ### Benchmark: 184 | ``` 185 | python benchmark.py --model vit_b_16 \ 186 | --batch-size 256 \ 187 | --sparsity-linear ${SPARSITY} \ 188 | --sp-linear-tile-size ${BLOCK_SIZE} \ 189 | --sparsify-weights \ 190 | --bsr ${BLOCK_SIZE} \ 191 | --weights-path ./checkpoints/sp${SPARSITY}-ts${BLOCK_SIZE}.pth \ 192 | > /dev/null 193 | ``` 194 | Result: 195 | ``` 196 | 530.342578125 ms 197 | ``` 198 | 199 | ### Evaluate: 200 | 8 x A100 GPUs: 201 | ``` 202 | torchrun --nproc_per_node=8 evaluate.py --model vit_b_16 --batch-size 256 --sparsity-linear ${SPARSITY} --sp-linear-tile-size ${BLOCK_SIZE} --bsr ${BLOCK_SIZE} --sparsify-weights --weights-path checkpoints/sp${SPARSITY}-ts${BLOCK_SIZE}.pth --data-path ${IMAGENET_PATH} 203 | ``` 204 | Result: 205 | ``` 206 | Test: Total time: 0:01:01 207 | Test: Acc@1 77.644 Acc@5 93.554 208 | ``` 209 | 210 | 1 x A100 GPUs: 211 | ``` 212 | torchrun --nproc_per_node=1 evaluate.py --model vit_b_16 --batch-size 256 --sparsity-linear ${SPARSITY} --sp-linear-tile-size ${BLOCK_SIZE} --bsr ${BLOCK_SIZE} --sparsify-weights --weights-path checkpoints/sp${SPARSITY}-ts${BLOCK_SIZE}.pth --data-path ${IMAGENET_PATH} 213 | ``` 214 | Result: 215 | ``` 216 | Test: Total time: 0:01:51 217 | Test: Acc@1 77.644 Acc@5 93.554 218 | ``` 219 | 220 | ## License 221 | SuperBlock is released under the [MIT license](https://github.com/pytorch-labs/superblock?tab=MIT-1-ov-file#readme). 
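
## Appendix: How Supermask and BSR fit together

The snippet below is a minimal, self-contained sketch of the two ideas described at the top of this README, not the repo's actual implementation (see `supermask.py` for the learned mask and `apply_bsr` in `benchmark.py`/`evaluate.py` for the real conversion). The tensor sizes, the 4x4 tile size and the use of `torch.quantile` are illustrative choices only:

```
import torch

torch.manual_seed(0)
tile, sparsity = 4, 0.75
weight = torch.randn(16, 16)                 # dense weight of a linear layer
scores = torch.rand(16 // tile, 16 // tile)  # one (learned) score per 4x4 tile

# Threshold the scores at the target sparsity to get a binary tile mask,
# then expand the mask to the shape of the weight.
threshold = torch.quantile(scores.flatten(), sparsity)
mask = (scores >= threshold).float()
mask = mask.repeat_interleave(tile, dim=0).repeat_interleave(tile, dim=1)

pruned = weight * mask                       # ~75% of the 4x4 tiles are now all zero
bsr = pruned.to_sparse_bsr(blocksize=tile)   # store only the non-zero 4x4 blocks
print(bsr.values().shape)                    # e.g. torch.Size([4, 4, 4]): kept blocks
```

Because the sparsity is imposed per tile rather than per element, every surviving block is fully dense, which is exactly the structure the BSR kernels exploit, and it is also why the evaluation commands above recommend setting `--bsr` equal to the tile size.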
222 | -------------------------------------------------------------------------------- /TRAINING.md: -------------------------------------------------------------------------------- 1 | # Image classification reference training scripts 2 | 3 | This folder contains reference training scripts for image classification. 4 | They serve as a log of how to train specific models, as well as provide baseline 5 | training and evaluation scripts to quickly bootstrap research. 6 | 7 | Unless otherwise noted, all models have been trained on 8x V100 GPUs with 8 | the following parameters: 9 | 10 | | Parameter | Value | 11 | | ------------------------ | ------ | 12 | | `--batch_size` | `32` | 13 | | `--epochs` | `90` | 14 | | `--lr` | `0.1` | 15 | | `--momentum` | `0.9` | 16 | | `--wd`, `--weight-decay` | `1e-4` | 17 | | `--lr-step-size` | `30` | 18 | | `--lr-gamma` | `0.1` | 19 | 20 | ### AlexNet and VGG 21 | 22 | Since `AlexNet` and the original `VGG` architectures do not include batch 23 | normalization, the default initial learning rate `--lr 0.1` is too high. 24 | 25 | ``` 26 | torchrun --nproc_per_node=8 train.py\ 27 | --model $MODEL --lr 1e-2 28 | ``` 29 | 30 | Here `$MODEL` is one of `alexnet`, `vgg11`, `vgg13`, `vgg16` or `vgg19`. Note 31 | that `vgg11_bn`, `vgg13_bn`, `vgg16_bn`, and `vgg19_bn` include batch 32 | normalization and thus are trained with the default parameters. 33 | 34 | ### GoogLeNet 35 | 36 | The weights of the GoogLeNet model are ported from the original paper rather than trained from scratch. 37 | 38 | ### Inception V3 39 | 40 | The weights of the Inception V3 model are ported from the original paper rather than trained from scratch. 41 | 42 | Since it expects tensors with a size of N x 3 x 299 x 299, to validate the model use the following command: 43 | 44 | ``` 45 | torchrun --nproc_per_node=8 train.py --model inception_v3\ 46 | --test-only --weights Inception_V3_Weights.IMAGENET1K_V1 47 | ``` 48 | 49 | ### ResNet 50 | ``` 51 | torchrun --nproc_per_node=8 train.py --model $MODEL 52 | ``` 53 | 54 | Here `$MODEL` is one of `resnet18`, `resnet34`, `resnet50`, `resnet101` or `resnet152`. 55 | 56 | ### ResNext 57 | ``` 58 | torchrun --nproc_per_node=8 train.py\ 59 | --model $MODEL --epochs 100 60 | ``` 61 | 62 | Here `$MODEL` is one of `resnext50_32x4d` or `resnext101_32x8d`. 63 | Note that the above command corresponds to a single node with 8 GPUs. If you use 64 | a different number of GPUs and/or a different batch size, then the learning rate 65 | should be scaled accordingly. For example, the pretrained model provided by 66 | `torchvision` was trained on 8 nodes, each with 8 GPUs (for a total of 64 GPUs), 67 | with `--batch_size 16` and `--lr 0.4`, instead of the current defaults, 68 | which are `batch_size=32` and `lr=0.1`, respectively. 69 | 70 | ### MobileNetV2 71 | ``` 72 | torchrun --nproc_per_node=8 train.py\ 73 | --model mobilenet_v2 --epochs 300 --lr 0.045 --wd 0.00004\ 74 | --lr-step-size 1 --lr-gamma 0.98 75 | ``` 76 | 77 | 78 | ### MobileNetV3 Large & Small 79 | ``` 80 | torchrun --nproc_per_node=8 train.py\ 81 | --model $MODEL --epochs 600 --opt rmsprop --batch-size 128 --lr 0.064\ 82 | --wd 0.00001 --lr-step-size 2 --lr-gamma 0.973 --auto-augment imagenet --random-erase 0.2 83 | ``` 84 | 85 | Here `$MODEL` is one of `mobilenet_v3_large` or `mobilenet_v3_small`. 86 | 87 | Then we averaged the parameters of the last 3 checkpoints that improved the Acc@1.
See [#3182](https://github.com/pytorch/vision/pull/3182) 88 | and [#3354](https://github.com/pytorch/vision/pull/3354) for details. 89 | 90 | 91 | ### EfficientNet-V1 92 | 93 | The weights of the B0-B4 variants are ported from Ross Wightman's [timm repo](https://github.com/rwightman/pytorch-image-models/blob/01cb46a9a50e3ba4be167965b5764e9702f09b30/timm/models/efficientnet.py#L95-L108). 94 | 95 | The weights of the B5-B7 variants are ported from Luke Melas' [EfficientNet-PyTorch repo](https://github.com/lukemelas/EfficientNet-PyTorch/blob/1039e009545d9329ea026c9f7541341439712b96/efficientnet_pytorch/utils.py#L562-L564). 96 | 97 | All models were trained using Bicubic interpolation and each have custom crop and resize sizes. To validate the models use the following commands: 98 | ``` 99 | torchrun --nproc_per_node=8 train.py --model efficientnet_b0 --test-only --weights EfficientNet_B0_Weights.IMAGENET1K_V1 100 | torchrun --nproc_per_node=8 train.py --model efficientnet_b1 --test-only --weights EfficientNet_B1_Weights.IMAGENET1K_V1 101 | torchrun --nproc_per_node=8 train.py --model efficientnet_b2 --test-only --weights EfficientNet_B2_Weights.IMAGENET1K_V1 102 | torchrun --nproc_per_node=8 train.py --model efficientnet_b3 --test-only --weights EfficientNet_B3_Weights.IMAGENET1K_V1 103 | torchrun --nproc_per_node=8 train.py --model efficientnet_b4 --test-only --weights EfficientNet_B4_Weights.IMAGENET1K_V1 104 | torchrun --nproc_per_node=8 train.py --model efficientnet_b5 --test-only --weights EfficientNet_B5_Weights.IMAGENET1K_V1 105 | torchrun --nproc_per_node=8 train.py --model efficientnet_b6 --test-only --weights EfficientNet_B6_Weights.IMAGENET1K_V1 106 | torchrun --nproc_per_node=8 train.py --model efficientnet_b7 --test-only --weights EfficientNet_B7_Weights.IMAGENET1K_V1 107 | ``` 108 | 109 | 110 | ### EfficientNet-V2 111 | ``` 112 | torchrun --nproc_per_node=8 train.py \ 113 | --model $MODEL --batch-size 128 --lr 0.5 --lr-scheduler cosineannealinglr \ 114 | --lr-warmup-epochs 5 --lr-warmup-method linear --auto-augment ta_wide --epochs 600 --random-erase 0.1 \ 115 | --label-smoothing 0.1 --mixup-alpha 0.2 --cutmix-alpha 1.0 --weight-decay 0.00002 --norm-weight-decay 0.0 \ 116 | --train-crop-size $TRAIN_SIZE --model-ema --val-crop-size $EVAL_SIZE --val-resize-size $EVAL_SIZE \ 117 | --ra-sampler --ra-reps 4 118 | ``` 119 | Here `$MODEL` is one of `efficientnet_v2_s` and `efficientnet_v2_m`. 120 | Note that the Small variant had a `$TRAIN_SIZE` of `300` and a `$EVAL_SIZE` of `384`, while the Medium `384` and `480` respectively. 121 | 122 | Note that the above command corresponds to training on a single node with 8 GPUs. 123 | For generatring the pre-trained weights, we trained with 4 nodes, each with 8 GPUs (for a total of 32 GPUs), 124 | and `--batch_size 32`. 125 | 126 | The weights of the Large variant are ported from the original paper rather than trained from scratch. See the `EfficientNet_V2_L_Weights` entry for their exact preprocessing transforms. 127 | 128 | 129 | ### RegNet 130 | 131 | #### Small models 132 | ``` 133 | torchrun --nproc_per_node=8 train.py\ 134 | --model $MODEL --epochs 100 --batch-size 128 --wd 0.00005 --lr=0.8\ 135 | --lr-scheduler=cosineannealinglr --lr-warmup-method=linear\ 136 | --lr-warmup-epochs=5 --lr-warmup-decay=0.1 137 | ``` 138 | Here `$MODEL` is one of `regnet_x_400mf`, `regnet_x_800mf`, `regnet_x_1_6gf`, `regnet_y_400mf`, `regnet_y_800mf` and `regnet_y_1_6gf`. 
Please note we used learning rate 0.4 for `regnet_y_400mf` to get the same Acc@1 as [the paper](https://arxiv.org/abs/2003.13678). 139 | 140 | #### Medium models 141 | ``` 142 | torchrun --nproc_per_node=8 train.py\ 143 | --model $MODEL --epochs 100 --batch-size 64 --wd 0.00005 --lr=0.4\ 144 | --lr-scheduler=cosineannealinglr --lr-warmup-method=linear\ 145 | --lr-warmup-epochs=5 --lr-warmup-decay=0.1 146 | ``` 147 | Here `$MODEL` is one of `regnet_x_3_2gf`, `regnet_x_8gf`, `regnet_x_16gf`, `regnet_y_3_2gf` and `regnet_y_8gf`. 148 | 149 | #### Large models 150 | ``` 151 | torchrun --nproc_per_node=8 train.py\ 152 | --model $MODEL --epochs 100 --batch-size 32 --wd 0.00005 --lr=0.2\ 153 | --lr-scheduler=cosineannealinglr --lr-warmup-method=linear\ 154 | --lr-warmup-epochs=5 --lr-warmup-decay=0.1 155 | ``` 156 | Here `$MODEL` is one of `regnet_x_32gf`, `regnet_y_16gf` and `regnet_y_32gf`. 157 | 158 | ### Vision Transformer 159 | 160 | #### vit_b_16 161 | ``` 162 | torchrun --nproc_per_node=8 train.py\ 163 | --model vit_b_16 --epochs 300 --batch-size 512 --opt adamw --lr 0.003 --wd 0.3\ 164 | --lr-scheduler cosineannealinglr --lr-warmup-method linear --lr-warmup-epochs 30\ 165 | --lr-warmup-decay 0.033 --amp --label-smoothing 0.11 --mixup-alpha 0.2 --auto-augment ra\ 166 | --clip-grad-norm 1 --ra-sampler --cutmix-alpha 1.0 --model-ema 167 | ``` 168 | 169 | Note that the above command corresponds to training on a single node with 8 GPUs. 170 | For generating the pre-trained weights, we trained with 8 nodes, each with 8 GPUs (for a total of 64 GPUs), 171 | and `--batch_size 64`. 172 | 173 | #### vit_b_32 174 | ``` 175 | torchrun --nproc_per_node=8 train.py\ 176 | --model vit_b_32 --epochs 300 --batch-size 512 --opt adamw --lr 0.003 --wd 0.3\ 177 | --lr-scheduler cosineannealinglr --lr-warmup-method linear --lr-warmup-epochs 30\ 178 | --lr-warmup-decay 0.033 --amp --label-smoothing 0.11 --mixup-alpha 0.2 --auto-augment imagenet\ 179 | --clip-grad-norm 1 --ra-sampler --cutmix-alpha 1.0 --model-ema 180 | ``` 181 | 182 | Note that the above command corresponds to training on a single node with 8 GPUs. 183 | For generating the pre-trained weights, we trained with 2 nodes, each with 8 GPUs (for a total of 16 GPUs), 184 | and `--batch_size 256`. 185 | 186 | #### vit_l_16 187 | ``` 188 | torchrun --nproc_per_node=8 train.py\ 189 | --model vit_l_16 --epochs 600 --batch-size 128 --lr 0.5 --lr-scheduler cosineannealinglr\ 190 | --lr-warmup-method linear --lr-warmup-epochs 5 --label-smoothing 0.1 --mixup-alpha 0.2\ 191 | --auto-augment ta_wide --random-erase 0.1 --weight-decay 0.00002 --norm-weight-decay 0.0\ 192 | --clip-grad-norm 1 --ra-sampler --cutmix-alpha 1.0 --model-ema --val-resize-size 232 193 | ``` 194 | 195 | Note that the above command corresponds to training on a single node with 8 GPUs. 196 | For generating the pre-trained weights, we trained with 2 nodes, each with 8 GPUs (for a total of 16 GPUs), 197 | and `--batch_size 64`. 198 | 199 | #### vit_l_32 200 | ``` 201 | torchrun --nproc_per_node=8 train.py\ 202 | --model vit_l_32 --epochs 300 --batch-size 512 --opt adamw --lr 0.003 --wd 0.3\ 203 | --lr-scheduler cosineannealinglr --lr-warmup-method linear --lr-warmup-epochs 30\ 204 | --lr-warmup-decay 0.033 --amp --label-smoothing 0.11 --mixup-alpha 0.2 --auto-augment ra\ 205 | --clip-grad-norm 1 --ra-sampler --cutmix-alpha 1.0 --model-ema 206 | ``` 207 | 208 | Note that the above command corresponds to training on a single node with 8 GPUs.
209 | For generating the pre-trained weights, we trained with 8 nodes, each with 8 GPUs (for a total of 64 GPUs), 210 | and `--batch_size 64`. 211 | 212 | 213 | ### ConvNeXt 214 | ``` 215 | torchrun --nproc_per_node=8 train.py\ 216 | --model $MODEL --batch-size 128 --opt adamw --lr 1e-3 --lr-scheduler cosineannealinglr \ 217 | --lr-warmup-epochs 5 --lr-warmup-method linear --auto-augment ta_wide --epochs 600 --random-erase 0.1 \ 218 | --label-smoothing 0.1 --mixup-alpha 0.2 --cutmix-alpha 1.0 --weight-decay 0.05 --norm-weight-decay 0.0 \ 219 | --train-crop-size 176 --model-ema --val-resize-size 232 --ra-sampler --ra-reps 4 220 | ``` 221 | Here `$MODEL` is one of `convnext_tiny`, `convnext_small`, `convnext_base` and `convnext_large`. Note that each variant had its `--val-resize-size` optimized in a post-training step; see their `Weights` entry for their exact value. 222 | 223 | Note that the above command corresponds to training on a single node with 8 GPUs. 224 | For generating the pre-trained weights, we trained with 2 nodes, each with 8 GPUs (for a total of 16 GPUs), 225 | and `--batch_size 64`. 226 | 227 | 228 | ### SwinTransformer 229 | ``` 230 | torchrun --nproc_per_node=8 train.py\ 231 | --model $MODEL --epochs 300 --batch-size 128 --opt adamw --lr 0.001 --weight-decay 0.05 --norm-weight-decay 0.0 --bias-weight-decay 0.0 --transformer-embedding-decay 0.0 --lr-scheduler cosineannealinglr --lr-min 0.00001 --lr-warmup-method linear --lr-warmup-epochs 20 --lr-warmup-decay 0.01 --amp --label-smoothing 0.1 --mixup-alpha 0.8 --clip-grad-norm 5.0 --cutmix-alpha 1.0 --random-erase 0.25 --interpolation bicubic --auto-augment ta_wide --model-ema --ra-sampler --ra-reps 4 --val-resize-size 224 232 | ``` 233 | Here `$MODEL` is one of `swin_t`, `swin_s` or `swin_b`. 234 | Note that `--val-resize-size` was optimized in a post-training step; see their `Weights` entry for the exact value. 235 | 236 | 237 | 238 | 239 | ### SwinTransformer V2 240 | ``` 241 | torchrun --nproc_per_node=8 train.py\ 242 | --model $MODEL --epochs 300 --batch-size 128 --opt adamw --lr 0.001 --weight-decay 0.05 --norm-weight-decay 0.0 --bias-weight-decay 0.0 --transformer-embedding-decay 0.0 --lr-scheduler cosineannealinglr --lr-min 0.00001 --lr-warmup-method linear --lr-warmup-epochs 20 --lr-warmup-decay 0.01 --amp --label-smoothing 0.1 --mixup-alpha 0.8 --clip-grad-norm 5.0 --cutmix-alpha 1.0 --random-erase 0.25 --interpolation bicubic --auto-augment ta_wide --model-ema --ra-sampler --ra-reps 4 --val-resize-size 256 --val-crop-size 256 --train-crop-size 256 243 | ``` 244 | Here `$MODEL` is one of `swin_v2_t`, `swin_v2_s` or `swin_v2_b`. 245 | Note that `--val-resize-size` was optimized in a post-training step; see their `Weights` entry for the exact value. 246 | 247 | 248 | ### MaxViT 249 | ``` 250 | torchrun --nproc_per_node=8 --nnodes=4 train.py\ 251 | --model $MODEL --epochs 400 --batch-size 128 --opt adamw --lr 3e-3 --weight-decay 0.05 --lr-scheduler cosineannealinglr --lr-min 1e-5 --lr-warmup-method linear --lr-warmup-epochs 32 --label-smoothing 0.1 --mixup-alpha 0.8 --clip-grad-norm 1.0 --interpolation bicubic --auto-augment ta_wide --policy-magnitude 15 --model-ema --val-resize-size 224\ 252 | --val-crop-size 224 --train-crop-size 224 --amp --model-ema-steps 32 --transformer-embedding-decay 0 --sync-bn 253 | ``` 254 | Here `$MODEL` is `maxvit_t`. 255 | Note that `--val-resize-size` was not optimized in a post-training step.
256 | 257 | 258 | ### ShuffleNet V2 259 | ``` 260 | torchrun --nproc_per_node=8 train.py \ 261 | --batch-size=128 \ 262 | --lr=0.5 --lr-scheduler=cosineannealinglr --lr-warmup-epochs=5 --lr-warmup-method=linear \ 263 | --auto-augment=ta_wide --epochs=600 --random-erase=0.1 --weight-decay=0.00002 \ 264 | --norm-weight-decay=0.0 --label-smoothing=0.1 --mixup-alpha=0.2 --cutmix-alpha=1.0 \ 265 | --train-crop-size=176 --model-ema --val-resize-size=232 --ra-sampler --ra-reps=4 266 | ``` 267 | Here `$MODEL` is either `shufflenet_v2_x1_5` or `shufflenet_v2_x2_0`. 268 | 269 | The models `shufflenet_v2_x0_5` and `shufflenet_v2_x1_0` were contributed by the community. See [PR-849](https://github.com/pytorch/vision/pull/849#issuecomment-483391686) for details. 270 | 271 | 272 | ## Mixed precision training 273 | Automatic Mixed Precision (AMP) training on GPU for Pytorch can be enabled with the [torch.cuda.amp](https://pytorch.org/docs/stable/amp.html?highlight=amp#module-torch.cuda.amp). 274 | 275 | Mixed precision training makes use of both FP32 and FP16 precisions where appropriate. FP16 operations can leverage the Tensor cores on NVIDIA GPUs (Volta, Turing or newer architectures) for improved throughput, generally without loss in model accuracy. Mixed precision training also often allows larger batch sizes. GPU automatic mixed precision training for Pytorch Vision can be enabled via the flag value `--amp=True`. 276 | 277 | ``` 278 | torchrun --nproc_per_node=8 train.py\ 279 | --model resnext50_32x4d --epochs 100 --amp 280 | ``` 281 | 282 | ## Quantized 283 | 284 | ### Post training quantized models 285 | 286 | For all post training quantized models, the settings are: 287 | 288 | 1. num_calibration_batches: 32 289 | 2. num_workers: 16 290 | 3. batch_size: 32 291 | 4. eval_batch_size: 128 292 | 5. backend: 'fbgemm' 293 | 294 | ``` 295 | python train_quantization.py --device='cpu' --post-training-quantize --backend='fbgemm' --model='$MODEL' 296 | ``` 297 | Here `$MODEL` is one of `googlenet`, `inception_v3`, `resnet18`, `resnet50`, `resnext101_32x8d`, `shufflenet_v2_x0_5` and `shufflenet_v2_x1_0`. 298 | 299 | ### Quantized ShuffleNet V2 300 | 301 | Here are commands that we use to quantized the `shufflenet_v2_x1_5` and `shufflenet_v2_x2_0` models. 302 | ``` 303 | # For shufflenet_v2_x1_5 304 | python train_quantization.py --device='cpu' --post-training-quantize --backend='fbgemm' \ 305 | --model=shufflenet_v2_x1_5 --weights="ShuffleNet_V2_X1_5_Weights.IMAGENET1K_V1" \ 306 | --train-crop-size 176 --val-resize-size 232 --data-path /datasets01_ontap/imagenet_full_size/061417/ 307 | 308 | # For shufflenet_v2_x2_0 309 | python train_quantization.py --device='cpu' --post-training-quantize --backend='fbgemm' \ 310 | --model=shufflenet_v2_x2_0 --weights="ShuffleNet_V2_X2_0_Weights.IMAGENET1K_V1" \ 311 | --train-crop-size 176 --val-resize-size 232 --data-path /datasets01_ontap/imagenet_full_size/061417/ 312 | ``` 313 | 314 | ### QAT MobileNetV2 315 | 316 | For Mobilenet-v2, the model was trained with quantization aware training, the settings used are: 317 | 1. num_workers: 16 318 | 2. batch_size: 32 319 | 3. eval_batch_size: 128 320 | 4. backend: 'qnnpack' 321 | 5. learning-rate: 0.0001 322 | 6. num_epochs: 90 323 | 7. num_observer_update_epochs:4 324 | 8. num_batch_norm_update_epochs:3 325 | 9. momentum: 0.9 326 | 10. lr_step_size:30 327 | 11. lr_gamma: 0.1 328 | 12. 
weight-decay: 0.0001 329 | 330 | ``` 331 | torchrun --nproc_per_node=8 train_quantization.py --model='mobilenet_v2' 332 | ``` 333 | 334 | Training converges at about 10 epochs. 335 | 336 | ### QAT MobileNetV3 337 | 338 | For Mobilenet-v3 Large, the model was trained with quantization aware training, the settings used are: 339 | 1. num_workers: 16 340 | 2. batch_size: 32 341 | 3. eval_batch_size: 128 342 | 4. backend: 'qnnpack' 343 | 5. learning-rate: 0.001 344 | 6. num_epochs: 90 345 | 7. num_observer_update_epochs:4 346 | 8. num_batch_norm_update_epochs:3 347 | 9. momentum: 0.9 348 | 10. lr_step_size:30 349 | 11. lr_gamma: 0.1 350 | 12. weight-decay: 0.00001 351 | 352 | ``` 353 | torchrun --nproc_per_node=8 train_quantization.py --model='mobilenet_v3_large' \ 354 | --wd 0.00001 --lr 0.001 355 | ``` 356 | 357 | For post training quant, device is set to CPU. For training, the device is set to CUDA. 358 | 359 | ### Command to evaluate quantized models using the pre-trained weights: 360 | 361 | ``` 362 | python train_quantization.py --device='cpu' --test-only --backend='' --model='' 363 | ``` 364 | 365 | For inception_v3 you need to pass the following extra parameters: 366 | ``` 367 | --val-resize-size 342 --val-crop-size 299 --train-crop-size 299 368 | ``` -------------------------------------------------------------------------------- /benchmark.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | import os 4 | import time 5 | import sys 6 | import warnings 7 | import hashlib 8 | import torchvision 9 | 10 | import presets 11 | import torch 12 | import torch.utils.data 13 | import utils 14 | from torch import nn 15 | 16 | sys.path.append(os.path.join(os.path.dirname(__file__), "..")) 17 | from supermask import apply_supermask, SupermaskLinear 18 | 19 | 20 | def apply_sparsity(model): 21 | for name, module in model.named_modules(): 22 | if isinstance(module, SupermaskLinear) and "mlp" in name: 23 | module.sparsify_offline() 24 | 25 | 26 | def apply_bsr(model, blocksize): 27 | for name, module in model.named_modules(): 28 | if isinstance(module, torch.nn.Linear) and "mlp" in name: 29 | try: 30 | module.weight = torch.nn.Parameter(to_bsr(module.weight.data, blocksize)) 31 | print(f"Converted {name} to bsr format.") 32 | except ValueError as e: 33 | print(f"Unable to convert weight of {name} to bsr format: {e}") 34 | 35 | 36 | def to_bsr(tensor, blocksize): 37 | if tensor.ndim != 2: 38 | raise ValueError("to_bsr expects 2D tensor") 39 | if tensor.size(0) % blocksize or tensor.size(1) % blocksize: 40 | raise ValueError("Tensor dimensions must be divisible by blocksize") 41 | return tensor.to_sparse_bsr(blocksize) 42 | 43 | 44 | def verify_sparsity(model): 45 | for name, module in model.named_modules(): 46 | if isinstance(module, nn.Linear): 47 | total_weights = module.weight.numel() 48 | sparse_weights = (module.weight == 0).sum().item() 49 | sparsity_percentage = (sparse_weights / total_weights) * 100 50 | print(f"Sparsity verified in layer {name}: {sparsity_percentage:.2f}%") 51 | 52 | 53 | def benchmark_in_ms(warmup, iters, f, *args, **kwargs): 54 | for _ in range(warmup): 55 | f(*args, **kwargs) 56 | torch.cuda.synchronize() 57 | start_event = torch.cuda.Event(enable_timing=True) 58 | end_event = torch.cuda.Event(enable_timing=True) 59 | start_event.record() 60 | 61 | for _ in range(iters): 62 | f(*args, **kwargs) 63 | 64 | end_event.record() 65 | torch.cuda.synchronize() 66 | return 
start_event.elapsed_time(end_event) / float(iters) 67 | 68 | 69 | def main(args): 70 | print(args) 71 | device = torch.device(args.device) 72 | 73 | # We disable the cudnn benchmarking because it can noticeably affect the accuracy 74 | torch.backends.cudnn.benchmark = False 75 | torch.backends.cudnn.deterministic = True 76 | num_classes = 1000 77 | 78 | dtype = None 79 | if args.bfloat16: 80 | print("Using bfloat16") 81 | dtype = torch.bfloat16 82 | elif args.float16: 83 | print("Using float16") 84 | dtype = torch.float16 85 | 86 | # Sample input 87 | # input = torch.rand(32, 3, 224, 224, dtype=dtype).to(device) 88 | 89 | print("Creating model") 90 | model = torchvision.models.get_model(args.model, weights=args.weights, num_classes=num_classes) 91 | 92 | apply_supermask( 93 | model, 94 | linear_sparsity=args.sparsity_linear, 95 | linear_sp_tilesize=args.sp_linear_tile_size, 96 | conv1x1_sparsity=args.sparsity_conv1x1, 97 | conv1x1_sp_tilesize=args.sp_conv1x1_tile_size, 98 | conv_sparsity=args.sparsity_conv, 99 | conv_sp_tilesize=args.sp_conv_tile_size, 100 | skip_last_layer_sparsity=args.skip_last_layer_sparsity, 101 | skip_first_transformer_sparsity=args.skip_first_transformer_sparsity, 102 | device=device, 103 | verbose=False, 104 | ) 105 | 106 | if args.weights_path: 107 | try: 108 | checkpoint = torch.load(args.weights_path, map_location="cpu") 109 | model.load_state_dict(checkpoint["model"]) 110 | print(f"Loaded checkpoint successfully from: {args.weights_path}") 111 | except FileNotFoundError: 112 | raise FileNotFoundError(f"No checkpoint found at {args.weights_path}.") 113 | 114 | model.to(device) 115 | # output0 = model(input) 116 | 117 | if args.sparsify_weights: 118 | apply_sparsity(model) 119 | verify_sparsity(model) 120 | 121 | # verify correctness 122 | # output1 = model(input) 123 | # assert torch.allclose(output0, output1), "Output of model before and after weight sparsification should be equal" 124 | 125 | if dtype: 126 | model = model.to(dtype) 127 | 128 | if args.bsr: 129 | if not args.sparsify_weights: 130 | raise ValueError("--bsr can only be used when --sparsify_weights is also specified.") 131 | apply_bsr(model, blocksize=args.bsr) 132 | 133 | # verify correctness 134 | # output2 = model(input) 135 | # assert torch.allclose(output2, output1), "Output of model before and after changing format to BSR should be equal" 136 | 137 | image = torch.empty(args.batch_size, 3, args.val_crop_size, args.val_crop_size, dtype=dtype, device=device) 138 | # model = torch.compile(model, mode='max-autotune') 139 | return benchmark_in_ms(10, 100, model, image) 140 | 141 | 142 | def get_args_parser(add_help=True): 143 | import argparse 144 | 145 | parser = argparse.ArgumentParser(description="PyTorch Classification Training", add_help=add_help) 146 | parser.add_argument("--model", default="resnet18", type=str, help="model name") 147 | parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)") 148 | parser.add_argument( 149 | "-b", "--batch-size", default=32, type=int, help="images per gpu, the total batch size is $NGPU x batch_size" 150 | ) 151 | 152 | # Mixed precision training parameters 153 | parser.add_argument( 154 | "--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)" 155 | ) 156 | parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load") 157 | parser.add_argument("--weights-path", type=str, help="path of pretrained weights to load") 158 | 159 | # 
NOTE: sparsity args 160 | parser.add_argument("--sparsity-linear", type=float, default=0.0) 161 | parser.add_argument("--sp-linear-tile-size", type=int, default=1) 162 | parser.add_argument("--sparsity-conv1x1", type=float, default=0.0) 163 | parser.add_argument("--sp-conv1x1-tile-size", type=int, default=1) 164 | parser.add_argument("--sparsity-conv", type=float, default=0.0) 165 | parser.add_argument("--sp-conv-tile-size", type=int, default=1) 166 | parser.add_argument("--skip-last-layer-sparsity", action="store_true", help="Skip applying sparsity to the last linear layer (for vit only)") 167 | parser.add_argument("--skip-first-transformer-sparsity", action="store_true", help="Skip applying sparsity to the first transformer layer (for vit only)") 168 | parser.add_argument('--sparsify-weights', action='store_true', help='Apply weight sparsification in evaluation mode') 169 | parser.add_argument('--bsr', type=int, nargs='?', const=256, default=None, help='Convert sparsified weights to BSR format with optional block size (default: 256)') 170 | parser.add_argument("--bfloat16", action="store_true", help="Use bfloat16") 171 | parser.add_argument("--float16", action="store_true", help="Use float16") 172 | 173 | return parser 174 | 175 | 176 | if __name__ == "__main__": 177 | args = get_args_parser().parse_args() 178 | result = main(args) 179 | print(f"{result} ms", file=sys.stderr) 180 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | import os 4 | import sys 5 | import warnings 6 | import hashlib 7 | 8 | import presets 9 | import torch 10 | import torch.utils.data 11 | import torchvision 12 | import utils 13 | from torch import nn 14 | from torchvision.transforms.functional import InterpolationMode 15 | 16 | sys.path.append(os.path.join(os.path.dirname(__file__), "..")) 17 | from supermask import apply_supermask, SupermaskLinear 18 | 19 | 20 | def apply_sparsity(model): 21 | for name, module in model.named_modules(): 22 | if isinstance(module, SupermaskLinear) and "mlp" in name: 23 | module.sparsify_offline() 24 | 25 | 26 | def apply_bsr(model): 27 | for name, module in model.named_modules(): 28 | if isinstance(module, torch.nn.Linear) and "mlp" in name: 29 | try: 30 | module.weight = torch.nn.Parameter(to_bsr(module.weight.data, args.bsr)) 31 | print(f"Converted {name} to bsr format.") 32 | except ValueError as e: 33 | print(f"Unable to convert weight of {name} to bsr format: {e}") 34 | 35 | 36 | def to_bsr(tensor, blocksize): 37 | if tensor.ndim != 2: 38 | raise ValueError("to_bsr expects 2D tensor") 39 | if tensor.size(0) % blocksize or tensor.size(1) % blocksize: 40 | raise ValueError("Tensor dimensions must be divisible by blocksize") 41 | return tensor.to_sparse_bsr(blocksize) 42 | 43 | 44 | def verify_sparsity(model): 45 | for name, module in model.named_modules(): 46 | if isinstance(module, nn.Linear): 47 | total_weights = module.weight.numel() 48 | sparse_weights = (module.weight == 0).sum().item() 49 | sparsity_percentage = (sparse_weights / total_weights) * 100 50 | print(f"Sparsity verified in layer {name}: {sparsity_percentage:.2f}%") 51 | 52 | 53 | def _get_cache_path(filepath): 54 | h = hashlib.sha1(filepath.encode()).hexdigest() 55 | cache_path = os.path.join("~", ".torch", "vision", "datasets", "imagefolder", h[:10] + ".pt") 56 | cache_path = os.path.expanduser(cache_path) 57 | return 
cache_path 58 | 59 | 60 | def load_data(valdir, args): 61 | # Data loading code 62 | print("Loading data") 63 | val_resize_size, val_crop_size = ( 64 | args.val_resize_size, 65 | args.val_crop_size 66 | ) 67 | interpolation = InterpolationMode(args.interpolation) 68 | 69 | print("Loading validation data") 70 | cache_path = _get_cache_path(valdir) 71 | if args.cache_dataset and os.path.exists(cache_path): 72 | # Attention, as the transforms are also cached! 73 | print(f"Loading dataset_test from {cache_path}") 74 | dataset_test, _ = torch.load(cache_path) 75 | else: 76 | if args.weights: 77 | weights = torchvision.models.get_weight(args.weights) 78 | preprocessing = weights.transforms() 79 | else: 80 | preprocessing = presets.ClassificationPresetEval( 81 | crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation 82 | ) 83 | 84 | # for META internal 85 | dataset_test = torchvision.datasets.ImageFolder( 86 | valdir, 87 | preprocessing, 88 | ) 89 | # for OSS 90 | # dataset_test = torchvision.datasets.ImageNet( 91 | # valdir, 92 | # split='val', 93 | # transform=preprocessing 94 | # ) 95 | if args.cache_dataset: 96 | print(f"Saving dataset_test to {cache_path}") 97 | utils.mkdir(os.path.dirname(cache_path)) 98 | utils.save_on_master((dataset_test, valdir), cache_path) 99 | 100 | print(f"Number of validation images: {len(dataset_test)}") 101 | 102 | print("Creating data loaders") 103 | if args.distributed: 104 | test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False) 105 | else: 106 | test_sampler = torch.utils.data.SequentialSampler(dataset_test) 107 | 108 | return dataset_test, test_sampler 109 | 110 | 111 | def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix="", args=None): 112 | model.eval() 113 | metric_logger = utils.MetricLogger(delimiter=" ") 114 | header = f"Test: {log_suffix}" 115 | 116 | num_processed_samples = 0 117 | with torch.no_grad(): 118 | for image, target in metric_logger.log_every(data_loader, print_freq, header): 119 | image = image.to(device, non_blocking=True) 120 | target = target.to(device, non_blocking=True) 121 | output = model(image) 122 | loss = criterion(output, target) 123 | 124 | acc1, acc5 = utils.accuracy(output, target, topk=(1, 5)) 125 | # FIXME need to take into account that the datasets 126 | # could have been padded in distributed setup 127 | batch_size = image.shape[0] 128 | metric_logger.update(loss=loss.item()) 129 | metric_logger.meters["acc1"].update(acc1.item(), n=batch_size) 130 | metric_logger.meters["acc5"].update(acc5.item(), n=batch_size) 131 | num_processed_samples += batch_size 132 | # gather the stats from all processes 133 | 134 | num_processed_samples = utils.reduce_across_processes(num_processed_samples) 135 | if ( 136 | hasattr(data_loader.dataset, "__len__") 137 | and len(data_loader.dataset) != num_processed_samples 138 | and torch.distributed.get_rank() == 0 139 | ): 140 | # See FIXME above 141 | warnings.warn( 142 | f"It looks like the dataset has {len(data_loader.dataset)} samples, but {num_processed_samples} " 143 | "samples were used for the validation, which might bias the results. " 144 | "Try adjusting the batch size and / or the world size. " 145 | "Setting the world size to 1 is always a safe bet." 
146 | ) 147 | 148 | metric_logger.synchronize_between_processes() 149 | 150 | print(f"{header} Acc@1 {metric_logger.acc1.global_avg:.3f} Acc@5 {metric_logger.acc5.global_avg:.3f}") 151 | return metric_logger.acc1.global_avg 152 | 153 | 154 | def main(args): 155 | 156 | utils.init_distributed_mode(args) 157 | print(args) 158 | 159 | device = torch.device(args.device) 160 | 161 | # We disable the cudnn benchmarking because it can noticeably affect the accuracy 162 | torch.backends.cudnn.benchmark = False 163 | torch.backends.cudnn.deterministic = True 164 | 165 | val_dir = os.path.join(args.data_path, "val") 166 | dataset_test, test_sampler = load_data(val_dir, args) 167 | 168 | data_loader_test = torch.utils.data.DataLoader( 169 | dataset_test, batch_size=args.batch_size, sampler=test_sampler, num_workers=args.workers, pin_memory=True 170 | ) 171 | num_classes = len(dataset_test.classes) 172 | 173 | print("Creating model") 174 | model = torchvision.models.get_model(args.model, weights=args.weights, num_classes=num_classes) 175 | 176 | apply_supermask( 177 | model, 178 | linear_sparsity=args.sparsity_linear, 179 | linear_sp_tilesize=args.sp_linear_tile_size, 180 | conv1x1_sparsity=args.sparsity_conv1x1, 181 | conv1x1_sp_tilesize=args.sp_conv1x1_tile_size, 182 | conv_sparsity=args.sparsity_conv, 183 | conv_sp_tilesize=args.sp_conv_tile_size, 184 | skip_last_layer_sparsity=args.skip_last_layer_sparsity, 185 | skip_first_transformer_sparsity=args.skip_first_transformer_sparsity, 186 | device=device, 187 | verbose=True, 188 | ) 189 | 190 | model.to(device) 191 | if args.distributed and args.sync_bn: 192 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 193 | 194 | criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing) 195 | scaler = torch.cuda.amp.GradScaler() if args.amp else None 196 | 197 | model_without_ddp = model 198 | if args.distributed: 199 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False) 200 | model_without_ddp = model.module 201 | 202 | model_ema = None 203 | if args.model_ema: 204 | # Decay adjustment that aims to keep the decay independent from other hyper-parameters originally proposed at: 205 | # https://github.com/facebookresearch/pycls/blob/f8cd9627/pycls/core/net.py#L123 206 | # 207 | # total_ema_updates = (Dataset_size / n_GPUs) * epochs / (batch_size_per_gpu * EMA_steps) 208 | # We consider constant = Dataset_size for a given dataset/setup and ommit it. 
Thus: 209 | # adjust = 1 / total_ema_updates ~= n_GPUs * batch_size_per_gpu * EMA_steps / epochs 210 | adjust = args.world_size * args.batch_size * args.model_ema_steps / args.epochs 211 | alpha = 1.0 - args.model_ema_decay 212 | alpha = min(1.0, alpha * adjust) 213 | model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=1.0 - alpha) 214 | 215 | if args.weights_path: 216 | try: 217 | checkpoint = torch.load(args.weights_path, map_location="cpu") 218 | model_without_ddp.load_state_dict(checkpoint["model"]) 219 | if model_ema: 220 | model_ema.load_state_dict(checkpoint["model_ema"]) 221 | if scaler: 222 | scaler.load_state_dict(checkpoint["scaler"]) 223 | print(f"Loaded checkpoint successfully from: {args.weights_path}") 224 | except FileNotFoundError: 225 | raise FileNotFoundError(f"No checkpoint found at {args.weights_path}") 226 | 227 | if args.bsr and not args.sparsify_weights: 228 | raise ValueError("--bsr can only be used when --sparsify_weights is also specified.") 229 | if args.sparsify_weights: 230 | apply_sparsity(model) 231 | verify_sparsity(model) 232 | if args.bsr: 233 | apply_bsr(model) 234 | if model_ema: 235 | evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA", args=args) 236 | else: 237 | evaluate(model, criterion, data_loader_test, device=device) 238 | return 239 | 240 | 241 | def get_args_parser(add_help=True): 242 | import argparse 243 | 244 | parser = argparse.ArgumentParser(description="PyTorch Classification Training", add_help=add_help) 245 | parser.add_argument("--data-path", default="/datasets01/imagenet_full_size/061417", type=str, help="dataset path") 246 | parser.add_argument("--model", default="resnet18", type=str, help="model name") 247 | parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)") 248 | parser.add_argument( 249 | "-b", "--batch-size", default=32, type=int, help="images per gpu, the total batch size is $NGPU x batch_size" 250 | ) 251 | parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run") 252 | parser.add_argument( 253 | "-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)" 254 | ) 255 | parser.add_argument( 256 | "--label-smoothing", default=0.0, type=float, help="label smoothing (default: 0.0)", dest="label_smoothing" 257 | ) 258 | parser.add_argument("--print-freq", default=10, type=int, help="print frequency") 259 | parser.add_argument( 260 | "--cache-dataset", 261 | dest="cache_dataset", 262 | help="Cache the datasets for quicker initialization. 
It also serializes the transforms", 263 | action="store_true", 264 | ) 265 | parser.add_argument( 266 | "--sync-bn", 267 | dest="sync_bn", 268 | help="Use sync batch norm", 269 | action="store_true", 270 | ) 271 | 272 | # Mixed precision training parameters 273 | parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training") 274 | 275 | # distributed training parameters 276 | parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes") 277 | parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training") 278 | parser.add_argument( 279 | "--model-ema", action="store_true", help="enable tracking Exponential Moving Average of model parameters" 280 | ) 281 | parser.add_argument( 282 | "--model-ema-steps", 283 | type=int, 284 | default=32, 285 | help="the number of iterations that controls how often to update the EMA model (default: 32)", 286 | ) 287 | parser.add_argument( 288 | "--model-ema-decay", 289 | type=float, 290 | default=0.99998, 291 | help="decay factor for Exponential Moving Average of model parameters (default: 0.99998)", 292 | ) 293 | parser.add_argument( 294 | "--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)" 295 | ) 296 | parser.add_argument( 297 | "--val-resize-size", default=256, type=int, help="the resize size used for validation (default: 256)" 298 | ) 299 | parser.add_argument( 300 | "--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)" 301 | ) 302 | parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load") 303 | parser.add_argument("--weights-path", type=str, help="path of pretrained weights to load") 304 | 305 | # NOTE: sparsity args 306 | parser.add_argument("--sparsity-linear", type=float, default=0.0) 307 | parser.add_argument("--sp-linear-tile-size", type=int, default=1) 308 | parser.add_argument("--sparsity-conv1x1", type=float, default=0.0) 309 | parser.add_argument("--sp-conv1x1-tile-size", type=int, default=1) 310 | parser.add_argument("--sparsity-conv", type=float, default=0.0) 311 | parser.add_argument("--sp-conv-tile-size", type=int, default=1) 312 | parser.add_argument("--skip-last-layer-sparsity", action="store_true", help="Skip applying sparsity to the last linear layer (for vit only)") 313 | parser.add_argument("--skip-first-transformer-sparsity", action="store_true", help="Skip applying sparsity to the first transformer layer (for vit only)") 314 | parser.add_argument('--sparsify-weights', action='store_true', help='Apply weight sparsification in evaluation mode') 315 | parser.add_argument('--bsr', type=int, nargs='?', const=256, default=None, help='Convert sparsified weights to BSR format with optional block size (default: 256)') 316 | 317 | return parser 318 | 319 | 320 | if __name__ == "__main__": 321 | args = get_args_parser().parse_args() 322 | main(args) 323 | -------------------------------------------------------------------------------- /presets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
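# Preprocessing presets used by train.py, evaluate.py and benchmark.py:
#   - ClassificationPresetTrain: RandomResizedCrop + optional horizontal flip, auto-augment and
#     random erasing, followed by tensor conversion and normalization.
#   - ClassificationPresetEval: Resize + CenterCrop, followed by tensor conversion and normalization.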
2 | 3 | import torch 4 | from torchvision.transforms import autoaugment, transforms 5 | from torchvision.transforms.functional import InterpolationMode 6 | 7 | 8 | class ClassificationPresetTrain: 9 | def __init__( 10 | self, 11 | *, 12 | crop_size, 13 | mean=(0.485, 0.456, 0.406), 14 | std=(0.229, 0.224, 0.225), 15 | interpolation=InterpolationMode.BILINEAR, 16 | hflip_prob=0.5, 17 | auto_augment_policy=None, 18 | ra_magnitude=9, 19 | augmix_severity=3, 20 | random_erase_prob=0.0, 21 | ): 22 | trans = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)] 23 | if hflip_prob > 0: 24 | trans.append(transforms.RandomHorizontalFlip(hflip_prob)) 25 | if auto_augment_policy is not None: 26 | if auto_augment_policy == "ra": 27 | trans.append(autoaugment.RandAugment(interpolation=interpolation, magnitude=ra_magnitude)) 28 | elif auto_augment_policy == "ta_wide": 29 | trans.append(autoaugment.TrivialAugmentWide(interpolation=interpolation)) 30 | elif auto_augment_policy == "augmix": 31 | trans.append(autoaugment.AugMix(interpolation=interpolation, severity=augmix_severity)) 32 | else: 33 | aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy) 34 | trans.append(autoaugment.AutoAugment(policy=aa_policy, interpolation=interpolation)) 35 | trans.extend( 36 | [ 37 | transforms.PILToTensor(), 38 | transforms.ConvertImageDtype(torch.float), 39 | transforms.Normalize(mean=mean, std=std), 40 | ] 41 | ) 42 | if random_erase_prob > 0: 43 | trans.append(transforms.RandomErasing(p=random_erase_prob)) 44 | 45 | self.transforms = transforms.Compose(trans) 46 | 47 | def __call__(self, img): 48 | return self.transforms(img) 49 | 50 | 51 | class ClassificationPresetEval: 52 | def __init__( 53 | self, 54 | *, 55 | crop_size, 56 | resize_size=256, 57 | mean=(0.485, 0.456, 0.406), 58 | std=(0.229, 0.224, 0.225), 59 | interpolation=InterpolationMode.BILINEAR, 60 | ): 61 | 62 | self.transforms = transforms.Compose( 63 | [ 64 | transforms.Resize(resize_size, interpolation=interpolation), 65 | transforms.CenterCrop(crop_size), 66 | transforms.PILToTensor(), 67 | transforms.ConvertImageDtype(torch.float), 68 | transforms.Normalize(mean=mean, std=std), 69 | ] 70 | ) 71 | 72 | def __call__(self, img): 73 | return self.transforms(img) 74 | -------------------------------------------------------------------------------- /sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | import math 4 | 5 | import torch 6 | import torch.distributed as dist 7 | 8 | 9 | class RASampler(torch.utils.data.Sampler): 10 | """Sampler that restricts data loading to a subset of the dataset for distributed, 11 | with repeated augmentation. 12 | It ensures that different each augmented version of a sample will be visible to a 13 | different process (GPU). 14 | Heavily based on 'torch.utils.data.DistributedSampler'. 
15 | 16 | This is borrowed from the DeiT Repo: 17 | https://github.com/facebookresearch/deit/blob/main/samplers.py 18 | """ 19 | 20 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0, repetitions=3): 21 | if num_replicas is None: 22 | if not dist.is_available(): 23 | raise RuntimeError("Requires distributed package to be available!") 24 | num_replicas = dist.get_world_size() 25 | if rank is None: 26 | if not dist.is_available(): 27 | raise RuntimeError("Requires distributed package to be available!") 28 | rank = dist.get_rank() 29 | self.dataset = dataset 30 | self.num_replicas = num_replicas 31 | self.rank = rank 32 | self.epoch = 0 33 | self.num_samples = int(math.ceil(len(self.dataset) * float(repetitions) / self.num_replicas)) 34 | self.total_size = self.num_samples * self.num_replicas 35 | self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas)) 36 | self.shuffle = shuffle 37 | self.seed = seed 38 | self.repetitions = repetitions 39 | 40 | def __iter__(self): 41 | if self.shuffle: 42 | # Deterministically shuffle based on epoch 43 | g = torch.Generator() 44 | g.manual_seed(self.seed + self.epoch) 45 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 46 | else: 47 | indices = list(range(len(self.dataset))) 48 | 49 | # Add extra samples to make it evenly divisible 50 | indices = [ele for ele in indices for i in range(self.repetitions)] 51 | indices += indices[: (self.total_size - len(indices))] 52 | assert len(indices) == self.total_size 53 | 54 | # Subsample 55 | indices = indices[self.rank : self.total_size : self.num_replicas] 56 | assert len(indices) == self.num_samples 57 | 58 | return iter(indices[: self.num_selected_samples]) 59 | 60 | def __len__(self): 61 | return self.num_selected_samples 62 | 63 | def set_epoch(self, epoch): 64 | self.epoch = epoch 65 | -------------------------------------------------------------------------------- /supermask.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | import torch.nn as nn 4 | import math 5 | import torch 6 | from torch.autograd import Variable 7 | import torch.nn.functional as F 8 | import numpy as np 9 | 10 | 11 | # original supermask 12 | scores_min=None 13 | scores_max=9e9 14 | uniform_init_01 = False 15 | 16 | # adjusted supermask, initialize scores with uniform distribution in [0,1], clamp scores in each step in [0,1] 17 | # scores_min=0. 18 | # scores_max=1. 
19 | # uniform_init_01 = True 20 | 21 | def percentile(t, q): 22 | """Return the value that is larger than q% of t""" 23 | k = 1 + round(.01 * float(q) * (t.numel() - 1)) 24 | return t.view(-1).kthvalue(k).values.item() 25 | 26 | 27 | def to_bsr(tensor, blocksize=256): 28 | if tensor.ndim != 2: 29 | print("Tensor is not 2D, skipping BSR conversion.") 30 | return tensor 31 | 32 | if tensor.size(0) % blocksize or tensor.size(1) % blocksize: 33 | print("Tensor dimensions are not divisible by blocksize, skipping BSR conversion.") 34 | return tensor 35 | 36 | try: 37 | converted_tensor = tensor.to_sparse_bsr(blocksize=blocksize) 38 | print(f"Converted tensor to BSR format with blocksize: {blocksize}") 39 | return converted_tensor 40 | except ValueError as e: 41 | print(f"Unable to convert tensor to BSR format: {e}") 42 | return tensor 43 | 44 | 45 | class GetSubnet(torch.autograd.Function): 46 | """Supermask STE function""" 47 | @staticmethod 48 | def forward(ctx, scores, zeros, ones, sparsity): 49 | scores.clamp_(min=scores_min,max=scores_max) 50 | k_val = percentile(scores, sparsity*100) 51 | return torch.where(scores < k_val, zeros.to(scores.device), ones.to(scores.device)) 52 | @staticmethod 53 | def backward(ctx, g): 54 | return g, None, None, None 55 | 56 | 57 | class SupermaskLinear(nn.Linear): 58 | """Supermask class for Linear layer""" 59 | def __init__(self, sparsity, fixed_mask, fixed_weight, bitwidth, transform, fixed_transform, *args, **kwargs): 60 | tile_size = kwargs.pop("tile_size", 1) 61 | super(SupermaskLinear, self).__init__(*args, **kwargs) 62 | # initialize the scores 63 | max_sparsity = 1 - (1 / math.prod([math.ceil(k / tile_size) for k in self.weight.size()])) 64 | self.sparsity = sparsity 65 | if self.sparsity > max_sparsity: 66 | print( 67 | f"reducing sparsity from {self.sparsity} to {max_sparsity}", 68 | f"(maximum sparsity for layer with shape {self.weight.size()} and tile size {tile_size})" 69 | ) 70 | self.sparsity = max_sparsity 71 | self.tile_size = tile_size 72 | self.sparsify_weights = False 73 | self.scores = nn.Parameter( 74 | torch.empty( 75 | [max(1, int(math.ceil(wn / tile_size))) for wn in self.weight.size()] 76 | ), 77 | requires_grad=not fixed_mask, 78 | ) 79 | nn.init.uniform_(self.scores) if uniform_init_01 else nn.init.kaiming_uniform_(self.scores, a=math.sqrt(5)) 80 | 81 | # the shift and the scale are transformation parameters 82 | # the actually used weights = self.weight*self.scale+self.shift 83 | # the transformation is activated only for quantized weights 84 | self.shift=nn.Parameter(torch.Tensor(1).fill_(0.), requires_grad=False) 85 | self.scale=nn.Parameter(torch.Tensor(1).fill_(1.), requires_grad=False) 86 | 87 | with torch.no_grad(): 88 | # if bitwidth is None, then use floating point values in self.weight 89 | # if bitwidth is not None, then quantize self.weight into k-bit (k=bitwidth) 90 | # quantized values are -2^(k-1), -2^(k-1)+1, ..., 0, 1, ..., 2^(k-1)-1 91 | # these quantized values are uniformly distributed 92 | if bitwidth is not None: 93 | weights_max = torch.max(self.weight).item() 94 | weights_min = torch.min(self.weight).item() 95 | least_step = (weights_max-weights_min)/pow(2,bitwidth) 96 | left_bound = weights_min-1e-6 97 | right_bound = weights_min+least_step+1e-6 98 | # self.shift=nn.Parameter(torch.Tensor(1).fill_( (weights_min+(pow(2,bitwidth-1)+0.5)*least_step) if transform[0] is None else transform[0] ), requires_grad=not fixed_transform[0]) 99 | # self.scale=nn.Parameter(torch.Tensor(1).fill_( least_step if 
transform[1] is None else transform[1] ), requires_grad=not fixed_transform[1]) 100 | # for example, if using binary weights (k=1) with -a, +a, set transform = [a,2a]; if using binary weights (k=1) with a, 0, set transform = [0,-a]; 101 | self.shift=nn.Parameter(torch.Tensor(1).fill_( 0. if transform[0] is None else transform[0] ), requires_grad=not fixed_transform[0]) 102 | self.scale=nn.Parameter(torch.Tensor(1).fill_( 1. if transform[1] is None else transform[1] ), requires_grad=not fixed_transform[1]) 103 | for i in range(-int(pow(2,bitwidth-1)),int(pow(2,bitwidth-1))): 104 | self.weight[torch.logical_and(self.weight>left_bound, self.weight<=right_bound)] = i 105 | left_bound = right_bound 106 | right_bound += least_step 107 | 108 | self.weight.requires_grad = not fixed_weight 109 | 110 | def get_mask(self): 111 | subnet = GetSubnet.apply(self.scores, 112 | torch.zeros_like(self.scores), 113 | torch.ones_like(self.scores), 114 | self.sparsity) 115 | 116 | if self.tile_size != 1: 117 | for i, k in enumerate(self.weight.shape): 118 | subnet = subnet.repeat_interleave(self.tile_size, dim=i) 119 | subnet = torch.narrow(subnet, i, 0, k) 120 | 121 | return subnet 122 | 123 | def sparsify_offline(self): 124 | subnet = self.get_mask() 125 | self.weight.data = (self.weight*self.scale+self.shift) * subnet 126 | self.sparsify_weights = True 127 | 128 | def forward(self, x): 129 | if not self.sparsify_weights: 130 | subnet = self.get_mask() 131 | w = (self.weight*self.scale+self.shift) * subnet 132 | else: 133 | w = self.weight.data 134 | return F.linear(x, w, self.bias) 135 | 136 | 137 | class SupermaskConv2d(nn.Conv2d): 138 | """Supermask class for Conv2d layer""" 139 | def __init__(self, sparsity, fixed_mask, fixed_weight, bitwidth, transform, fixed_transform, *args, **kwargs): 140 | tile_size = kwargs.pop("tile_size", 1) 141 | super(SupermaskConv2d, self).__init__(*args, **kwargs) 142 | # initialize the scores 143 | max_sparsity = 1 - (1 / math.prod([math.ceil(k / tile_size) for k in self.weight.size()])) 144 | self.sparsity = sparsity 145 | if self.sparsity > max_sparsity: 146 | print( 147 | f"reducing sparsity from {self.sparsity} to {max_sparsity}", 148 | f"(maximum sparsity for layer with shape {self.weight.size()} and tile size {tile_size})" 149 | ) 150 | self.sparsity = max_sparsity 151 | self.tile_size = tile_size 152 | self.scores = nn.Parameter( 153 | torch.empty( 154 | [max(1, int(math.ceil(wn / tile_size))) for wn in self.weight.size()] 155 | ), 156 | requires_grad=not fixed_mask, 157 | ) 158 | nn.init.uniform_(self.scores) if uniform_init_01 else nn.init.kaiming_uniform_(self.scores, a=math.sqrt(5)) 159 | 160 | # the shift and the scale are transformation parameters 161 | # the actually used weights = self.weight*self.scale+self.shift 162 | # the transformation is activated only for quantized weights 163 | self.shift=nn.Parameter(torch.Tensor(1).fill_(0.), requires_grad=False) 164 | self.scale=nn.Parameter(torch.Tensor(1).fill_(1.), requires_grad=False) 165 | 166 | with torch.no_grad(): 167 | # if bitwidth is None, then use floating point values in self.weight 168 | # if bitwidth is not None, then quantize self.weight into k-bit (k=bitwidth) 169 | # quantized values are -2^(k-1), -2^(k-1)+1, ..., 0, 1, ..., 2^(k-1)-1 170 | # these quantized values are uniformly distributed 171 | if bitwidth is not None: 172 | weights_max = torch.max(self.weight).item() 173 | weights_min = torch.min(self.weight).item() 174 | least_step = (weights_max-weights_min)/pow(2,bitwidth) 175 | left_bound = 
weights_min-1e-6 176 | right_bound = weights_min+least_step+1e-6 177 | # self.shift=nn.Parameter(torch.Tensor(1).fill_( (weights_min+(pow(2,bitwidth-1)+0.5)*least_step) if transform[0] is None else transform[0] ), requires_grad=not fixed_transform[0]) 178 | # self.scale=nn.Parameter(torch.Tensor(1).fill_( least_step if transform[1] is None else transform[1]), requires_grad=not fixed_transform[1]) 179 | # for example, if using binary weights (k=1) with -a, +a, set transform = [a,2a]; if using binary weights (k=1) with a, 0, set transform = [0,-a]; 180 | self.shift=nn.Parameter(torch.Tensor(1).fill_( 0. if transform[0] is None else transform[0] ), requires_grad=not fixed_transform[0]) 181 | self.scale=nn.Parameter(torch.Tensor(1).fill_( 1. if transform[1] is None else transform[1] ), requires_grad=not fixed_transform[1]) 182 | for i in range(-int(pow(2,bitwidth-1)),int(pow(2,bitwidth-1))): 183 | self.weight[torch.logical_and(self.weight>left_bound, self.weight<=right_bound)] = i 184 | left_bound = right_bound 185 | right_bound += least_step 186 | 187 | self.weight.requires_grad = not fixed_weight 188 | 189 | def forward(self, x): 190 | subnet = GetSubnet.apply(self.scores, 191 | torch.zeros_like(self.scores), 192 | torch.ones_like(self.scores), 193 | self.sparsity) 194 | 195 | if self.tile_size != 1: 196 | for i, k in enumerate(self.weight.shape): 197 | # if k == 1: continue 198 | subnet = subnet.repeat_interleave(self.tile_size, dim=i) 199 | subnet = torch.narrow(subnet, i, 0, k) 200 | 201 | w = (self.weight*self.scale+self.shift) * subnet 202 | return F.conv2d(x, w, self.bias, self.stride, self.padding, self.dilation, self.groups) 203 | 204 | @torch.no_grad() 205 | def set_sparsity(modules, sparsity): 206 | """Set the sparsity for supermask layers""" 207 | sm_idx = 0 208 | for mod in modules: 209 | if isinstance(mod, (SupermaskLinear, SupermaskConv2d)): 210 | mod.sparsity=sparsity[sm_idx] 211 | sm_idx += 1 212 | print(mod) 213 | print('Sparsity: ', mod.sparsity) 214 | 215 | 216 | def apply_supermask( 217 | model, 218 | linear_sparsity=0.0, 219 | linear_sp_tilesize=1, 220 | conv1x1_sparsity=0.0, 221 | conv1x1_sp_tilesize=1, 222 | conv_sparsity=0.0, 223 | conv_sp_tilesize=1, 224 | skip_last_layer_sparsity=False, 225 | skip_first_transformer_sparsity=False, 226 | device="cuda", 227 | verbose=False, 228 | ): 229 | sparsified_modules = {} 230 | 231 | for n, m in model.named_modules(): 232 | # check conditions for skipping sparsity 233 | if skip_last_layer_sparsity and n == "heads.head": 234 | continue 235 | if skip_first_transformer_sparsity and "encoder.layers.encoder_layer_0" in n: 236 | continue 237 | 238 | # convert 1x1 convolutions 239 | if conv1x1_sparsity != 0.0 and isinstance(m, torch.nn.Conv2d) and m.kernel_size == (1, 1): 240 | new_m = SupermaskConv2d( 241 | conv1x1_sparsity, False, False, None, None, None, 242 | m.in_channels, 243 | m.out_channels, 244 | m.kernel_size, 245 | stride=m.stride, 246 | padding=m.padding, 247 | dilation=m.dilation, 248 | groups=m.groups, 249 | bias=m.bias is not None, 250 | padding_mode=m.padding_mode, 251 | device=device, 252 | tile_size=conv1x1_sp_tilesize, 253 | ) 254 | new_m.weight.data.copy_(m.weight.data) 255 | if m.bias is not None: 256 | new_m.bias.data.copy_(m.bias.data) 257 | sparsified_modules[n] = new_m 258 | continue 259 | 260 | # convert all other convolutions (not tested!) 
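# (this branch reuses SupermaskConv2d for any remaining Conv2d, driven by the
# --sparsity-conv / --sp-conv-tile-size args; 1x1 convolutions were already handled above)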
261 | if conv_sparsity != 0.0 and isinstance(m, torch.nn.Conv2d): 262 | new_m = SupermaskConv2d( 263 | conv_sparsity, False, False, None, None, None, 264 | m.in_channels, 265 | m.out_channels, 266 | m.kernel_size, 267 | stride=m.stride, 268 | padding=m.padding, 269 | dilation=m.dilation, 270 | groups=m.groups, 271 | bias=m.bias is not None, 272 | padding_mode=m.padding_mode, 273 | device=device, 274 | tile_size=conv_sp_tilesize, 275 | ) 276 | new_m.weight.data.copy_(m.weight.data) 277 | if m.bias is not None: 278 | new_m.bias.data.copy_(m.bias.data) 279 | sparsified_modules[n] = new_m 280 | continue 281 | 282 | if linear_sparsity != 0.0 and isinstance(m, torch.nn.Linear): 283 | new_m = SupermaskLinear( 284 | linear_sparsity, False, False, None, None, None, 285 | m.in_features, 286 | m.out_features, 287 | bias=m.bias is not None, 288 | device=device, 289 | tile_size=linear_sp_tilesize, 290 | ) 291 | new_m.weight.data.copy_(m.weight.data) 292 | if m.bias is not None: 293 | new_m.bias.data.copy_(m.bias.data) 294 | sparsified_modules[n] = new_m 295 | continue 296 | 297 | # add modules to model 298 | for k, v in sparsified_modules.items(): 299 | sm_name, ch_name = k.rsplit(".", 1) 300 | sm = model.get_submodule(sm_name) 301 | sm.add_module(ch_name, v) 302 | 303 | if verbose: 304 | print(f'sparsified module "{k}" with sparsity={v.sparsity}, tile size={v.tile_size}') 305 | 306 | return model 307 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | import datetime 4 | import os 5 | import glob 6 | import sys 7 | import time 8 | import warnings 9 | 10 | import presets 11 | import torch 12 | import torch.utils.data 13 | import torchvision 14 | import transforms 15 | import utils 16 | from sampler import RASampler 17 | from torch import nn 18 | from torch.utils.data.dataloader import default_collate 19 | from torchvision.transforms.functional import InterpolationMode 20 | 21 | sys.path.append(os.path.join(os.path.dirname(__file__), "..")) 22 | from supermask import apply_supermask, SupermaskLinear 23 | 24 | 25 | def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None): 26 | model.train() 27 | metric_logger = utils.MetricLogger(delimiter=" ") 28 | metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}")) 29 | metric_logger.add_meter("img/s", utils.SmoothedValue(window_size=10, fmt="{value}")) 30 | 31 | header = f"Epoch: [{epoch}]" 32 | accumulation_counter = 0 # Counter for tracking accumulated gradients 33 | 34 | for i, (image, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)): 35 | start_time = time.time() 36 | image, target = image.to(device), target.to(device) 37 | 38 | with torch.cuda.amp.autocast(enabled=scaler is not None): 39 | output = model(image) 40 | loss = criterion(output, target) / args.accumulation_steps # Scale loss 41 | 42 | if scaler is not None: 43 | scaler.scale(loss).backward() 44 | else: 45 | loss.backward() 46 | 47 | accumulation_counter += 1 48 | 49 | if accumulation_counter % args.accumulation_steps == 0: 50 | if scaler is not None: 51 | if args.clip_grad_norm is not None: 52 | scaler.unscale_(optimizer) # Unscale gradients before clipping 53 | nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm) 54 | scaler.step(optimizer) 55 | scaler.update() 56 | else: 57 | if 
args.clip_grad_norm is not None: 58 | nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm) 59 | optimizer.step() 60 | 61 | optimizer.zero_grad() # Zero out gradients after optimization step 62 | 63 | if model_ema and i % args.model_ema_steps == 0: 64 | model_ema.update_parameters(model) 65 | if epoch < args.lr_warmup_epochs: 66 | model_ema.n_averaged.fill_(0) 67 | 68 | acc1, acc5 = utils.accuracy(output, target, topk=(1, 5)) 69 | batch_size = image.shape[0] 70 | metric_logger.update(loss=loss.item() * args.accumulation_steps, lr=optimizer.param_groups[0]["lr"]) # Scale back up for logging 71 | metric_logger.meters["acc1"].update(acc1.item(), n=batch_size) 72 | metric_logger.meters["acc5"].update(acc5.item(), n=batch_size) 73 | metric_logger.meters["img/s"].update(batch_size / (time.time() - start_time)) 74 | 75 | 76 | def apply_sparsity(model): 77 | for module in model.modules(): 78 | if isinstance(module, SupermaskLinear): 79 | module.sparsify_offline() 80 | 81 | 82 | def apply_bsr(model): 83 | for name, module in model.named_modules(): 84 | if isinstance(module, torch.nn.Linear): 85 | try: 86 | module.weight = torch.nn.Parameter(to_bsr(module.weight.data, args.bsr)) 87 | print(f"Converted {name} to bsr format.") 88 | except ValueError as e: 89 | print(f"Unable to convert weight of {name} to bsr format: {e}") 90 | 91 | 92 | def to_bsr(tensor, blocksize): 93 | if tensor.ndim != 2: 94 | raise ValueError("to_bsr expects 2D tensor") 95 | if tensor.size(0) % blocksize or tensor.size(1) % blocksize: 96 | raise ValueError("Tensor dimensions must be divisible by blocksize") 97 | return tensor.to_sparse_bsr(blocksize) 98 | 99 | 100 | def verify_sparsity(model): 101 | for name, module in model.named_modules(): 102 | if isinstance(module, nn.Linear): 103 | total_weights = module.weight.numel() 104 | sparse_weights = (module.weight == 0).sum().item() 105 | sparsity_percentage = (sparse_weights / total_weights) * 100 106 | print(f"Sparsity verified in layer {name}: {sparsity_percentage:.2f}%") 107 | 108 | 109 | def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=""): 110 | model.eval() 111 | metric_logger = utils.MetricLogger(delimiter=" ") 112 | header = f"Test: {log_suffix}" 113 | 114 | num_processed_samples = 0 115 | with torch.inference_mode(): 116 | for image, target in metric_logger.log_every(data_loader, print_freq, header): 117 | image = image.to(device, non_blocking=True) 118 | target = target.to(device, non_blocking=True) 119 | output = model(image) 120 | loss = criterion(output, target) 121 | 122 | acc1, acc5 = utils.accuracy(output, target, topk=(1, 5)) 123 | # FIXME need to take into account that the datasets 124 | # could have been padded in distributed setup 125 | batch_size = image.shape[0] 126 | metric_logger.update(loss=loss.item()) 127 | metric_logger.meters["acc1"].update(acc1.item(), n=batch_size) 128 | metric_logger.meters["acc5"].update(acc5.item(), n=batch_size) 129 | num_processed_samples += batch_size 130 | # gather the stats from all processes 131 | 132 | num_processed_samples = utils.reduce_across_processes(num_processed_samples) 133 | if ( 134 | hasattr(data_loader.dataset, "__len__") 135 | and len(data_loader.dataset) != num_processed_samples 136 | and torch.distributed.get_rank() == 0 137 | ): 138 | # See FIXME above 139 | warnings.warn( 140 | f"It looks like the dataset has {len(data_loader.dataset)} samples, but {num_processed_samples} " 141 | "samples were used for the validation, which might bias the results. 
" 142 | "Try adjusting the batch size and / or the world size. " 143 | "Setting the world size to 1 is always a safe bet." 144 | ) 145 | 146 | metric_logger.synchronize_between_processes() 147 | 148 | print(f"{header} Acc@1 {metric_logger.acc1.global_avg:.3f} Acc@5 {metric_logger.acc5.global_avg:.3f}") 149 | return metric_logger.acc1.global_avg 150 | 151 | 152 | def _get_cache_path(filepath): 153 | import hashlib 154 | 155 | h = hashlib.sha1(filepath.encode()).hexdigest() 156 | cache_path = os.path.join("~", ".torch", "vision", "datasets", "imagefolder", h[:10] + ".pt") 157 | cache_path = os.path.expanduser(cache_path) 158 | return cache_path 159 | 160 | 161 | def load_data(traindir, valdir, args): 162 | # Data loading code 163 | print("Loading data") 164 | val_resize_size, val_crop_size, train_crop_size = ( 165 | args.val_resize_size, 166 | args.val_crop_size, 167 | args.train_crop_size, 168 | ) 169 | interpolation = InterpolationMode(args.interpolation) 170 | 171 | print("Loading training data") 172 | st = time.time() 173 | cache_path = _get_cache_path(traindir) 174 | if args.cache_dataset and os.path.exists(cache_path): 175 | # Attention, as the transforms are also cached! 176 | print(f"Loading dataset_train from {cache_path}") 177 | dataset, _ = torch.load(cache_path) 178 | else: 179 | auto_augment_policy = getattr(args, "auto_augment", None) 180 | random_erase_prob = getattr(args, "random_erase", 0.0) 181 | ra_magnitude = args.ra_magnitude 182 | augmix_severity = args.augmix_severity 183 | dataset = torchvision.datasets.ImageFolder( 184 | traindir, 185 | presets.ClassificationPresetTrain( 186 | crop_size=train_crop_size, 187 | interpolation=interpolation, 188 | auto_augment_policy=auto_augment_policy, 189 | random_erase_prob=random_erase_prob, 190 | ra_magnitude=ra_magnitude, 191 | augmix_severity=augmix_severity, 192 | ), 193 | ) 194 | if args.cache_dataset: 195 | print(f"Saving dataset_train to {cache_path}") 196 | utils.mkdir(os.path.dirname(cache_path)) 197 | utils.save_on_master((dataset, traindir), cache_path) 198 | print("Took", time.time() - st) 199 | 200 | print("Loading validation data") 201 | cache_path = _get_cache_path(valdir) 202 | if args.cache_dataset and os.path.exists(cache_path): 203 | # Attention, as the transforms are also cached! 
204 | print(f"Loading dataset_test from {cache_path}") 205 | dataset_test, _ = torch.load(cache_path) 206 | else: 207 | if args.weights and args.test_only: 208 | weights = torchvision.models.get_weight(args.weights) 209 | preprocessing = weights.transforms() 210 | else: 211 | preprocessing = presets.ClassificationPresetEval( 212 | crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation 213 | ) 214 | 215 | dataset_test = torchvision.datasets.ImageFolder( 216 | valdir, 217 | preprocessing, 218 | ) 219 | if args.cache_dataset: 220 | print(f"Saving dataset_test to {cache_path}") 221 | utils.mkdir(os.path.dirname(cache_path)) 222 | utils.save_on_master((dataset_test, valdir), cache_path) 223 | 224 | print(f"Number of training images: {len(dataset)}") 225 | print(f"Number of validation images: {len(dataset_test)}") 226 | 227 | print("Creating data loaders") 228 | if args.distributed: 229 | if hasattr(args, "ra_sampler") and args.ra_sampler: 230 | train_sampler = RASampler(dataset, shuffle=True, repetitions=args.ra_reps) 231 | else: 232 | train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) 233 | test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False) 234 | else: 235 | train_sampler = torch.utils.data.RandomSampler(dataset) 236 | test_sampler = torch.utils.data.SequentialSampler(dataset_test) 237 | 238 | return dataset, dataset_test, train_sampler, test_sampler 239 | 240 | 241 | def main(args): 242 | if args.output_dir: 243 | utils.mkdir(args.output_dir) 244 | 245 | utils.init_distributed_mode(args) 246 | print(args) 247 | 248 | device = torch.device(args.device) 249 | 250 | if args.use_deterministic_algorithms: 251 | torch.backends.cudnn.benchmark = False 252 | torch.use_deterministic_algorithms(True) 253 | else: 254 | torch.backends.cudnn.benchmark = True 255 | 256 | train_dir = os.path.join(args.data_path, "train") 257 | val_dir = os.path.join(args.data_path, "val") 258 | dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args) 259 | 260 | collate_fn = None 261 | num_classes = len(dataset.classes) 262 | mixup_transforms = [] 263 | if args.mixup_alpha > 0.0: 264 | mixup_transforms.append(transforms.RandomMixup(num_classes, p=1.0, alpha=args.mixup_alpha)) 265 | if args.cutmix_alpha > 0.0: 266 | mixup_transforms.append(transforms.RandomCutmix(num_classes, p=1.0, alpha=args.cutmix_alpha)) 267 | if mixup_transforms: 268 | mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms) 269 | 270 | def collate_fn(batch): 271 | return mixupcutmix(*default_collate(batch)) 272 | 273 | data_loader = torch.utils.data.DataLoader( 274 | dataset, 275 | batch_size=args.batch_size, 276 | sampler=train_sampler, 277 | num_workers=args.workers, 278 | pin_memory=True, 279 | collate_fn=collate_fn, 280 | ) 281 | data_loader_test = torch.utils.data.DataLoader( 282 | dataset_test, batch_size=args.batch_size, sampler=test_sampler, num_workers=args.workers, pin_memory=True 283 | ) 284 | 285 | print("Creating model") 286 | model = torchvision.models.get_model(args.model, weights=args.weights, num_classes=num_classes) 287 | 288 | if args.weights_path is not None: 289 | sd = torch.load(args.weights_path, map_location="cpu") 290 | model.load_state_dict(sd) 291 | 292 | if args.sparsify_weights and not args.test_only: 293 | raise ValueError("--sparsify-weights can only be used when --test-only is also specified.") 294 | 295 | apply_supermask( 296 | model, 297 | linear_sparsity=args.sparsity_linear, 298 | 
linear_sp_tilesize=args.sp_linear_tile_size, 299 | conv1x1_sparsity=args.sparsity_conv1x1, 300 | conv1x1_sp_tilesize=args.sp_conv1x1_tile_size, 301 | conv_sparsity=args.sparsity_conv, 302 | conv_sp_tilesize=args.sp_conv_tile_size, 303 | skip_last_layer_sparsity=args.skip_last_layer_sparsity, 304 | skip_first_transformer_sparsity=args.skip_first_transformer_sparsity, 305 | device=device, 306 | verbose=True, 307 | ) 308 | 309 | model.to(device) 310 | if args.distributed and args.sync_bn: 311 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 312 | 313 | criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing) 314 | 315 | custom_keys_weight_decay = [] 316 | if args.bias_weight_decay is not None: 317 | custom_keys_weight_decay.append(("bias", args.bias_weight_decay)) 318 | if args.transformer_embedding_decay is not None: 319 | for key in ["class_token", "position_embedding", "relative_position_bias_table"]: 320 | custom_keys_weight_decay.append((key, args.transformer_embedding_decay)) 321 | parameters = utils.set_weight_decay( 322 | model, 323 | args.weight_decay, 324 | norm_weight_decay=args.norm_weight_decay, 325 | custom_keys_weight_decay=custom_keys_weight_decay if len(custom_keys_weight_decay) > 0 else None, 326 | ) 327 | 328 | opt_name = args.opt.lower() 329 | if opt_name.startswith("sgd"): 330 | optimizer = torch.optim.SGD( 331 | parameters, 332 | lr=args.lr, 333 | momentum=args.momentum, 334 | weight_decay=args.weight_decay, 335 | nesterov="nesterov" in opt_name, 336 | ) 337 | elif opt_name == "rmsprop": 338 | optimizer = torch.optim.RMSprop( 339 | parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, eps=0.0316, alpha=0.9 340 | ) 341 | elif opt_name == "adamw": 342 | optimizer = torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay) 343 | else: 344 | raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD, RMSprop and AdamW are supported.") 345 | 346 | scaler = torch.cuda.amp.GradScaler() if args.amp else None 347 | 348 | args.lr_scheduler = args.lr_scheduler.lower() 349 | if args.lr_scheduler == "steplr": 350 | main_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma) 351 | elif args.lr_scheduler == "cosineannealinglr": 352 | main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 353 | optimizer, T_max=args.epochs - args.lr_warmup_epochs, eta_min=args.lr_min 354 | ) 355 | elif args.lr_scheduler == "exponentiallr": 356 | main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_gamma) 357 | else: 358 | raise RuntimeError( 359 | f"Invalid lr scheduler '{args.lr_scheduler}'. Only StepLR, CosineAnnealingLR and ExponentialLR " 360 | "are supported." 361 | ) 362 | 363 | if args.lr_warmup_epochs > 0: 364 | if args.lr_warmup_method == "linear": 365 | warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR( 366 | optimizer, start_factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs 367 | ) 368 | elif args.lr_warmup_method == "constant": 369 | warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR( 370 | optimizer, factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs 371 | ) 372 | else: 373 | raise RuntimeError( 374 | f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported." 
375 | ) 376 | lr_scheduler = torch.optim.lr_scheduler.SequentialLR( 377 | optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[args.lr_warmup_epochs] 378 | ) 379 | else: 380 | lr_scheduler = main_lr_scheduler 381 | 382 | model_without_ddp = model 383 | if args.distributed: 384 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True) 385 | model_without_ddp = model.module 386 | 387 | model_ema = None 388 | if args.model_ema: 389 | # Decay adjustment that aims to keep the decay independent from other hyper-parameters originally proposed at: 390 | # https://github.com/facebookresearch/pycls/blob/f8cd9627/pycls/core/net.py#L123 391 | # 392 | # total_ema_updates = (Dataset_size / n_GPUs) * epochs / (batch_size_per_gpu * EMA_steps) 393 | # We consider constant = Dataset_size for a given dataset/setup and ommit it. Thus: 394 | # adjust = 1 / total_ema_updates ~= n_GPUs * batch_size_per_gpu * EMA_steps / epochs 395 | adjust = args.world_size * args.batch_size * args.model_ema_steps / args.epochs 396 | alpha = 1.0 - args.model_ema_decay 397 | alpha = min(1.0, alpha * adjust) 398 | model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=1.0 - alpha) 399 | 400 | #TODO: need to test resume functionality 401 | if args.resume: 402 | checkpoint_pattern = os.path.join(args.output_dir, "model_*.pth") 403 | checkpoint_files = glob.glob(checkpoint_pattern) 404 | epochs = [int(f.split('_')[-1].split('.')[0]) for f in checkpoint_files] 405 | if epochs: 406 | latest_epoch = max(epochs) 407 | latest_checkpoint = os.path.join(args.output_dir, f"model_{latest_epoch}.pth") 408 | try: 409 | checkpoint = torch.load(latest_checkpoint, map_location="cpu") 410 | model_without_ddp.load_state_dict(checkpoint["model"]) 411 | if not args.test_only: 412 | optimizer.load_state_dict(checkpoint["optimizer"]) 413 | lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) 414 | args.start_epoch = checkpoint["epoch"] + 1 415 | if model_ema: 416 | model_ema.load_state_dict(checkpoint["model_ema"]) 417 | if scaler: 418 | scaler.load_state_dict(checkpoint["scaler"]) 419 | print(f"Resumed training from epoch {args.start_epoch}.") 420 | except FileNotFoundError: 421 | print(f"No checkpoint found at {latest_checkpoint}. Starting training from scratch.") 422 | args.start_epoch = 0 423 | else: 424 | print("No checkpoint found. 
Starting training from scratch.") 425 | args.start_epoch = 0 426 | else: 427 | args.start_epoch = 0 428 | 429 | if args.test_only: 430 | # We disable the cudnn benchmarking because it can noticeably affect the accuracy 431 | torch.backends.cudnn.benchmark = False 432 | torch.backends.cudnn.deterministic = True 433 | if args.bsr and not args.sparsify_weights: 434 | raise ValueError("--bsr can only be used when --sparsify_weights is also specified.") 435 | if args.sparsify_weights: 436 | apply_sparsity(model) 437 | verify_sparsity(model) 438 | if args.bsr: 439 | apply_bsr(model) 440 | 441 | if model_ema: 442 | evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA") 443 | else: 444 | evaluate(model, criterion, data_loader_test, device=device) 445 | return 446 | 447 | print("Start training") 448 | start_time = time.time() 449 | for epoch in range(args.start_epoch, args.epochs): 450 | if args.distributed: 451 | train_sampler.set_epoch(epoch) 452 | train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema, scaler) 453 | lr_scheduler.step() 454 | evaluate(model, criterion, data_loader_test, device=device) 455 | if model_ema: 456 | evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA") 457 | if args.output_dir: 458 | checkpoint = { 459 | "model": model_without_ddp.state_dict(), 460 | "optimizer": optimizer.state_dict(), 461 | "lr_scheduler": lr_scheduler.state_dict(), 462 | "epoch": epoch, 463 | "args": args, 464 | } 465 | if model_ema: 466 | checkpoint["model_ema"] = model_ema.state_dict() 467 | if scaler: 468 | checkpoint["scaler"] = scaler.state_dict() 469 | utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth")) 470 | utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth")) 471 | 472 | total_time = time.time() - start_time 473 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 474 | print(f"Training time {total_time_str}") 475 | 476 | 477 | def get_args_parser(add_help=True): 478 | import argparse 479 | 480 | parser = argparse.ArgumentParser(description="PyTorch Classification Training", add_help=add_help) 481 | parser.add_argument("--data-path", type=str, help="dataset path") 482 | parser.add_argument("--model", default="resnet18", type=str, help="model name") 483 | parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)") 484 | parser.add_argument( 485 | "-b", "--batch-size", default=32, type=int, help="images per gpu, the total batch size is $NGPU x batch_size" 486 | ) 487 | parser.add_argument("--accumulation-steps", default=1, type=int, help="Number of steps to accumulate gradients over") 488 | parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run") 489 | parser.add_argument( 490 | "-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)" 491 | ) 492 | parser.add_argument("--opt", default="sgd", type=str, help="optimizer") 493 | parser.add_argument("--lr", default=0.1, type=float, help="initial learning rate") 494 | parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum") 495 | parser.add_argument( 496 | "--wd", 497 | "--weight-decay", 498 | default=1e-4, 499 | type=float, 500 | metavar="W", 501 | help="weight decay (default: 1e-4)", 502 | dest="weight_decay", 503 | ) 504 | parser.add_argument( 505 | "--norm-weight-decay", 506 | default=None, 507 | type=float, 
508 | help="weight decay for Normalization layers (default: None, same value as --wd)", 509 | ) 510 | parser.add_argument( 511 | "--bias-weight-decay", 512 | default=None, 513 | type=float, 514 | help="weight decay for bias parameters of all layers (default: None, same value as --wd)", 515 | ) 516 | parser.add_argument( 517 | "--transformer-embedding-decay", 518 | default=None, 519 | type=float, 520 | help="weight decay for embedding parameters for vision transformer models (default: None, same value as --wd)", 521 | ) 522 | parser.add_argument( 523 | "--label-smoothing", default=0.0, type=float, help="label smoothing (default: 0.0)", dest="label_smoothing" 524 | ) 525 | parser.add_argument("--mixup-alpha", default=0.0, type=float, help="mixup alpha (default: 0.0)") 526 | parser.add_argument("--cutmix-alpha", default=0.0, type=float, help="cutmix alpha (default: 0.0)") 527 | parser.add_argument("--lr-scheduler", default="steplr", type=str, help="the lr scheduler (default: steplr)") 528 | parser.add_argument("--lr-warmup-epochs", default=0, type=int, help="the number of epochs to warmup (default: 0)") 529 | parser.add_argument( 530 | "--lr-warmup-method", default="constant", type=str, help="the warmup method (default: constant)" 531 | ) 532 | parser.add_argument("--lr-warmup-decay", default=0.01, type=float, help="the decay for lr") 533 | parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs") 534 | parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma") 535 | parser.add_argument("--lr-min", default=0.0, type=float, help="minimum lr of lr schedule (default: 0.0)") 536 | parser.add_argument("--print-freq", default=10, type=int, help="print frequency") 537 | parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs") 538 | parser.add_argument('--resume', action='store_true', help='Resumes training from latest available checkpoint ("model_.pth")') 539 | parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch") 540 | parser.add_argument( 541 | "--cache-dataset", 542 | dest="cache_dataset", 543 | help="Cache the datasets for quicker initialization. 
It also serializes the transforms", 544 | action="store_true", 545 | ) 546 | parser.add_argument( 547 | "--sync-bn", 548 | dest="sync_bn", 549 | help="Use sync batch norm", 550 | action="store_true", 551 | ) 552 | parser.add_argument( 553 | "--test-only", 554 | dest="test_only", 555 | help="Only test the model", 556 | action="store_true", 557 | ) 558 | parser.add_argument("--auto-augment", default=None, type=str, help="auto augment policy (default: None)") 559 | parser.add_argument("--ra-magnitude", default=9, type=int, help="magnitude of auto augment policy") 560 | parser.add_argument("--augmix-severity", default=3, type=int, help="severity of augmix policy") 561 | parser.add_argument("--random-erase", default=0.0, type=float, help="random erasing probability (default: 0.0)") 562 | 563 | # Mixed precision training parameters 564 | parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training") 565 | 566 | # distributed training parameters 567 | parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes") 568 | parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training") 569 | parser.add_argument( 570 | "--model-ema", action="store_true", help="enable tracking Exponential Moving Average of model parameters" 571 | ) 572 | parser.add_argument( 573 | "--model-ema-steps", 574 | type=int, 575 | default=32, 576 | help="the number of iterations that controls how often to update the EMA model (default: 32)", 577 | ) 578 | parser.add_argument( 579 | "--model-ema-decay", 580 | type=float, 581 | default=0.99998, 582 | help="decay factor for Exponential Moving Average of model parameters (default: 0.99998)", 583 | ) 584 | parser.add_argument( 585 | "--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only." 
586 | ) 587 | parser.add_argument( 588 | "--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)" 589 | ) 590 | parser.add_argument( 591 | "--val-resize-size", default=256, type=int, help="the resize size used for validation (default: 256)" 592 | ) 593 | parser.add_argument( 594 | "--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)" 595 | ) 596 | parser.add_argument( 597 | "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)" 598 | ) 599 | parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)") 600 | parser.add_argument("--ra-sampler", action="store_true", help="whether to use Repeated Augmentation in training") 601 | parser.add_argument( 602 | "--ra-reps", default=3, type=int, help="number of repetitions for Repeated Augmentation (default: 3)" 603 | ) 604 | parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load") 605 | parser.add_argument("--weights-path", type=str) 606 | 607 | # NOTE: sparsity args 608 | parser.add_argument("--sparsity-linear", type=float, default=0.0) 609 | parser.add_argument("--sp-linear-tile-size", type=int, default=1) 610 | parser.add_argument("--sparsity-conv1x1", type=float, default=0.0) 611 | parser.add_argument("--sp-conv1x1-tile-size", type=int, default=1) 612 | parser.add_argument("--sparsity-conv", type=float, default=0.0) 613 | parser.add_argument("--sp-conv-tile-size", type=int, default=1) 614 | parser.add_argument("--skip-last-layer-sparsity", action="store_true", help="Skip applying sparsity to the last linear layer (for vit only)") 615 | parser.add_argument("--skip-first-transformer-sparsity", action="store_true", help="Skip applying sparsity to the first transformer layer (for vit only)") 616 | parser.add_argument('--sparsify-weights', action='store_true', help='Apply weight sparsification in evaluation mode') 617 | parser.add_argument('--bsr', type=int, nargs='?', const=256, default=None, help='Convert sparsified weights to BSR format with optional block size (default: 256)') 618 | 619 | 620 | return parser 621 | 622 | 623 | if __name__ == "__main__": 624 | args = get_args_parser().parse_args() 625 | main(args) 626 | -------------------------------------------------------------------------------- /transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | import math 4 | from typing import Tuple 5 | 6 | import torch 7 | from torch import Tensor 8 | from torchvision.transforms import functional as F 9 | 10 | 11 | class RandomMixup(torch.nn.Module): 12 | """Randomly apply Mixup to the provided batch and targets. 13 | The class implements the data augmentations as described in the paper 14 | `"mixup: Beyond Empirical Risk Minimization" `_. 15 | 16 | Args: 17 | num_classes (int): number of classes used for one-hot encoding. 18 | p (float): probability of the batch being transformed. Default value is 0.5. 19 | alpha (float): hyperparameter of the Beta distribution used for mixup. 20 | Default value is 1.0. 21 | inplace (bool): boolean to make this transform inplace. Default set to False. 
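    Example (illustrative, not from the original source)::

        mixup = RandomMixup(num_classes=10, p=1.0, alpha=0.2)
        images = torch.rand(8, 3, 224, 224)        # float batch of shape (B, C, H, W)
        labels = torch.randint(0, 10, (8,))        # int64 class ids of shape (B,)
        mixed_images, soft_targets = mixup(images, labels)  # soft_targets: (8, 10) float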
22 | """ 23 | 24 | def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None: 25 | super().__init__() 26 | 27 | if num_classes < 1: 28 | raise ValueError( 29 | f"Please provide a valid positive value for the num_classes. Got num_classes={num_classes}" 30 | ) 31 | 32 | if alpha <= 0: 33 | raise ValueError("Alpha param can't be zero.") 34 | 35 | self.num_classes = num_classes 36 | self.p = p 37 | self.alpha = alpha 38 | self.inplace = inplace 39 | 40 | def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: 41 | """ 42 | Args: 43 | batch (Tensor): Float tensor of size (B, C, H, W) 44 | target (Tensor): Integer tensor of size (B, ) 45 | 46 | Returns: 47 | Tensor: Randomly transformed batch. 48 | """ 49 | if batch.ndim != 4: 50 | raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}") 51 | if target.ndim != 1: 52 | raise ValueError(f"Target ndim should be 1. Got {target.ndim}") 53 | if not batch.is_floating_point(): 54 | raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.") 55 | if target.dtype != torch.int64: 56 | raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}") 57 | 58 | if not self.inplace: 59 | batch = batch.clone() 60 | target = target.clone() 61 | 62 | if target.ndim == 1: 63 | target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype) 64 | 65 | if torch.rand(1).item() >= self.p: 66 | return batch, target 67 | 68 | # It's faster to roll the batch by one instead of shuffling it to create image pairs 69 | batch_rolled = batch.roll(1, 0) 70 | target_rolled = target.roll(1, 0) 71 | 72 | # Implemented as on mixup paper, page 3. 73 | lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0]) 74 | batch_rolled.mul_(1.0 - lambda_param) 75 | batch.mul_(lambda_param).add_(batch_rolled) 76 | 77 | target_rolled.mul_(1.0 - lambda_param) 78 | target.mul_(lambda_param).add_(target_rolled) 79 | 80 | return batch, target 81 | 82 | def __repr__(self) -> str: 83 | s = ( 84 | f"{self.__class__.__name__}(" 85 | f"num_classes={self.num_classes}" 86 | f", p={self.p}" 87 | f", alpha={self.alpha}" 88 | f", inplace={self.inplace}" 89 | f")" 90 | ) 91 | return s 92 | 93 | 94 | class RandomCutmix(torch.nn.Module): 95 | """Randomly apply Cutmix to the provided batch and targets. 96 | The class implements the data augmentations as described in the paper 97 | `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features" 98 | `_. 99 | 100 | Args: 101 | num_classes (int): number of classes used for one-hot encoding. 102 | p (float): probability of the batch being transformed. Default value is 0.5. 103 | alpha (float): hyperparameter of the Beta distribution used for cutmix. 104 | Default value is 1.0. 105 | inplace (bool): boolean to make this transform inplace. Default set to False. 
106 | """ 107 | 108 | def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None: 109 | super().__init__() 110 | if num_classes < 1: 111 | raise ValueError("Please provide a valid positive value for the num_classes.") 112 | if alpha <= 0: 113 | raise ValueError("Alpha param can't be zero.") 114 | 115 | self.num_classes = num_classes 116 | self.p = p 117 | self.alpha = alpha 118 | self.inplace = inplace 119 | 120 | def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: 121 | """ 122 | Args: 123 | batch (Tensor): Float tensor of size (B, C, H, W) 124 | target (Tensor): Integer tensor of size (B, ) 125 | 126 | Returns: 127 | Tensor: Randomly transformed batch. 128 | """ 129 | if batch.ndim != 4: 130 | raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}") 131 | if target.ndim != 1: 132 | raise ValueError(f"Target ndim should be 1. Got {target.ndim}") 133 | if not batch.is_floating_point(): 134 | raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.") 135 | if target.dtype != torch.int64: 136 | raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}") 137 | 138 | if not self.inplace: 139 | batch = batch.clone() 140 | target = target.clone() 141 | 142 | if target.ndim == 1: 143 | target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype) 144 | 145 | if torch.rand(1).item() >= self.p: 146 | return batch, target 147 | 148 | # It's faster to roll the batch by one instead of shuffling it to create image pairs 149 | batch_rolled = batch.roll(1, 0) 150 | target_rolled = target.roll(1, 0) 151 | 152 | # Implemented as on cutmix paper, page 12 (with minor corrections on typos). 153 | lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0]) 154 | _, H, W = F.get_dimensions(batch) 155 | 156 | r_x = torch.randint(W, (1,)) 157 | r_y = torch.randint(H, (1,)) 158 | 159 | r = 0.5 * math.sqrt(1.0 - lambda_param) 160 | r_w_half = int(r * W) 161 | r_h_half = int(r * H) 162 | 163 | x1 = int(torch.clamp(r_x - r_w_half, min=0)) 164 | y1 = int(torch.clamp(r_y - r_h_half, min=0)) 165 | x2 = int(torch.clamp(r_x + r_w_half, max=W)) 166 | y2 = int(torch.clamp(r_y + r_h_half, max=H)) 167 | 168 | batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2] 169 | lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H)) 170 | 171 | target_rolled.mul_(1.0 - lambda_param) 172 | target.mul_(lambda_param).add_(target_rolled) 173 | 174 | return batch, target 175 | 176 | def __repr__(self) -> str: 177 | s = ( 178 | f"{self.__class__.__name__}(" 179 | f"num_classes={self.num_classes}" 180 | f", p={self.p}" 181 | f", alpha={self.alpha}" 182 | f", inplace={self.inplace}" 183 | f")" 184 | ) 185 | return s 186 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | import copy 4 | import datetime 5 | import errno 6 | import hashlib 7 | import os 8 | import time 9 | from collections import defaultdict, deque, OrderedDict 10 | from typing import List, Optional, Tuple 11 | 12 | import torch 13 | import torch.distributed as dist 14 | 15 | 16 | class SmoothedValue: 17 | """Track a series of values and provide access to smoothed values over a 18 | window or the global series average. 
19 | """ 20 | 21 | def __init__(self, window_size=20, fmt=None): 22 | if fmt is None: 23 | fmt = "{median:.4f} ({global_avg:.4f})" 24 | self.deque = deque(maxlen=window_size) 25 | self.total = 0.0 26 | self.count = 0 27 | self.fmt = fmt 28 | 29 | def update(self, value, n=1): 30 | self.deque.append(value) 31 | self.count += n 32 | self.total += value * n 33 | 34 | def synchronize_between_processes(self): 35 | """ 36 | Warning: does not synchronize the deque! 37 | """ 38 | t = reduce_across_processes([self.count, self.total]) 39 | t = t.tolist() 40 | self.count = int(t[0]) 41 | self.total = t[1] 42 | 43 | @property 44 | def median(self): 45 | d = torch.tensor(list(self.deque)) 46 | return d.median().item() 47 | 48 | @property 49 | def avg(self): 50 | d = torch.tensor(list(self.deque), dtype=torch.float32) 51 | return d.mean().item() 52 | 53 | @property 54 | def global_avg(self): 55 | return self.total / self.count 56 | 57 | @property 58 | def max(self): 59 | return max(self.deque) 60 | 61 | @property 62 | def value(self): 63 | return self.deque[-1] 64 | 65 | def __str__(self): 66 | return self.fmt.format( 67 | median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value 68 | ) 69 | 70 | 71 | class MetricLogger: 72 | def __init__(self, delimiter="\t"): 73 | self.meters = defaultdict(SmoothedValue) 74 | self.delimiter = delimiter 75 | 76 | def update(self, **kwargs): 77 | for k, v in kwargs.items(): 78 | if isinstance(v, torch.Tensor): 79 | v = v.item() 80 | assert isinstance(v, (float, int)) 81 | self.meters[k].update(v) 82 | 83 | def __getattr__(self, attr): 84 | if attr in self.meters: 85 | return self.meters[attr] 86 | if attr in self.__dict__: 87 | return self.__dict__[attr] 88 | raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'") 89 | 90 | def __str__(self): 91 | loss_str = [] 92 | for name, meter in self.meters.items(): 93 | loss_str.append(f"{name}: {str(meter)}") 94 | return self.delimiter.join(loss_str) 95 | 96 | def synchronize_between_processes(self): 97 | for meter in self.meters.values(): 98 | meter.synchronize_between_processes() 99 | 100 | def add_meter(self, name, meter): 101 | self.meters[name] = meter 102 | 103 | def log_every(self, iterable, print_freq, header=None): 104 | i = 0 105 | if not header: 106 | header = "" 107 | start_time = time.time() 108 | end = time.time() 109 | iter_time = SmoothedValue(fmt="{avg:.4f}") 110 | data_time = SmoothedValue(fmt="{avg:.4f}") 111 | space_fmt = ":" + str(len(str(len(iterable)))) + "d" 112 | if torch.cuda.is_available(): 113 | log_msg = self.delimiter.join( 114 | [ 115 | header, 116 | "[{0" + space_fmt + "}/{1}]", 117 | "eta: {eta}", 118 | "{meters}", 119 | "time: {time}", 120 | "data: {data}", 121 | "max mem: {memory:.0f}", 122 | ] 123 | ) 124 | else: 125 | log_msg = self.delimiter.join( 126 | [header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}", "time: {time}", "data: {data}"] 127 | ) 128 | MB = 1024.0 * 1024.0 129 | for obj in iterable: 130 | data_time.update(time.time() - end) 131 | yield obj 132 | iter_time.update(time.time() - end) 133 | if i % print_freq == 0: 134 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 135 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 136 | if torch.cuda.is_available(): 137 | print( 138 | log_msg.format( 139 | i, 140 | len(iterable), 141 | eta=eta_string, 142 | meters=str(self), 143 | time=str(iter_time), 144 | data=str(data_time), 145 | memory=torch.cuda.max_memory_allocated() / MB, 146 | ) 147 | ) 
148 | else: 149 | print( 150 | log_msg.format( 151 | i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time) 152 | ) 153 | ) 154 | i += 1 155 | end = time.time() 156 | total_time = time.time() - start_time 157 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 158 | print(f"{header} Total time: {total_time_str}") 159 | 160 | 161 | class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel): 162 | """Maintains moving averages of model parameters using an exponential decay. 163 | ``ema_avg = decay * avg_model_param + (1 - decay) * model_param`` 164 | `torch.optim.swa_utils.AveragedModel `_ 165 | is used to compute the EMA. 166 | """ 167 | 168 | def __init__(self, model, decay, device="cpu"): 169 | def ema_avg(avg_model_param, model_param, num_averaged): 170 | return decay * avg_model_param + (1 - decay) * model_param 171 | 172 | super().__init__(model, device, ema_avg, use_buffers=True) 173 | 174 | 175 | def accuracy(output, target, topk=(1,)): 176 | """Computes the accuracy over the k top predictions for the specified values of k""" 177 | with torch.inference_mode(): 178 | maxk = max(topk) 179 | batch_size = target.size(0) 180 | if target.ndim == 2: 181 | target = target.max(dim=1)[1] 182 | 183 | _, pred = output.topk(maxk, 1, True, True) 184 | pred = pred.t() 185 | correct = pred.eq(target[None]) 186 | 187 | res = [] 188 | for k in topk: 189 | correct_k = correct[:k].flatten().sum(dtype=torch.float32) 190 | res.append(correct_k * (100.0 / batch_size)) 191 | return res 192 | 193 | 194 | def mkdir(path): 195 | try: 196 | os.makedirs(path) 197 | except OSError as e: 198 | if e.errno != errno.EEXIST: 199 | raise 200 | 201 | 202 | def setup_for_distributed(is_master): 203 | """ 204 | This function disables printing when not in master process 205 | """ 206 | import builtins as __builtin__ 207 | 208 | builtin_print = __builtin__.print 209 | 210 | def print(*args, **kwargs): 211 | force = kwargs.pop("force", False) 212 | if is_master or force: 213 | builtin_print(*args, **kwargs) 214 | 215 | __builtin__.print = print 216 | 217 | 218 | def is_dist_avail_and_initialized(): 219 | if not dist.is_available(): 220 | return False 221 | if not dist.is_initialized(): 222 | return False 223 | return True 224 | 225 | 226 | def get_world_size(): 227 | if not is_dist_avail_and_initialized(): 228 | return 1 229 | return dist.get_world_size() 230 | 231 | 232 | def get_rank(): 233 | if not is_dist_avail_and_initialized(): 234 | return 0 235 | return dist.get_rank() 236 | 237 | 238 | def is_main_process(): 239 | return get_rank() == 0 240 | 241 | 242 | def save_on_master(*args, **kwargs): 243 | if is_main_process(): 244 | torch.save(*args, **kwargs) 245 | 246 | 247 | def init_distributed_mode(args): 248 | if "RANK" in os.environ and "WORLD_SIZE" in os.environ: 249 | args.rank = int(os.environ["RANK"]) 250 | args.world_size = int(os.environ["WORLD_SIZE"]) 251 | args.gpu = int(os.environ["LOCAL_RANK"]) 252 | elif "SLURM_PROCID" in os.environ: 253 | args.rank = int(os.environ["SLURM_PROCID"]) 254 | args.gpu = args.rank % torch.cuda.device_count() 255 | elif hasattr(args, "rank"): 256 | pass 257 | else: 258 | print("Not using distributed mode") 259 | args.distributed = False 260 | return 261 | 262 | args.distributed = True 263 | 264 | torch.cuda.set_device(args.gpu) 265 | args.dist_backend = "nccl" 266 | print(f"| distributed init (rank {args.rank}): {args.dist_url}", flush=True) 267 | torch.distributed.init_process_group( 268 | backend=args.dist_backend, 
init_method=args.dist_url, world_size=args.world_size, rank=args.rank 269 | ) 270 | torch.distributed.barrier() 271 | setup_for_distributed(args.rank == 0) 272 | 273 | 274 | def average_checkpoints(inputs): 275 | """Loads checkpoints from inputs and returns a model with averaged weights. Original implementation taken from: 276 | https://github.com/pytorch/fairseq/blob/a48f235636557b8d3bc4922a6fa90f3a0fa57955/scripts/average_checkpoints.py#L16 277 | 278 | Args: 279 | inputs (List[str]): An iterable of string paths of checkpoints to load from. 280 | Returns: 281 | A dict of string keys mapping to various values. The 'model' key 282 | from the returned dict should correspond to an OrderedDict mapping 283 | string parameter names to torch Tensors. 284 | """ 285 | params_dict = OrderedDict() 286 | params_keys = None 287 | new_state = None 288 | num_models = len(inputs) 289 | for fpath in inputs: 290 | with open(fpath, "rb") as f: 291 | state = torch.load( 292 | f, 293 | map_location=(lambda s, _: torch.serialization.default_restore_location(s, "cpu")), 294 | ) 295 | # Copies over the settings from the first checkpoint 296 | if new_state is None: 297 | new_state = state 298 | model_params = state["model"] 299 | model_params_keys = list(model_params.keys()) 300 | if params_keys is None: 301 | params_keys = model_params_keys 302 | elif params_keys != model_params_keys: 303 | raise KeyError( 304 | f"For checkpoint {f}, expected list of params: {params_keys}, but found: {model_params_keys}" 305 | ) 306 | for k in params_keys: 307 | p = model_params[k] 308 | if isinstance(p, torch.HalfTensor): 309 | p = p.float() 310 | if k not in params_dict: 311 | params_dict[k] = p.clone() 312 | # NOTE: clone() is needed in case of p is a shared parameter 313 | else: 314 | params_dict[k] += p 315 | averaged_params = OrderedDict() 316 | for k, v in params_dict.items(): 317 | averaged_params[k] = v 318 | if averaged_params[k].is_floating_point(): 319 | averaged_params[k].div_(num_models) 320 | else: 321 | averaged_params[k] //= num_models 322 | new_state["model"] = averaged_params 323 | return new_state 324 | 325 | 326 | def store_model_weights(model, checkpoint_path, checkpoint_key="model", strict=True): 327 | """ 328 | This method can be used to prepare weights files for new models. It receives as 329 | input a model architecture and a checkpoint from the training script and produces 330 | a file with the weights ready for release. 331 | 332 | Examples: 333 | from torchvision import models as M 334 | 335 | # Classification 336 | model = M.mobilenet_v3_large(weights=None) 337 | print(store_model_weights(model, './class.pth')) 338 | 339 | # Quantized Classification 340 | model = M.quantization.mobilenet_v3_large(weights=None, quantize=False) 341 | model.fuse_model(is_qat=True) 342 | model.qconfig = torch.ao.quantization.get_default_qat_qconfig('qnnpack') 343 | _ = torch.ao.quantization.prepare_qat(model, inplace=True) 344 | print(store_model_weights(model, './qat.pth')) 345 | 346 | # Object Detection 347 | model = M.detection.fasterrcnn_mobilenet_v3_large_fpn(weights=None, weights_backbone=None) 348 | print(store_model_weights(model, './obj.pth')) 349 | 350 | # Segmentation 351 | model = M.segmentation.deeplabv3_mobilenet_v3_large(weights=None, weights_backbone=None, aux_loss=True) 352 | print(store_model_weights(model, './segm.pth', strict=False)) 353 | 354 | Args: 355 | model (pytorch.nn.Module): The model on which the weights will be loaded for validation purposes. 
356 |         checkpoint_path (str): The path of the checkpoint we will load.
357 |         checkpoint_key (str, optional): The key of the checkpoint where the model weights are stored.
358 |             Default: "model".
359 |         strict (bool): whether to strictly enforce that the keys
360 |             in :attr:`state_dict` match the keys returned by this module's
361 |             :meth:`~torch.nn.Module.state_dict` function. Default: ``True``
362 |
363 |     Returns:
364 |         output_path (str): The location where the weights are saved.
365 |     """
366 |     # Store the new model next to the checkpoint_path
367 |     checkpoint_path = os.path.abspath(checkpoint_path)
368 |     output_dir = os.path.dirname(checkpoint_path)
369 |
370 |     # Deep copy to avoid side-effects on the model object.
371 |     model = copy.deepcopy(model)
372 |     checkpoint = torch.load(checkpoint_path, map_location="cpu")
373 |
374 |     # Load the weights to the model to validate that everything works
375 |     # and remove unnecessary weights (such as auxiliaries, etc)
376 |     if checkpoint_key == "model_ema":
377 |         del checkpoint[checkpoint_key]["n_averaged"]
378 |         torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(checkpoint[checkpoint_key], "module.")
379 |     model.load_state_dict(checkpoint[checkpoint_key], strict=strict)
380 |
381 |     tmp_path = os.path.join(output_dir, str(model.__hash__()))
382 |     torch.save(model.state_dict(), tmp_path)
383 |
384 |     sha256_hash = hashlib.sha256()
385 |     with open(tmp_path, "rb") as f:
386 |         # Read and update hash string value in blocks of 4K
387 |         for byte_block in iter(lambda: f.read(4096), b""):
388 |             sha256_hash.update(byte_block)
389 |         hh = sha256_hash.hexdigest()
390 |
391 |     output_path = os.path.join(output_dir, "weights-" + str(hh[:8]) + ".pth")
392 |     os.replace(tmp_path, output_path)
393 |
394 |     return output_path
395 |
396 |
397 | def reduce_across_processes(val):
398 |     if not is_dist_avail_and_initialized():
399 |         # nothing to sync, but we still convert to tensor for consistency with the distributed case.
400 |         return torch.tensor(val)
401 |
402 |     t = torch.tensor(val, device="cuda")
403 |     dist.barrier()
404 |     dist.all_reduce(t)
405 |     return t
406 |
407 |
408 | def set_weight_decay(
409 |     model: torch.nn.Module,
410 |     weight_decay: float,
411 |     norm_weight_decay: Optional[float] = None,
412 |     norm_classes: Optional[List[type]] = None,
413 |     custom_keys_weight_decay: Optional[List[Tuple[str, float]]] = None,
414 | ):
415 |     if not norm_classes:
416 |         norm_classes = [
417 |             torch.nn.modules.batchnorm._BatchNorm,
418 |             torch.nn.LayerNorm,
419 |             torch.nn.GroupNorm,
420 |             torch.nn.modules.instancenorm._InstanceNorm,
421 |             torch.nn.LocalResponseNorm,
422 |         ]
423 |     norm_classes = tuple(norm_classes)
424 |
425 |     params = {
426 |         "other": [],
427 |         "norm": [],
428 |     }
429 |     params_weight_decay = {
430 |         "other": weight_decay,
431 |         "norm": norm_weight_decay,
432 |     }
433 |     custom_keys = []
434 |     if custom_keys_weight_decay is not None:
435 |         for key, weight_decay in custom_keys_weight_decay:
436 |             params[key] = []
437 |             params_weight_decay[key] = weight_decay
438 |             custom_keys.append(key)
439 |
440 |     def _add_params(module, prefix=""):
441 |         for name, p in module.named_parameters(recurse=False):
442 |             if not p.requires_grad:
443 |                 continue
444 |             is_custom_key = False
445 |             for key in custom_keys:
446 |                 target_name = f"{prefix}.{name}" if prefix != "" and "." in key else name
447 |                 if key == target_name:
448 |                     params[key].append(p)
449 |                     is_custom_key = True
450 |                     break
451 |             if not is_custom_key:
452 |                 if norm_weight_decay is not None and isinstance(module, norm_classes):
453 |                     params["norm"].append(p)
454 |                 else:
455 |                     params["other"].append(p)
456 |
457 |         for child_name, child_module in module.named_children():
458 |             child_prefix = f"{prefix}.{child_name}" if prefix != "" else child_name
459 |             _add_params(child_module, prefix=child_prefix)
460 |
461 |     _add_params(model)
462 |
463 |     param_groups = []
464 |     for key in params:
465 |         if len(params[key]) > 0:
466 |             param_groups.append({"params": params[key], "weight_decay": params_weight_decay[key]})
467 |     return param_groups
468 |
--------------------------------------------------------------------------------
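The snippet below is a minimal, illustrative sketch of how the helpers in utils.py are typically composed in a training loop: set_weight_decay builds the optimizer parameter groups, ExponentialMovingAverage tracks an EMA copy of the weights, and accuracy reports top-1/top-5 numbers. It assumes the file is importable as `utils` and uses a torchvision ResNet-18 with made-up hyperparameters purely for illustration; it is not part of the repository.

    # Hypothetical usage sketch; model choice and hyperparameters are illustrative only.
    import torch
    import torchvision

    import utils

    model = torchvision.models.resnet18(weights=None)

    # Group parameters so norm layers (and the custom "bias" key) get their own weight decay.
    param_groups = utils.set_weight_decay(
        model,
        weight_decay=2e-5,
        norm_weight_decay=0.0,
        custom_keys_weight_decay=[("bias", 0.0)],
    )
    optimizer = torch.optim.SGD(param_groups, lr=0.1, momentum=0.9)

    # Track an exponential moving average of the weights for evaluation.
    model_ema = utils.ExponentialMovingAverage(model, decay=0.999, device="cpu")

    # One illustrative optimization step on random data.
    images = torch.randn(8, 3, 224, 224)
    target = torch.randint(0, 1000, (8,))
    output = model(images)
    loss = torch.nn.functional.cross_entropy(output, target)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    model_ema.update_parameters(model)

    # Top-1 / top-5 accuracy on the same batch, reported as percentages.
    acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
    print(f"acc@1 {acc1.item():.2f}  acc@5 {acc5.item():.2f}")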